// Copyright 2020 Signal Messenger, LLC // SPDX-License-Identifier: AGPL-3.0-only /* eslint-disable max-classes-per-file */ /* * WebSocket-Resources * * Create a request-response interface over websockets using the * WebSocket-Resources sub-protocol[1]. * * var client = new WebSocketResource(socket, function(request) { * request.respond(200, 'OK'); * }); * * const { response, status } = await client.sendRequest({ * verb: 'PUT', * path: '/v1/messages', * headers: ['content-type:application/json'], * body: Buffer.from('{ some: "json" }'), * }); * * 1. https://github.com/signalapp/WebSocket-Resources * */ /* eslint-disable @typescript-eslint/no-namespace */ /* eslint-disable @typescript-eslint/brace-style */ import type { connection as WebSocket, IMessage } from 'websocket'; import Long from 'long'; import pTimeout from 'p-timeout'; import { Response } from 'node-fetch'; import net from 'net'; import { z } from 'zod'; import { clearInterval } from 'timers'; import { random } from 'lodash'; import type { ChatServiceDebugInfo } from '@signalapp/libsignal-client/Native'; import type { LibSignalError, Net } from '@signalapp/libsignal-client'; import { Buffer } from 'node:buffer'; import type { ChatServerMessageAck, ChatServiceListener, ConnectionEventsListener, } from '@signalapp/libsignal-client/dist/net'; import type { EventHandler } from './EventTarget'; import EventTarget from './EventTarget'; import * as durations from '../util/durations'; import { dropNull } from '../util/dropNull'; import { drop } from '../util/drop'; import { isOlderThan } from '../util/timestamp'; import { strictAssert } from '../util/assert'; import * as Errors from '../types/errors'; import { SignalService as Proto } from '../protobuf'; import * as log from '../logging/log'; import * as Timers from '../Timers'; import type { IResource } from './WebSocket'; import { isProduction } from '../util/version'; import { ToastType } from '../types/Toast'; import { AbortableProcess } from '../util/AbortableProcess'; import type { WebAPICredentials } from './Types'; import { NORMAL_DISCONNECT_CODE } from './SocketManager'; import { parseUnknown } from '../util/schemas'; const THIRTY_SECONDS = 30 * durations.SECOND; const STATS_UPDATE_INTERVAL = durations.MINUTE; const MAX_MESSAGE_SIZE = 512 * 1024; const AGGREGATED_STATS_KEY = 'websocketStats'; export enum IpVersion { IPv4 = 'ipv4', IPv6 = 'ipv6', } export namespace IpVersion { export function fromDebugInfoCode(ipType: number): IpVersion | undefined { switch (ipType) { case 1: return IpVersion.IPv4; case 2: return IpVersion.IPv6; default: return undefined; } } } const AggregatedStatsSchema = z.object({ connectionFailures: z.number(), requestsCompared: z.number(), ipVersionMismatches: z.number(), healthcheckFailures: z.number(), healthcheckBadStatus: z.number(), lastToastTimestamp: z.number(), }); export type AggregatedStats = z.infer; // eslint-disable-next-line @typescript-eslint/no-redeclare export namespace AggregatedStats { export function loadOrCreateEmpty(name: string): AggregatedStats { const key = localStorageKey(name); try { const json = localStorage.getItem(key); return json != null ? parseUnknown(AggregatedStatsSchema, JSON.parse(json) as unknown) : createEmpty(); } catch (error) { log.warn( `Could not load [${key}] from local storage. Possibly, attempting to load for the first time`, Errors.toLogFormat(error) ); return createEmpty(); } } export function store(stats: AggregatedStats, name: string): void { const key = localStorageKey(name); try { const json = JSON.stringify(stats); localStorage.setItem(key, json); } catch (error) { log.warn( `Failed to store key [${key}] to the local storage`, Errors.toLogFormat(error) ); } } export function add(a: AggregatedStats, b: AggregatedStats): AggregatedStats { return { requestsCompared: a.requestsCompared + b.requestsCompared, connectionFailures: a.connectionFailures + b.connectionFailures, healthcheckFailures: a.healthcheckFailures + b.healthcheckFailures, ipVersionMismatches: a.ipVersionMismatches + b.ipVersionMismatches, healthcheckBadStatus: a.healthcheckBadStatus + b.healthcheckBadStatus, lastToastTimestamp: Math.max(a.lastToastTimestamp, b.lastToastTimestamp), }; } export function createEmpty(): AggregatedStats { return { requestsCompared: 0, connectionFailures: 0, ipVersionMismatches: 0, healthcheckFailures: 0, healthcheckBadStatus: 0, lastToastTimestamp: 0, }; } export function shouldReportError(stats: AggregatedStats): boolean { const timeSinceLastToast = Date.now() - stats.lastToastTimestamp; if (timeSinceLastToast < durations.DAY || stats.requestsCompared < 1000) { return false; } const totalFailuresSinceLastToast = stats.healthcheckBadStatus + stats.healthcheckFailures + stats.connectionFailures; return totalFailuresSinceLastToast > 20; } export function localStorageKey(name: string): string { return `${AGGREGATED_STATS_KEY}.${name}`; } } export enum ServerRequestType { ApiMessage = '/api/v1/message', ApiEmptyQueue = '/api/v1/queue/empty', ProvisioningMessage = '/v1/message', ProvisioningAddress = '/v1/address', Unknown = 'unknown', } export type IncomingWebSocketRequest = { readonly requestType: ServerRequestType; readonly body: Uint8Array | undefined; readonly timestamp: number | undefined; respond(status: number, message: string): void; }; export class IncomingWebSocketRequestLibsignal implements IncomingWebSocketRequest { constructor( readonly requestType: ServerRequestType, readonly body: Uint8Array | undefined, readonly timestamp: number | undefined, private readonly ack: ChatServerMessageAck | undefined ) {} respond(status: number, _message: string): void { if (this.ack) { drop(this.ack.send(status)); } } } export class IncomingWebSocketRequestLegacy implements IncomingWebSocketRequest { private readonly id: Long; public readonly requestType: ServerRequestType; public readonly body: Uint8Array | undefined; public readonly timestamp: number | undefined; constructor( request: Proto.IWebSocketRequestMessage, private readonly sendBytes: (bytes: Buffer) => void ) { strictAssert(request.id, 'request without id'); strictAssert(request.verb, 'request without verb'); strictAssert(request.path, 'request without path'); this.id = request.id; this.requestType = resolveType(request.path, request.verb); this.body = dropNull(request.body); this.timestamp = resolveTimestamp(request.headers || []); } public respond(status: number, message: string): void { const bytes = Proto.WebSocketMessage.encode({ type: Proto.WebSocketMessage.Type.RESPONSE, response: { id: this.id, message, status }, }).finish(); this.sendBytes(Buffer.from(bytes)); } } function resolveType(path: string, verb: string): ServerRequestType { if (path === ServerRequestType.ApiMessage) { return ServerRequestType.ApiMessage; } if (path === ServerRequestType.ApiEmptyQueue && verb === 'PUT') { return ServerRequestType.ApiEmptyQueue; } if (path === ServerRequestType.ProvisioningAddress && verb === 'PUT') { return ServerRequestType.ProvisioningAddress; } if (path === ServerRequestType.ProvisioningMessage && verb === 'PUT') { return ServerRequestType.ProvisioningMessage; } return ServerRequestType.Unknown; } function resolveTimestamp(headers: ReadonlyArray): number | undefined { // The 'X-Signal-Timestamp' is usually the last item, so start there. let it = headers.length; // eslint-disable-next-line no-plusplus while (--it >= 0) { const match = headers[it].match(/^X-Signal-Timestamp:\s*(\d+)\s*$/i); if (match && match.length === 2) { return Number(match[1]); } } return undefined; } export type SendRequestOptions = Readonly<{ verb: string; path: string; body?: Uint8Array; timeout?: number; headers?: ReadonlyArray<[string, string]>; }>; export type SendRequestResult = Readonly<{ status: number; message: string; response?: Uint8Array; headers: ReadonlyArray; }>; export enum TransportOption { // Only original transport is used Original = 'original', // All requests are going through the original transport, // but for every request that completes sucessfully we're initiating // a healthcheck request via libsignal transport, // collecting comparison statistics, and if we see many inconsistencies, // we're showing a toast asking user to submit a debug log ShadowingHigh = 'shadowingHigh', // Similar to `shadowingHigh`, however, only 10% of requests // will trigger a healthcheck, and toast is never shown. // Statistics data is still added to the debug logs, // so it will be available to us with all the debug log uploads. ShadowingLow = 'shadowingLow', // Only libsignal transport is used Libsignal = 'libsignal', } export type WebSocketResourceOptions = { name: string; handleRequest?: (request: IncomingWebSocketRequest) => void; keepalive?: KeepAliveOptionsType; transportOption?: TransportOption; }; export class CloseEvent extends Event { constructor( public readonly code: number, public readonly reason: string ) { super('close'); } } // eslint-disable-next-line no-restricted-syntax export interface IWebSocketResource extends IResource { sendRequest(options: SendRequestOptions): Promise; addEventListener(name: 'close', handler: (ev: CloseEvent) => void): void; forceKeepAlive(timeout?: number): void; shutdown(): void; close(code?: number, reason?: string): void; localPort(): number | undefined; } type LibsignalWebSocketResourceHolder = { resource: LibsignalWebSocketResource | undefined; }; const UNEXPECTED_DISCONNECT_CODE = 3001; export function connectUnauthenticatedLibsignal({ libsignalNet, name, keepalive, }: { libsignalNet: Net.Net; name: string; keepalive: KeepAliveOptionsType; }): AbortableProcess { const logId = `LibsignalWebSocketResource(${name})`; const listener: LibsignalWebSocketResourceHolder & ConnectionEventsListener = { resource: undefined, onConnectionInterrupted(cause: LibSignalError | null): void { if (!this.resource) { logDisconnectedListenerWarn(logId, 'onConnectionInterrupted'); return; } this.resource.onConnectionInterrupted(cause); this.resource = undefined; }, }; return connectLibsignal( libsignalNet.newUnauthenticatedChatService(listener), listener, logId, keepalive ); } export function connectAuthenticatedLibsignal({ libsignalNet, name, credentials, handler, receiveStories, keepalive, }: { libsignalNet: Net.Net; name: string; credentials: WebAPICredentials; handler: (request: IncomingWebSocketRequest) => void; receiveStories: boolean; keepalive: KeepAliveOptionsType; }): AbortableProcess { const logId = `LibsignalWebSocketResource(${name})`; const listener: LibsignalWebSocketResourceHolder & ChatServiceListener = { resource: undefined, onIncomingMessage( envelope: Buffer, timestamp: number, ack: ChatServerMessageAck ): void { // Handle incoming messages even if we've disconnected. const request = new IncomingWebSocketRequestLibsignal( ServerRequestType.ApiMessage, envelope, timestamp, ack ); handler(request); }, onQueueEmpty(): void { if (!this.resource) { logDisconnectedListenerWarn(logId, 'onQueueEmpty'); return; } const request = new IncomingWebSocketRequestLibsignal( ServerRequestType.ApiEmptyQueue, undefined, undefined, undefined ); handler(request); }, onConnectionInterrupted(cause): void { if (!this.resource) { logDisconnectedListenerWarn(logId, 'onConnectionInterrupted'); return; } this.resource.onConnectionInterrupted(cause); this.resource = undefined; }, }; return connectLibsignal( libsignalNet.newAuthenticatedChatService( credentials.username, credentials.password, receiveStories, listener ), listener, logId, keepalive ); } function logDisconnectedListenerWarn(logId: string, method: string): void { log.warn(`${logId} received ${method}, but listener already disconnected`); } function connectLibsignal( chatService: Net.ChatService, resourceHolder: LibsignalWebSocketResourceHolder, logId: string, keepalive: KeepAliveOptionsType ): AbortableProcess { const connectAsync = async () => { try { const debugInfo = await chatService.connect(); log.info(`${logId} connected`, debugInfo); const resource = new LibsignalWebSocketResource( chatService, IpVersion.fromDebugInfoCode(debugInfo.ipType), logId, keepalive ); // eslint-disable-next-line no-param-reassign resourceHolder.resource = resource; return resource; } catch (error) { // Handle any errors that occur during connection log.error(`${logId} connection failed`, Errors.toLogFormat(error)); throw error; } }; return new AbortableProcess( `${logId}.connect`, { abort() { // if interrupted, trying to disconnect drop(chatService.disconnect()); }, }, connectAsync() ); } export class LibsignalWebSocketResource extends EventTarget implements IWebSocketResource { closed = false; // Unlike WebSocketResource, libsignal will automatically attempt to keep the // socket alive using websocket pings, so we don't need a timer-based // keepalive mechanism. But we still send one-off keepalive requests when // things change (see forceKeepAlive()). private keepalive: KeepAliveSender; constructor( private readonly chatService: Net.ChatService, private readonly socketIpVersion: IpVersion | undefined, private readonly logId: string, keepalive: KeepAliveOptionsType ) { super(); this.keepalive = new KeepAliveSender(this, this.logId, keepalive); } public localPort(): number | undefined { return undefined; } public ipVersion(): IpVersion | undefined { return this.socketIpVersion; } public override addEventListener( name: 'close', handler: (ev: CloseEvent) => void ): void; public override addEventListener(name: string, handler: EventHandler): void { return super.addEventListener(name, handler); } public close(code = NORMAL_DISCONNECT_CODE, reason?: string): void { if (this.closed) { log.info(`${this.logId}.close: Already closed! ${code}/${reason}`); return; } drop(this.chatService.disconnect()); // On linux the socket can wait a long time to emit its close event if we've // lost the internet connection. On the order of minutes. This speeds that // process up. Timers.setTimeout( () => this.onConnectionInterrupted(null), 5 * durations.SECOND ); } public shutdown(): void { this.close(NORMAL_DISCONNECT_CODE, 'Shutdown'); } onConnectionInterrupted(cause: LibSignalError | null): void { if (this.closed) { log.warn( `${this.logId}.onConnectionInterrupted called after resource is closed` ); return; } this.closed = true; log.warn(`${this.logId}: connection closed`); let event; if (cause) { event = new CloseEvent(UNEXPECTED_DISCONNECT_CODE, cause.message); } else { // The cause was an intentional disconnect. Report normal closure. event = new CloseEvent(NORMAL_DISCONNECT_CODE, 'normal'); } this.dispatchEvent(event); } public forceKeepAlive(timeout?: number): void { drop(this.keepalive.send(timeout)); } public async sendRequest(options: SendRequestOptions): Promise { const [response] = await this.sendRequestGetDebugInfo(options); return response; } public async sendRequestGetDebugInfo( options: SendRequestOptions ): Promise<[Response, ChatServiceDebugInfo]> { const { response, debugInfo } = await this.chatService.fetchAndDebug({ verb: options.verb, path: options.path, headers: options.headers ? options.headers : [], body: options.body, timeoutMillis: options.timeout, }); return [ new Response(response.body, { status: response.status, statusText: response.message, headers: [...response.headers], }), debugInfo, ]; } } export class WebSocketResourceWithShadowing implements IWebSocketResource { private shadowing: LibsignalWebSocketResource | undefined; private stats: AggregatedStats; private statsTimer: NodeJS.Timeout; private shadowingWithReporting: boolean; private logId: string; constructor( private readonly main: WebSocketResource, private readonly shadowingConnection: AbortableProcess, options: WebSocketResourceOptions ) { this.stats = AggregatedStats.createEmpty(); this.logId = `WebSocketResourceWithShadowing(${options.name})`; this.statsTimer = setInterval( () => this.updateStats(options.name), STATS_UPDATE_INTERVAL ); this.shadowingWithReporting = options.transportOption === TransportOption.ShadowingHigh; // the idea is that we want to keep the shadowing connection process // "in the background", so that the main connection wouldn't need to wait on it. // then when we're connected, `this.shadowing` socket resource is initialized // or an error reported in case of connection failure const initializeAfterConnected = async () => { try { this.shadowing = await shadowingConnection.resultPromise; // checking IP one time per connection if (this.main.ipVersion() !== this.shadowing.ipVersion()) { this.stats.ipVersionMismatches += 1; const mainIpType = this.main.ipVersion(); const shadowIpType = this.shadowing.ipVersion(); log.warn( `${this.logId}: libsignal websocket IP [${shadowIpType}], Desktop websocket IP [${mainIpType}]` ); } } catch (error) { this.stats.connectionFailures += 1; } }; drop(initializeAfterConnected()); this.addEventListener('close', (_ev): void => { clearInterval(this.statsTimer); this.updateStats(options.name); }); } private updateStats(name: string) { const storedStats = AggregatedStats.loadOrCreateEmpty(name); let updatedStats = AggregatedStats.add(storedStats, this.stats); if ( this.shadowingWithReporting && AggregatedStats.shouldReportError(updatedStats) && !isProduction(window.getVersion()) ) { window.reduxActions.toast.showToast({ toastType: ToastType.TransportError, }); log.warn( `${this.logId}: experimental transport toast displayed, flushing transport statistics before resetting`, updatedStats ); updatedStats = AggregatedStats.createEmpty(); updatedStats.lastToastTimestamp = Date.now(); } AggregatedStats.store(updatedStats, name); this.stats = AggregatedStats.createEmpty(); } public localPort(): number | undefined { return this.main.localPort(); } public addEventListener( name: 'close', handler: (ev: CloseEvent) => void ): void { this.main.addEventListener(name, handler); } public close(code = NORMAL_DISCONNECT_CODE, reason?: string): void { this.main.close(code, reason); if (this.shadowing) { this.shadowing.close(code, reason); this.shadowing = undefined; } else { this.shadowingConnection.abort(); } } public shutdown(): void { this.main.shutdown(); if (this.shadowing) { this.shadowing.shutdown(); this.shadowing = undefined; } else { this.shadowingConnection.abort(); } } public forceKeepAlive(timeout?: number): void { this.main.forceKeepAlive(timeout); } public async sendRequest(options: SendRequestOptions): Promise { const responsePromise = this.main.sendRequest(options); const response = await responsePromise; // if we're received a response from the main channel and the status was successful, // attempting to run a healthcheck on a libsignal transport. if ( isSuccessfulStatusCode(response.status) && this.shouldSendShadowRequest() ) { drop(this.sendShadowRequest()); } return response; } private async sendShadowRequest(): Promise { // In the shadowing mode, it could be that we're either // still connecting libsignal websocket or have already closed it. // In those cases we're not running shadowing check. if (!this.shadowing) { log.info( `${this.logId}: skipping healthcheck - websocket not connected or already closed` ); return; } try { const healthCheckResult = await this.shadowing.sendRequest({ verb: 'GET', path: '/v1/keepalive', timeout: KEEPALIVE_TIMEOUT_MS, }); this.stats.requestsCompared += 1; if (!isSuccessfulStatusCode(healthCheckResult.status)) { this.stats.healthcheckBadStatus += 1; log.warn( `${this.logId}: keepalive via libsignal responded with status [${healthCheckResult.status}]` ); } } catch (error) { this.stats.healthcheckFailures += 1; log.warn( `${this.logId}: failed to send keepalive via libsignal`, Errors.toLogFormat(error) ); } } private shouldSendShadowRequest(): boolean { return this.shadowingWithReporting || random(0, 100) < 10; } } function isSuccessfulStatusCode(status: number): boolean { return status >= 200 && status < 300; } export default class WebSocketResource extends EventTarget implements IWebSocketResource { private outgoingId = Long.fromNumber(1, true); private closed = false; private readonly outgoingMap = new Map< string, (result: SendRequestResult) => void >(); private readonly boundOnMessage: (message: IMessage) => void; private activeRequests = new Set(); private shuttingDown = false; private shutdownTimer?: Timers.Timeout; private readonly logId: string; private readonly localSocketPort: number | undefined; private readonly socketIpVersion: IpVersion | undefined; // Public for tests public readonly keepalive?: KeepAlive; constructor( private readonly socket: WebSocket, private readonly options: WebSocketResourceOptions ) { super(); this.logId = `WebSocketResource(${options.name})`; this.localSocketPort = socket.socket.localPort; if (!socket.socket.localAddress) { this.socketIpVersion = undefined; } if (socket.socket.localAddress == null) { this.socketIpVersion = undefined; } else if (net.isIPv4(socket.socket.localAddress)) { this.socketIpVersion = IpVersion.IPv4; } else if (net.isIPv6(socket.socket.localAddress)) { this.socketIpVersion = IpVersion.IPv6; } else { this.socketIpVersion = undefined; } this.boundOnMessage = this.onMessage.bind(this); socket.on('message', this.boundOnMessage); if (options.keepalive) { const keepalive = new KeepAlive( this, options.name, options.keepalive ?? {} ); this.keepalive = keepalive; keepalive.reset(); socket.on('close', () => this.keepalive?.stop()); socket.on('error', (error: Error) => { log.warn(`${this.logId}: WebSocket error`, Errors.toLogFormat(error)); }); } socket.on('close', (code, reason) => { this.closed = true; log.warn(`${this.logId}: Socket closed`); this.dispatchEvent(new CloseEvent(code, reason || 'normal')); }); this.addEventListener('close', () => this.onClose()); } public ipVersion(): IpVersion | undefined { return this.socketIpVersion; } public localPort(): number | undefined { return this.localSocketPort; } public override addEventListener( name: 'close', handler: (ev: CloseEvent) => void ): void; public override addEventListener(name: string, handler: EventHandler): void { return super.addEventListener(name, handler); } public async sendRequest(options: SendRequestOptions): Promise { const id = this.outgoingId; const idString = id.toString(); strictAssert(!this.outgoingMap.has(idString), 'Duplicate outgoing request'); // Note that this automatically wraps this.outgoingId = this.outgoingId.add(1); const bytes = Proto.WebSocketMessage.encode({ type: Proto.WebSocketMessage.Type.REQUEST, request: { verb: options.verb, path: options.path, body: options.body, headers: options.headers ? options.headers .map(([key, value]) => { return `${key}:${value}`; }) .slice() : undefined, id, }, }).finish(); strictAssert( bytes.length <= MAX_MESSAGE_SIZE, 'WebSocket request byte size exceeded' ); strictAssert(!this.shuttingDown, 'Cannot send request, shutting down'); this.addActive(idString); const promise = new Promise((resolve, reject) => { let timer = options.timeout ? Timers.setTimeout(() => { this.removeActive(idString); this.close(UNEXPECTED_DISCONNECT_CODE, 'Request timed out'); reject(new Error(`Request timed out; id: [${idString}]`)); }, options.timeout) : undefined; this.outgoingMap.set(idString, result => { if (timer !== undefined) { Timers.clearTimeout(timer); timer = undefined; } this.keepalive?.reset(); this.removeActive(idString); resolve(result); }); }); this.socket.sendBytes(Buffer.from(bytes)); const requestResult = await promise; return WebSocketResource.intoResponse(requestResult); } public forceKeepAlive(timeout?: number): void { if (!this.keepalive) { return; } drop(this.keepalive.send(timeout)); } public close(code = NORMAL_DISCONNECT_CODE, reason?: string): void { if (this.closed) { log.info(`${this.logId}.close: Already closed! ${code}/${reason}`); return; } log.info(`${this.logId}.close(${code})`); if (this.keepalive) { this.keepalive.stop(); } this.socket.close(code, reason); this.socket.removeListener('message', this.boundOnMessage); // On linux the socket can wait a long time to emit its close event if we've // lost the internet connection. On the order of minutes. This speeds that // process up. Timers.setTimeout(() => { if (this.closed) { return; } log.warn(`${this.logId}.close: Dispatching our own socket close event`); this.dispatchEvent(new CloseEvent(code, reason || 'normal')); }, 5 * durations.SECOND); } public shutdown(): void { if (this.closed) { return; } if (this.activeRequests.size === 0) { log.info(`${this.logId}.shutdown: no active requests, closing`); this.close(NORMAL_DISCONNECT_CODE, 'Shutdown'); return; } this.shuttingDown = true; log.info(`${this.logId}.shutdown: shutting down`); this.shutdownTimer = Timers.setTimeout(() => { if (this.closed) { return; } log.warn(`${this.logId}.shutdown: Failed to shutdown gracefully`); this.close(NORMAL_DISCONNECT_CODE, 'Shutdown'); }, THIRTY_SECONDS); } private onMessage({ type, binaryData }: IMessage): void { if (type !== 'binary' || !binaryData) { throw new Error(`Unsupported websocket message type: ${type}`); } const message = Proto.WebSocketMessage.decode(binaryData); if ( message.type === Proto.WebSocketMessage.Type.REQUEST && message.request ) { const handleRequest = this.options.handleRequest || (request => request.respond(404, 'Not found')); const incomingRequest = new IncomingWebSocketRequestLegacy( message.request, (bytes: Buffer): void => { this.removeActive(incomingRequest); strictAssert( bytes.length <= MAX_MESSAGE_SIZE, 'WebSocket response byte size exceeded' ); this.socket.sendBytes(bytes); } ); if (this.shuttingDown) { incomingRequest.respond(-1, 'Shutting down'); return; } this.addActive(incomingRequest); handleRequest(incomingRequest); } else if ( message.type === Proto.WebSocketMessage.Type.RESPONSE && message.response ) { const { response } = message; strictAssert(response.id, 'response without id'); const responseIdString = response.id.toString(); const resolve = this.outgoingMap.get(responseIdString); this.outgoingMap.delete(responseIdString); if (!resolve) { throw new Error(`Received response for unknown request ${response.id}`); } resolve({ status: response.status ?? -1, message: response.message ?? '', response: dropNull(response.body), headers: response.headers ?? [], }); } } private onClose(): void { const outgoing = new Map(this.outgoingMap); this.outgoingMap.clear(); for (const resolve of outgoing.values()) { resolve({ status: -1, message: 'Connection closed', response: undefined, headers: [], }); } } private addActive(request: IncomingWebSocketRequest | string): void { this.activeRequests.add(request); } private removeActive(request: IncomingWebSocketRequest | string): void { if (!this.activeRequests.has(request)) { log.warn(`${this.logId}.removeActive: removing unknown request`); return; } this.activeRequests.delete(request); if (this.activeRequests.size !== 0) { return; } if (!this.shuttingDown) { return; } if (this.shutdownTimer) { Timers.clearTimeout(this.shutdownTimer); this.shutdownTimer = undefined; } log.info(`${this.logId}.removeActive: shutdown complete`); this.close(NORMAL_DISCONNECT_CODE, 'Shutdown'); } private static intoResponse(sendRequestResult: SendRequestResult): Response { const { status, message: statusText, response, headers: flatResponseHeaders, } = sendRequestResult; const headers: Array<[string, string]> = flatResponseHeaders.map(header => { const [key, value] = header.split(':', 2); strictAssert(value !== undefined, 'Invalid header!'); return [key, value]; }); return new Response(response, { status, statusText, headers, }); } } export type KeepAliveOptionsType = { path?: string; }; // 30 seconds + 5 seconds for closing the socket above. const KEEPALIVE_INTERVAL_MS = 30 * durations.SECOND; // If the machine was in suspended mode for more than 5 minutes - trigger // immediate disconnect. const STALE_THRESHOLD_MS = 5 * durations.MINUTE; // If we don't receive a response to keepalive request within 30 seconds - // close the socket. const KEEPALIVE_TIMEOUT_MS = 30 * durations.SECOND; const LOG_KEEPALIVE_AFTER_MS = 500; /** * References an {@link IWebSocketResource} and a request path that should * return promptly to determine whether the connection is still alive. * * The response to the request must have a 2xx status code but is otherwise * ignored. A failing response or a timeout results in the socket being closed * with {@link UNEXPECTED_DISCONNECT_CODE}. * * Use the subclass {@link KeepAlive} if you want to send the request at regular * intervals. */ class KeepAliveSender { private path: string; protected wsr: IWebSocketResource; protected logId: string; constructor( websocketResource: IWebSocketResource, name: string, opts: KeepAliveOptionsType = {} ) { this.logId = `WebSocketResources.KeepAlive(${name})`; this.path = opts.path ?? '/'; this.wsr = websocketResource; } public async send(timeout = KEEPALIVE_TIMEOUT_MS): Promise { log.info(`${this.logId}.send: Sending a keepalive message`); const sentAt = Date.now(); try { const { status } = await pTimeout( this.wsr.sendRequest({ verb: 'GET', path: this.path, }), timeout ); if (status < 200 || status >= 300) { log.warn(`${this.logId}.send: keepalive response status ${status}`); this.wsr.close( UNEXPECTED_DISCONNECT_CODE, `keepalive response with ${status} code` ); return false; } } catch (error) { this.wsr.close( UNEXPECTED_DISCONNECT_CODE, 'No response to keepalive request' ); return false; } const responseTime = Date.now() - sentAt; if (responseTime > LOG_KEEPALIVE_AFTER_MS) { log.warn( `${this.logId}.send: delayed response to keepalive request, ` + `response time: ${responseTime}ms` ); } return true; } } /** * Manages a timer that checks if a particular {@link WebSocketResource} is * still alive. * * The resource must specifically be a {@link WebSocketResource}. Other kinds of * resource are expected to manage their own liveness checks. If you want to * manually send keepalive requests to such resources, use the base class * {@link KeepAliveSender}. */ class KeepAlive extends KeepAliveSender { private keepAliveTimer: Timers.Timeout | undefined; private lastAliveAt: number = Date.now(); constructor( websocketResource: WebSocketResource, name: string, opts: KeepAliveOptionsType = {} ) { super(websocketResource, name, opts); } public stop(): void { this.clearTimers(); } public override async send(timeout = KEEPALIVE_TIMEOUT_MS): Promise { this.clearTimers(); const isStale = isOlderThan(this.lastAliveAt, STALE_THRESHOLD_MS); if (isStale) { log.info(`${this.logId}.send: disconnecting due to stale state`); this.wsr.close( UNEXPECTED_DISCONNECT_CODE, `Last keepalive request was too far in the past: ${this.lastAliveAt}` ); return false; } const isAlive = await super.send(timeout); if (!isAlive) { return false; } // Successful response on time this.reset(); return true; } public reset(): void { this.lastAliveAt = Date.now(); this.clearTimers(); this.keepAliveTimer = Timers.setTimeout( () => this.send(), KEEPALIVE_INTERVAL_MS ); } private clearTimers(): void { if (this.keepAliveTimer) { Timers.clearTimeout(this.keepAliveTimer); this.keepAliveTimer = undefined; } } }