Refactor provisioning flow

Co-authored-by: Fedor Indutny <79877362+indutny-signal@users.noreply.github.com>
This commit is contained in:
automated-signal 2024-08-30 15:23:04 -05:00 committed by GitHub
parent f81c1b9331
commit d77d112a49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 373 additions and 225 deletions

View file

@ -12,7 +12,7 @@ import { hasExpired as hasExpiredSelector } from '../selectors/expiration';
import * as log from '../../logging/log'; import * as log from '../../logging/log';
import type { Loadable } from '../../util/loadable'; import type { Loadable } from '../../util/loadable';
import { LoadingState } from '../../util/loadable'; import { LoadingState } from '../../util/loadable';
import { assertDev } from '../../util/assert'; import { assertDev, strictAssert } from '../../util/assert';
import { explodePromise } from '../../util/explodePromise'; import { explodePromise } from '../../util/explodePromise';
import { missingCaseError } from '../../util/missingCaseError'; import { missingCaseError } from '../../util/missingCaseError';
import * as Registration from '../../util/registration'; import * as Registration from '../../util/registration';
@ -26,7 +26,7 @@ import { MAX_DEVICE_NAME_LENGTH } from '../../components/installScreen/InstallSc
import { WidthBreakpoint } from '../../components/_util'; import { WidthBreakpoint } from '../../components/_util';
import { HTTPError } from '../../textsecure/Errors'; import { HTTPError } from '../../textsecure/Errors';
import { isRecord } from '../../util/isRecord'; import { isRecord } from '../../util/isRecord';
import type { ConfirmNumberResultType } from '../../textsecure/AccountManager'; import { Provisioner } from '../../textsecure/Provisioner';
import * as Errors from '../../types/errors'; import * as Errors from '../../types/errors';
import { normalizeDeviceName } from '../../util/normalizeDeviceName'; import { normalizeDeviceName } from '../../util/normalizeDeviceName';
import OS from '../../util/os/osMain'; import OS from '../../util/os/osMain';
@ -89,8 +89,8 @@ function classifyError(
return { loadError: LoadError.Unknown }; return { loadError: LoadError.Unknown };
} }
} }
// AccountManager.registerSecondDevice uses this specific "websocket closed" error // AccountManager.registerSecondDevice uses this specific "websocket closed"
// message. // error message.
if (isRecord(err) && err.message === 'websocket closed') { if (isRecord(err) && err.message === 'websocket closed') {
return { installError: InstallError.ConnectionFailed }; return { installError: InstallError.ConnectionFailed };
} }
@ -197,23 +197,30 @@ export const SmartInstallScreen = memo(function SmartInstallScreen() {
useEffect(() => { useEffect(() => {
let hasCleanedUp = false; let hasCleanedUp = false;
const qrCodeResolution = explodePromise<void>();
const { server } = window.textsecure;
strictAssert(server, 'Expected a server');
let provisioner = new Provisioner(server);
const accountManager = window.getAccountManager(); const accountManager = window.getAccountManager();
assertDev(accountManager, 'Expected an account manager'); strictAssert(accountManager, 'Expected an account manager');
const updateProvisioningUrl = (value: string): void => { async function getQRCode(): Promise<void> {
const sleepError = new TimeoutError();
try {
const qrCodePromise = provisioner.getURL();
const sleepMs = qrCodeBackOff.getAndIncrement();
log.info(`InstallScreen/getQRCode: race to ${sleepMs}ms`);
const url = await pTimeout(qrCodePromise, sleepMs, sleepError);
if (hasCleanedUp) { if (hasCleanedUp) {
return; return;
} }
qrCodeResolution.resolve();
setProvisioningUrl(value);
};
const confirmNumber = async (): Promise<ConfirmNumberResultType> => { window.IPC.removeSetupMenuItems();
if (hasCleanedUp) { setProvisioningUrl(url);
throw new Error('Cannot confirm number; the component was unmounted');
} await provisioner.waitForEnvelope();
onQrCodeScanned(); onQrCodeScanned();
let deviceName: string; let deviceName: string;
@ -225,16 +232,19 @@ export const SmartInstallScreen = memo(function SmartInstallScreen() {
const backupFile = const backupFile =
await chooseBackupFilePromiseWrapperRef.current.promise; await chooseBackupFilePromiseWrapperRef.current.promise;
backupFileData = backupFile ? await fileToBytes(backupFile) : undefined; backupFileData = backupFile
? await fileToBytes(backupFile)
: undefined;
} }
if (hasCleanedUp) { if (hasCleanedUp) {
throw new Error('Cannot confirm number; the component was unmounted'); throw new Error('Cannot confirm number; the component was unmounted');
} }
// Delete all data from the database unless we're in the middle of a re-link. // Delete all data from the database unless we're in the middle of a
// Without this, the app restarts at certain times and can cause weird things to // re-link. Without this, the app restarts at certain times and can
// happen, like data from a previous light import showing up after a new install. // cause weird things to happen, like data from a previous light
// import showing up after a new install.
const shouldRetainData = Registration.everDone(); const shouldRetainData = Registration.everDone();
if (!shouldRetainData) { if (!shouldRetainData) {
try { try {
@ -251,28 +261,16 @@ export const SmartInstallScreen = memo(function SmartInstallScreen() {
throw new Error('Cannot confirm number; the component was unmounted'); throw new Error('Cannot confirm number; the component was unmounted');
} }
return { deviceName, backupFile: backupFileData }; const data = provisioner.prepareLinkData({
}; deviceName,
backupFile: backupFileData,
async function getQRCode(): Promise<void> { });
const sleepError = new TimeoutError(); await accountManager.registerSecondDevice(data);
try {
const qrCodePromise = accountManager.registerSecondDevice(
updateProvisioningUrl,
confirmNumber
);
const sleepMs = qrCodeBackOff.getAndIncrement();
log.info(`InstallScreen/getQRCode: race to ${sleepMs}ms`);
await Promise.all([
pTimeout(qrCodeResolution.promise, sleepMs, sleepError),
// Note that `registerSecondDevice` resolves once the registration
// is fully complete and thus should not be subjected to a timeout.
qrCodePromise,
]);
window.IPC.removeSetupMenuItems();
} catch (error) { } catch (error) {
provisioner.close();
strictAssert(server, 'Expected a server');
provisioner = new Provisioner(server);
log.error( log.error(
'account.registerSecondDevice: got an error', 'account.registerSecondDevice: got an error',
Errors.toLogFormat(error) Errors.toLogFormat(error)

View file

@ -21,9 +21,6 @@ import type {
KyberPreKeyType, KyberPreKeyType,
PniKeyMaterialType, PniKeyMaterialType,
} from './Types.d'; } from './Types.d';
import ProvisioningCipher from './ProvisioningCipher';
import type { IncomingWebSocketRequest } from './WebsocketResources';
import { ServerRequestType } from './WebsocketResources';
import createTaskWithTimeout from './TaskWithTimeout'; import createTaskWithTimeout from './TaskWithTimeout';
import * as Bytes from '../Bytes'; import * as Bytes from '../Bytes';
import * as Errors from '../types/errors'; import * as Errors from '../types/errors';
@ -46,7 +43,6 @@ import {
import type { AciString, PniString, ServiceIdString } from '../types/ServiceId'; import type { AciString, PniString, ServiceIdString } from '../types/ServiceId';
import { import {
isUntaggedPniString, isUntaggedPniString,
normalizePni,
ServiceIdKind, ServiceIdKind,
toTaggedPni, toTaggedPni,
} from '../types/ServiceId'; } from '../types/ServiceId';
@ -61,7 +57,6 @@ import { missingCaseError } from '../util/missingCaseError';
import { SignalService as Proto } from '../protobuf'; import { SignalService as Proto } from '../protobuf';
import * as log from '../logging/log'; import * as log from '../logging/log';
import type { StorageAccessType } from '../types/Storage'; import type { StorageAccessType } from '../types/Storage';
import { linkDeviceRoute } from '../util/signalRoutes';
import { getRelativePath, createName } from '../util/attachmentPath'; import { getRelativePath, createName } from '../util/attachmentPath';
import { isBackupEnabled } from '../util/isBackupEnabled'; import { isBackupEnabled } from '../util/isBackupEnabled';
@ -116,7 +111,7 @@ const SIGNED_PRE_KEY_UPDATE_TIME_KEY: StorageKeyByServiceIdKind = {
[ServiceIdKind.PNI]: 'signedKeyUpdateTimePNI', [ServiceIdKind.PNI]: 'signedKeyUpdateTimePNI',
}; };
enum AccountType { export enum AccountType {
Primary = 'Primary', Primary = 'Primary',
Linked = 'Linked', Linked = 'Linked',
} }
@ -146,7 +141,7 @@ type CreatePrimaryDeviceOptionsType = Readonly<{
}> & }> &
CreateAccountSharedOptionsType; CreateAccountSharedOptionsType;
type CreateLinkedDeviceOptionsType = Readonly<{ export type CreateLinkedDeviceOptionsType = Readonly<{
type: AccountType.Linked; type: AccountType.Linked;
deviceName: string; deviceName: string;
@ -327,8 +322,6 @@ export default class AccountManager extends EventTarget {
const accessKey = deriveAccessKey(profileKey); const accessKey = deriveAccessKey(profileKey);
const masterKey = getRandomBytes(MASTER_KEY_LENGTH); const masterKey = getRandomBytes(MASTER_KEY_LENGTH);
const registrationBaton = this.server.startRegistration();
try {
await this.createAccount({ await this.createAccount({
type: AccountType.Primary, type: AccountType.Primary,
number, number,
@ -341,132 +334,14 @@ export default class AccountManager extends EventTarget {
masterKey, masterKey,
readReceipts: true, readReceipts: true,
}); });
} finally {
this.server.finishRegistration(registrationBaton);
}
await this.registrationDone();
}); });
} }
async registerSecondDevice( async registerSecondDevice(
setProvisioningUrl: (url: string) => void, options: CreateLinkedDeviceOptionsType
confirmNumber: (number?: string) => Promise<ConfirmNumberResultType>
): Promise<void> { ): Promise<void> {
const provisioningCipher = new ProvisioningCipher();
const pubKey = await provisioningCipher.getPublicKey();
let envelopeCallbacks:
| {
resolve(data: Proto.ProvisionEnvelope): void;
reject(error: Error): void;
}
| undefined;
const envelopePromise = new Promise<Proto.ProvisionEnvelope>(
(resolve, reject) => {
envelopeCallbacks = { resolve, reject };
}
);
const wsr = await this.server.getProvisioningResource({
handleRequest(request: IncomingWebSocketRequest) {
if (
request.requestType === ServerRequestType.ProvisioningAddress &&
request.body
) {
const proto = Proto.ProvisioningUuid.decode(request.body);
const { uuid } = proto;
if (!uuid) {
throw new Error('registerSecondDevice: expected a UUID');
}
const url = linkDeviceRoute
.toAppUrl({
uuid,
pubKey: Bytes.toBase64(pubKey),
})
.toString();
window.SignalCI?.setProvisioningURL(url);
setProvisioningUrl(url);
request.respond(200, 'OK');
} else if (
request.requestType === ServerRequestType.ProvisioningMessage &&
request.body
) {
const envelope = Proto.ProvisionEnvelope.decode(request.body);
request.respond(200, 'OK');
wsr.close();
envelopeCallbacks?.resolve(envelope);
} else {
log.error('Unknown websocket message', request.requestType);
}
},
});
log.info('provisioning socket open');
wsr.addEventListener('close', ({ code, reason }) => {
log.info(`provisioning socket closed. Code: ${code} Reason: ${reason}`);
// Note: if we have resolved the envelope already - this has no effect
envelopeCallbacks?.reject(new Error('websocket closed'));
});
const envelope = await envelopePromise;
const provisionMessage = await provisioningCipher.decrypt(envelope);
await this.queueTask(async () => { await this.queueTask(async () => {
const { deviceName, backupFile } = await confirmNumber( await this.createAccount(options);
provisionMessage.number
);
if (typeof deviceName !== 'string' || deviceName.length === 0) {
throw new Error(
'AccountManager.registerSecondDevice: Invalid device name'
);
}
if (
!provisionMessage.number ||
!provisionMessage.provisioningCode ||
!provisionMessage.aciKeyPair ||
!provisionMessage.pniKeyPair ||
!provisionMessage.aci ||
!Bytes.isNotEmpty(provisionMessage.profileKey) ||
!Bytes.isNotEmpty(provisionMessage.masterKey) ||
!isUntaggedPniString(provisionMessage.untaggedPni)
) {
throw new Error(
'AccountManager.registerSecondDevice: Provision message was missing key data'
);
}
const ourAci = normalizeAci(provisionMessage.aci, 'provisionMessage.aci');
const ourPni = normalizePni(
toTaggedPni(provisionMessage.untaggedPni),
'provisionMessage.pni'
);
const registrationBaton = this.server.startRegistration();
try {
await this.createAccount({
type: AccountType.Linked,
number: provisionMessage.number,
verificationCode: provisionMessage.provisioningCode,
aciKeyPair: provisionMessage.aciKeyPair,
pniKeyPair: provisionMessage.pniKeyPair,
profileKey: provisionMessage.profileKey,
deviceName,
backupFile,
userAgent: provisionMessage.userAgent,
ourAci,
ourPni,
readReceipts: Boolean(provisionMessage.readReceipts),
masterKey: provisionMessage.masterKey,
});
} finally {
this.server.finishRegistration(registrationBaton);
}
await this.registrationDone();
}); });
} }
@ -1021,6 +896,18 @@ export default class AccountManager extends EventTarget {
private async createAccount( private async createAccount(
options: CreateAccountOptionsType options: CreateAccountOptionsType
): Promise<void> {
const registrationBaton = this.server.startRegistration();
try {
await this.doCreateAccount(options);
} finally {
this.server.finishRegistration(registrationBaton);
}
await this.registrationDone();
}
private async doCreateAccount(
options: CreateAccountOptionsType
): Promise<void> { ): Promise<void> {
const { const {
number, number,

View file

@ -0,0 +1,266 @@
// Copyright 2024 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
import {
type ExplodePromiseResultType,
explodePromise,
} from '../util/explodePromise';
import { linkDeviceRoute } from '../util/signalRoutes';
import { strictAssert } from '../util/assert';
import { normalizeAci } from '../util/normalizeAci';
import {
isUntaggedPniString,
normalizePni,
toTaggedPni,
} from '../types/ServiceId';
import * as Errors from '../types/errors';
import { SignalService as Proto } from '../protobuf';
import * as Bytes from '../Bytes';
import * as log from '../logging/log';
import { type WebAPIType } from './WebAPI';
import ProvisioningCipher, {
type ProvisionDecryptResult,
} from './ProvisioningCipher';
import {
type CreateLinkedDeviceOptionsType,
AccountType,
} from './AccountManager';
import {
type IWebSocketResource,
type IncomingWebSocketRequest,
ServerRequestType,
} from './WebsocketResources';
enum Step {
Idle = 'Idle',
Connecting = 'Connecting',
WaitingForURL = 'WaitingForURL',
WaitingForEnvelope = 'WaitingForEnvelope',
ReadyToLink = 'ReadyToLink',
Done = 'Done',
}
type StateType = Readonly<
| {
step: Step.Idle;
}
| {
step: Step.Connecting;
}
| {
step: Step.WaitingForURL;
url: ExplodePromiseResultType<string>;
}
| {
step: Step.WaitingForEnvelope;
done: ExplodePromiseResultType<void>;
}
| {
step: Step.ReadyToLink;
envelope: ProvisionDecryptResult;
}
| {
step: Step.Done;
}
>;
export type PrepareLinkDataOptionsType = Readonly<{
deviceName: string;
backupFile?: Uint8Array;
}>;
export class Provisioner {
private readonly cipher = new ProvisioningCipher();
private state: StateType = { step: Step.Idle };
private wsr: IWebSocketResource | undefined;
constructor(private readonly server: WebAPIType) {}
public close(error = new Error('Provisioner closed')): void {
try {
this.wsr?.close();
} catch {
// Best effort
}
const prevState = this.state;
this.state = { step: Step.Done };
if (prevState.step === Step.WaitingForURL) {
prevState.url.reject(error);
} else if (prevState.step === Step.WaitingForEnvelope) {
prevState.done.reject(error);
}
}
public async getURL(): Promise<string> {
strictAssert(
this.state.step === Step.Idle,
`Invalid state for getURL: ${this.state.step}`
);
this.state = { step: Step.Connecting };
const wsr = await this.server.getProvisioningResource({
handleRequest: (request: IncomingWebSocketRequest) => {
try {
this.handleRequest(request);
} catch (error) {
log.error(
'Provisioner.handleRequest: failure',
Errors.toLogFormat(error)
);
this.close();
}
},
});
this.wsr = wsr;
if (this.state.step !== Step.Connecting) {
this.close();
throw new Error('Provisioner closed early');
}
this.state = {
step: Step.WaitingForURL,
url: explodePromise(),
};
wsr.addEventListener('close', ({ code, reason }) => {
if (this.state.step === Step.ReadyToLink) {
// WebSocket close is not an issue since we no longer need it
return;
}
log.info(`provisioning socket closed. Code: ${code} Reason: ${reason}`);
this.close(new Error('websocket closed'));
});
return this.state.url.promise;
}
public async waitForEnvelope(): Promise<void> {
strictAssert(
this.state.step === Step.WaitingForEnvelope,
`Invalid state for waitForEnvelope: ${this.state.step}`
);
await this.state.done.promise;
}
public prepareLinkData({
deviceName,
backupFile,
}: PrepareLinkDataOptionsType): CreateLinkedDeviceOptionsType {
strictAssert(
this.state.step === Step.ReadyToLink,
`Invalid state for prepareLinkData: ${this.state.step}`
);
const { envelope } = this.state;
this.state = { step: Step.Done };
const {
number,
provisioningCode,
aciKeyPair,
pniKeyPair,
aci,
profileKey,
masterKey,
untaggedPni,
userAgent,
readReceipts,
} = envelope;
strictAssert(number, 'prepareLinkData: missing number');
strictAssert(provisioningCode, 'prepareLinkData: missing provisioningCode');
strictAssert(aciKeyPair, 'prepareLinkData: missing aciKeyPair');
strictAssert(pniKeyPair, 'prepareLinkData: missing pniKeyPair');
strictAssert(aci, 'prepareLinkData: missing aci');
strictAssert(
Bytes.isNotEmpty(profileKey),
'prepareLinkData: missing profileKey'
);
strictAssert(
Bytes.isNotEmpty(masterKey),
'prepareLinkData: missing masterKey'
);
strictAssert(
isUntaggedPniString(untaggedPni),
'prepareLinkData: invalid untaggedPni'
);
const ourAci = normalizeAci(aci, 'provisionMessage.aci');
const ourPni = normalizePni(
toTaggedPni(untaggedPni),
'provisionMessage.pni'
);
return {
type: AccountType.Linked,
number,
verificationCode: provisioningCode,
aciKeyPair,
pniKeyPair,
profileKey,
deviceName,
backupFile,
userAgent,
ourAci,
ourPni,
readReceipts: Boolean(readReceipts),
masterKey,
};
}
private handleRequest(request: IncomingWebSocketRequest): void {
const pubKey = this.cipher.getPublicKey();
if (
request.requestType === ServerRequestType.ProvisioningAddress &&
request.body
) {
strictAssert(
this.state.step === Step.WaitingForURL,
`Unexpected provisioning address, state: ${this.state}`
);
const prevState = this.state;
this.state = { step: Step.WaitingForEnvelope, done: explodePromise() };
const proto = Proto.ProvisioningUuid.decode(request.body);
const { uuid } = proto;
strictAssert(uuid, 'Provisioner.getURL: expected a UUID');
const url = linkDeviceRoute
.toAppUrl({
uuid,
pubKey: Bytes.toBase64(pubKey),
})
.toString();
window.SignalCI?.setProvisioningURL(url);
prevState.url.resolve(url);
request.respond(200, 'OK');
} else if (
request.requestType === ServerRequestType.ProvisioningMessage &&
request.body
) {
strictAssert(
this.state.step === Step.WaitingForEnvelope,
`Unexpected provisioning address, state: ${this.state}`
);
const prevState = this.state;
const ciphertext = Proto.ProvisionEnvelope.decode(request.body);
const message = this.cipher.decrypt(ciphertext);
this.state = { step: Step.ReadyToLink, envelope: message };
request.respond(200, 'OK');
this.wsr?.close();
prevState.done.resolve();
} else {
log.error('Unknown websocket message', request.requestType);
}
}
}

View file

@ -15,7 +15,7 @@ import { SignalService as Proto } from '../protobuf';
import { strictAssert } from '../util/assert'; import { strictAssert } from '../util/assert';
import { dropNull } from '../util/dropNull'; import { dropNull } from '../util/dropNull';
type ProvisionDecryptResult = { export type ProvisionDecryptResult = Readonly<{
aciKeyPair: KeyPairType; aciKeyPair: KeyPairType;
pniKeyPair?: KeyPairType; pniKeyPair?: KeyPairType;
number?: string; number?: string;
@ -26,14 +26,12 @@ type ProvisionDecryptResult = {
readReceipts?: boolean; readReceipts?: boolean;
profileKey?: Uint8Array; profileKey?: Uint8Array;
masterKey?: Uint8Array; masterKey?: Uint8Array;
}; }>;
class ProvisioningCipherInner { class ProvisioningCipherInner {
keyPair?: KeyPairType; keyPair?: KeyPairType;
async decrypt( decrypt(provisionEnvelope: Proto.ProvisionEnvelope): ProvisionDecryptResult {
provisionEnvelope: Proto.ProvisionEnvelope
): Promise<ProvisionDecryptResult> {
strictAssert( strictAssert(
provisionEnvelope.publicKey, provisionEnvelope.publicKey,
'Missing publicKey in ProvisionEnvelope' 'Missing publicKey in ProvisionEnvelope'
@ -77,7 +75,7 @@ class ProvisioningCipherInner {
strictAssert(aci, 'Missing aci in provisioning message'); strictAssert(aci, 'Missing aci in provisioning message');
strictAssert(pni, 'Missing pni in provisioning message'); strictAssert(pni, 'Missing pni in provisioning message');
const ret: ProvisionDecryptResult = { return {
aciKeyPair, aciKeyPair,
pniKeyPair, pniKeyPair,
number: dropNull(provisionMessage.number), number: dropNull(provisionMessage.number),
@ -86,17 +84,16 @@ class ProvisioningCipherInner {
provisioningCode: dropNull(provisionMessage.provisioningCode), provisioningCode: dropNull(provisionMessage.provisioningCode),
userAgent: dropNull(provisionMessage.userAgent), userAgent: dropNull(provisionMessage.userAgent),
readReceipts: provisionMessage.readReceipts ?? false, readReceipts: provisionMessage.readReceipts ?? false,
profileKey: Bytes.isNotEmpty(provisionMessage.profileKey)
? provisionMessage.profileKey
: undefined,
masterKey: Bytes.isNotEmpty(provisionMessage.masterKey)
? provisionMessage.masterKey
: undefined,
}; };
if (Bytes.isNotEmpty(provisionMessage.profileKey)) {
ret.profileKey = provisionMessage.profileKey;
}
if (Bytes.isNotEmpty(provisionMessage.masterKey)) {
ret.masterKey = provisionMessage.masterKey;
}
return ret;
} }
async getPublicKey(): Promise<Uint8Array> { getPublicKey(): Uint8Array {
if (!this.keyPair) { if (!this.keyPair) {
this.keyPair = generateKeyPair(); this.keyPair = generateKeyPair();
} }
@ -119,7 +116,7 @@ export default class ProvisioningCipher {
decrypt: ( decrypt: (
provisionEnvelope: Proto.ProvisionEnvelope provisionEnvelope: Proto.ProvisionEnvelope
) => Promise<ProvisionDecryptResult>; ) => ProvisionDecryptResult;
getPublicKey: () => Promise<Uint8Array>; getPublicKey: () => Uint8Array;
} }