diff --git a/ts/ConversationController.ts b/ts/ConversationController.ts index 36dc7f98a135..bb34d5205b64 100644 --- a/ts/ConversationController.ts +++ b/ts/ConversationController.ts @@ -1062,7 +1062,9 @@ export class ConversationController { log.warn(`${logId}: Delete all sessions tied to old conversationId`); // Note: we use the conversationId here in case we've already lost our uuid. - await window.textsecure.storage.protocol.removeAllSessions(obsoleteId); + await window.textsecure.storage.protocol.removeSessionsByConversation( + obsoleteId + ); log.warn( `${logId}: Delete all identity information tied to old conversationId` diff --git a/ts/SignalProtocolStore.ts b/ts/SignalProtocolStore.ts index 7d7c03db32b2..231abe315f5a 100644 --- a/ts/SignalProtocolStore.ts +++ b/ts/SignalProtocolStore.ts @@ -1226,35 +1226,68 @@ export class SignalProtocolStore extends EventEmitter { }); } - async removeAllSessions(identifier: string): Promise { - return this.withZone(GLOBAL_ZONE, 'removeAllSessions', async () => { + async removeSessionsByConversation(identifier: string): Promise { + return this.withZone( + GLOBAL_ZONE, + 'removeSessionsByConversation', + async () => { + if (!this.sessions) { + throw new Error( + 'removeSessionsByConversation: this.sessions not yet cached!' + ); + } + + if (identifier == null) { + throw new Error( + 'removeSessionsByConversation: identifier was undefined/null' + ); + } + + log.info( + 'removeSessionsByConversation: deleting sessions for', + identifier + ); + + const id = window.ConversationController.getConversationId(identifier); + strictAssert( + id, + `removeSessionsByConversation: Conversation not found: ${identifier}` + ); + + const entries = Array.from(this.sessions.values()); + + for (let i = 0, max = entries.length; i < max; i += 1) { + const entry = entries[i]; + if (entry.fromDB.conversationId === id) { + this.sessions.delete(entry.fromDB.id); + this.pendingSessions.delete(entry.fromDB.id); + } + } + + await window.Signal.Data.removeSessionsByConversation(id); + } + ); + } + + async removeSessionsByUUID(uuid: UUIDStringType): Promise { + return this.withZone(GLOBAL_ZONE, 'removeSessionsByUUID', async () => { if (!this.sessions) { - throw new Error('removeAllSessions: this.sessions not yet cached!'); + throw new Error('removeSessionsByUUID: this.sessions not yet cached!'); } - if (identifier == null) { - throw new Error('removeAllSessions: identifier was undefined/null'); - } - - log.info('removeAllSessions: deleting sessions for', identifier); - - const id = window.ConversationController.getConversationId(identifier); - strictAssert( - id, - `removeAllSessions: Conversation not found: ${identifier}` - ); + log.info('removeSessionsByUUID: deleting sessions for', uuid); const entries = Array.from(this.sessions.values()); for (let i = 0, max = entries.length; i < max; i += 1) { const entry = entries[i]; - if (entry.fromDB.conversationId === id) { + if (entry.fromDB.uuid === uuid) { this.sessions.delete(entry.fromDB.id); this.pendingSessions.delete(entry.fromDB.id); } } - await window.Signal.Data.removeSessionsByConversation(id); + await window.Signal.Data.removeSessionsByUUID(uuid); }); } @@ -1961,10 +1994,7 @@ export class SignalProtocolStore extends EventEmitter { return false; } - async removeIdentityKey( - uuid: UUID, - options?: { disableSessionDeletion: boolean } - ): Promise { + async removeIdentityKey(uuid: UUID): Promise { if (!this.identityKeys) { throw new Error('removeIdentityKey: this.identityKeys not yet cached!'); } @@ -1972,9 +2002,7 @@ export class SignalProtocolStore extends EventEmitter { const id = uuid.toString(); this.identityKeys.delete(id); await window.Signal.Data.removeIdentityKeyById(id); - if (!options?.disableSessionDeletion) { - await this.removeAllSessions(id); - } + await this.removeSessionsByUUID(id); } // Not yet processed messages - for resiliency diff --git a/ts/models/conversations.ts b/ts/models/conversations.ts index a1a93101c8c6..463081c6541d 100644 --- a/ts/models/conversations.ts +++ b/ts/models/conversations.ts @@ -1966,14 +1966,7 @@ export class ConversationModel extends window.Backbone // for the case where we need to do old and new PNI comparisons. We'll wait // for the PNI update to do that. if (oldValue && oldValue !== this.get('pni')) { - // We've already changed our UUID, so we need account for lookups on that old UUID - // to returng nothing: pass conversationId into removeAllSessions, and disable - // auto-deletion in removeIdentityKey. - window.textsecure.storage.protocol.removeAllSessions(this.id); - window.textsecure.storage.protocol.removeIdentityKey( - UUID.cast(oldValue), - { disableSessionDeletion: true } - ); + window.textsecure.storage.protocol.removeIdentityKey(UUID.cast(oldValue)); } this.captureChange('updateUuid'); @@ -2059,14 +2052,7 @@ export class ConversationModel extends window.Backbone // If this PNI is going away or going to someone else, we'll delete all its sessions if (oldValue) { - // We've already changed our UUID, so we need account for lookups on that old UUID - // to returng nothing: pass conversationId into removeAllSessions, and disable - // auto-deletion in removeIdentityKey. - window.textsecure.storage.protocol.removeAllSessions(this.id); - window.textsecure.storage.protocol.removeIdentityKey( - UUID.cast(oldValue), - { disableSessionDeletion: true } - ); + window.textsecure.storage.protocol.removeIdentityKey(UUID.cast(oldValue)); } if (pni && !this.get('uuid')) { diff --git a/ts/sql/Interface.ts b/ts/sql/Interface.ts index 7cc0c0d42e69..1e99a29f04fe 100644 --- a/ts/sql/Interface.ts +++ b/ts/sql/Interface.ts @@ -425,6 +425,7 @@ export type DataInterface = { bulkAddSessions: (array: Array) => Promise; removeSessionById: (id: SessionIdType) => Promise; removeSessionsByConversation: (conversationId: string) => Promise; + removeSessionsByUUID: (uuid: UUIDStringType) => Promise; removeAllSessions: () => Promise; getAllSessions: () => Promise>; diff --git a/ts/sql/Server.ts b/ts/sql/Server.ts index 0c3a874c845d..cfccde2db703 100644 --- a/ts/sql/Server.ts +++ b/ts/sql/Server.ts @@ -200,6 +200,7 @@ const dataInterface: ServerInterface = { bulkAddSessions, removeSessionById, removeSessionsByConversation, + removeSessionsByUUID, removeAllSessions, getAllSessions, @@ -1308,6 +1309,17 @@ async function removeSessionsByConversation( conversationId, }); } +async function removeSessionsByUUID(uuid: UUIDStringType): Promise { + const db = getInstance(); + db.prepare( + ` + DELETE FROM sessions + WHERE uuid = $uuid; + ` + ).run({ + uuid, + }); +} async function removeAllSessions(): Promise { return removeAllFromTable(getInstance(), SESSIONS_TABLE); } diff --git a/ts/test-electron/SignalProtocolStore_test.ts b/ts/test-electron/SignalProtocolStore_test.ts index cabbf7f80b13..ba77e478332e 100644 --- a/ts/test-electron/SignalProtocolStore_test.ts +++ b/ts/test-electron/SignalProtocolStore_test.ts @@ -995,7 +995,7 @@ describe('SignalProtocolStore', () => { assert.equal(record, testRecord); }); }); - describe('removeAllSessions', () => { + describe('removeSessionsByUUID', () => { it('removes all sessions for a uuid', async () => { const devices = [1, 2, 3].map( deviceId => @@ -1008,14 +1008,72 @@ describe('SignalProtocolStore', () => { }) ); - await store.removeAllSessions(theirUuid.toString()); + const records0 = await Promise.all( + devices.map(device => store.loadSession(device)) + ); + for (let i = 0, max = records0.length; i < max; i += 1) { + assert.exists(records0[i], 'before delete'); + } + + await store.removeSessionsByUUID(theirUuid.toString()); const records = await Promise.all( devices.map(device => store.loadSession(device)) ); - for (let i = 0, max = records.length; i < max; i += 1) { - assert.isUndefined(records[i]); + assert.isUndefined(records[i], 'in-memory'); + } + + await store.hydrateCaches(); + + const records2 = await Promise.all( + devices.map(device => store.loadSession(device)) + ); + for (let i = 0, max = records2.length; i < max; i += 1) { + assert.isUndefined(records2[i], 'from database'); + } + }); + }); + describe('removeSessionsByConversation', () => { + it('removes all sessions for a uuid', async () => { + const devices = [1, 2, 3].map( + deviceId => + new QualifiedAddress(ourUuid, new Address(theirUuid, deviceId)) + ); + const conversationId = window.ConversationController.getOrCreate( + theirUuid.toString(), + 'private' + ).id; + + await Promise.all( + devices.map(async encodedAddress => { + await store.storeSession(encodedAddress, getSessionRecord()); + }) + ); + + const records0 = await Promise.all( + devices.map(device => store.loadSession(device)) + ); + for (let i = 0, max = records0.length; i < max; i += 1) { + assert.exists(records0[i], 'before delete'); + } + + await store.removeSessionsByConversation(conversationId); + + const records = await Promise.all( + devices.map(device => store.loadSession(device)) + ); + for (let i = 0, max = records.length; i < max; i += 1) { + assert.isUndefined(records[i], 'in-memory'); + } + + await store.hydrateCaches(); + + const records2 = await Promise.all( + devices.map(device => store.loadSession(device)) + ); + for (let i = 0, max = records2.length; i < max; i += 1) { + assert.isUndefined(records[i], 'from database'); } }); }); @@ -1145,7 +1203,7 @@ describe('SignalProtocolStore', () => { beforeEach(async () => { await store.removeAllUnprocessed(); - await store.removeAllSessions(theirUuid.toString()); + await store.removeSessionsByUUID(theirUuid.toString()); await store.removeAllSenderKeys(); });