Introduce in-memory transactions for sessions

This commit is contained in:
Fedor Indutny 2021-05-17 11:03:42 -07:00 committed by Scott Nonnenberg
parent 403b3c5fc6
commit 94d2c56ab9
12 changed files with 874 additions and 391 deletions

View file

@ -2,6 +2,7 @@
// SPDX-License-Identifier: AGPL-3.0-only
/* eslint-disable class-methods-use-this */
/* eslint-disable no-restricted-syntax */
import PQueue from 'p-queue';
import { isNumber } from 'lodash';
@ -22,7 +23,9 @@ import {
fromEncodedBinaryToArrayBuffer,
typedArrayToArrayBuffer,
} from './Crypto';
import { assert } from './util/assert';
import { isNotNil } from './util/isNotNil';
import { Lock } from './util/Lock';
import { isMoreRecentThan } from './util/timestamp';
import {
sessionRecordToProtobuf,
@ -102,9 +105,22 @@ type CacheEntryType<DBType, HydratedType> =
}
| { hydrated: true; fromDB: DBType; item: HydratedType };
type MapFields =
| 'identityKeys'
| 'preKeys'
| 'senderKeys'
| 'sessions'
| 'signedPreKeys';
export type SessionTransactionOptions = {
readonly lock?: Lock;
};
const GLOBAL_LOCK = new Lock();
async function _fillCaches<ID, T extends HasIdType<ID>, HydratedType>(
object: SignalProtocolStore,
field: keyof SignalProtocolStore,
field: MapFields,
itemsPromise: Promise<Array<T>>
): Promise<void> {
const items = await itemsPromise;
@ -182,6 +198,8 @@ const EventsMixin = (function EventsMixin(this: unknown) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any) as typeof window.Backbone.EventsMixin;
type SessionCacheEntry = CacheEntryType<SessionType, SessionRecord>;
export class SignalProtocolStore extends EventsMixin {
// Enums used across the app
@ -197,7 +215,15 @@ export class SignalProtocolStore extends EventsMixin {
senderKeys?: Map<string, CacheEntryType<SenderKeyType, SenderKeyRecord>>;
sessions?: Map<string, CacheEntryType<SessionType, SessionRecord>>;
sessions?: Map<string, SessionCacheEntry>;
sessionLock?: Lock;
sessionLockQueue: Array<() => void> = [];
pendingSessions = new Map<string, SessionCacheEntry>();
pendingUnprocessed = new Map<string, UnprocessedType>();
preKeys?: Map<number, CacheEntryType<PreKeyType, PreKeyRecord>>;
@ -562,43 +588,154 @@ export class SignalProtocolStore extends EventsMixin {
// Sessions
async loadSession(
encodedAddress: string
): Promise<SessionRecord | undefined> {
if (!this.sessions) {
throw new Error('loadSession: this.sessions not yet cached!');
// Re-entrant session transaction routine. Only one session transaction could
// be running at the same time.
//
// While in transaction:
//
// - `storeSession()` adds the updated session to the `pendingSessions`
// - `loadSession()` looks up the session first in `pendingSessions` and only
// then in the main `sessions` store
//
// When transaction ends:
//
// - successfully: pending session stores are batched into the database
// - with an error: pending session stores are reverted
async sessionTransaction<T>(
name: string,
body: () => Promise<T>,
lock: Lock = GLOBAL_LOCK
): Promise<T> {
// Allow re-entering from LibSignalStores
const isNested = this.sessionLock === lock;
if (this.sessionLock && !isNested) {
window.log.info(`sessionTransaction(${name}): sessions locked, waiting`);
await new Promise<void>(resolve => this.sessionLockQueue.push(resolve));
}
if (encodedAddress === null || encodedAddress === undefined) {
throw new Error('loadSession: encodedAddress was undefined/null');
if (!isNested) {
if (lock !== GLOBAL_LOCK) {
window.log.info(`sessionTransaction(${name}): enter`);
}
this.sessionLock = lock;
}
let result: T;
try {
const id = await normalizeEncodedAddress(encodedAddress);
const entry = this.sessions.get(id);
if (!entry) {
return undefined;
}
if (entry.hydrated) {
return entry.item;
}
const item = await this._maybeMigrateSession(entry.fromDB);
this.sessions.set(id, {
hydrated: true,
item,
fromDB: entry.fromDB,
});
return item;
result = await body();
} catch (error) {
const errorString = error && error.stack ? error.stack : error;
window.log.error(
`loadSession: failed to load session ${encodedAddress}: ${errorString}`
);
return undefined;
if (!isNested) {
await this.revertSessions(name, error);
this.releaseSessionLock();
}
throw error;
}
if (!isNested) {
await this.commitSessions(name);
this.releaseSessionLock();
}
return result;
}
private async commitSessions(name: string): Promise<void> {
const { pendingSessions, pendingUnprocessed } = this;
if (pendingSessions.size === 0 && pendingUnprocessed.size === 0) {
return;
}
window.log.info(
`commitSessions(${name}): pending sessions ${pendingSessions.size} ` +
`pending unprocessed ${pendingUnprocessed.size}`
);
this.pendingSessions = new Map();
this.pendingUnprocessed = new Map();
// Commit both unprocessed and sessions in the same database transaction
// to unroll both on error.
await window.Signal.Data.commitSessionsAndUnprocessed({
sessions: Array.from(pendingSessions.values()).map(
({ fromDB }) => fromDB
),
unprocessed: Array.from(pendingUnprocessed.values()),
});
const { sessions } = this;
assert(sessions !== undefined, "Can't commit unhydrated storage");
// Apply changes to in-memory storage after successful DB write.
pendingSessions.forEach((value, key) => {
sessions.set(key, value);
});
}
private async revertSessions(name: string, error: Error): Promise<void> {
window.log.info(
`revertSessions(${name}): pending size ${this.pendingSessions.size}`,
error && error.stack
);
this.pendingSessions.clear();
this.pendingUnprocessed.clear();
}
private releaseSessionLock(): void {
this.sessionLock = undefined;
const next = this.sessionLockQueue.shift();
if (next) {
next();
}
}
async loadSession(
encodedAddress: string,
{ lock }: SessionTransactionOptions = {}
): Promise<SessionRecord | undefined> {
return this.sessionTransaction(
'loadSession',
async () => {
if (!this.sessions) {
throw new Error('loadSession: this.sessions not yet cached!');
}
if (encodedAddress === null || encodedAddress === undefined) {
throw new Error('loadSession: encodedAddress was undefined/null');
}
try {
const id = await normalizeEncodedAddress(encodedAddress);
const map = this.pendingSessions.has(id)
? this.pendingSessions
: this.sessions;
const entry = map.get(id);
if (!entry) {
return undefined;
}
if (entry.hydrated) {
return entry.item;
}
const item = await this._maybeMigrateSession(entry.fromDB);
map.set(id, {
hydrated: true,
item,
fromDB: entry.fromDB,
});
return item;
} catch (error) {
const errorString = error && error.stack ? error.stack : error;
window.log.error(
`loadSession: failed to load session ${encodedAddress}: ${errorString}`
);
return undefined;
}
},
lock
);
}
private async _maybeMigrateSession(
@ -643,139 +780,155 @@ export class SignalProtocolStore extends EventsMixin {
async storeSession(
encodedAddress: string,
record: SessionRecord
record: SessionRecord,
{ lock }: SessionTransactionOptions = {}
): Promise<void> {
if (!this.sessions) {
throw new Error('storeSession: this.sessions not yet cached!');
}
await this.sessionTransaction(
'storeSession',
async () => {
if (!this.sessions) {
throw new Error('storeSession: this.sessions not yet cached!');
}
if (encodedAddress === null || encodedAddress === undefined) {
throw new Error('storeSession: encodedAddress was undefined/null');
}
const unencoded = window.textsecure.utils.unencodeNumber(encodedAddress);
const deviceId = parseInt(unencoded[1], 10);
if (encodedAddress === null || encodedAddress === undefined) {
throw new Error('storeSession: encodedAddress was undefined/null');
}
const unencoded = window.textsecure.utils.unencodeNumber(
encodedAddress
);
const deviceId = parseInt(unencoded[1], 10);
try {
const id = await normalizeEncodedAddress(encodedAddress);
const fromDB = {
id,
version: 2,
conversationId: window.textsecure.utils.unencodeNumber(id)[0],
deviceId,
record: record.serialize().toString('base64'),
};
try {
const id = await normalizeEncodedAddress(encodedAddress);
const fromDB = {
id,
version: 2,
conversationId: window.textsecure.utils.unencodeNumber(id)[0],
deviceId,
record: record.serialize().toString('base64'),
};
await window.Signal.Data.createOrUpdateSession(fromDB);
this.sessions.set(id, {
hydrated: true,
fromDB,
item: record,
});
} catch (error) {
const errorString = error && error.stack ? error.stack : error;
window.log.error(
`storeSession: Save failed fo ${encodedAddress}: ${errorString}`
);
throw error;
}
const newSession = {
hydrated: true,
fromDB,
item: record,
};
this.pendingSessions.set(id, newSession);
} catch (error) {
const errorString = error && error.stack ? error.stack : error;
window.log.error(
`storeSession: Save failed fo ${encodedAddress}: ${errorString}`
);
throw error;
}
},
lock
);
}
async getDeviceIds(identifier: string): Promise<Array<number>> {
if (!this.sessions) {
throw new Error('getDeviceIds: this.sessions not yet cached!');
}
if (identifier === null || identifier === undefined) {
throw new Error('getDeviceIds: identifier was undefined/null');
}
try {
const id = window.ConversationController.getConversationId(identifier);
if (!id) {
throw new Error(
`getDeviceIds: No conversationId found for identifier ${identifier}`
);
return this.sessionTransaction('getDeviceIds', async () => {
if (!this.sessions) {
throw new Error('getDeviceIds: this.sessions not yet cached!');
}
if (identifier === null || identifier === undefined) {
throw new Error('getDeviceIds: identifier was undefined/null');
}
const allSessions = Array.from(this.sessions.values());
const entries = allSessions.filter(
session => session.fromDB.conversationId === id
);
const openIds = await Promise.all(
entries.map(async entry => {
if (entry.hydrated) {
const record = entry.item;
try {
const id = window.ConversationController.getConversationId(identifier);
if (!id) {
throw new Error(
`getDeviceIds: No conversationId found for identifier ${identifier}`
);
}
const allSessions = this._getAllSessions();
const entries = allSessions.filter(
session => session.fromDB.conversationId === id
);
const openIds = await Promise.all(
entries.map(async entry => {
if (entry.hydrated) {
const record = entry.item;
if (record.hasCurrentState()) {
return entry.fromDB.deviceId;
}
return undefined;
}
const record = await this._maybeMigrateSession(entry.fromDB);
if (record.hasCurrentState()) {
return entry.fromDB.deviceId;
}
return undefined;
}
})
);
const record = await this._maybeMigrateSession(entry.fromDB);
if (record.hasCurrentState()) {
return entry.fromDB.deviceId;
}
return openIds.filter(isNotNil);
} catch (error) {
window.log.error(
`getDeviceIds: Failed to get device ids for identifier ${identifier}`,
error && error.stack ? error.stack : error
);
}
return undefined;
})
);
return openIds.filter(isNotNil);
} catch (error) {
window.log.error(
`getDeviceIds: Failed to get device ids for identifier ${identifier}`,
error && error.stack ? error.stack : error
);
}
return [];
return [];
});
}
async removeSession(encodedAddress: string): Promise<void> {
if (!this.sessions) {
throw new Error('removeSession: this.sessions not yet cached!');
}
return this.sessionTransaction('removeSession', async () => {
if (!this.sessions) {
throw new Error('removeSession: this.sessions not yet cached!');
}
window.log.info('removeSession: deleting session for', encodedAddress);
try {
const id = await normalizeEncodedAddress(encodedAddress);
await window.Signal.Data.removeSessionById(id);
this.sessions.delete(id);
} catch (e) {
window.log.error(
`removeSession: Failed to delete session for ${encodedAddress}`
);
}
window.log.info('removeSession: deleting session for', encodedAddress);
try {
const id = await normalizeEncodedAddress(encodedAddress);
await window.Signal.Data.removeSessionById(id);
this.sessions.delete(id);
this.pendingSessions.delete(id);
} catch (e) {
window.log.error(
`removeSession: Failed to delete session for ${encodedAddress}`
);
}
});
}
async removeAllSessions(identifier: string): Promise<void> {
if (!this.sessions) {
throw new Error('removeAllSessions: this.sessions not yet cached!');
}
if (identifier === null || identifier === undefined) {
throw new Error('removeAllSessions: identifier was undefined/null');
}
window.log.info('removeAllSessions: deleting sessions for', identifier);
const id = window.ConversationController.getConversationId(identifier);
const entries = Array.from(this.sessions.values());
for (let i = 0, max = entries.length; i < max; i += 1) {
const entry = entries[i];
if (entry.fromDB.conversationId === id) {
this.sessions.delete(entry.fromDB.id);
return this.sessionTransaction('removeAllSessions', async () => {
if (!this.sessions) {
throw new Error('removeAllSessions: this.sessions not yet cached!');
}
}
await window.Signal.Data.removeSessionsByConversation(identifier);
if (identifier === null || identifier === undefined) {
throw new Error('removeAllSessions: identifier was undefined/null');
}
window.log.info('removeAllSessions: deleting sessions for', identifier);
const id = window.ConversationController.getConversationId(identifier);
const entries = Array.from(this.sessions.values());
for (let i = 0, max = entries.length; i < max; i += 1) {
const entry = entries[i];
if (entry.fromDB.conversationId === id) {
this.sessions.delete(entry.fromDB.id);
this.pendingSessions.delete(entry.fromDB.id);
}
}
await window.Signal.Data.removeSessionsByConversation(identifier);
});
}
private async _archiveSession(
entry?: CacheEntryType<SessionType, SessionRecord>
) {
private async _archiveSession(entry?: SessionCacheEntry) {
if (!entry) {
return;
}
@ -796,74 +949,87 @@ export class SignalProtocolStore extends EventsMixin {
}
async archiveSession(encodedAddress: string): Promise<void> {
if (!this.sessions) {
throw new Error('archiveSession: this.sessions not yet cached!');
}
return this.sessionTransaction('archiveSession', async () => {
if (!this.sessions) {
throw new Error('archiveSession: this.sessions not yet cached!');
}
window.log.info(`archiveSession: session for ${encodedAddress}`);
window.log.info(`archiveSession: session for ${encodedAddress}`);
const id = await normalizeEncodedAddress(encodedAddress);
const entry = this.sessions.get(id);
const id = await normalizeEncodedAddress(encodedAddress);
await this._archiveSession(entry);
const entry = this.pendingSessions.get(id) || this.sessions.get(id);
await this._archiveSession(entry);
});
}
async archiveSiblingSessions(encodedAddress: string): Promise<void> {
if (!this.sessions) {
throw new Error('archiveSiblingSessions: this.sessions not yet cached!');
}
return this.sessionTransaction('archiveSiblingSessions', async () => {
if (!this.sessions) {
throw new Error(
'archiveSiblingSessions: this.sessions not yet cached!'
);
}
window.log.info(
'archiveSiblingSessions: archiving sibling sessions for',
encodedAddress
);
window.log.info(
'archiveSiblingSessions: archiving sibling sessions for',
encodedAddress
);
const id = await normalizeEncodedAddress(encodedAddress);
const [identifier, deviceId] = window.textsecure.utils.unencodeNumber(id);
const deviceIdNumber = parseInt(deviceId, 10);
const id = await normalizeEncodedAddress(encodedAddress);
const [identifier, deviceId] = window.textsecure.utils.unencodeNumber(id);
const deviceIdNumber = parseInt(deviceId, 10);
const allEntries = Array.from(this.sessions.values());
const entries = allEntries.filter(
entry =>
entry.fromDB.conversationId === identifier &&
entry.fromDB.deviceId !== deviceIdNumber
);
const allEntries = this._getAllSessions();
const entries = allEntries.filter(
entry =>
entry.fromDB.conversationId === identifier &&
entry.fromDB.deviceId !== deviceIdNumber
);
await Promise.all(
entries.map(async entry => {
await this._archiveSession(entry);
})
);
await Promise.all(
entries.map(async entry => {
await this._archiveSession(entry);
})
);
});
}
async archiveAllSessions(identifier: string): Promise<void> {
if (!this.sessions) {
throw new Error('archiveAllSessions: this.sessions not yet cached!');
}
return this.sessionTransaction('archiveAllSessions', async () => {
if (!this.sessions) {
throw new Error('archiveAllSessions: this.sessions not yet cached!');
}
window.log.info(
'archiveAllSessions: archiving all sessions for',
identifier
);
window.log.info(
'archiveAllSessions: archiving all sessions for',
identifier
);
const id = window.ConversationController.getConversationId(identifier);
const allEntries = Array.from(this.sessions.values());
const entries = allEntries.filter(
entry => entry.fromDB.conversationId === id
);
const id = window.ConversationController.getConversationId(identifier);
await Promise.all(
entries.map(async entry => {
await this._archiveSession(entry);
})
);
const allEntries = this._getAllSessions();
const entries = allEntries.filter(
entry => entry.fromDB.conversationId === id
);
await Promise.all(
entries.map(async entry => {
await this._archiveSession(entry);
})
);
});
}
async clearSessionStore(): Promise<void> {
if (this.sessions) {
this.sessions.clear();
}
window.Signal.Data.removeAllSessions();
return this.sessionTransaction('clearSessionStore', async () => {
if (this.sessions) {
this.sessions.clear();
}
this.pendingSessions.clear();
await window.Signal.Data.removeAllSessions();
});
}
// Identity Keys
@ -1403,56 +1569,92 @@ export class SignalProtocolStore extends EventsMixin {
// Not yet processed messages - for resiliency
getUnprocessedCount(): Promise<number> {
return window.Signal.Data.getUnprocessedCount();
return this.sessionTransaction('getUnprocessedCount', async () => {
this._checkNoPendingUnprocessed();
return window.Signal.Data.getUnprocessedCount();
});
}
getAllUnprocessed(): Promise<Array<UnprocessedType>> {
return window.Signal.Data.getAllUnprocessed();
return this.sessionTransaction('getAllUnprocessed', async () => {
this._checkNoPendingUnprocessed();
return window.Signal.Data.getAllUnprocessed();
});
}
getUnprocessedById(id: string): Promise<UnprocessedType | undefined> {
return window.Signal.Data.getUnprocessedById(id);
}
addUnprocessed(data: UnprocessedType): Promise<string> {
// We need to pass forceSave because the data has an id already, which will cause
// an update instead of an insert.
return window.Signal.Data.saveUnprocessed(data, {
forceSave: true,
return this.sessionTransaction('getUnprocessedById', async () => {
this._checkNoPendingUnprocessed();
return window.Signal.Data.getUnprocessedById(id);
});
}
addMultipleUnprocessed(array: Array<UnprocessedType>): Promise<void> {
// We need to pass forceSave because the data has an id already, which will cause
// an update instead of an insert.
return window.Signal.Data.saveUnprocesseds(array, {
forceSave: true,
});
addUnprocessed(
data: UnprocessedType,
{ lock }: SessionTransactionOptions = {}
): Promise<void> {
return this.sessionTransaction(
'addUnprocessed',
async () => {
this.pendingUnprocessed.set(data.id, data);
},
lock
);
}
addMultipleUnprocessed(
array: Array<UnprocessedType>,
{ lock }: SessionTransactionOptions = {}
): Promise<void> {
return this.sessionTransaction(
'addMultipleUnprocessed',
async () => {
for (const elem of array) {
this.pendingUnprocessed.set(elem.id, elem);
}
},
lock
);
}
updateUnprocessedAttempts(id: string, attempts: number): Promise<void> {
return window.Signal.Data.updateUnprocessedAttempts(id, attempts);
return this.sessionTransaction('updateUnprocessedAttempts', async () => {
this._checkNoPendingUnprocessed();
await window.Signal.Data.updateUnprocessedAttempts(id, attempts);
});
}
updateUnprocessedWithData(
id: string,
data: UnprocessedUpdateType
): Promise<void> {
return window.Signal.Data.updateUnprocessedWithData(id, data);
return this.sessionTransaction('updateUnprocessedWithData', async () => {
this._checkNoPendingUnprocessed();
await window.Signal.Data.updateUnprocessedWithData(id, data);
});
}
updateUnprocessedsWithData(
items: Array<{ id: string; data: UnprocessedUpdateType }>
): Promise<void> {
return window.Signal.Data.updateUnprocessedsWithData(items);
return this.sessionTransaction('updateUnprocessedsWithData', async () => {
this._checkNoPendingUnprocessed();
await window.Signal.Data.updateUnprocessedsWithData(items);
});
}
removeUnprocessed(idOrArray: string | Array<string>): Promise<void> {
return window.Signal.Data.removeUnprocessed(idOrArray);
return this.sessionTransaction('removeUnprocessed', async () => {
this._checkNoPendingUnprocessed();
await window.Signal.Data.removeUnprocessed(idOrArray);
});
}
removeAllUnprocessed(): Promise<void> {
return window.Signal.Data.removeAllUnprocessed();
return this.sessionTransaction('removeAllUnprocessed', async () => {
this._checkNoPendingUnprocessed();
await window.Signal.Data.removeAllUnprocessed();
});
}
async removeAllData(): Promise<void> {
@ -1473,6 +1675,30 @@ export class SignalProtocolStore extends EventsMixin {
window.storage.reset();
await window.storage.fetch();
}
private _getAllSessions(): Array<SessionCacheEntry> {
const union = new Map<string, SessionCacheEntry>();
this.sessions?.forEach((value, key) => {
union.set(key, value);
});
this.pendingSessions.forEach((value, key) => {
union.set(key, value);
});
return Array.from(union.values());
}
private _checkNoPendingUnprocessed(): void {
assert(
!this.sessionLock || this.sessionLock === GLOBAL_LOCK,
"Can't use this function with a global lock"
);
assert(
this.pendingUnprocessed.size === 0,
'Missing support for pending unprocessed'
);
}
}
window.SignalProtocolStore = SignalProtocolStore;