Refactor AttachmentCrypto

This commit is contained in:
Jamie Kyle 2024-02-05 15:17:28 -08:00 committed by GitHub
parent 96131112da
commit 395c67f6c4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 443 additions and 792 deletions

File diff suppressed because it is too large Load diff

View file

@ -7,6 +7,7 @@ import { assert } from 'chai';
import { readFileSync, unlinkSync, writeFileSync } from 'fs';
import { join } from 'path';
import { randomBytes } from 'crypto';
import * as log from '../logging/log';
import * as Bytes from '../Bytes';
import * as Curve from '../Curve';
@ -36,7 +37,12 @@ import {
decryptAttachmentV1,
padAndEncryptAttachment,
} from '../Crypto';
import { decryptAttachmentV2, encryptAttachmentV2 } from '../AttachmentCrypto';
import {
KEY_SET_LENGTH,
_generateAttachmentIv,
decryptAttachmentV2,
encryptAttachmentV2,
} from '../AttachmentCrypto';
import { createTempDir, deleteTempDir } from '../updater/common';
import { uuidToBytes, bytesToUuid } from '../util/uuidToBytes';
@ -605,6 +611,10 @@ describe('Crypto', () => {
const FILE_CONTENTS = readFileSync(FILE_PATH);
let tempDir: string | undefined;
function generateAttachmentKeys(): Uint8Array {
return randomBytes(KEY_SET_LENGTH);
}
beforeEach(async () => {
tempDir = await createTempDir();
});
@ -615,7 +625,7 @@ describe('Crypto', () => {
});
it('v1 roundtrips (memory only)', () => {
const keys = getRandomBytes(64);
const keys = generateAttachmentKeys();
// Note: support for padding is not in decryptAttachmentV1, so we don't pad here
const encryptedAttachment = encryptAttachment({
@ -632,7 +642,7 @@ describe('Crypto', () => {
});
it('v1 -> v2 (memory -> disk)', async () => {
const keys = getRandomBytes(64);
const keys = generateAttachmentKeys();
const ciphertextPath = join(tempDir!, 'file');
let plaintextPath;
@ -670,7 +680,7 @@ describe('Crypto', () => {
});
it('v2 roundtrips (all on disk)', async () => {
const keys = getRandomBytes(64);
const keys = generateAttachmentKeys();
let plaintextPath;
let ciphertextPath;
@ -680,7 +690,6 @@ describe('Crypto', () => {
plaintextAbsolutePath: FILE_PATH,
size: FILE_CONTENTS.byteLength,
});
ciphertextPath = window.Signal.Migrations.getAbsoluteAttachmentPath(
encryptedAttachment.path
);
@ -695,9 +704,7 @@ describe('Crypto', () => {
decryptedAttachment.path
);
const plaintext = readFileSync(plaintextPath);
assert.isTrue(constantTimeEqual(FILE_CONTENTS, plaintext));
assert.strictEqual(encryptedAttachment.plaintextHash, GHOST_KITTY_HASH);
assert.strictEqual(
decryptedAttachment.plaintextHash,
@ -714,7 +721,7 @@ describe('Crypto', () => {
});
it('v2 -> v1 (disk -> memory)', async () => {
const keys = getRandomBytes(64);
const keys = generateAttachmentKeys();
let ciphertextPath;
try {
@ -760,11 +767,10 @@ describe('Crypto', () => {
});
it('v1 and v2 produce the same ciphertext, given same iv', async () => {
const keys = getRandomBytes(64);
const keys = generateAttachmentKeys();
const dangerousTestOnlyIv = _generateAttachmentIv();
let ciphertextPath;
const dangerousTestOnlyIv = getRandomBytes(16);
try {
const encryptedAttachmentV1 = padAndEncryptAttachment({
plaintext: FILE_CONTENTS,

View file

@ -1,13 +1,12 @@
// Copyright 2020 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
import { createWriteStream, existsSync, unlinkSync } from 'fs';
import { createWriteStream } from 'fs';
import { isNumber, omit } from 'lodash';
import type { Readable } from 'stream';
import { Transform } from 'stream';
import { pipeline } from 'stream/promises';
import { ensureFile } from 'fs-extra';
import * as log from '../logging/log';
import * as Errors from '../types/errors';
import { strictAssert } from '../util/assert';
@ -19,21 +18,23 @@ import {
} from '../types/Attachment';
import * as MIME from '../types/MIME';
import * as Bytes from '../Bytes';
import {
getFirstBytes,
decryptAttachmentV1,
getAttachmentSizeBucket,
} from '../Crypto';
import { getFirstBytes, decryptAttachmentV1 } from '../Crypto';
import {
decryptAttachmentV2,
IV_LENGTH,
ATTACHMENT_MAC_LENGTH,
getAttachmentDownloadSize,
safeUnlinkSync,
} from '../AttachmentCrypto';
import type { ProcessedAttachment } from './Types.d';
import type { WebAPIType } from './WebAPI';
import { createName, getRelativePath } from '../windows/attachments';
function getCdn(attachment: ProcessedAttachment) {
const { cdnId, cdnKey } = attachment;
const cdn = cdnId || cdnKey;
strictAssert(cdn, 'Attachment was missing cdnId or cdnKey');
return cdn;
}
export async function downloadAttachmentV1(
server: WebAPIType,
attachment: ProcessedAttachment,
@ -42,24 +43,16 @@ export async function downloadAttachmentV1(
timeout?: number;
}
): Promise<DownloadedAttachmentType> {
const cdnId = attachment.cdnId || attachment.cdnKey;
const { cdnNumber } = attachment;
if (!cdnId) {
throw new Error('downloadAttachment: Attachment was missing cdnId!');
}
const { cdnNumber, key, digest, size, contentType } = attachment;
const cdn = getCdn(attachment);
const encrypted = await server.getAttachment(
cdnId,
cdn,
dropNull(cdnNumber),
options
);
const { key, digest, size, contentType } = attachment;
if (!digest) {
throw new Error('Failure: Ask sender to update Signal and resend.');
}
strictAssert(digest, 'Failure: Ask sender to update Signal and resend.');
strictAssert(key, 'attachment has no key');
const paddedData = decryptAttachmentV1(
@ -78,7 +71,6 @@ export async function downloadAttachmentV1(
return {
...attachment,
size,
contentType: contentType
? MIME.stringToMIMEType(contentType)
@ -95,13 +87,10 @@ export async function downloadAttachmentV2(
timeout?: number;
}
): Promise<AttachmentType> {
const { cdnId, cdnKey, cdnNumber, contentType, digest, key, size } =
attachment;
const cdn = cdnId || cdnKey;
const { cdnNumber, contentType, digest, key, size } = attachment;
const cdn = getCdn(attachment);
const logId = `downloadAttachmentV2(${cdn}):`;
strictAssert(cdn, `${logId}: missing cdnId or cdnKey`);
strictAssert(digest, `${logId}: missing digest`);
strictAssert(key, `${logId}: missing key`);
strictAssert(isNumber(size), `${logId}: missing size`);
@ -124,9 +113,7 @@ export async function downloadAttachmentV2(
theirDigest: Bytes.fromBase64(digest),
});
if (existsSync(cipherTextAbsolutePath)) {
unlinkSync(cipherTextAbsolutePath);
}
safeUnlinkSync(cipherTextAbsolutePath);
return {
...omit(attachment, 'key'),
@ -151,19 +138,13 @@ async function downloadToDisk({
window.Signal.Migrations.getAbsoluteAttachmentPath(relativeTargetPath);
await ensureFile(absoluteTargetPath);
const writeStream = createWriteStream(absoluteTargetPath);
const targetSize =
getAttachmentSizeBucket(size) * 1.05 + IV_LENGTH + ATTACHMENT_MAC_LENGTH;
const checkSizeTransform = new CheckSizeTransform(targetSize);
const targetSize = getAttachmentDownloadSize(size);
try {
await pipeline(downloadStream, checkSizeTransform, writeStream);
await pipeline(downloadStream, checkSize(targetSize), writeStream);
} catch (error) {
try {
writeStream.close();
if (absoluteTargetPath && existsSync(absoluteTargetPath)) {
unlinkSync(absoluteTargetPath);
}
safeUnlinkSync(absoluteTargetPath);
} catch (cleanupError) {
log.error(
'downloadToDisk: Error while cleaning up',
@ -178,41 +159,21 @@ async function downloadToDisk({
}
// A simple transform that throws if it sees more than maxBytes on the stream.
class CheckSizeTransform extends Transform {
private bytesSeen = 0;
constructor(private maxBytes: number) {
super();
}
override _transform(
chunk: Buffer | undefined,
_encoding: string,
done: (error?: Error) => void
) {
if (!chunk || chunk.byteLength === 0) {
done();
return;
}
try {
this.bytesSeen += chunk.byteLength;
if (this.bytesSeen > this.maxBytes) {
done(
function checkSize(expectedBytes: number) {
let totalBytes = 0;
return new Transform({
transform(chunk, encoding, callback) {
totalBytes += chunk.byteLength;
if (totalBytes > expectedBytes) {
callback(
new AttachmentSizeError(
`CheckSizeTransform: Saw ${this.bytesSeen} bytes, max is ${this.maxBytes} bytes`
`checkSize: Received ${totalBytes} bytes, max is ${expectedBytes}, `
)
);
return;
}
this.push(chunk);
} catch (error) {
done(error);
return;
}
done();
}
this.push(chunk, encoding);
callback();
},
});
}