diff --git a/package.json b/package.json index 8ff414d7501b..598d13b59298 100644 --- a/package.json +++ b/package.json @@ -211,7 +211,7 @@ "@formatjs/intl": "2.6.7", "@indutny/rezip-electron": "1.3.1", "@mixer/parallel-prettier": "2.0.3", - "@signalapp/mock-server": "6.4.1", + "@signalapp/mock-server": "6.4.2", "@storybook/addon-a11y": "7.4.5", "@storybook/addon-actions": "7.4.5", "@storybook/addon-controls": "7.4.5", diff --git a/ts/services/backups/api.ts b/ts/services/backups/api.ts index d0f2a635f954..68720f7f6181 100644 --- a/ts/services/backups/api.ts +++ b/ts/services/backups/api.ts @@ -1,13 +1,13 @@ // Copyright 2024 Signal Messenger, LLC // SPDX-License-Identifier: AGPL-3.0-only -import type { Readable } from 'stream'; - import { strictAssert } from '../../util/assert'; +import { tusUpload } from '../../util/uploads/tusProtocol'; +import { defaultFileReader } from '../../util/uploads/uploads'; import type { WebAPIType, + AttachmentV3ResponseType, GetBackupInfoResponseType, - GetBackupUploadFormResponseType, BackupMediaItemType, BackupMediaBatchResponseType, BackupListMediaResponseType, @@ -55,14 +55,25 @@ export class BackupAPI { return (await this.getInfo()).backupName; } - public async upload(stream: Readable): Promise { - return this.server.uploadBackup({ - headers: await this.credentials.getHeadersForToday(), - stream, + public async upload(filePath: string, fileSize: number): Promise { + const form = await this.server.getBackupUploadForm( + await this.credentials.getHeadersForToday() + ); + + const fetchFn = this.server.createFetchForAttachmentUpload(form); + + await tusUpload({ + endpoint: form.signedUploadLocation, + headers: {}, + fileName: form.key, + filePath, + fileSize, + reader: defaultFileReader, + fetchFn, }); } - public async getMediaUploadForm(): Promise { + public async getMediaUploadForm(): Promise { return this.server.getBackupMediaUploadForm( await this.credentials.getHeadersForToday() ); diff --git a/ts/services/backups/export.ts 
b/ts/services/backups/export.ts index 216892ba625b..202c9d5ff5c2 100644 --- a/ts/services/backups/export.ts +++ b/ts/services/backups/export.ts @@ -622,6 +622,7 @@ export class BackupExportStream extends Readable { }; if (!isNormalBubble(message)) { + result.directionless = {}; return this.toChatItemFromNonBubble(result, message, options); } diff --git a/ts/services/backups/import.ts b/ts/services/backups/import.ts index 364dcb99fe70..be8c1c516957 100644 --- a/ts/services/backups/import.ts +++ b/ts/services/backups/import.ts @@ -638,12 +638,6 @@ export class BackupImportStream extends Writable { ? this.recipientIdToConvo.get(item.authorId.toNumber()) : undefined; - const isOutgoing = - authorConvo && this.ourConversation.id === authorConvo?.id; - const isIncoming = - authorConvo && this.ourConversation.id !== authorConvo?.id; - const isDirectionLess = !isOutgoing && !isIncoming; - let attributes: MessageAttributesType = { id: generateUuid(), canReplyToStory: false, @@ -653,7 +647,7 @@ export class BackupImportStream extends Writable { source: authorConvo?.e164, sourceServiceId: authorConvo?.serviceId, timestamp, - type: isOutgoing ? 'outgoing' : 'incoming', + type: item.outgoing != null ? 'outgoing' : 'incoming', unidentifiedDeliveryReceived: false, expirationStartTimestamp: item.expireStartDate ? 
getTimestampFromLong(item.expireStartDate) @@ -667,7 +661,7 @@ export class BackupImportStream extends Writable { const { outgoing, incoming, directionless } = item; if (outgoing) { strictAssert( - isOutgoing, + authorConvo && this.ourConversation.id === authorConvo?.id, `${logId}: outgoing message must have outgoing field` ); @@ -722,10 +716,9 @@ export class BackupImportStream extends Writable { attributes.sendStateByConversationId = sendStateByConversationId; chatConvo.active_at = attributes.sent_at; - } - if (incoming) { + } else if (incoming) { strictAssert( - isIncoming, + authorConvo && this.ourConversation.id !== authorConvo?.id, `${logId}: message with incoming field must be incoming` ); attributes.received_at_ms = @@ -741,12 +734,8 @@ export class BackupImportStream extends Writable { } chatConvo.active_at = attributes.received_at_ms; - } - if (directionless) { - strictAssert( - isDirectionLess, - `${logId}: directionless message must not be incoming/outgoing` - ); + } else if (directionless) { + // Nothing to do } if (item.standardMessage) { @@ -793,7 +782,7 @@ export class BackupImportStream extends Writable { // TODO (DESKTOP-6964): We'll want to increment for more types here - stickers, etc. if (item.standardMessage) { - if (isOutgoing) { + if (item.outgoing != null) { chatConvo.sentMessageCount = (chatConvo.sentMessageCount ?? 0) + 1; } else { chatConvo.messageCount = (chatConvo.messageCount ?? 
0) + 1; diff --git a/ts/services/backups/index.ts b/ts/services/backups/index.ts index 9d1dcd641307..1e08a33e3fce 100644 --- a/ts/services/backups/index.ts +++ b/ts/services/backups/index.ts @@ -5,6 +5,8 @@ import { pipeline } from 'stream/promises'; import { PassThrough } from 'stream'; import type { Readable, Writable } from 'stream'; import { createReadStream, createWriteStream } from 'fs'; +import { unlink } from 'fs/promises'; +import { join } from 'path'; import { createGzip, createGunzip } from 'zlib'; import { createCipheriv, createHmac, randomBytes } from 'crypto'; import { noop } from 'lodash'; @@ -27,6 +29,7 @@ import { BackupImportStream } from './import'; import { getKeyMaterial } from './crypto'; import { BackupCredentials } from './credentials'; import { BackupAPI } from './api'; +import { validateBackup } from './validator'; const IV_LENGTH = 16; @@ -61,41 +64,23 @@ export class BackupsService { }); } - public async exportBackup(sink: Writable): Promise { - strictAssert(!this.isRunning, 'BackupService is already running'); - - log.info('exportBackup: starting...'); - this.isRunning = true; + public async upload(): Promise { + const fileName = `backup-${randomBytes(32).toString('hex')}`; + const filePath = join(window.BasePaths.temp, fileName); try { - const { aesKey, macKey } = getKeyMaterial(); + const fileSize = await this.exportToDisk(filePath); - const recordStream = new BackupExportStream(); - recordStream.run(); - - const iv = randomBytes(IV_LENGTH); - - await pipeline( - recordStream, - createGzip(), - appendPaddingStream(), - createCipheriv(CipherType.AES256CBC, aesKey, iv), - prependStream(iv), - appendMacStream(macKey), - sink - ); + await this.api.upload(filePath, fileSize); } finally { - log.info('exportBackup: finished...'); - this.isRunning = false; + try { + await unlink(filePath); + } catch { + // Ignore + } } } - public async upload(): Promise { - const pipe = new PassThrough(); - - await Promise.all([this.api.upload(pipe), 
this.exportBackup(pipe)]); - } - // Test harness public async exportBackupData(): Promise { const sink = new PassThrough(); @@ -108,8 +93,12 @@ export class BackupsService { } // Test harness - public async exportToDisk(path: string): Promise { - await this.exportBackup(createWriteStream(path)); + public async exportToDisk(path: string): Promise { + const size = await this.exportBackup(createWriteStream(path)); + + await validateBackup(path, size); + + return size; } // Test harness @@ -185,6 +174,49 @@ export class BackupsService { } } + private async exportBackup(sink: Writable): Promise { + strictAssert(!this.isRunning, 'BackupService is already running'); + + log.info('exportBackup: starting...'); + this.isRunning = true; + + try { + const { aesKey, macKey } = getKeyMaterial(); + + const recordStream = new BackupExportStream(); + recordStream.run(); + + const iv = randomBytes(IV_LENGTH); + + const pass = new PassThrough(); + + let totalBytes = 0; + + // Pause the flow first so that we respect backpressure. The + // `pipeline` call below will control the flow anyway.
+ pass.pause(); + pass.on('data', chunk => { + totalBytes += chunk.length; + }); + + await pipeline( + recordStream, + createGzip(), + appendPaddingStream(), + createCipheriv(CipherType.AES256CBC, aesKey, iv), + prependStream(iv), + appendMacStream(macKey), + pass, + sink + ); + + return totalBytes; + } finally { + log.info('exportBackup: finished...'); + this.isRunning = false; + } + } + private async runPeriodicRefresh(): Promise { try { await this.api.refresh(); diff --git a/ts/services/backups/validator.ts b/ts/services/backups/validator.ts new file mode 100644 index 000000000000..57b71537ab74 --- /dev/null +++ b/ts/services/backups/validator.ts @@ -0,0 +1,92 @@ +// Copyright 2024 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +import { type FileHandle, open } from 'node:fs/promises'; +import * as libsignal from '@signalapp/libsignal-client/dist/MessageBackup'; +import { InputStream } from '@signalapp/libsignal-client/dist/io'; + +import { strictAssert } from '../../util/assert'; +import { toAciObject } from '../../util/ServiceId'; +import { isTestOrMockEnvironment } from '../../environment'; + +class FileStream extends InputStream { + private file: FileHandle | undefined; + private position = 0; + private buffer = Buffer.alloc(16 * 1024); + private initPromise: Promise | undefined; + + constructor(private readonly filePath: string) { + super(); + } + + public async close(): Promise { + await this.initPromise; + await this.file?.close(); + } + + async read(amount: number): Promise { + await this.initPromise; + + if (!this.file) { + const filePromise = open(this.filePath); + this.initPromise = filePromise; + this.file = await filePromise; + } + + if (this.buffer.length < amount) { + this.buffer = Buffer.alloc(amount); + } + const { bytesRead } = await this.file.read( + this.buffer, + 0, + amount, + this.position + ); + this.position += bytesRead; + return this.buffer.slice(0, bytesRead); + } + + async skip(amount: number): Promise { + 
this.position += amount; + } +} + +export async function validateBackup( + filePath: string, + fileSize: number +): Promise { + const masterKeyBase64 = window.storage.get('masterKey'); + strictAssert(masterKeyBase64, 'Master key not available'); + + const masterKey = Buffer.from(masterKeyBase64, 'base64'); + + const aci = toAciObject(window.storage.user.getCheckedAci()); + const backupKey = new libsignal.MessageBackupKey(masterKey, aci); + + const streams = new Array(); + + let outcome: libsignal.ValidationOutcome; + try { + outcome = await libsignal.validate( + backupKey, + libsignal.Purpose.RemoteBackup, + () => { + const stream = new FileStream(filePath); + streams.push(stream); + return stream; + }, + BigInt(fileSize) + ); + } finally { + await Promise.all(streams.map(stream => stream.close())); + } + + if (isTestOrMockEnvironment()) { + strictAssert( + outcome.ok, + `Backup validation failed: ${outcome.errorMessage}` + ); + } else { + strictAssert(outcome.ok, 'Backup validation failed'); + } +} diff --git a/ts/test-electron/backup/backup_groupv2_notifications_test.ts b/ts/test-electron/backup/backup_groupv2_notifications_test.ts index 6a9d21792921..439a8b8e3fa3 100644 --- a/ts/test-electron/backup/backup_groupv2_notifications_test.ts +++ b/ts/test-electron/backup/backup_groupv2_notifications_test.ts @@ -3,7 +3,8 @@ import path from 'path'; import { tmpdir } from 'os'; -import { rmSync, mkdtempSync, createReadStream } from 'fs'; +import { createReadStream } from 'fs'; +import { mkdtemp, rm } from 'fs/promises'; import { v4 as generateGuid } from 'uuid'; import { assert } from 'chai'; @@ -71,7 +72,7 @@ async function asymmetricRoundtripHarness( before: Array, after: Array ) { - const outDir = mkdtempSync(path.join(tmpdir(), 'signal-temp-')); + const outDir = await mkdtemp(path.join(tmpdir(), 'signal-temp-')); try { const targetOutputFile = path.join(outDir, 'backup.bin'); @@ -89,7 +90,7 @@ async function asymmetricRoundtripHarness( const actual = 
sortAndNormalize(messagesFromDatabase); assert.deepEqual(expected, actual); } finally { - rmSync(outDir, { recursive: true }); + await rm(outDir, { recursive: true }); } } diff --git a/ts/test-node/util/uploads/helpers.ts b/ts/test-node/util/uploads/helpers.ts index 39b82736459a..6a9e3757b3ba 100644 --- a/ts/test-node/util/uploads/helpers.ts +++ b/ts/test-node/util/uploads/helpers.ts @@ -65,7 +65,7 @@ export class TestServer extends EventEmitter { typeof address === 'object' && address != null, 'address must be an object' ); - return `http://localhost:${address.port}/}`; + return `http://localhost:${address.port}/`; } respondWith(status: number, headers: OutgoingHttpHeaders = {}): void { diff --git a/ts/test-node/util/uploads/tusProtocol_test.ts b/ts/test-node/util/uploads/tusProtocol_test.ts index 2ef59f4b0e4f..057e16d21554 100644 --- a/ts/test-node/util/uploads/tusProtocol_test.ts +++ b/ts/test-node/util/uploads/tusProtocol_test.ts @@ -146,7 +146,7 @@ describe('tusProtocol', () => { }), }); assert.strictEqual(result, false); - assert.strictEqual(caughtError?.message, 'fetch failed'); + assert.strictEqual(caughtError?.message, 'test'); }); }); @@ -317,7 +317,7 @@ describe('tusProtocol', () => { }), }); assert.strictEqual(result, false); - assert.strictEqual(caughtError?.message, 'fetch failed'); + assert.strictEqual(caughtError?.message, 'test'); }); }); @@ -327,7 +327,6 @@ describe('tusProtocol', () => { function assertSocketCloseError(error: unknown) { // There isn't an equivalent to this chain in assert() expect(error, toLogFormat(error)) - .property('cause') .property('code') .oneOf(['ECONNRESET', 'UND_ERR_SOCKET']); } diff --git a/ts/textsecure/Errors.ts b/ts/textsecure/Errors.ts index 88bd7a1bf149..eb7d1b7f8478 100644 --- a/ts/textsecure/Errors.ts +++ b/ts/textsecure/Errors.ts @@ -3,6 +3,7 @@ /* eslint-disable max-classes-per-file */ +import type { Response } from 'node-fetch'; import type { LibSignalErrorBase } from '@signalapp/libsignal-client'; import { 
parseRetryAfter } from '../util/parseRetryAfter'; diff --git a/ts/textsecure/WebAPI.ts b/ts/textsecure/WebAPI.ts index bd49dedc1677..268e03906626 100644 --- a/ts/textsecure/WebAPI.ts +++ b/ts/textsecure/WebAPI.ts @@ -6,7 +6,7 @@ /* eslint-disable no-restricted-syntax */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import type { Response } from 'node-fetch'; +import type { RequestInit, Response } from 'node-fetch'; import fetch from 'node-fetch'; import type { Agent } from 'https'; import { escapeRegExp, isNumber, isString, isObject } from 'lodash'; @@ -30,6 +30,7 @@ import { getBasicAuth } from '../util/getBasicAuth'; import { createHTTPSAgent } from '../util/createHTTPSAgent'; import { createProxyAgent } from '../util/createProxyAgent'; import type { ProxyAgent } from '../util/createProxyAgent'; +import type { FetchFunctionType } from '../util/uploads/tusProtocol'; import type { SocketStatus } from '../types/SocketStatus'; import { VerificationTransport } from '../types/VerificationTransport'; import { toLogFormat } from '../types/errors'; @@ -143,7 +144,6 @@ function _validateResponse(response: any, schema: any) { const FIVE_MINUTES = 5 * durations.MINUTE; const GET_ATTACHMENT_CHUNK_TIMEOUT = 10 * durations.SECOND; -const BACKUP_CDN_VERSION = 3; type AgentCacheType = { [name: string]: { @@ -260,19 +260,21 @@ function getHostname(url: string): string { return urlObject.hostname; } -async function _promiseAjax( - providedUrl: string | null, +type FetchOptionsType = { + method: string; + body?: Uint8Array | Readable | string; + headers: FetchHeaderListType; + redirect?: 'error' | 'follow' | 'manual'; + agent?: Agent; + ca?: string; + timeout?: number; + abortSignal?: AbortSignal; +}; + +async function getFetchOptions( options: PromiseAjaxOptionsType -): Promise { - const { proxyUrl, socketManager } = options; - - const url = providedUrl || `${options.host}/${options.path}`; - const logType = socketManager ? 
'(WS)' : '(REST)'; - const redactedURL = options.redactUrl ? options.redactUrl(url) : url; - - const unauthLabel = options.unauthenticated ? ' (unauth)' : ''; - const logId = `${options.type} ${logType} ${redactedURL}${unauthLabel}`; - log.info(logId); +): Promise { + const { proxyUrl } = options; const timeout = typeof options.timeout === 'number' ? options.timeout : DEFAULT_TIMEOUT; @@ -313,6 +315,28 @@ async function _promiseAjax( abortSignal: options.abortSignal, }; + if (options.contentType) { + fetchOptions.headers['Content-Type'] = options.contentType; + } + + return fetchOptions; +} + +async function _promiseAjax( + providedUrl: string | null, + options: PromiseAjaxOptionsType +): Promise { + const fetchOptions = await getFetchOptions(options); + const { socketManager } = options; + + const url = providedUrl || `${options.host}/${options.path}`; + const logType = socketManager ? '(WS)' : '(REST)'; + const redactedURL = options.redactUrl ? options.redactUrl(url) : url; + + const unauthLabel = options.unauthenticated ? 
' (unauth)' : ''; + const logId = `${options.type} ${logType} ${redactedURL}${unauthLabel}`; + log.info(logId); + if (fetchOptions.body instanceof Uint8Array) { // node-fetch doesn't support Uint8Array, only node Buffer const contentLength = fetchOptions.body.byteLength; @@ -337,10 +361,6 @@ async function _promiseAjax( }); } - if (options.contentType) { - fetchOptions.headers['Content-Type'] = options.contentType; - } - let response: Response; let result: string | Uint8Array | Readable | unknown; try { @@ -963,7 +983,7 @@ const artAuthZod = z.object({ export type ArtAuthType = z.infer; const attachmentV3Response = z.object({ - cdn: z.literal(2), + cdn: z.literal(2).or(z.literal(3)), key: z.string(), headers: z.record(z.string()), signedUploadLocation: z.string(), @@ -1129,17 +1149,6 @@ export type GetBackupInfoResponseType = z.infer< typeof getBackupInfoResponseSchema >; -export const getBackupUploadFormResponseSchema = z.object({ - cdn: z.number(), - key: z.string(), - headers: z.record(z.string(), z.string()), - signedUploadLocation: z.string(), -}); - -export type GetBackupUploadFormResponseType = z.infer< - typeof getBackupUploadFormResponseSchema ->; - export type WebAPIType = { startRegistration(): unknown; finishRegistration(baton: unknown): void; @@ -1327,15 +1336,18 @@ export type WebAPIType = { urgent?: boolean; } ) => Promise; + createFetchForAttachmentUpload( + attachment: AttachmentV3ResponseType + ): FetchFunctionType; getBackupInfo: ( headers: BackupPresentationHeadersType ) => Promise; getBackupUploadForm: ( headers: BackupPresentationHeadersType - ) => Promise; + ) => Promise; getBackupMediaUploadForm: ( headers: BackupPresentationHeadersType - ) => Promise; + ) => Promise; refreshBackup: (headers: BackupPresentationHeadersType) => Promise; getBackupCredentials: ( options: GetBackupCredentialsOptionsType @@ -1360,7 +1372,6 @@ export type WebAPIType = { uploadAvatarRequestHeaders: UploadAvatarHeadersType, avatarData: Uint8Array ) => Promise; - 
uploadBackup: (options: UploadBackupOptionsType) => Promise; uploadGroupAvatar: ( avatarData: Uint8Array, options: GroupCredentialsType ) => Promise; @@ -1646,6 +1657,7 @@ export function initialize({ checkAccountExistence, checkSockets, createAccount, + createFetchForAttachmentUpload, confirmUsername, createGroup, deleteUsername, @@ -1727,7 +1739,6 @@ unregisterRequestHandler, updateDeviceName, uploadAvatar, - uploadBackup, uploadGroupAvatar, whoami, }; @@ -2725,7 +2736,45 @@ responseType: 'json', }); - return getBackupUploadFormResponseSchema.parse(res); + return attachmentV3Response.parse(res); + } + + function createFetchForAttachmentUpload({ + signedUploadLocation, + headers: uploadHeaders, + cdn, + }: AttachmentV3ResponseType): FetchFunctionType { + strictAssert(cdn === 3, 'Fetch can only be created for CDN 3'); + const { origin: expectedOrigin } = new URL(signedUploadLocation); + + return async ( + endpoint: string | URL, + init: RequestInit + ): Promise => { + const { origin } = new URL(endpoint); + strictAssert(origin === expectedOrigin, `Unexpected origin: ${origin}`); + + const fetchOptions = await getFetchOptions({ + // Will be overridden + type: 'GET', + + certificateAuthority, + proxyUrl, + timeout: 0, + version, + + headers: uploadHeaders, + }); + + return fetch(endpoint, { + ...fetchOptions, + ...init, + headers: { + ...fetchOptions.headers, + ...init.headers, + }, + }); + }; } async function getBackupUploadForm(headers: BackupPresentationHeadersType) { @@ -2738,94 +2787,7 @@ responseType: 'json', }); - return attachmentV3Response.parse(res); - } - - async function uploadBackup({ headers, stream }: UploadBackupOptionsType) { - const { - signedUploadLocation, - headers: uploadHeaders, - cdn, - key, - } = await getBackupUploadForm(headers); - - strictAssert( - cdn === BACKUP_CDN_VERSION, - 'uploadBackup: unexpected cdn version' - ); - - let size = 0n; -
stream.pause(); - stream.on('data', chunk => { - size += BigInt(chunk.length); - }); - - const uploadOptions = { - certificateAuthority, - proxyUrl, - timeout: 0, - type: 'POST' as const, - version, - headers: { - ...uploadHeaders, - 'Tus-Resumable': '1.0.0', - 'Content-Type': 'application/offset+octet-stream', - 'Upload-Defer-Length': '1', - }, - redactUrl: () => { - const tmp = new URL(signedUploadLocation); - tmp.search = ''; - tmp.pathname = ''; - return `${tmp}[REDACTED]`; - }, - data: stream, - responseType: 'byteswithdetails' as const, - }; - - let response: Response; - try { - ({ response } = await _outerAjax(signedUploadLocation, uploadOptions)); - } catch (e) { - // Another upload in progress, getting 409 should have aborted it. - if (e instanceof HTTPError && e.code === 409) { - log.warn('uploadBackup: aborting previous unfinished upload'); - ({ response } = await _outerAjax( - signedUploadLocation, - uploadOptions - )); - } else { - throw e; - } - } - - const uploadLocation = response.headers.get('location'); - strictAssert(uploadLocation, 'backup response header has no location'); - - // Finish the upload by sending a PATCH with the stream length - - // This is going to the CDN, not the service, so we use _outerAjax - await _outerAjax(uploadLocation, { - certificateAuthority, - proxyUrl, - timeout: 0, - type: 'PATCH', - version, - headers: { - ...uploadHeaders, - 'Tus-Resumable': '1.0.0', - 'Content-Type': 'application/offset+octet-stream', - 'Upload-Offset': String(size), - 'Upload-Length': String(size), - }, - redactUrl: () => { - const tmp = new URL(uploadLocation); - tmp.search = ''; - tmp.pathname = ''; - return `${tmp}[REDACTED]`; - }, - }); - - return key; + return attachmentV3Response.parse(res); } async function refreshBackup(headers: BackupPresentationHeadersType) { diff --git a/ts/util/uploads/tusProtocol.ts b/ts/util/uploads/tusProtocol.ts index 119acfe086cd..a061aa41b410 100644 --- a/ts/util/uploads/tusProtocol.ts +++ 
b/ts/util/uploads/tusProtocol.ts @@ -1,6 +1,7 @@ // Copyright 2024 Signal Messenger, LLC // SPDX-License-Identifier: AGPL-3.0-only import { type Readable } from 'node:stream'; +import fetch, { type RequestInit, type Response } from 'node-fetch'; import { HTTPError } from '../../textsecure/Errors'; import * as log from '../../logging/log'; @@ -8,6 +9,11 @@ import * as Errors from '../../types/errors'; import { sleep } from '../sleep'; import { FIBONACCI_TIMEOUTS, BackOff } from '../BackOff'; +export type FetchFunctionType = ( + url: string | URL, + init: RequestInit +) => Promise; + const DEFAULT_MAX_RETRIES = 3; function toLogId(input: string) { @@ -49,6 +55,17 @@ function addProgressHandler( }); } +function wrapFetchWithBody( + responsePromise: Promise, + body: Readable +): Promise { + const errorPromise = new Promise((_resolve, reject) => { + body.on('error', reject); + }); + + return Promise.race([responsePromise, errorPromise]); +} + /** * @private * Generic TUS POST implementation with creation-with-upload. @@ -65,6 +82,7 @@ export async function _tusCreateWithUploadRequest({ onProgress, onCaughtError, signal, + fetchFn = fetch, }: { endpoint: string; headers: Record; @@ -74,6 +92,7 @@ export async function _tusCreateWithUploadRequest({ onProgress?: (bytesUploaded: number) => void; onCaughtError?: (error: Error) => void; signal?: AbortSignal; + fetchFn?: FetchFunctionType; }): Promise { const logId = `tusProtocol: CreateWithUpload(${toLogId(fileName)})`; if (onProgress != null) { @@ -83,23 +102,26 @@ export async function _tusCreateWithUploadRequest({ let response: Response; try { log.info(`${logId} init`); - response = await fetch(endpoint, { - method: 'POST', - signal, - // @ts-expect-error: `duplex` is missing from TypeScript's `RequestInit`. 
- duplex: 'half', - headers: { - ...headers, - 'Tus-Resumable': '1.0.0', - 'Upload-Length': String(fileSize), - 'Upload-Metadata': _getUploadMetadataHeader({ - filename: fileName, - }), - 'Content-Type': 'application/offset+octet-stream', - }, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - body: readable as any, - }); + response = await wrapFetchWithBody( + fetchFn(endpoint, { + method: 'POST', + signal, + // @ts-expect-error: `duplex` is missing from TypeScript's `RequestInit`. + duplex: 'half', + headers: { + ...headers, + 'Tus-Resumable': '1.0.0', + 'Upload-Length': String(fileSize), + 'Upload-Metadata': _getUploadMetadataHeader({ + filename: fileName, + }), + 'Content-Type': 'application/offset+octet-stream', + }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + body: readable as any, + }), + readable + ); } catch (error) { log.error(`${logId} closed without response`, Errors.toLogFormat(error)); onCaughtError?.(error); @@ -130,16 +152,18 @@ export async function _tusGetCurrentOffsetRequest({ headers, fileName, signal, + fetchFn = fetch, }: { endpoint: string; headers: Record; fileName: string; signal?: AbortSignal; + fetchFn?: FetchFunctionType; }): Promise { const logId = `tusProtocol: GetCurrentOffsetRequest(${toLogId(fileName)})`; log.info(`${logId} init`); - const response = await fetch(`${endpoint}/${fileName}`, { + const response = await fetchFn(`${endpoint}/${fileName}`, { method: 'HEAD', signal, headers: { @@ -183,6 +207,7 @@ export async function _tusResumeUploadRequest({ onProgress, onCaughtError, signal, + fetchFn = fetch, }: { endpoint: string; headers: Record; @@ -192,6 +217,7 @@ export async function _tusResumeUploadRequest({ onProgress?: (bytesUploaded: number) => void; onCaughtError?: (error: Error) => void; signal?: AbortSignal; + fetchFn?: FetchFunctionType; }): Promise { const logId = `tusProtocol: ResumeUploadRequest(${toLogId(fileName)})`; if (onProgress != null) { @@ -201,20 +227,23 @@ export async 
function _tusResumeUploadRequest({ let response: Response; try { log.info(`${logId} init`); - response = await fetch(`${endpoint}/${fileName}`, { - method: 'PATCH', - signal, - // @ts-expect-error: `duplex` is missing from TypeScript's `RequestInit`. - duplex: 'half', - headers: { - ...headers, - 'Tus-Resumable': '1.0.0', - 'Upload-Offset': String(uploadOffset), - 'Content-Type': 'application/offset+octet-stream', - }, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - body: readable as any, - }); + response = await wrapFetchWithBody( + fetchFn(`${endpoint}/${fileName}`, { + method: 'PATCH', + signal, + // @ts-expect-error: `duplex` is missing from TypeScript's `RequestInit`. + duplex: 'half', + headers: { + ...headers, + 'Tus-Resumable': '1.0.0', + 'Upload-Offset': String(uploadOffset), + 'Content-Type': 'application/offset+octet-stream', + }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + body: readable as any, + }), + readable + ); } catch (error) { log.error(`${logId} closed without response`, Errors.toLogFormat(error)); onCaughtError?.(error); @@ -244,6 +273,7 @@ export async function tusUpload({ onCaughtError, maxRetries = DEFAULT_MAX_RETRIES, signal, + fetchFn = fetch, }: { endpoint: string; headers: Record; @@ -255,6 +285,7 @@ export async function tusUpload({ onCaughtError?: (error: Error) => void; maxRetries?: number; signal?: AbortSignal; + fetchFn?: FetchFunctionType; }): Promise { const readable = reader(filePath); const done = await _tusCreateWithUploadRequest({ @@ -267,6 +298,7 @@ export async function tusUpload({ onProgress, onCaughtError, signal, + fetchFn, }); if (!done) { await tusResumeUpload({ @@ -280,6 +312,7 @@ export async function tusUpload({ onCaughtError, maxRetries, signal, + fetchFn, }); } } @@ -302,6 +335,7 @@ export async function tusResumeUpload({ onCaughtError, maxRetries = DEFAULT_MAX_RETRIES, signal, + fetchFn = fetch, }: { endpoint: string; headers: Record; @@ -313,6 +347,7 @@ export async 
function tusResumeUpload({ onCaughtError?: (error: Error) => void; maxRetries?: number; signal?: AbortSignal; + fetchFn?: FetchFunctionType; }): Promise { const backoff = new BackOff(FIBONACCI_TIMEOUTS, { jitter: BACKOFF_JITTER_MS, @@ -330,6 +365,7 @@ export async function tusResumeUpload({ headers, fileName, signal, + fetchFn, }); if (uploadOffset === fileSize) { @@ -348,6 +384,7 @@ export async function tusResumeUpload({ onProgress, onCaughtError, signal, + fetchFn, }); if (done) { diff --git a/ts/util/uploads/uploads.ts b/ts/util/uploads/uploads.ts index e9c280563db1..76f6c0113546 100644 --- a/ts/util/uploads/uploads.ts +++ b/ts/util/uploads/uploads.ts @@ -1,12 +1,13 @@ // Copyright 2024 Signal Messenger, LLC // SPDX-License-Identifier: AGPL-3.0-only +import fetch from 'node-fetch'; import { createReadStream, createWriteStream } from 'node:fs'; -import { Writable } from 'node:stream'; -import type { TusFileReader } from './tusProtocol'; +import { pipeline } from 'node:stream/promises'; +import type { TusFileReader, FetchFunctionType } from './tusProtocol'; import { tusResumeUpload, tusUpload } from './tusProtocol'; import { HTTPError } from '../../textsecure/Errors'; -const defaultFileReader: TusFileReader = (filePath, offset) => { +export const defaultFileReader: TusFileReader = (filePath, offset) => { return createReadStream(filePath, { start: offset }); }; @@ -87,13 +88,15 @@ export async function _doDownload({ headers = {}, filePath, signal, + fetchFn = fetch, }: { endpoint: string; filePath: string; headers?: Record; signal?: AbortSignal; + fetchFn?: FetchFunctionType; }): Promise { - const response = await fetch(endpoint, { + const response = await fetchFn(endpoint, { method: 'GET', signal, redirect: 'error', @@ -106,7 +109,7 @@ export async function _doDownload({ throw new Error('Response has no body'); } const writable = createWriteStream(filePath); - await response.body.pipeTo(Writable.toWeb(writable)); + await pipeline(response.body, writable); } /** 
diff --git a/yarn.lock b/yarn.lock index 9fe5551882b4..54773d0a671e 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4001,10 +4001,10 @@ type-fest "^3.5.0" uuid "^8.3.0" -"@signalapp/mock-server@6.4.1": - version "6.4.1" - resolved "https://registry.yarnpkg.com/@signalapp/mock-server/-/mock-server-6.4.1.tgz#b49700f8d43b0c76d3f02820dd3b3da82a910f12" - integrity sha512-is75JwGL2CjLJ3NakMxw6rkgx379aKc3n328lSaiwLKVgBpuG/ms8wF3fNALxFstKoMl41lPzooOMWeqm+ubVQ== +"@signalapp/mock-server@6.4.2": + version "6.4.2" + resolved "https://registry.yarnpkg.com/@signalapp/mock-server/-/mock-server-6.4.2.tgz#9c0ccabaf7d9a8728503245d2fa2b4d7da6a5ccd" + integrity sha512-qL5wUGkbquZA6mKieuSOwlX51UyUFlLeQq+Z/F+gX910l8aYVV0niwtR1hYNPgvgxakPPXJ3VhIWE4qMgQRkrw== dependencies: "@signalapp/libsignal-client" "^0.42.0" debug "^4.3.2"