Make updates atomic again

This commit is contained in:
Fedor Indutny 2022-03-03 14:34:51 -08:00 committed by GitHub
parent c87cb59676
commit 26100ea562
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 162 additions and 94 deletions

View file

@ -68,12 +68,21 @@ describe('updater/signatures', () => {
const { publicKey, privateKey } = keyPair(); const { publicKey, privateKey } = keyPair();
await writeHexToPath(privateKeyPath, privateKey); await writeHexToPath(privateKeyPath, privateKey);
await writeSignature(updatePath, version, privateKeyPath); const signature = await writeSignature(
updatePath,
version,
privateKeyPath
);
const signaturePath = getSignaturePath(updatePath); const signaturePath = getSignaturePath(updatePath);
assert.strictEqual(existsSync(signaturePath), true); assert.strictEqual(existsSync(signaturePath), true);
const verified = await verifySignature(updatePath, version, publicKey); const verified = await verifySignature(
updatePath,
version,
signature,
publicKey
);
assert.strictEqual(verified, true); assert.strictEqual(verified, true);
} finally { } finally {
if (tempDir) { if (tempDir) {
@ -99,11 +108,16 @@ describe('updater/signatures', () => {
const { publicKey, privateKey } = keyPair(); const { publicKey, privateKey } = keyPair();
await writeHexToPath(privateKeyPath, privateKey); await writeHexToPath(privateKeyPath, privateKey);
await writeSignature(updatePath, version, privateKeyPath); const signature = await writeSignature(
updatePath,
version,
privateKeyPath
);
const verified = await verifySignature( const verified = await verifySignature(
updatePath, updatePath,
brokenVersion, brokenVersion,
signature,
publicKey publicKey
); );
assert.strictEqual(verified, false); assert.strictEqual(verified, false);
@ -130,14 +144,19 @@ describe('updater/signatures', () => {
const { publicKey, privateKey } = keyPair(); const { publicKey, privateKey } = keyPair();
await writeHexToPath(privateKeyPath, privateKey); await writeHexToPath(privateKeyPath, privateKey);
await writeSignature(updatePath, version, privateKeyPath); const signature = await writeSignature(
updatePath,
const signaturePath = getSignaturePath(updatePath); version,
const signature = Buffer.from(await loadHexFromPath(signaturePath)); privateKeyPath
);
signature[4] += 3; signature[4] += 3;
await writeHexToPath(signaturePath, signature);
const verified = await verifySignature(updatePath, version, publicKey); const verified = await verifySignature(
updatePath,
version,
signature,
publicKey
);
assert.strictEqual(verified, false); assert.strictEqual(verified, false);
} finally { } finally {
if (tempDir) { if (tempDir) {
@ -162,7 +181,11 @@ describe('updater/signatures', () => {
const { publicKey, privateKey } = keyPair(); const { publicKey, privateKey } = keyPair();
await writeHexToPath(privateKeyPath, privateKey); await writeHexToPath(privateKeyPath, privateKey);
await writeSignature(updatePath, version, privateKeyPath); const signature = await writeSignature(
updatePath,
version,
privateKeyPath
);
const brokenSourcePath = join( const brokenSourcePath = join(
__dirname, __dirname,
@ -170,7 +193,12 @@ describe('updater/signatures', () => {
); );
await copy(brokenSourcePath, updatePath); await copy(brokenSourcePath, updatePath);
const verified = await verifySignature(updatePath, version, publicKey); const verified = await verifySignature(
updatePath,
version,
signature,
publicKey
);
assert.strictEqual(verified, false); assert.strictEqual(verified, false);
} finally { } finally {
if (tempDir) { if (tempDir) {
@ -196,9 +224,18 @@ describe('updater/signatures', () => {
const { privateKey } = keyPair(); const { privateKey } = keyPair();
await writeHexToPath(privateKeyPath, privateKey); await writeHexToPath(privateKeyPath, privateKey);
await writeSignature(updatePath, version, privateKeyPath); const signature = await writeSignature(
updatePath,
version,
privateKeyPath
);
const verified = await verifySignature(updatePath, version, publicKey); const verified = await verifySignature(
updatePath,
version,
signature,
publicKey
);
assert.strictEqual(verified, false); assert.strictEqual(verified, false);
} finally { } finally {
if (tempDir) { if (tempDir) {

View file

@ -2,9 +2,9 @@
// SPDX-License-Identifier: AGPL-3.0-only // SPDX-License-Identifier: AGPL-3.0-only
/* eslint-disable no-console */ /* eslint-disable no-console */
import { createWriteStream, statSync } from 'fs'; import { createWriteStream } from 'fs';
import { pathExists } from 'fs-extra'; import { pathExists } from 'fs-extra';
import { readdir, writeFile } from 'fs/promises'; import { readdir, rename, stat, writeFile } from 'fs/promises';
import { promisify } from 'util'; import { promisify } from 'util';
import { execFile } from 'child_process'; import { execFile } from 'child_process';
import { join, normalize, extname } from 'path'; import { join, normalize, extname } from 'path';
@ -83,6 +83,11 @@ enum DownloadMode {
Automatic = 'Automatic', Automatic = 'Automatic',
} }
type DownloadUpdateResultType = Readonly<{
updateFilePath: string;
signature: Buffer;
}>;
export abstract class Updater { export abstract class Updater {
protected fileName: string | undefined; protected fileName: string | undefined;
@ -92,6 +97,8 @@ export abstract class Updater {
private throttledSendDownloadingUpdate: (downloadedSize: number) => void; private throttledSendDownloadingUpdate: (downloadedSize: number) => void;
private activeDownload: Promise<boolean> | undefined;
constructor( constructor(
protected readonly logger: LoggerType, protected readonly logger: LoggerType,
private readonly settingsChannel: SettingsChannel, private readonly settingsChannel: SettingsChannel,
@ -156,6 +163,23 @@ export abstract class Updater {
private async downloadAndInstall( private async downloadAndInstall(
updateInfo: UpdateInformationType, updateInfo: UpdateInformationType,
mode: DownloadMode mode: DownloadMode
): Promise<boolean> {
if (this.activeDownload) {
return this.activeDownload;
}
try {
this.activeDownload = this.doDownloadAndInstall(updateInfo, mode);
return await this.activeDownload;
} finally {
this.activeDownload = undefined;
}
}
private async doDownloadAndInstall(
updateInfo: UpdateInformationType,
mode: DownloadMode
): Promise<boolean> { ): Promise<boolean> {
const { logger } = this; const { logger } = this;
@ -163,19 +187,20 @@ export abstract class Updater {
try { try {
const oldVersion = this.version; const oldVersion = this.version;
this.version = newVersion; this.version = newVersion;
let updateFilePath: string | undefined; let downloadResult: DownloadUpdateResultType | undefined;
try { try {
updateFilePath = await this.downloadUpdate(updateInfo, mode); downloadResult = await this.downloadUpdate(updateInfo, mode);
} catch (error) { } catch (error) {
// Restore state in case of download error // Restore state in case of download error
this.version = oldVersion; this.version = oldVersion;
throw error; throw error;
} }
if (!updateFilePath) { if (!downloadResult) {
logger.warn('downloadAndInstall: no update was downloaded'); logger.warn('downloadAndInstall: no update was downloaded');
strictAssert( strictAssert(
mode !== DownloadMode.Automatic && mode !== DownloadMode.FullOnly, mode !== DownloadMode.Automatic && mode !== DownloadMode.FullOnly,
@ -184,10 +209,13 @@ export abstract class Updater {
return false; return false;
} }
const { updateFilePath, signature } = downloadResult;
const publicKey = hexToBinary(config.get('updatesPublicKey')); const publicKey = hexToBinary(config.get('updatesPublicKey'));
const verified = await verifySignature( const verified = await verifySignature(
updateFilePath, updateFilePath,
this.version, this.version,
signature,
publicKey publicKey
); );
if (!verified) { if (!verified) {
@ -403,7 +431,7 @@ export abstract class Updater {
private async downloadUpdate( private async downloadUpdate(
{ fileName, sha512, differentialData }: UpdateInformationType, { fileName, sha512, differentialData }: UpdateInformationType,
mode: DownloadMode mode: DownloadMode
): Promise<string | undefined> { ): Promise<DownloadUpdateResultType | undefined> {
const baseUrl = getUpdatesBase(); const baseUrl = getUpdatesBase();
const updateFileUrl = `${baseUrl}/${fileName}`; const updateFileUrl = `${baseUrl}/${fileName}`;
@ -414,43 +442,37 @@ export abstract class Updater {
const signatureUrl = `${baseUrl}/${signatureFileName}`; const signatureUrl = `${baseUrl}/${signatureFileName}`;
const blockMapUrl = `${baseUrl}/${blockMapFileName}`; const blockMapUrl = `${baseUrl}/${blockMapFileName}`;
const cacheDir = await createUpdateCacheDirIfNeeded(); let cacheDir = await createUpdateCacheDirIfNeeded();
const targetUpdatePath = join(cacheDir, fileName); const targetUpdatePath = join(cacheDir, fileName);
const targetSignaturePath = join(cacheDir, signatureFileName);
const targetBlockMapPath = join(cacheDir, blockMapFileName);
const targetPaths = [ const tempDir = await createTempDir();
targetUpdatePath, const restoreDir = await createTempDir();
targetSignaturePath,
targetBlockMapPath,
];
// List of files to be deleted on success const tempUpdatePath = join(tempDir, fileName);
const oldFiles = (await readdir(cacheDir)) const tempBlockMapPath = join(tempDir, blockMapFileName);
.map(oldFileName => {
return join(cacheDir, oldFileName);
})
.filter(path => !targetPaths.includes(path));
try { try {
validatePath(cacheDir, targetUpdatePath); validatePath(cacheDir, targetUpdatePath);
validatePath(cacheDir, targetSignaturePath);
validatePath(cacheDir, targetBlockMapPath); validatePath(tempDir, tempUpdatePath);
validatePath(tempDir, tempBlockMapPath);
this.logger.info(`downloadUpdate: Downloading signature ${signatureUrl}`); this.logger.info(`downloadUpdate: Downloading signature ${signatureUrl}`);
const signature = await got(signatureUrl, getGotOptions()).buffer(); const signature = Buffer.from(
await writeFile(targetSignaturePath, signature); await got(signatureUrl, getGotOptions()).text(),
'hex'
);
if (differentialData) { if (differentialData) {
this.logger.info(`downloadUpdate: Saving blockmap ${blockMapUrl}`); this.logger.info(`downloadUpdate: Saving blockmap ${blockMapUrl}`);
await writeFile(targetBlockMapPath, differentialData.newBlockMap); await writeFile(tempBlockMapPath, differentialData.newBlockMap);
} else { } else {
try { try {
this.logger.info( this.logger.info(
`downloadUpdate: Downloading blockmap ${blockMapUrl}` `downloadUpdate: Downloading blockmap ${blockMapUrl}`
); );
const blockMap = await got(blockMapUrl, getGotOptions()).buffer(); const blockMap = await got(blockMapUrl, getGotOptions()).buffer();
await writeFile(targetBlockMapPath, blockMap); await writeFile(tempBlockMapPath, blockMap);
} catch (error) { } catch (error) {
this.logger.warn( this.logger.warn(
'downloadUpdate: Failed to download blockmap, continuing', 'downloadUpdate: Failed to download blockmap, continuing',
@ -467,7 +489,17 @@ export abstract class Updater {
`downloadUpdate: Not downloading update ${updateFileUrl}, ` + `downloadUpdate: Not downloading update ${updateFileUrl}, ` +
'local file has the same hash' 'local file has the same hash'
); );
gotUpdate = true;
// Move file into downloads directory
try {
await rename(targetUpdatePath, tempUpdatePath);
gotUpdate = true;
} catch (error) {
this.logger.error(
'downloadUpdate: failed to move already downloaded file',
Errors.toLogFormat(error)
);
}
} else { } else {
this.logger.error( this.logger.error(
'downloadUpdate: integrity check failure', 'downloadUpdate: integrity check failure',
@ -484,7 +516,7 @@ export abstract class Updater {
); );
try { try {
await downloadDifferentialData(targetUpdatePath, differentialData, { await downloadDifferentialData(tempUpdatePath, differentialData, {
statusCallback: updateOnProgress statusCallback: updateOnProgress
? this.throttledSendDownloadingUpdate ? this.throttledSendDownloadingUpdate
: undefined, : undefined,
@ -505,9 +537,16 @@ export abstract class Updater {
this.logger.info( this.logger.info(
`downloadUpdate: Downloading full update ${updateFileUrl}` `downloadUpdate: Downloading full update ${updateFileUrl}`
); );
// We could have failed to update differentially due to low free disk
// space. Remove all cached updates since we are doing a full download
// anyway.
await rimrafPromise(cacheDir);
cacheDir = await createUpdateCacheDirIfNeeded();
await this.downloadAndReport( await this.downloadAndReport(
updateFileUrl, updateFileUrl,
targetUpdatePath, tempUpdatePath,
updateOnProgress updateOnProgress
); );
gotUpdate = true; gotUpdate = true;
@ -517,17 +556,22 @@ export abstract class Updater {
return undefined; return undefined;
} }
// Now that we successfully downloaded an update - remove old files // Backup old files
await Promise.all(oldFiles.map(path => rimrafPromise(path))); await rename(cacheDir, restoreDir);
return targetUpdatePath; // Move the files into the final position
} catch (error) {
try { try {
await Promise.all([targetPaths.map(path => rimrafPromise(path))]); await rename(tempDir, cacheDir);
} catch (_) { } catch (error) {
// Ignore error, this is a cleanup // Attempt to restore old files
await rename(restoreDir, cacheDir);
throw error;
} }
throw error;
return { updateFilePath: targetUpdatePath, signature };
} finally {
await Promise.all([deleteTempDir(tempDir), deleteTempDir(restoreDir)]);
} }
} }
@ -758,11 +802,13 @@ export async function createUpdateCacheDirIfNeeded(): Promise<string> {
} }
export async function deleteTempDir(targetDir: string): Promise<void> { export async function deleteTempDir(targetDir: string): Promise<void> {
const pathInfo = statSync(targetDir); if (await pathExists(targetDir)) {
if (!pathInfo.isDirectory()) { const pathInfo = await stat(targetDir);
throw new Error( if (!pathInfo.isDirectory()) {
`deleteTempDir: Cannot delete path '${targetDir}' because it is not a directory` throw new Error(
); `deleteTempDir: Cannot delete path '${targetDir}' because it is not a directory`
);
}
} }
const baseTempDir = getBaseTempDir(); const baseTempDir = getBaseTempDir();

View file

@ -2,11 +2,9 @@
// SPDX-License-Identifier: AGPL-3.0-only // SPDX-License-Identifier: AGPL-3.0-only
import type { FileHandle } from 'fs/promises'; import type { FileHandle } from 'fs/promises';
import { readFile, open, mkdtemp, mkdir, rename, unlink } from 'fs/promises'; import { readFile, open } from 'fs/promises';
import { promisify } from 'util'; import { promisify } from 'util';
import { gunzip as nativeGunzip } from 'zlib'; import { gunzip as nativeGunzip } from 'zlib';
import { tmpdir } from 'os';
import path from 'path';
import got from 'got'; import got from 'got';
import { chunk as lodashChunk } from 'lodash'; import { chunk as lodashChunk } from 'lodash';
import pMap from 'p-map'; import pMap from 'p-map';
@ -254,12 +252,10 @@ export async function download(
{ statusCallback, logger, gotOptions }: DownloadOptionsType = {} { statusCallback, logger, gotOptions }: DownloadOptionsType = {}
): Promise<void> { ): Promise<void> {
const input = await open(oldFile, 'r'); const input = await open(oldFile, 'r');
const output = await open(newFile, 'w');
const tempDir = await mkdtemp(path.join(tmpdir(), 'signal-temp-')); const abortController = new AbortController();
await mkdir(tempDir, { recursive: true }); const { signal: abortSignal } = abortController;
const tempFile = path.join(tempDir, path.basename(newFile));
const output = await open(tempFile, 'w');
const copyActions = diff.filter(({ action }) => action === 'copy'); const copyActions = diff.filter(({ action }) => action === 'copy');
@ -278,15 +274,16 @@ export async function download(
`Not enough data to read from offset=${readOffset} size=${size}` `Not enough data to read from offset=${readOffset} size=${size}`
); );
if (abortSignal?.aborted) {
return;
}
await output.write(chunk, 0, chunk.length, writeOffset); await output.write(chunk, 0, chunk.length, writeOffset);
}) })
); );
const downloadActions = diff.filter(({ action }) => action === 'download'); const downloadActions = diff.filter(({ action }) => action === 'download');
const abortController = new AbortController();
const { signal: abortSignal } = abortController;
try { try {
let downloadedSize = 0; let downloadedSize = 0;
@ -314,16 +311,8 @@ export async function download(
await Promise.all([input.close(), output.close()]); await Promise.all([input.close(), output.close()]);
} }
const checkResult = await checkIntegrity(tempFile, sha512); const checkResult = await checkIntegrity(newFile, sha512);
strictAssert(checkResult.ok, checkResult.error ?? ''); strictAssert(checkResult.ok, checkResult.error ?? '');
// Finally move the file into its final location
try {
await unlink(newFile);
} catch (_) {
// ignore errors
}
await rename(tempFile, newFile);
} }
export async function downloadRanges( export async function downloadRanges(
@ -389,6 +378,11 @@ export async function downloadRanges(
`newChunk=${chunk.length} ` + `newChunk=${chunk.length} ` +
`maxSize=${diff.size}` `maxSize=${diff.size}`
); );
if (abortSignal?.aborted) {
return;
}
await output.write(chunk, 0, chunk.length, offset + diff.writeOffset); await output.write(chunk, 0, chunk.length, offset + diff.writeOffset);
offset += chunk.length; offset += chunk.length;

View file

@ -7,6 +7,7 @@ import {
readFile as readFileCallback, readFile as readFileCallback,
writeFile as writeFileCallback, writeFile as writeFileCallback,
} from 'fs'; } from 'fs';
import { pipeline } from 'stream/promises';
import { basename, dirname, join, resolve as resolvePath } from 'path'; import { basename, dirname, join, resolve as resolvePath } from 'path';
import pify from 'pify'; import pify from 'pify';
@ -30,10 +31,9 @@ export async function generateSignature(
export async function verifySignature( export async function verifySignature(
updatePackagePath: string, updatePackagePath: string,
version: string, version: string,
signature: Buffer,
publicKey: Buffer publicKey: Buffer
): Promise<boolean> { ): Promise<boolean> {
const signaturePath = getSignaturePath(updatePackagePath);
const signature = await loadHexFromPath(signaturePath);
const message = await generateMessage(updatePackagePath, version); const message = await generateMessage(updatePackagePath, version);
return verify(publicKey, message, signature); return verify(publicKey, message, signature);
@ -55,7 +55,7 @@ export async function writeSignature(
updatePackagePath: string, updatePackagePath: string,
version: string, version: string,
privateKeyPath: string privateKeyPath: string
): Promise<void> { ): Promise<Buffer> {
const signaturePath = getSignaturePath(updatePackagePath); const signaturePath = getSignaturePath(updatePackagePath);
const signature = await generateSignature( const signature = await generateSignature(
updatePackagePath, updatePackagePath,
@ -63,23 +63,15 @@ export async function writeSignature(
privateKeyPath privateKeyPath
); );
await writeHexToPath(signaturePath, signature); await writeHexToPath(signaturePath, signature);
return signature;
} }
export async function _getFileHash(updatePackagePath: string): Promise<Buffer> { export async function _getFileHash(updatePackagePath: string): Promise<Buffer> {
const hash = createHash('sha256'); const hash = createHash('sha256');
const stream = createReadStream(updatePackagePath); await pipeline(createReadStream(updatePackagePath), hash);
return new Promise((resolve, reject) => { return hash.digest();
stream.on('data', data => {
hash.update(data);
});
stream.on('close', () => {
resolve(hash.digest());
});
stream.on('error', error => {
reject(error);
});
});
} }
export function getSignatureFileName(fileName: string): string { export function getSignatureFileName(fileName: string): string {

View file

@ -2,6 +2,7 @@
// SPDX-License-Identifier: AGPL-3.0-only // SPDX-License-Identifier: AGPL-3.0-only
import { createReadStream } from 'fs'; import { createReadStream } from 'fs';
import { pipeline } from 'stream/promises';
import { createHash } from 'crypto'; import { createHash } from 'crypto';
import * as Errors from '../types/errors'; import * as Errors from '../types/errors';
@ -23,9 +24,7 @@ export async function checkIntegrity(
): Promise<CheckIntegrityResultType> { ): Promise<CheckIntegrityResultType> {
try { try {
const hash = createHash('sha512'); const hash = createHash('sha512');
for await (const chunk of createReadStream(fileName)) { await pipeline(createReadStream(fileName), hash);
hash.update(chunk);
}
const actualSHA512 = hash.digest('base64'); const actualSHA512 = hash.digest('base64');
if (sha512 === actualSHA512) { if (sha512 === actualSHA512) {