Use TUS utilities for backup upload

Fedor Indutny 2024-05-14 10:04:50 -07:00 committed by GitHub
parent 4eb5458ace
commit 4fed756661
14 changed files with 356 additions and 228 deletions

View file

@@ -211,7 +211,7 @@
     "@formatjs/intl": "2.6.7",
     "@indutny/rezip-electron": "1.3.1",
     "@mixer/parallel-prettier": "2.0.3",
-    "@signalapp/mock-server": "6.4.1",
+    "@signalapp/mock-server": "6.4.2",
     "@storybook/addon-a11y": "7.4.5",
     "@storybook/addon-actions": "7.4.5",
    "@storybook/addon-controls": "7.4.5",

View file

@@ -1,13 +1,13 @@
 // Copyright 2024 Signal Messenger, LLC
 // SPDX-License-Identifier: AGPL-3.0-only
-import type { Readable } from 'stream';
 import { strictAssert } from '../../util/assert';
+import { tusUpload } from '../../util/uploads/tusProtocol';
+import { defaultFileReader } from '../../util/uploads/uploads';
 import type {
   WebAPIType,
+  AttachmentV3ResponseType,
   GetBackupInfoResponseType,
-  GetBackupUploadFormResponseType,
   BackupMediaItemType,
   BackupMediaBatchResponseType,
   BackupListMediaResponseType,
@@ -55,14 +55,25 @@ export class BackupAPI {
     return (await this.getInfo()).backupName;
   }

-  public async upload(stream: Readable): Promise<string> {
-    return this.server.uploadBackup({
-      headers: await this.credentials.getHeadersForToday(),
-      stream,
+  public async upload(filePath: string, fileSize: number): Promise<void> {
+    const form = await this.server.getBackupUploadForm(
+      await this.credentials.getHeadersForToday()
+    );
+
+    const fetchFn = this.server.createFetchForAttachmentUpload(form);
+
+    await tusUpload({
+      endpoint: form.signedUploadLocation,
+      headers: {},
+      fileName: form.key,
+      filePath,
+      fileSize,
+      reader: defaultFileReader,
+      fetchFn,
     });
   }

-  public async getMediaUploadForm(): Promise<GetBackupUploadFormResponseType> {
+  public async getMediaUploadForm(): Promise<AttachmentV3ResponseType> {
     return this.server.getBackupMediaUploadForm(
       await this.credentials.getHeadersForToday()
     );
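
Note: the rewritten upload() above hands the whole transfer to tusUpload, which reads the file from disk through defaultFileReader and can resume if the connection drops. A rough, self-contained sketch of driving the same helper outside this class follows; the endpoint, file name, and progress handler are made-up example values, not taken from the diff (in the real flow they come from getBackupUploadForm() and createFetchForAttachmentUpload()):

import { tusUpload } from '../../util/uploads/tusProtocol';
import { defaultFileReader } from '../../util/uploads/uploads';

// Hypothetical stand-alone usage of the TUS helper used above.
async function uploadOneFile(filePath: string, fileSize: number): Promise<void> {
  await tusUpload({
    endpoint: 'https://cdn.example.invalid/backups', // assumed TUS-capable endpoint
    headers: {},
    fileName: 'example-key', // in the real flow this is form.key
    filePath,
    fileSize,
    // defaultFileReader opens createReadStream(filePath, { start: offset })
    reader: defaultFileReader,
    onProgress: bytesUploaded => {
      console.log(`uploaded ${bytesUploaded}/${fileSize} bytes`);
    },
  });
}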

View file

@@ -622,6 +622,7 @@ export class BackupExportStream extends Readable {
     };

     if (!isNormalBubble(message)) {
+      result.directionless = {};
       return this.toChatItemFromNonBubble(result, message, options);
     }

View file

@@ -638,12 +638,6 @@ export class BackupImportStream extends Writable {
       ? this.recipientIdToConvo.get(item.authorId.toNumber())
       : undefined;

-    const isOutgoing =
-      authorConvo && this.ourConversation.id === authorConvo?.id;
-    const isIncoming =
-      authorConvo && this.ourConversation.id !== authorConvo?.id;
-    const isDirectionLess = !isOutgoing && !isIncoming;
-
     let attributes: MessageAttributesType = {
       id: generateUuid(),
       canReplyToStory: false,
@@ -653,7 +647,7 @@ export class BackupImportStream extends Writable {
       source: authorConvo?.e164,
       sourceServiceId: authorConvo?.serviceId,
       timestamp,
-      type: isOutgoing ? 'outgoing' : 'incoming',
+      type: item.outgoing != null ? 'outgoing' : 'incoming',
       unidentifiedDeliveryReceived: false,
       expirationStartTimestamp: item.expireStartDate
         ? getTimestampFromLong(item.expireStartDate)
@@ -667,7 +661,7 @@ export class BackupImportStream extends Writable {
     const { outgoing, incoming, directionless } = item;
     if (outgoing) {
       strictAssert(
-        isOutgoing,
+        authorConvo && this.ourConversation.id === authorConvo?.id,
         `${logId}: outgoing message must have outgoing field`
       );
@@ -722,10 +716,9 @@ export class BackupImportStream extends Writable {
       attributes.sendStateByConversationId = sendStateByConversationId;
       chatConvo.active_at = attributes.sent_at;
-    }
-
-    if (incoming) {
+    } else if (incoming) {
       strictAssert(
-        isIncoming,
+        authorConvo && this.ourConversation.id !== authorConvo?.id,
         `${logId}: message with incoming field must be incoming`
       );
       attributes.received_at_ms =
@@ -741,12 +734,8 @@ export class BackupImportStream extends Writable {
       }
       chatConvo.active_at = attributes.received_at_ms;
-    }
-
-    if (directionless) {
-      strictAssert(
-        isDirectionLess,
-        `${logId}: directionless message must not be incoming/outgoing`
-      );
+    } else if (directionless) {
+      // Nothing to do
     }

     if (item.standardMessage) {
@@ -793,7 +782,7 @@ export class BackupImportStream extends Writable {
       // TODO (DESKTOP-6964): We'll want to increment for more types here - stickers, etc.
       if (item.standardMessage) {
-        if (isOutgoing) {
+        if (item.outgoing != null) {
           chatConvo.sentMessageCount = (chatConvo.sentMessageCount ?? 0) + 1;
         } else {
           chatConvo.messageCount = (chatConvo.messageCount ?? 0) + 1;

View file

@@ -5,6 +5,8 @@ import { pipeline } from 'stream/promises';
 import { PassThrough } from 'stream';
 import type { Readable, Writable } from 'stream';
 import { createReadStream, createWriteStream } from 'fs';
+import { unlink } from 'fs/promises';
+import { join } from 'path';
 import { createGzip, createGunzip } from 'zlib';
 import { createCipheriv, createHmac, randomBytes } from 'crypto';
 import { noop } from 'lodash';
@@ -27,6 +29,7 @@ import { BackupImportStream } from './import';
 import { getKeyMaterial } from './crypto';
 import { BackupCredentials } from './credentials';
 import { BackupAPI } from './api';
+import { validateBackup } from './validator';

 const IV_LENGTH = 16;
@@ -61,39 +64,21 @@ export class BackupsService {
     });
   }

-  public async exportBackup(sink: Writable): Promise<void> {
-    strictAssert(!this.isRunning, 'BackupService is already running');
-
-    log.info('exportBackup: starting...');
-    this.isRunning = true;
-
+  public async upload(): Promise<void> {
+    const fileName = `backup-${randomBytes(32).toString('hex')}`;
+    const filePath = join(window.BasePaths.temp, fileName);
+
     try {
-      const { aesKey, macKey } = getKeyMaterial();
-
-      const recordStream = new BackupExportStream();
-      recordStream.run();
-
-      const iv = randomBytes(IV_LENGTH);
-
-      await pipeline(
-        recordStream,
-        createGzip(),
-        appendPaddingStream(),
-        createCipheriv(CipherType.AES256CBC, aesKey, iv),
-        prependStream(iv),
-        appendMacStream(macKey),
-        sink
-      );
+      const fileSize = await this.exportToDisk(filePath);
+
+      await this.api.upload(filePath, fileSize);
     } finally {
-      log.info('exportBackup: finished...');
-      this.isRunning = false;
+      try {
+        await unlink(filePath);
+      } catch {
+        // Ignore
+      }
     }
   }

-  public async upload(): Promise<void> {
-    const pipe = new PassThrough();
-
-    await Promise.all([this.api.upload(pipe), this.exportBackup(pipe)]);
-  }
-
   // Test harness
@@ -108,8 +93,12 @@ export class BackupsService {
   }

   // Test harness
-  public async exportToDisk(path: string): Promise<void> {
-    await this.exportBackup(createWriteStream(path));
+  public async exportToDisk(path: string): Promise<number> {
+    const size = await this.exportBackup(createWriteStream(path));
+
+    await validateBackup(path, size);
+
+    return size;
   }

   // Test harness
@@ -185,6 +174,49 @@ export class BackupsService {
     }
   }

+  private async exportBackup(sink: Writable): Promise<number> {
+    strictAssert(!this.isRunning, 'BackupService is already running');
+
+    log.info('exportBackup: starting...');
+    this.isRunning = true;
+
+    try {
+      const { aesKey, macKey } = getKeyMaterial();
+
+      const recordStream = new BackupExportStream();
+      recordStream.run();
+
+      const iv = randomBytes(IV_LENGTH);
+
+      const pass = new PassThrough();
+
+      let totalBytes = 0;
+
+      // Pause the flow first so that we respect backpressure. The
+      // `pipeline` call below will control the flow anyway.
+      pass.pause();
+      pass.on('data', chunk => {
+        totalBytes += chunk.length;
+      });
+
+      await pipeline(
+        recordStream,
+        createGzip(),
+        appendPaddingStream(),
+        createCipheriv(CipherType.AES256CBC, aesKey, iv),
+        prependStream(iv),
+        appendMacStream(macKey),
+        pass,
+        sink
+      );
+
+      return totalBytes;
+    } finally {
+      log.info('exportBackup: finished...');
+      this.isRunning = false;
+    }
+  }
+
   private async runPeriodicRefresh(): Promise<void> {
     try {
       await this.api.refresh();
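
Aside: the new private exportBackup() above measures the encrypted output size by inserting a paused PassThrough just before the sink and summing 'data' chunks while pipeline() drives the flow. A minimal, generic sketch of that pattern (gzip and the file paths are placeholders, not the backup pipeline itself):

import { createReadStream, createWriteStream } from 'node:fs';
import { PassThrough } from 'node:stream';
import { pipeline } from 'node:stream/promises';
import { createGzip } from 'node:zlib';

// Counts how many bytes reach the sink without buffering the stream.
async function gzipToFileWithSize(input: string, output: string): Promise<number> {
  const counter = new PassThrough();
  let totalBytes = 0;

  // Stay paused until pipeline() starts pulling, so backpressure is respected.
  counter.pause();
  counter.on('data', chunk => {
    totalBytes += chunk.length;
  });

  await pipeline(
    createReadStream(input),
    createGzip(),
    counter,
    createWriteStream(output)
  );

  return totalBytes;
}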

View file

@@ -0,0 +1,92 @@
+// Copyright 2024 Signal Messenger, LLC
+// SPDX-License-Identifier: AGPL-3.0-only
+
+import { type FileHandle, open } from 'node:fs/promises';
+import * as libsignal from '@signalapp/libsignal-client/dist/MessageBackup';
+import { InputStream } from '@signalapp/libsignal-client/dist/io';
+
+import { strictAssert } from '../../util/assert';
+import { toAciObject } from '../../util/ServiceId';
+import { isTestOrMockEnvironment } from '../../environment';
+
+class FileStream extends InputStream {
+  private file: FileHandle | undefined;
+  private position = 0;
+  private buffer = Buffer.alloc(16 * 1024);
+  private initPromise: Promise<unknown> | undefined;
+
+  constructor(private readonly filePath: string) {
+    super();
+  }
+
+  public async close(): Promise<void> {
+    await this.initPromise;
+    await this.file?.close();
+  }
+
+  async read(amount: number): Promise<Buffer> {
+    await this.initPromise;
+
+    if (!this.file) {
+      const filePromise = open(this.filePath);
+      this.initPromise = filePromise;
+      this.file = await filePromise;
+    }
+
+    if (this.buffer.length < amount) {
+      this.buffer = Buffer.alloc(amount);
+    }
+    const { bytesRead } = await this.file.read(
+      this.buffer,
+      0,
+      amount,
+      this.position
+    );
+    this.position += bytesRead;
+    return this.buffer.slice(0, bytesRead);
+  }
+
+  async skip(amount: number): Promise<void> {
+    this.position += amount;
+  }
+}
+
+export async function validateBackup(
+  filePath: string,
+  fileSize: number
+): Promise<void> {
+  const masterKeyBase64 = window.storage.get('masterKey');
+  strictAssert(masterKeyBase64, 'Master key not available');
+
+  const masterKey = Buffer.from(masterKeyBase64, 'base64');
+  const aci = toAciObject(window.storage.user.getCheckedAci());
+  const backupKey = new libsignal.MessageBackupKey(masterKey, aci);
+
+  const streams = new Array<FileStream>();
+
+  let outcome: libsignal.ValidationOutcome;
+  try {
+    outcome = await libsignal.validate(
+      backupKey,
+      libsignal.Purpose.RemoteBackup,
+      () => {
+        const stream = new FileStream(filePath);
+        streams.push(stream);
+        return stream;
+      },
+      BigInt(fileSize)
+    );
+  } finally {
+    await Promise.all(streams.map(stream => stream.close()));
+  }
+
+  if (isTestOrMockEnvironment()) {
+    strictAssert(
+      outcome.ok,
+      `Backup validation failed: ${outcome.errorMessage}`
+    );
+  } else {
+    strictAssert(outcome.ok, 'Backup validation failed');
+  }
+}
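
For orientation, FileStream above adapts a file on disk to libsignal's InputStream interface: the file is opened lazily on the first read(), each read() returns at most the requested number of bytes from the current position, and skip() just advances the cursor. A test-style sketch of that contract (it assumes FileStream were exported, which this diff does not do; the file path is a placeholder):

// Hypothetical usage, assuming `export class FileStream ...` above.
async function peekAtBackupFile(filePath: string): Promise<void> {
  const stream = new FileStream(filePath);
  try {
    const header = await stream.read(16); // up to 16 bytes from offset 0
    await stream.skip(1024);              // advance 1 KiB without reading
    const later = await stream.read(16);  // up to 16 bytes after the skip
    console.log(header.toString('hex'), later.toString('hex'));
  } finally {
    await stream.close();
  }
}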

View file

@@ -3,7 +3,8 @@
 import path from 'path';
 import { tmpdir } from 'os';
-import { rmSync, mkdtempSync, createReadStream } from 'fs';
+import { createReadStream } from 'fs';
+import { mkdtemp, rm } from 'fs/promises';
 import { v4 as generateGuid } from 'uuid';
 import { assert } from 'chai';
@@ -71,7 +72,7 @@ async function asymmetricRoundtripHarness(
   before: Array<MessageAttributesType>,
   after: Array<MessageAttributesType>
 ) {
-  const outDir = mkdtempSync(path.join(tmpdir(), 'signal-temp-'));
+  const outDir = await mkdtemp(path.join(tmpdir(), 'signal-temp-'));
   try {
     const targetOutputFile = path.join(outDir, 'backup.bin');
@@ -89,7 +90,7 @@ async function asymmetricRoundtripHarness(
     const actual = sortAndNormalize(messagesFromDatabase);
     assert.deepEqual(expected, actual);
   } finally {
-    rmSync(outDir, { recursive: true });
+    await rm(outDir, { recursive: true });
   }
 }

View file

@@ -65,7 +65,7 @@ export class TestServer extends EventEmitter {
       typeof address === 'object' && address != null,
       'address must be an object'
     );
-    return `http://localhost:${address.port}/}`;
+    return `http://localhost:${address.port}/`;
   }

   respondWith(status: number, headers: OutgoingHttpHeaders = {}): void {

View file

@@ -146,7 +146,7 @@ describe('tusProtocol', () => {
         }),
       });
       assert.strictEqual(result, false);
-      assert.strictEqual(caughtError?.message, 'fetch failed');
+      assert.strictEqual(caughtError?.message, 'test');
     });
   });
@@ -317,7 +317,7 @@ describe('tusProtocol', () => {
         }),
       });
       assert.strictEqual(result, false);
-      assert.strictEqual(caughtError?.message, 'fetch failed');
+      assert.strictEqual(caughtError?.message, 'test');
     });
   });
@@ -327,7 +327,6 @@ describe('tusProtocol', () => {
 function assertSocketCloseError(error: unknown) {
   // There isn't an equivalent to this chain in assert()
   expect(error, toLogFormat(error))
-    .property('cause')
     .property('code')
     .oneOf(['ECONNRESET', 'UND_ERR_SOCKET']);
 }

View file

@@ -3,6 +3,7 @@
 /* eslint-disable max-classes-per-file */

+import type { Response } from 'node-fetch';
 import type { LibSignalErrorBase } from '@signalapp/libsignal-client';
 import { parseRetryAfter } from '../util/parseRetryAfter';

View file

@@ -6,7 +6,7 @@
 /* eslint-disable no-restricted-syntax */
 /* eslint-disable @typescript-eslint/no-explicit-any */

-import type { Response } from 'node-fetch';
+import type { RequestInit, Response } from 'node-fetch';
 import fetch from 'node-fetch';
 import type { Agent } from 'https';
 import { escapeRegExp, isNumber, isString, isObject } from 'lodash';
@@ -30,6 +30,7 @@ import { getBasicAuth } from '../util/getBasicAuth';
 import { createHTTPSAgent } from '../util/createHTTPSAgent';
 import { createProxyAgent } from '../util/createProxyAgent';
 import type { ProxyAgent } from '../util/createProxyAgent';
+import type { FetchFunctionType } from '../util/uploads/tusProtocol';
 import type { SocketStatus } from '../types/SocketStatus';
 import { VerificationTransport } from '../types/VerificationTransport';
 import { toLogFormat } from '../types/errors';
@@ -143,7 +144,6 @@ function _validateResponse(response: any, schema: any) {
 const FIVE_MINUTES = 5 * durations.MINUTE;
 const GET_ATTACHMENT_CHUNK_TIMEOUT = 10 * durations.SECOND;

-const BACKUP_CDN_VERSION = 3;
-
 type AgentCacheType = {
   [name: string]: {
@@ -260,19 +260,21 @@ function getHostname(url: string): string {
   return urlObject.hostname;
 }

-async function _promiseAjax(
-  providedUrl: string | null,
+type FetchOptionsType = {
+  method: string;
+  body?: Uint8Array | Readable | string;
+  headers: FetchHeaderListType;
+  redirect?: 'error' | 'follow' | 'manual';
+  agent?: Agent;
+  ca?: string;
+  timeout?: number;
+  abortSignal?: AbortSignal;
+};
+
+async function getFetchOptions(
   options: PromiseAjaxOptionsType
-): Promise<unknown> {
-  const { proxyUrl, socketManager } = options;
-
-  const url = providedUrl || `${options.host}/${options.path}`;
-  const logType = socketManager ? '(WS)' : '(REST)';
-  const redactedURL = options.redactUrl ? options.redactUrl(url) : url;
-  const unauthLabel = options.unauthenticated ? ' (unauth)' : '';
-  const logId = `${options.type} ${logType} ${redactedURL}${unauthLabel}`;
-  log.info(logId);
+): Promise<FetchOptionsType> {
+  const { proxyUrl } = options;

   const timeout =
     typeof options.timeout === 'number' ? options.timeout : DEFAULT_TIMEOUT;
@@ -313,6 +315,28 @@ async function _promiseAjax(
     abortSignal: options.abortSignal,
   };

+  if (options.contentType) {
+    fetchOptions.headers['Content-Type'] = options.contentType;
+  }
+
+  return fetchOptions;
+}
+
+async function _promiseAjax(
+  providedUrl: string | null,
+  options: PromiseAjaxOptionsType
+): Promise<unknown> {
+  const fetchOptions = await getFetchOptions(options);
+  const { socketManager } = options;
+
+  const url = providedUrl || `${options.host}/${options.path}`;
+  const logType = socketManager ? '(WS)' : '(REST)';
+  const redactedURL = options.redactUrl ? options.redactUrl(url) : url;
+  const unauthLabel = options.unauthenticated ? ' (unauth)' : '';
+  const logId = `${options.type} ${logType} ${redactedURL}${unauthLabel}`;
+  log.info(logId);
+
   if (fetchOptions.body instanceof Uint8Array) {
     // node-fetch doesn't support Uint8Array, only node Buffer
     const contentLength = fetchOptions.body.byteLength;
@@ -337,10 +361,6 @@
     });
   }

-  if (options.contentType) {
-    fetchOptions.headers['Content-Type'] = options.contentType;
-  }
-
   let response: Response;
   let result: string | Uint8Array | Readable | unknown;
   try {
@@ -963,7 +983,7 @@
 export type ArtAuthType = z.infer<typeof artAuthZod>;

 const attachmentV3Response = z.object({
-  cdn: z.literal(2),
+  cdn: z.literal(2).or(z.literal(3)),
   key: z.string(),
   headers: z.record(z.string()),
   signedUploadLocation: z.string(),
@@ -1129,17 +1149,6 @@ export type GetBackupInfoResponseType = z.infer<
   typeof getBackupInfoResponseSchema
 >;

-export const getBackupUploadFormResponseSchema = z.object({
-  cdn: z.number(),
-  key: z.string(),
-  headers: z.record(z.string(), z.string()),
-  signedUploadLocation: z.string(),
-});
-
-export type GetBackupUploadFormResponseType = z.infer<
-  typeof getBackupUploadFormResponseSchema
->;
-
 export type WebAPIType = {
   startRegistration(): unknown;
   finishRegistration(baton: unknown): void;
@@ -1327,15 +1336,18 @@ export type WebAPIType = {
       urgent?: boolean;
     }
   ) => Promise<MultiRecipient200ResponseType>;
+  createFetchForAttachmentUpload(
+    attachment: AttachmentV3ResponseType
+  ): FetchFunctionType;
   getBackupInfo: (
     headers: BackupPresentationHeadersType
   ) => Promise<GetBackupInfoResponseType>;
   getBackupUploadForm: (
     headers: BackupPresentationHeadersType
-  ) => Promise<GetBackupUploadFormResponseType>;
+  ) => Promise<AttachmentV3ResponseType>;
   getBackupMediaUploadForm: (
     headers: BackupPresentationHeadersType
-  ) => Promise<GetBackupUploadFormResponseType>;
+  ) => Promise<AttachmentV3ResponseType>;
   refreshBackup: (headers: BackupPresentationHeadersType) => Promise<void>;
   getBackupCredentials: (
     options: GetBackupCredentialsOptionsType
@@ -1360,7 +1372,6 @@ export type WebAPIType = {
     uploadAvatarRequestHeaders: UploadAvatarHeadersType,
     avatarData: Uint8Array
   ) => Promise<string>;
-  uploadBackup: (options: UploadBackupOptionsType) => Promise<string>;
   uploadGroupAvatar: (
     avatarData: Uint8Array,
     options: GroupCredentialsType
@@ -1646,6 +1657,7 @@ export function initialize({
     checkAccountExistence,
     checkSockets,
     createAccount,
+    createFetchForAttachmentUpload,
     confirmUsername,
     createGroup,
     deleteUsername,
@@ -1727,7 +1739,6 @@ export function initialize({
     unregisterRequestHandler,
     updateDeviceName,
     uploadAvatar,
-    uploadBackup,
     uploadGroupAvatar,
     whoami,
   };
@@ -2725,7 +2736,45 @@
       responseType: 'json',
     });

-    return getBackupUploadFormResponseSchema.parse(res);
+    return attachmentV3Response.parse(res);
   }

+  function createFetchForAttachmentUpload({
+    signedUploadLocation,
+    headers: uploadHeaders,
+    cdn,
+  }: AttachmentV3ResponseType): FetchFunctionType {
+    strictAssert(cdn === 3, 'Fetch can only be created for CDN 3');
+    const { origin: expectedOrigin } = new URL(signedUploadLocation);
+
+    return async (
+      endpoint: string | URL,
+      init: RequestInit
+    ): Promise<Response> => {
+      const { origin } = new URL(endpoint);
+      strictAssert(origin === expectedOrigin, `Unexpected origin: ${origin}`);
+
+      const fetchOptions = await getFetchOptions({
+        // Will be overridden
+        type: 'GET',
+        certificateAuthority,
+        proxyUrl,
+        timeout: 0,
+        version,
+        headers: uploadHeaders,
+      });
+
+      return fetch(endpoint, {
+        ...fetchOptions,
+        ...init,
+        headers: {
+          ...fetchOptions.headers,
+          ...init.headers,
+        },
+      });
+    };
+  }
+
   async function getBackupUploadForm(headers: BackupPresentationHeadersType) {
@@ -2738,94 +2787,7 @@
       responseType: 'json',
     });

-    return getBackupUploadFormResponseSchema.parse(res);
+    return attachmentV3Response.parse(res);
   }

-  async function uploadBackup({ headers, stream }: UploadBackupOptionsType) {
-    const {
-      signedUploadLocation,
-      headers: uploadHeaders,
-      cdn,
-      key,
-    } = await getBackupUploadForm(headers);
-
-    strictAssert(
-      cdn === BACKUP_CDN_VERSION,
-      'uploadBackup: unexpected cdn version'
-    );
-
-    let size = 0n;
-    stream.pause();
-    stream.on('data', chunk => {
-      size += BigInt(chunk.length);
-    });
-
-    const uploadOptions = {
-      certificateAuthority,
-      proxyUrl,
-      timeout: 0,
-      type: 'POST' as const,
-      version,
-      headers: {
-        ...uploadHeaders,
-        'Tus-Resumable': '1.0.0',
-        'Content-Type': 'application/offset+octet-stream',
-        'Upload-Defer-Length': '1',
-      },
-      redactUrl: () => {
-        const tmp = new URL(signedUploadLocation);
-        tmp.search = '';
-        tmp.pathname = '';
-        return `${tmp}[REDACTED]`;
-      },
-      data: stream,
-      responseType: 'byteswithdetails' as const,
-    };
-
-    let response: Response;
-    try {
-      ({ response } = await _outerAjax(signedUploadLocation, uploadOptions));
-    } catch (e) {
-      // Another upload in progress, getting 409 should have aborted it.
-      if (e instanceof HTTPError && e.code === 409) {
-        log.warn('uploadBackup: aborting previous unfinished upload');
-        ({ response } = await _outerAjax(
-          signedUploadLocation,
-          uploadOptions
-        ));
-      } else {
-        throw e;
-      }
-    }
-
-    const uploadLocation = response.headers.get('location');
-    strictAssert(uploadLocation, 'backup response header has no location');
-
-    // Finish the upload by sending a PATCH with the stream length
-    // This is going to the CDN, not the service, so we use _outerAjax
-    await _outerAjax(uploadLocation, {
-      certificateAuthority,
-      proxyUrl,
-      timeout: 0,
-      type: 'PATCH',
-      version,
-      headers: {
-        ...uploadHeaders,
-        'Tus-Resumable': '1.0.0',
-        'Content-Type': 'application/offset+octet-stream',
-        'Upload-Offset': String(size),
-        'Upload-Length': String(size),
-      },
-      redactUrl: () => {
-        const tmp = new URL(uploadLocation);
-        tmp.search = '';
-        tmp.pathname = '';
-        return `${tmp}[REDACTED]`;
-      },
-    });
-
-    return key;
-  }
-
   async function refreshBackup(headers: BackupPresentationHeadersType) {
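
A note on createFetchForAttachmentUpload above: it returns a fetch that is pinned to the origin of the signed upload URL, so the TUS helpers can never send the upload headers to a different host, while still reusing the proxy/CA/timeout handling from getFetchOptions(). A simplified, generic sketch of the origin-pinning idea (this is not the WebAPI implementation; it drops the getFetchOptions() plumbing and uses plain node-fetch):

import fetch, { type RequestInit, type Response } from 'node-fetch';

function createPinnedFetch(signedUploadLocation: string) {
  const { origin: expectedOrigin } = new URL(signedUploadLocation);

  return async (endpoint: string | URL, init: RequestInit): Promise<Response> => {
    const { origin } = new URL(endpoint.toString());
    if (origin !== expectedOrigin) {
      throw new Error(`Unexpected origin: ${origin}`);
    }
    return fetch(endpoint, init);
  };
}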

View file

@@ -1,6 +1,7 @@
 // Copyright 2024 Signal Messenger, LLC
 // SPDX-License-Identifier: AGPL-3.0-only

 import { type Readable } from 'node:stream';
+import fetch, { type RequestInit, type Response } from 'node-fetch';
 import { HTTPError } from '../../textsecure/Errors';
 import * as log from '../../logging/log';
@@ -8,6 +9,11 @@ import * as Errors from '../../types/errors';
 import { sleep } from '../sleep';
 import { FIBONACCI_TIMEOUTS, BackOff } from '../BackOff';

+export type FetchFunctionType = (
+  url: string | URL,
+  init: RequestInit
+) => Promise<Response>;
+
 const DEFAULT_MAX_RETRIES = 3;

 function toLogId(input: string) {
@@ -49,6 +55,17 @@ function addProgressHandler(
   });
 }

+function wrapFetchWithBody(
+  responsePromise: Promise<Response>,
+  body: Readable
+): Promise<Response> {
+  const errorPromise = new Promise<Response>((_resolve, reject) => {
+    body.on('error', reject);
+  });
+
+  return Promise.race([responsePromise, errorPromise]);
+}
+
 /**
  * @private
  * Generic TUS POST implementation with creation-with-upload.
@@ -65,6 +82,7 @@ export async function _tusCreateWithUploadRequest({
   onProgress,
   onCaughtError,
   signal,
+  fetchFn = fetch,
 }: {
   endpoint: string;
   headers: Record<string, string>;
@@ -74,6 +92,7 @@ export async function _tusCreateWithUploadRequest({
   onProgress?: (bytesUploaded: number) => void;
   onCaughtError?: (error: Error) => void;
   signal?: AbortSignal;
+  fetchFn?: FetchFunctionType;
 }): Promise<boolean> {
   const logId = `tusProtocol: CreateWithUpload(${toLogId(fileName)})`;
   if (onProgress != null) {
@@ -83,7 +102,8 @@
   let response: Response;
   try {
     log.info(`${logId} init`);
-    response = await fetch(endpoint, {
+    response = await wrapFetchWithBody(
+      fetchFn(endpoint, {
       method: 'POST',
       signal,
       // @ts-expect-error: `duplex` is missing from TypeScript's `RequestInit`.
@@ -99,7 +119,9 @@
       },
       // eslint-disable-next-line @typescript-eslint/no-explicit-any
       body: readable as any,
-    });
+      }),
+      readable
+    );
   } catch (error) {
     log.error(`${logId} closed without response`, Errors.toLogFormat(error));
     onCaughtError?.(error);
@@ -130,16 +152,18 @@ export async function _tusGetCurrentOffsetRequest({
   headers,
   fileName,
   signal,
+  fetchFn = fetch,
 }: {
   endpoint: string;
   headers: Record<string, string>;
   fileName: string;
   signal?: AbortSignal;
+  fetchFn?: FetchFunctionType;
 }): Promise<number> {
   const logId = `tusProtocol: GetCurrentOffsetRequest(${toLogId(fileName)})`;
   log.info(`${logId} init`);

-  const response = await fetch(`${endpoint}/${fileName}`, {
+  const response = await fetchFn(`${endpoint}/${fileName}`, {
     method: 'HEAD',
     signal,
     headers: {
@@ -183,6 +207,7 @@ export async function _tusResumeUploadRequest({
   onProgress,
   onCaughtError,
   signal,
+  fetchFn = fetch,
 }: {
   endpoint: string;
   headers: Record<string, string>;
@@ -192,6 +217,7 @@ export async function _tusResumeUploadRequest({
   onProgress?: (bytesUploaded: number) => void;
   onCaughtError?: (error: Error) => void;
   signal?: AbortSignal;
+  fetchFn?: FetchFunctionType;
 }): Promise<boolean> {
   const logId = `tusProtocol: ResumeUploadRequest(${toLogId(fileName)})`;
   if (onProgress != null) {
@@ -201,7 +227,8 @@ export async function _tusResumeUploadRequest({
   let response: Response;
   try {
     log.info(`${logId} init`);
-    response = await fetch(`${endpoint}/${fileName}`, {
+    response = await wrapFetchWithBody(
+      fetchFn(`${endpoint}/${fileName}`, {
       method: 'PATCH',
       signal,
       // @ts-expect-error: `duplex` is missing from TypeScript's `RequestInit`.
@@ -214,7 +241,9 @@ export async function _tusResumeUploadRequest({
       },
       // eslint-disable-next-line @typescript-eslint/no-explicit-any
       body: readable as any,
-    });
+      }),
+      readable
+    );
   } catch (error) {
     log.error(`${logId} closed without response`, Errors.toLogFormat(error));
     onCaughtError?.(error);
@@ -244,6 +273,7 @@ export async function tusUpload({
   onCaughtError,
   maxRetries = DEFAULT_MAX_RETRIES,
   signal,
+  fetchFn = fetch,
 }: {
   endpoint: string;
   headers: Record<string, string>;
@@ -255,6 +285,7 @@ export async function tusUpload({
   onCaughtError?: (error: Error) => void;
   maxRetries?: number;
   signal?: AbortSignal;
+  fetchFn?: FetchFunctionType;
 }): Promise<void> {
   const readable = reader(filePath);
   const done = await _tusCreateWithUploadRequest({
@@ -267,6 +298,7 @@ export async function tusUpload({
     onProgress,
     onCaughtError,
     signal,
+    fetchFn,
   });
   if (!done) {
     await tusResumeUpload({
@@ -280,6 +312,7 @@ export async function tusUpload({
       onCaughtError,
      maxRetries,
       signal,
+      fetchFn,
     });
   }
 }
@@ -302,6 +335,7 @@ export async function tusResumeUpload({
   onCaughtError,
   maxRetries = DEFAULT_MAX_RETRIES,
   signal,
+  fetchFn = fetch,
 }: {
   endpoint: string;
   headers: Record<string, string>;
@@ -313,6 +347,7 @@ export async function tusResumeUpload({
   onCaughtError?: (error: Error) => void;
   maxRetries?: number;
   signal?: AbortSignal;
+  fetchFn?: FetchFunctionType;
 }): Promise<void> {
   const backoff = new BackOff(FIBONACCI_TIMEOUTS, {
     jitter: BACKOFF_JITTER_MS,
@@ -330,6 +365,7 @@ export async function tusResumeUpload({
       headers,
       fileName,
       signal,
+      fetchFn,
     });

     if (uploadOffset === fileSize) {
@@ -348,6 +384,7 @@ export async function tusResumeUpload({
       onProgress,
       onCaughtError,
       signal,
+      fetchFn,
     });

     if (done) {
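
One more note: the point of threading fetchFn through every helper above is dependency injection, so callers (tests, or WebAPI's createFetchForAttachmentUpload) can substitute their own fetch without touching the global one. A small hedched-as-hypothetical sketch of a caller injecting a logging fetch; the endpoint, file name, and logging are illustrative only:

import { createReadStream } from 'node:fs';
import fetch from 'node-fetch';
import { tusUpload, type FetchFunctionType } from './tusProtocol';

// Wraps node-fetch so every TUS request is logged before being sent.
const loggingFetch: FetchFunctionType = async (url, init) => {
  console.log('tus request:', init.method, url.toString());
  return fetch(url, init);
};

async function uploadWithLoggingFetch(filePath: string, fileSize: number): Promise<void> {
  await tusUpload({
    endpoint: 'http://localhost:8080/upload', // e.g. a local TUS test server
    headers: {},
    fileName: 'example',
    filePath,
    fileSize,
    reader: (path, offset) => createReadStream(path, { start: offset }),
    fetchFn: loggingFetch,
  });
}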

View file

@@ -1,12 +1,13 @@
 // Copyright 2024 Signal Messenger, LLC
 // SPDX-License-Identifier: AGPL-3.0-only

+import fetch from 'node-fetch';
 import { createReadStream, createWriteStream } from 'node:fs';
-import { Writable } from 'node:stream';
+import { pipeline } from 'node:stream/promises';
-import type { TusFileReader } from './tusProtocol';
+import type { TusFileReader, FetchFunctionType } from './tusProtocol';
 import { tusResumeUpload, tusUpload } from './tusProtocol';
 import { HTTPError } from '../../textsecure/Errors';

-const defaultFileReader: TusFileReader = (filePath, offset) => {
+export const defaultFileReader: TusFileReader = (filePath, offset) => {
   return createReadStream(filePath, { start: offset });
 };
@@ -87,13 +88,15 @@ export async function _doDownload({
   headers = {},
   filePath,
   signal,
+  fetchFn = fetch,
 }: {
   endpoint: string;
   filePath: string;
   headers?: Record<string, string>;
   signal?: AbortSignal;
+  fetchFn?: FetchFunctionType;
 }): Promise<void> {
-  const response = await fetch(endpoint, {
+  const response = await fetchFn(endpoint, {
     method: 'GET',
     signal,
     redirect: 'error',
@@ -106,7 +109,7 @@ export async function _doDownload({
     throw new Error('Response has no body');
   }
   const writable = createWriteStream(filePath);
-  await response.body.pipeTo(Writable.toWeb(writable));
+  await pipeline(response.body, writable);
 }

 /**
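
The download path above now streams the node-fetch body (a Node Readable) straight to disk with stream/promises pipeline(), which propagates errors and backpressure in both directions instead of converting the write stream to a web stream. A minimal sketch of the same pattern; the URL and path are placeholders:

import { createWriteStream } from 'node:fs';
import { pipeline } from 'node:stream/promises';
import fetch from 'node-fetch';

async function downloadToFile(url: string, filePath: string): Promise<void> {
  const response = await fetch(url);
  if (!response.ok || !response.body) {
    throw new Error(`Download failed: HTTP ${response.status}`);
  }
  await pipeline(response.body, createWriteStream(filePath));
}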

View file

@@ -4001,10 +4001,10 @@
     type-fest "^3.5.0"
     uuid "^8.3.0"

-"@signalapp/mock-server@6.4.1":
-  version "6.4.1"
-  resolved "https://registry.yarnpkg.com/@signalapp/mock-server/-/mock-server-6.4.1.tgz#b49700f8d43b0c76d3f02820dd3b3da82a910f12"
-  integrity sha512-is75JwGL2CjLJ3NakMxw6rkgx379aKc3n328lSaiwLKVgBpuG/ms8wF3fNALxFstKoMl41lPzooOMWeqm+ubVQ==
+"@signalapp/mock-server@6.4.2":
+  version "6.4.2"
+  resolved "https://registry.yarnpkg.com/@signalapp/mock-server/-/mock-server-6.4.2.tgz#9c0ccabaf7d9a8728503245d2fa2b4d7da6a5ccd"
+  integrity sha512-qL5wUGkbquZA6mKieuSOwlX51UyUFlLeQq+Z/F+gX910l8aYVV0niwtR1hYNPgvgxakPPXJ3VhIWE4qMgQRkrw==
   dependencies:
     "@signalapp/libsignal-client" "^0.42.0"
     debug "^4.3.2"