diff --git a/package.json b/package.json index 64ece4c233a..d6254aa1514 100644 --- a/package.json +++ b/package.json @@ -82,6 +82,7 @@ "@babel/register": "^7.12.10", "@casualbot/jest-sonar-reporter": "^2.2.5", "@matrix-org/olm": "https://gitlab.matrix.org/api/v4/projects/27/packages/npm/@matrix-org/olm/-/@matrix-org/olm-3.2.14.tgz", + "@matrix-org/matrix-dmls-wasm": "https://gitlab.matrix.org/api/v4/projects/876/packages/npm/@matrix-org/matrix-dmls-wasm/-/@matrix-org/matrix-dmls-wasm-0.0.4.tgz", "@types/bs58": "^4.0.1", "@types/content-type": "^1.1.5", "@types/debug": "^4.1.7", diff --git a/src/client.ts b/src/client.ts index 8f861892a69..b56e708eeca 100644 --- a/src/client.ts +++ b/src/client.ts @@ -21,6 +21,7 @@ limitations under the License. import { Optional } from "matrix-events-sdk"; import type { IDeviceKeys, IMegolmSessionData, IOneTimeKey } from "./@types/crypto"; +import type { IMlsSessionData } from "./crypto/algorithms/dmls"; import { ISyncStateData, SyncApi, SyncApiOptions, SyncState } from "./sync"; import { EventStatus, @@ -3105,7 +3106,7 @@ export class MatrixClient extends TypedEventEmitter { + public importRoomKeys(keys: IMlsSessionData[], opts?: IImportRoomKeysOpts): Promise { if (!this.crypto) { throw new Error("End-to-end encryption disabled"); } @@ -3640,7 +3641,7 @@ export class MatrixClient extends TypedEventEmitter { + public async importRoomKey(session: IMlsSessionData, opts: object): Promise { // ignore by default } diff --git a/src/crypto/algorithms/dmls.ts b/src/crypto/algorithms/dmls.ts new file mode 100644 index 00000000000..33ff217e6ec --- /dev/null +++ b/src/crypto/algorithms/dmls.ts @@ -0,0 +1,968 @@ +/* +Copyright 2023 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/** + * Defines m.dmls encryption/decryption + */ + +import anotherjson from "another-json"; +import { + EventType +} from "../../@types/event"; +import { + DecryptionAlgorithm, + DecryptionError, + EncryptionAlgorithm, + registerAlgorithm, +} from "./base"; +import type { + IImportRoomKeysOpts, +} from "../api"; +import { Room } from "../../models/room"; +import { IContent, MatrixEvent } from "../../models/event"; +import { Crypto, IEncryptedContent, IEventDecryptionResult } from ".."; +import { UnstableValue } from "../../NamespacedValue"; +import * as matrixDmls from "@matrix-org/matrix-dmls-wasm"; +import * as olmlib from "../olmlib"; + +export const MLS_ALGORITHM = new UnstableValue( + "m.dmls.v1.dhkemx25519-aes128gcm-sha256-ed25519", + "org.matrix.msc2883.v0.dmls.dhkemx25519-aes128gcm-sha256-ed25519", +); +export const INIT_KEY_ALGORITHM = new UnstableValue( + "m.dmls.v1.key_package.dhkemx25519-aes128gcm-sha256-ed25519", + "org.matrix.msc2883.v0.dmls.key_package.dhkemx25519-aes128gcm-sha256-ed25519", +); +export const WELCOME_PACKAGE = new UnstableValue( + "m.dmls.v1.welcome.dhkemx25519-aes128gcm-sha256-ed25519", + "org.matrix.msc2883.v0.dmls.welcome.dhkemx25519-aes128gcm-sha256-ed25519", +); + +/* eslint-disable camelcase */ + +export interface IMlsSessionData { + room_id: string; + epoch: [number, string]; + group_export: string; + algorithm?: string; + untrusted?: boolean; +} + +/* eslint-enable camelcase */ + +let textEncoder = new TextEncoder(); +let textDecoder = new TextDecoder("utf-8", {fatal: true}); + +class MlsEncryption extends EncryptionAlgorithm { + public async encryptMessage(room: Room, eventType: string, content: IContent): Promise { + const mlsProvider = this.crypto.mlsProvider; + if (!this.roomId) { + console.error("MLS Error: No room ID") + throw "No room ID"; + } + let group = mlsProvider.getGroup(this.roomId); + if (!group || !group.is_joined()) { + const timeline = room.getLiveTimeline(); + const events = timeline.getEvents(); + events.reverse(); + let publicGroupStateEvent: MatrixEvent | undefined; + for (const event of events) { + if (event.getWireType() == "m.room.encrypted") { + const contents = event.getWireContent(); + if (contents.algorithm == MLS_ALGORITHM.name && + "public_group_state" in contents && + "sender" in contents) { + publicGroupStateEvent = event; + break; + } + } + } + // FIXME: search for more events if we still don't have public state + // FIXME: search for public group state again if the join fails + if (publicGroupStateEvent) { + const publicGroupStateContents = publicGroupStateEvent.getWireContent(); + const [joinedGroup, message] = mlsProvider.joinByExternalCommit( + publicGroupStateContents.public_group_state, + this.roomId, + publicGroupStateEvent.getId()!, + ); + + const senderB64 = olmlib.encodeUnpaddedBase64(joinId(this.userId, this.deviceId)); + group = joinedGroup; + const publicGroupState = group.public_group_state(mlsProvider.backend!); + const publicGroupStateB64 = olmlib.encodeUnpaddedBase64(Uint8Array.from(publicGroupState)); + const {event_id: eventId} = await this.baseApis.sendEvent(this.roomId, "m.room.encrypted", { + algorithm: MLS_ALGORITHM.name, + ciphertext: olmlib.encodeUnpaddedBase64(message), + epoch_creator: publicGroupStateContents.sender, + sender: senderB64, + resolves: [], + public_group_state: publicGroupStateB64, + commit_event: publicGroupStateEvent.getId(), + }); + mlsProvider.addEpochEvent(group, this.roomId, eventId); + } + } + if (!group) { + console.error("MLS error: No group available"); + throw "No group available"; + } + + // check if membership needs syncing, if group needs resolving + const members = await room.getEncryptionTargetMembers(); + const roomMembers = members.map(function (u) { + return u.userId; + }); + const devices = await this.crypto.downloadKeys(roomMembers, false); + // FIXME: remove blocked devices + + const memberMap: Map> = new Map(); + + for (const [userId, userDevices] of Object.entries(devices)) { + memberMap.set(userId, new Set(Object.keys(userDevices))); + } + + mlsProvider.syncMembers(this.roomId, memberMap); + + if (group.has_changes() || group.needs_resolve()) { + console.log("[MLS] has changes/needs resolve", group.has_changes(), group.needs_resolve()); + const [commit, baseEpochNum, baseEpochCreator, resolves, welcomeInfo] = await group.resolve(mlsProvider.backend!); + const [epochNum, epochCreator] = group.epoch(); + // don't wait for it to complete + this.crypto.backupManager.backupGroupSession(this.roomId, epochNum, olmlib.encodeUnpaddedBase64(epochCreator)); + + const creatorB64 = olmlib.encodeUnpaddedBase64(Uint8Array.from(baseEpochCreator)); + const senderB64 = olmlib.encodeUnpaddedBase64(joinId(this.userId, this.deviceId)); + + // FIXME: check if external commits are allowed + const publicGroupState = group.public_group_state(mlsProvider.backend!); + const publicGroupStateB64 = olmlib.encodeUnpaddedBase64(Uint8Array.from(publicGroupState)); + // FIXME: should we store public group state in media repo instead? + + const baseEventId = mlsProvider.getEpochEvent( + this.roomId, BigInt(baseEpochNum), creatorB64, + ); + + const {event_id: eventId} = await this.baseApis.sendEvent(this.roomId, "m.room.encrypted", { + algorithm: MLS_ALGORITHM.name, + ciphertext: olmlib.encodeUnpaddedBase64(Uint8Array.from(commit)), + epoch_creator: creatorB64, + sender: senderB64, + resolves: resolves.map(([epochNum, creator]: [number, number[]]) => { + return [epochNum, olmlib.encodeUnpaddedBase64(Uint8Array.from(creator))]; + }), + public_group_state: publicGroupStateB64, + commit_event: baseEventId, + }); + mlsProvider.addEpochEvent(group, this.roomId, eventId); + + if (welcomeInfo) { + const [welcome, adds] = welcomeInfo; + + const welcomeB64 = olmlib.encodeUnpaddedBase64(Uint8Array.from(welcome)); + + const contentMap: Record> = {}; + + const payload = { + algorithm: WELCOME_PACKAGE.name, + ciphertext: welcomeB64, + sender: senderB64, + room_id: room.roomId, + resolves: resolves.map(([epochNum, creator]: [number, number[]]) => { + return [epochNum, olmlib.encodeUnpaddedBase64(Uint8Array.from(creator))]; + }), + commit_event: eventId, + } + + for (const user of adds) { + try { + const [userId, deviceId] = splitId(user); + if (!(userId in contentMap)) { + contentMap[userId] = {}; + } + contentMap[userId][deviceId] = payload; + } catch (e) { + console.error("[MLS] Unable to add user", user, e); + } + } + + await this.baseApis.sendToDevice("m.room.encrypted", contentMap); + } + } + + const payload = textEncoder.encode(JSON.stringify({ + room_id: this.roomId, + type: eventType, + content: content, + })); + const [ciphertext, baseEpochNum, baseEpochCreator] = group.encrypt_message(mlsProvider.backend!, payload); + const creatorB64 = olmlib.encodeUnpaddedBase64(Uint8Array.from(baseEpochCreator)); + const baseEventId = mlsProvider.getEpochEvent( + this.roomId, BigInt(baseEpochNum), creatorB64, + )!; + return { + algorithm: MLS_ALGORITHM.name, + ciphertext: olmlib.encodeUnpaddedBase64(Uint8Array.from(ciphertext)), + epoch_creator: creatorB64, + commit_event: baseEventId, + } + } +} + +class MlsDecryption extends DecryptionAlgorithm { + private pendingEvents = new Map>>(); + private pendingBackfills = new Map>(); + + public async decryptEvent(event: MatrixEvent): Promise { + const content = event.getWireContent(); + if (typeof(content.ciphertext) !== "string" || typeof(content.epoch_creator) !== "string") { + throw new DecryptionError("MLS_MISSING_FIELDS", "Missing or invalid fields in input"); + } + if (content.ciphertext === "") { + // probably the initial commit + return { + clearEvent: { + type: "io.element.mls.internal", + content: {"body": "This is an MLS handshake message, so there's nothing useful to see here."}, + }, + }; + } + const mlsProvider = this.crypto.mlsProvider; + if (!this.roomId) { + throw "No room ID"; + } + const group = mlsProvider.getGroup(this.roomId); + const mlsMessage = new matrixDmls.MlsMessageIn(olmlib.decodeBase64(content.ciphertext)); + const isHandshake = mlsMessage.is_handshake_message; + const epochNumber = mlsMessage.epoch; + if (!group) { + this.addEventToPendingList(event, epochNumber, content.epoch_creator); + if (isHandshake) { + return { + clearEvent: { + type: "io.element.mls.internal", + content: {"body": "This is an MLS handshake message, so there's nothing useful to see here."}, + }, + }; + } + throw "No group available"; + } + const epochCreator = olmlib.decodeBase64(content.epoch_creator); + let unverifiedMessage; + try { + unverifiedMessage = group.parse_message( + mlsMessage, + epochCreator, + mlsProvider.backend!, + ); + } catch (e) { + console.log("Adding to pending:", epochNumber, content.epoch_creator); + this.addEventToPendingList(event, epochNumber, content.epoch_creator); + if (e == "Epoch not found") { + const parentCommit = content.commit_event; + if (parentCommit) { + this.backfillParent(parentCommit); + } + } + if (isHandshake) { + return { + clearEvent: { + type: "io.element.mls.internal", + content: {"body": "This is an MLS handshake message, so there's nothing useful to see here."}, + }, + }; + } + throw e; + } + let processedMessage; + try { + processedMessage = group.process_unverified_message( + unverifiedMessage, + epochCreator, + mlsProvider.backend!, + ); + } catch (e) { + if (isHandshake) { + return { + clearEvent: { + type: "io.element.mls.internal", + content: {"body": "This is an MLS handshake message, so there's nothing useful to see here."}, + }, + }; + } + throw e; + } + this.removeEventFromPendingList(event, epochNumber, content.epoch_creator); + if (processedMessage.is_application_message()) { + const messageArr = processedMessage.as_application_message(); + const clearEvent = JSON.parse(textDecoder.decode(Uint8Array.from(messageArr))); + if (typeof(clearEvent.room_id) !== "string" || + typeof(clearEvent.type) !== "string" || + typeof(clearEvent.content) !== "object") { + throw new DecryptionError("MLS_MISSING_FIELDS", "Missing or invalid fields in plaintext"); + } + return { + clearEvent + } + } else if (processedMessage.is_staged_commit()) { + if (typeof(content.sender) !== "string" || !Array.isArray(content.resolves)) { + throw new DecryptionError("MLS_MISSING_FIELDS", "Missing or invalid fields in cleartext"); + } + const sender = olmlib.decodeBase64(content.sender); + const resolves = content.resolves.map(([epochNum, creatorB64]: [number, string]) => { + return [epochNum, olmlib.decodeBase64(creatorB64)]; + }); + const commit = processedMessage.as_staged_commit(); + const [newEpochNum, newEpochCreator] = group.merge_staged_commit( + commit, epochNumber, epochCreator, + sender, resolves, + mlsProvider.backend!, + ); + this.retryDecryption(newEpochNum, olmlib.encodeUnpaddedBase64(newEpochCreator)); + // don't wait for it to complete + this.crypto.backupManager.backupGroupSession(this.roomId, newEpochNum, olmlib.encodeUnpaddedBase64(newEpochCreator)); + return { + clearEvent: { + type: "io.element.mls.internal", + content: {"body": "This is an MLS handshake message, so there's nothing useful to see here."}, + }, + }; + } else { + throw new DecryptionError("MLS_UNKNOWN_TYPE", "Unknown MLS message type"); + } + } + + public async importRoomKey(key: IMlsSessionData, opts: IImportRoomKeysOpts): Promise { + if (key.group_export) { + const mlsProvider = this.crypto.mlsProvider; + const [epochNumber, epochCreator] = key.epoch; + mlsProvider.importGroupData(this.roomId!, epochNumber, epochCreator, key.group_export); + this.retryDecryption(epochNumber, epochCreator); + } + } + + private async backfillParent(commitEventId: string): Promise { + if (!this.pendingBackfills.has(commitEventId)) { + this.pendingBackfills.set(commitEventId, (async () => { + try { + const event = await this.baseApis.fetchRoomEvent(this.roomId!, commitEventId); + if (event.type != "m.room.encrypted") { + return; + } + + const matrixEvent = new MatrixEvent(event); + await matrixEvent.attemptDecryption(this.crypto); + } finally { + this.pendingBackfills.delete(commitEventId); + } + })()); + } + return this.pendingBackfills.get(commitEventId)!; + } + + /** + * Add an event to the list of those awaiting their session keys. + * + * @internal + * + */ + private addEventToPendingList( + event: MatrixEvent, + epochNumber: BigInt, + epochCreator: string, + ): void { + if (!this.pendingEvents.has(epochNumber)) { + this.pendingEvents.set(epochNumber, new Map>()); + } + const epochNumPendingEvents = this.pendingEvents.get(epochNumber)!; + if (!epochNumPendingEvents.has(epochCreator)) { + epochNumPendingEvents.set(epochCreator, new Set()); + } + epochNumPendingEvents.get(epochCreator)!.add(event); + } + + /** + * Remove an event from the list of those awaiting their session keys. + * + * @internal + * + */ + private removeEventFromPendingList( + event: MatrixEvent, + epochNumber: BigInt, + epochCreator: string, + ): void { + const epochNumPendingEvents = this.pendingEvents.get(epochNumber); + const pendingEvents = epochNumPendingEvents?.get(epochCreator); + if (!pendingEvents) { + return; + } + + pendingEvents.delete(event); + if (pendingEvents.size === 0) { + epochNumPendingEvents!.delete(epochCreator); + } + if (epochNumPendingEvents!.size === 0) { + this.pendingEvents.delete(epochNumber); + } + } + + private async retryDecryption( + epochNumber: number, + epochCreator: string, + ): Promise { + const pending = this.pendingEvents.get(BigInt(epochNumber))?.get(epochCreator); + if (!pending) { + return true; + } + + const pendingList = [...pending]; + console.debug( + "Retrying decryption on events:", + pendingList.map((e) => `${e.getId()}`), + ); + + await Promise.all( + pendingList.map(async (ev) => { + try { + await ev.attemptDecryption(this.crypto, { isRetry: true }); + } catch (e) { + // don't die if something goes wrong + } + }), + ); + + // If decrypted successfully with trusted keys, they'll have + // been removed from pendingEvents + return !this.pendingEvents.get(BigInt(epochNumber))?.has(epochCreator); + } +} + +class WelcomeEncryption extends EncryptionAlgorithm { + public async encryptMessage(room: Room, eventType: string, content: IContent): Promise { + throw new Error("Encrypt not supported for welcome message"); + } +} + +class WelcomeDecryption extends DecryptionAlgorithm { + public async decryptEvent(event: MatrixEvent): Promise { + const content = event.getWireContent(); + console.log("[MLS] Got welcome", content); + // FIXME: check that it's a to-device event + if (typeof(content.ciphertext) !== "string" || + typeof(content.sender) !== "string" || + !Array.isArray(content.resolves)) { + throw new DecryptionError("MLS_WELCOME_MISSING_FIELDS", "Missing or invalid fields in input"); + } + this.crypto.mlsProvider.processWelcome( + content.ciphertext, + content.sender, + content.resolves, + content.commit_event, + ); + // welcome packages don't have any visible representation and don't get + // processed further + return { + clearEvent: { + type: "m.dummy", + content: {}, + } + } + } +} + +function joinId(userId: string, deviceId: string): Uint8Array { + return textEncoder.encode(userId + "|" + deviceId); +} + +function splitId(id: Uint8Array | number[]): [string, string] { + const userStr = textDecoder.decode(id instanceof Uint8Array ? id : Uint8Array.from(id)); + // FIXME: this will do the wrong thing if the device ID has a "|" + return userStr.split("|", 2) as [string, string]; +} + +export class MlsProvider { + private readonly groups: Map; + private readonly storage: Map; + private readonly members: Map>>; + private readonly epochMap: Map>>; + public backend?: matrixDmls.DmlsCryptoProvider; + public credential?: matrixDmls.Credential; + + constructor(public readonly crypto: Crypto) { + // FIXME: we should persist groups + this.groups = new Map(); + // FIXME: this should go in the cryptostorage + // FIXME: DmlsCryptoProvider should also use cryptostorage + this.storage = new Map(); + this.members = new Map(); + this.epochMap = new Map(); + } + + async init(): Promise { + await matrixDmls.initAsync(); + this.backend = new matrixDmls.DmlsCryptoProvider( + this.store.bind(this), + this.read.bind(this), + this.getInitKeys.bind(this), + ); + let baseApis = this.crypto.baseApis; + this.credential = new matrixDmls.Credential( + this.backend!, + joinId(baseApis.getUserId()!, baseApis.getDeviceId()!), + ); + } + + static keyToString([groupIdArr, epoch, creatorArr, historical]: [number[], number, number[], boolean]): string { + let groupId = new Uint8Array(groupIdArr); + let creator = new Uint8Array(creatorArr); + return olmlib.encodeUnpaddedBase64(groupId) + "|" + epoch + "|" + olmlib.encodeUnpaddedBase64(creator) + "|" + historical; + } + + store(key: [number[], number, number[], boolean], value: number[]): void { + this.storage.set(MlsProvider.keyToString(key), value); + } + + read(key: [number[], number, number[], boolean]): number[] | undefined { + return this.storage.get(MlsProvider.keyToString(key)); + } + + async getInitKeys(users: Uint8Array[]): Promise<(Uint8Array | undefined)[]> { + let baseApis = this.crypto.baseApis; + + if (users.length) { + const devicesToClaim: [string, string][] = users.map(splitId) + + const otks = await baseApis.claimOneTimeKeys(devicesToClaim, INIT_KEY_ALGORITHM.name); + + console.log("[MLS] InitKeys", otks); + + const keys: (Uint8Array | undefined)[] = []; + + for (const [user, device] of devicesToClaim) { + if (user in otks.one_time_keys && device in otks.one_time_keys[user]) { + const key = otks.one_time_keys[user][device]; + const initKeyB64 = Object.values(key)[0] as unknown as string; + const initKey = olmlib.decodeBase64(initKeyB64); + keys.push(initKey); + } else { + keys.push(undefined); + } + } + + return keys; + } else { + return []; + } + } + + async createGroup(room: Room, invite: string[]): Promise { + let baseApis = this.crypto.baseApis; + + const group = new matrixDmls.DmlsGroup(this.backend!, this.credential!, textEncoder.encode(room.roomId)) + this.groups.set(room.roomId, group); + + const [epochNum, epochCreator] = group.epoch(); + const epochCreatorB64 = olmlib.encodeUnpaddedBase64(epochCreator); + // don't wait for it to complete + this.crypto.backupManager.backupGroupSession(room.roomId, epochNum, epochCreatorB64); + + const userId = baseApis.getUserId()!; + const deviceMap = await this.crypto.deviceList.downloadKeys([userId].concat(invite), false); + delete deviceMap[userId][baseApis.getDeviceId()!]; + + let addedMembers = false; + const members: Map> = new Map(); + + for (const [user, devices] of Object.entries(deviceMap)) { + const memberDevices: Set = new Set(); + members.set(user, memberDevices); + for (const deviceId of Object.keys(devices)) { + addedMembers = true; + const mlsUser = joinId(user, deviceId); + group.add_member(mlsUser, this.backend!); + memberDevices.add(deviceId); + } + } + + members.get(userId)!.add(baseApis.getDeviceId()!); + this.members.set(room.roomId, members); + + const sender = joinId(baseApis.getUserId()!, baseApis.getDeviceId()!); + const senderB64 = olmlib.encodeUnpaddedBase64(sender); + + const createEvent = room.currentState.getStateEvents(EventType.RoomCreate, "")!; + + if (addedMembers) { + const [commit, _mlsEpoch, creator, resolves, welcomeInfo] = await group.resolve(this.backend!); + + const [epochNum, epochCreator] = group.epoch(); + const epochCreatorB64 = olmlib.encodeUnpaddedBase64(epochCreator); + // don't wait for it to complete + this.crypto.backupManager.backupGroupSession(room.roomId, epochNum, epochCreatorB64); + + const creatorB64 = olmlib.encodeUnpaddedBase64(Uint8Array.from(creator)); + + // FIXME: check if external commits are allowed + const publicGroupState = group.public_group_state(this.backend!); + const publicGroupStateB64 = olmlib.encodeUnpaddedBase64(Uint8Array.from(publicGroupState)); + + const {event_id: eventId} = await baseApis.sendEvent(room.roomId, "m.room.encrypted", { + algorithm: MLS_ALGORITHM.name, + ciphertext: olmlib.encodeUnpaddedBase64(Uint8Array.from(commit)), + epoch_creator: creatorB64, + sender: senderB64, + resolves: resolves.map(([epochNum, creator]: [number, number[]]) => { + return [epochNum, olmlib.encodeUnpaddedBase64(Uint8Array.from(creator))]; + }), + public_group_state: publicGroupStateB64, + commit_event: createEvent.getId(), + }); + + const roomEpochMap = new Map>(); + roomEpochMap.set(BigInt(epochNum), new Map([[epochCreatorB64, eventId]])); + this.epochMap.set(room.roomId, roomEpochMap); + + if (welcomeInfo) { + const [welcome, adds] = welcomeInfo; + + const welcomeB64 = olmlib.encodeUnpaddedBase64(Uint8Array.from(welcome)); + + const contentMap: Record> = {}; + + const payload = { + algorithm: WELCOME_PACKAGE.name, + ciphertext: welcomeB64, + sender: creatorB64, + room_id: room.roomId, + resolves: resolves.map(([epochNum, creator]: [number, number[]]) => { + return [epochNum, olmlib.encodeUnpaddedBase64(Uint8Array.from(creator))]; + }), + commit_event: eventId, + } + + for (const user of adds) { + try { + const [userId, deviceId] = splitId(user); + if (!(userId in contentMap)) { + contentMap[userId] = {}; + } + contentMap[userId][deviceId] = payload; + } catch (e) { + console.error("[MLS] Unable to add user", user, e); + } + } + + await baseApis.sendToDevice("m.room.encrypted", contentMap); + } + } else { + // FIXME: check if external commits are allowed + const publicGroupState = group.public_group_state(this.backend!); + const publicGroupStateB64 = olmlib.encodeUnpaddedBase64(Uint8Array.from(publicGroupState)); + + const {event_id: eventId} = await baseApis.sendEvent(room.roomId, "m.room.encrypted", { + algorithm: MLS_ALGORITHM.name, + ciphertext: "", + epoch_creator: senderB64, + sender: senderB64, + resolves: [], + public_group_state: publicGroupStateB64, + commit_event: createEvent.getId(), + }); + + const roomEpochMap = new Map>(); + roomEpochMap.set(BigInt(epochNum), new Map([[epochCreatorB64, eventId]])); + this.epochMap.set(room.roomId, roomEpochMap); + } + + return group; + } + + processWelcome( + welcomeB64: string, + creatorB64: string, + resolvesB64: [number, string][], + commitEvent: string, + ): void { + const welcome = olmlib.decodeBase64(welcomeB64); + const creator = olmlib.decodeBase64(creatorB64); + const resolves = resolvesB64.map(([epochNum, creatorB64]) => { + return [epochNum, olmlib.decodeBase64(creatorB64)]; + }); + const group = matrixDmls.DmlsGroup.new_from_welcome(this.backend!, welcome, creator); + const groupIdArr = group.group_id(); + const groupId = textDecoder.decode(groupIdArr); + console.log("[MLS] Welcome message for", groupId); + // FIXME: check that it's a valid room ID + + const [epochNum, epochCreator] = group.epoch(); + const epochCreatorB64 = olmlib.encodeUnpaddedBase64(epochCreator); + // don't wait for it to complete + this.crypto.backupManager.backupGroupSession(groupId, epochNum, epochCreatorB64); + + if (!this.epochMap.has(groupId)) { + this.epochMap.set(groupId, new Map()); + } + const roomEpochMap = this.epochMap.get(groupId)!; + if (!roomEpochMap.has(BigInt(epochNum))) { + roomEpochMap.set(BigInt(epochNum), new Map()); + } + roomEpochMap.get(BigInt(epochNum))!.set(epochCreatorB64, commitEvent); + + const oldGroup = this.groups.get(groupId); + + if (oldGroup) { + const joined = group.is_joined(); + oldGroup.add_epoch_from_new_group(this.backend!, group, resolves); + + if (!joined) { + const members: Map> = new Map(); + for (const member of group.members(this.backend!)) { + const [userId, deviceId] = splitId(member); + if (!members.has(userId)) { + members.set(userId, new Set()); + } + members.get(userId)!.add(deviceId); + } + + this.members.set(groupId, members); + } + } else { + this.groups.set(groupId, group); + + const members: Map> = new Map(); + for (const member of group.members(this.backend!)) { + const [userId, deviceId] = splitId(member); + if (!members.has(userId)) { + members.set(userId, new Set()); + } + members.get(userId)!.add(deviceId); + } + + this.members.set(groupId, members); + } + } + + joinByExternalCommit(publicGroupStateB64: string, roomId: string, commitEvent: string): [matrixDmls.DmlsGroup, Uint8Array] { + const publicGroupState = olmlib.decodeBase64(publicGroupStateB64); + const joinResult = matrixDmls.DmlsGroup.join_by_external_commit( + this.backend!, + publicGroupState, + this.credential!, + ); + const joinMsg = joinResult.message; + let group = joinResult.group; + const groupIdArr = group.group_id(); + const groupId = textDecoder.decode(groupIdArr); + if (groupId != roomId) { + throw "Group ID mismatch"; + } + + const oldGroup = this.groups.get(groupId); + if (oldGroup) { + oldGroup.add_epoch_from_new_group(this.backend!, group, []); + group = oldGroup; + } else { + this.groups.set(groupId, group); + } + + const [epochNum, epochCreator] = group.epoch(); + const epochCreatorB64 = olmlib.encodeUnpaddedBase64(epochCreator); + // don't wait for it to complete + this.crypto.backupManager.backupGroupSession(roomId, epochNum, epochCreatorB64); + + if (!this.epochMap.has(roomId)) { + this.epochMap.set(roomId, new Map()); + } + const roomEpochMap = this.epochMap.get(roomId)!; + if (!roomEpochMap.has(BigInt(epochNum))) { + roomEpochMap.set(BigInt(epochNum), new Map()); + } + roomEpochMap.get(BigInt(epochNum))!.set(epochCreatorB64, commitEvent); + + const members: Map> = new Map(); + for (const member of group.members(this.backend!)) { + const [userId, deviceId] = splitId(member); + if (!members.has(userId)) { + members.set(userId, new Set()); + } + members.get(userId)!.add(deviceId); + } + + this.members.set(groupId, members); + + return [group, joinMsg]; + } + + getGroup(roomId: string): matrixDmls.DmlsGroup | undefined { + return this.groups.get(roomId); + } + + addEpochEvent(group: matrixDmls.DmlsGroup, roomId: string, eventId: string): void { + const [epochNum, epochCreator] = group.epoch(); + const epochCreatorB64 = olmlib.encodeUnpaddedBase64(epochCreator); + + if (!this.epochMap.has(roomId)) { + this.epochMap.set(roomId, new Map()); + } + const roomEpochMap = this.epochMap.get(roomId)!; + if (!roomEpochMap.has(BigInt(epochNum))) { + roomEpochMap.set(BigInt(epochNum), new Map()); + } + roomEpochMap.get(BigInt(epochNum))!.set(epochCreatorB64, eventId); + } + + getEpochEvent(roomId: string, epochNum: BigInt, epochCreator: string): string | undefined { + return this.epochMap.get(roomId)?.get(epochNum)?.get(epochCreator); + } + + syncMembers(roomId: string, members: Map>): void { + /* Membership tracking: ideally, the way it would work is: + * + * - When we get a membership event in an encrypted group (join, leave, + * invite, etc.), then we mark the appropriate group adds/removes. + * (In the case of a join/invite, we need to get the user's devices, + * then add them all.) + * + * - We also store group membership by user -> groups. When we are + * notified that a user's devices have changed, we flag the user's + * groups a dirty. We will, at a later time, update the user's + * devices, and synchronize the device's membership. + * + * - We continue to receive and process incoming commits. + * + * - At a later time, we determine whether we need to send a commit, and + * do so if needed. + */ + const recordedMembers = this.members.get(roomId)!; + + console.log("[MLS] Syncing members", members, recordedMembers); + + // find out what devices have been added/removed + const adds: [string, string][] = []; + const removes: [string, string][] = []; + + for (const [userId, devices] of members.entries()) { + const recordedDevices = recordedMembers.get(userId); + if (recordedDevices) { + for (const deviceId of devices.values()) { + if (!recordedDevices.has(deviceId)) { + adds.push([userId, deviceId]); + } + } + for (const deviceId of recordedDevices.values()) { + if (!devices.has(deviceId)) { + removes.push([userId, deviceId]) + } + } + } else { + for (const deviceId of devices.values()) { + adds.push([userId, deviceId]); + } + } + } + + for (const [userId, devices] of recordedMembers.entries()) { + if (!members.has(userId)) { + for (const deviceId of devices.values()) { + removes.push([userId, deviceId]); + } + } + } + + console.log("[MLS] adds, removes", adds, removes); + + // sync up the group and recorded members + const group = this.groups.get(roomId)!; + + for (const [userId, deviceId] of adds) { + group.add_member(joinId(userId, deviceId), this.backend!); + if (!recordedMembers.has(userId)) { + recordedMembers.set(userId, new Set()); + } + recordedMembers.get(userId)!.add(deviceId); + } + + for (const [userId, deviceId] of removes) { + group.remove_member(joinId(userId, deviceId), this.backend!); + if (recordedMembers.has(userId)) { // should always be true, but be safe + const recordedDevices = recordedMembers.get(userId)!; + recordedDevices.delete(deviceId); + if (recordedDevices.size == 0) { + recordedMembers.delete(userId); + } + } + } + } + + exportGroupData(roomId: string, epochNumber: number, epochCreator: string): string { + const group = this.getGroup(roomId); + if (!group) { + throw new Error("No such group"); + } + console.info(roomId, epochNumber, BigInt(epochNumber), epochCreator, olmlib.decodeBase64(epochCreator)); + const groupExport = group.export_group(this.backend!, BigInt(epochNumber), olmlib.decodeBase64(epochCreator)); + console.info("OK", roomId, epochNumber, BigInt(epochNumber), epochCreator, olmlib.decodeBase64(epochCreator)); + return olmlib.encodeUnpaddedBase64(Uint8Array.from(groupExport)); + } + + importGroupData(roomId: string, epochNumber: number, epochCreator: string, groupExport: string): void { + console.log("Import group data for", roomId, epochNumber, epochCreator); + const groupExportBin = olmlib.decodeBase64(groupExport); + console.log("Decoded export"); + let group = this.getGroup(roomId); + if (!group) { + console.log("Creating group"); + const baseApis = this.crypto.baseApis; + group = matrixDmls.DmlsGroup.new_dummy_group( + this.backend!, + textEncoder.encode(roomId), + joinId(baseApis.getUserId()!, baseApis.getDeviceId()!), + ); + this.groups.set(roomId, group); + } + console.log("Importing..."); + try { + group.import_group(this.backend!, BigInt(epochNumber), olmlib.decodeBase64(epochCreator), groupExportBin); + } catch(e) { + console.error(e); + throw e; + } + console.log("Done"); + } + + signObject(obj: object & olmlib.IObject): void { + const sigs = obj.signatures || {}; + delete obj.signatures; + const unsigned = obj.unsigned; + if (obj.unsigned) delete obj.unsigned; + try { + const payload = textEncoder.encode(anotherjson.stringify(obj)); + const sig = this.backend!.sign(this.credential!, payload); + + const userId = this.crypto.baseApis.getUserId()! + const mysigs = sigs[userId] || {}; + sigs[userId] = mysigs; + mysigs[`org.matrix.msc2883.v0.dmls.credential.ed25519:${this.crypto.baseApis.getDeviceId()!}`] = olmlib.encodeBase64(sig); + } finally { + obj.signatures = sigs; + if (unsigned) obj.unsigned = unsigned; + } + } +} + +registerAlgorithm(MLS_ALGORITHM.name, MlsEncryption, MlsDecryption); +registerAlgorithm(WELCOME_PACKAGE.name, WelcomeEncryption, WelcomeDecryption); diff --git a/src/crypto/algorithms/index.ts b/src/crypto/algorithms/index.ts index b3c5b0ede84..f87f3a21192 100644 --- a/src/crypto/algorithms/index.ts +++ b/src/crypto/algorithms/index.ts @@ -16,5 +16,6 @@ limitations under the License. import "./olm"; import "./megolm"; +import "./dmls"; export * from "./base"; diff --git a/src/crypto/algorithms/megolm.ts b/src/crypto/algorithms/megolm.ts index 934b69bd35d..985c4c4b2f4 100644 --- a/src/crypto/algorithms/megolm.ts +++ b/src/crypto/algorithms/megolm.ts @@ -20,7 +20,7 @@ limitations under the License. import { v4 as uuidv4 } from "uuid"; -import type { IEventDecryptionResult, IMegolmSessionData } from "../../@types/crypto"; +import type { IEventDecryptionResult } from "../../@types/crypto"; import { logger, PrefixedLogger } from "../../logger"; import * as olmlib from "../olmlib"; import { @@ -555,7 +555,7 @@ export class MegolmEncryption extends EncryptionAlgorithm { ); // don't wait for it to complete - this.crypto.backupManager.backupGroupSession(this.olmDevice.deviceCurve25519Key!, sessionId); + //this.crypto.backupManager.backupGroupSession(this.olmDevice.deviceCurve25519Key!, sessionId); return new OutboundSessionInfo(sessionId, sharedHistory); } @@ -1806,7 +1806,7 @@ export class MegolmDecryption extends DecryptionAlgorithm { } // don't wait for the keys to be backed up for the server - await this.crypto.backupManager.backupGroupSession(roomKey.senderKey, roomKey.sessionId); + //await this.crypto.backupManager.backupGroupSession(roomKey.senderKey, roomKey.sessionId); } catch (e) { this.prefixedLogger.error(`Error handling m.room_key_event: ${e}`); } @@ -2054,7 +2054,7 @@ export class MegolmDecryption extends DecryptionAlgorithm { * @param untrusted - whether the key should be considered as untrusted * @param source - where the key came from */ - public importRoomKey( + /*public importRoomKey( session: IMegolmSessionData, { untrusted, source }: { untrusted?: boolean; source?: string } = {}, ): Promise { @@ -2088,7 +2088,7 @@ export class MegolmDecryption extends DecryptionAlgorithm { // have another go at decrypting events sent with this session. this.retryDecryption(session.sender_key, session.session_id, !extraSessionData.untrusted); }); - } + }*/ /** * Have another go at decrypting events after we receive a key. Resolves once diff --git a/src/crypto/backup.ts b/src/crypto/backup.ts index d71cce99c7d..d6c49f4a74f 100644 --- a/src/crypto/backup.ts +++ b/src/crypto/backup.ts @@ -18,10 +18,11 @@ limitations under the License. * Classes for dealing with key backup. */ -import type { IMegolmSessionData } from "../@types/crypto"; +// import type { IMegolmSessionData } from "../@types/crypto"; import { MatrixClient } from "../client"; import { logger } from "../logger"; -import { MEGOLM_ALGORITHM, verifySignature } from "./olmlib"; +import { verifySignature } from "./olmlib"; +import { MLS_ALGORITHM, IMlsSessionData } from "./algorithms/dmls"; import { DeviceInfo } from "./deviceinfo"; import { DeviceTrustLevel } from "./CrossSigning"; import { keyFromPassphrase } from "./key_passphrase"; @@ -32,13 +33,11 @@ import { calculateKeyCheck, decryptAES, encryptAES, IEncryptedPayload } from "./ import { Curve25519SessionData, IAes256AuthData, - ICurve25519AuthData, IKeyBackupInfo, IKeyBackupSession, } from "./keybackup"; import { UnstableValue } from "../NamespacedValue"; import { CryptoEvent } from "./index"; -import { crypto } from "./crypto"; import { HTTPError, MatrixError } from "../http-api"; const KEY_BACKUP_KEYS_PER_REQUEST = 200; @@ -93,7 +92,7 @@ interface BackupAlgorithmClass { interface BackupAlgorithm { untrusted: boolean; encryptSession(data: Record): Promise; - decryptSessions(ciphertexts: Record): Promise; + decryptSessions(ciphertexts: Record): Promise; authData: AuthData; keyMatches(key: ArrayLike): Promise; free(): void; @@ -496,30 +495,34 @@ export class BackupManager { this.baseApis.crypto!.emit(CryptoEvent.KeyBackupSessionsRemaining, remaining); const rooms: IKeyBackup["rooms"] = {}; + const mlsProvider = this.baseApis.crypto!.mlsProvider; for (const session of sessions) { - const roomId = session.sessionData!.room_id; + const roomId = session.roomId; if (rooms[roomId] === undefined) { rooms[roomId] = { sessions: {} }; } - const sessionData = this.baseApis.crypto!.olmDevice.exportInboundGroupSession( - session.senderKey, - session.sessionId, - session.sessionData!, - ); - sessionData.algorithm = MEGOLM_ALGORITHM; - - const forwardedCount = (sessionData.forwarding_curve25519_key_chain || []).length; + const groupExport = mlsProvider.exportGroupData(session.roomId, session.epochNumber, session.epochCreator); + const sessionData = { + algorithm: MLS_ALGORITHM.name, + group_export: groupExport, + room_id: roomId, + epoch: [session.epochNumber, session.epochCreator], + }; + /* const userId = this.baseApis.crypto!.deviceList.getUserByIdentityKey(MEGOLM_ALGORITHM, session.senderKey); const device = this.baseApis.crypto!.deviceList.getDeviceByIdentityKey(MEGOLM_ALGORITHM, session.senderKey) ?? undefined; - const verified = this.baseApis.crypto!.checkDeviceInfoTrust(userId!, device).isVerified(); + */ + const verified = false; //this.baseApis.crypto!.checkDeviceInfoTrust(userId!, device).isVerified(); + + const sessionId = session.epochNumber + "|" + session.epochCreator; - rooms[roomId]["sessions"][session.sessionId] = { - first_message_index: sessionData.first_known_index, - forwarded_count: forwardedCount, + rooms[roomId]["sessions"][sessionId] = { + first_message_index: 0, + forwarded_count: 0, is_verified: verified, session_data: await this.algorithm!.encryptSession(sessionData), }; @@ -534,11 +537,12 @@ export class BackupManager { return sessions.length; } - public async backupGroupSession(senderKey: string, sessionId: string): Promise { + public async backupGroupSession(roomId: string, epochNumber: number, epochCreator: string): Promise { await this.baseApis.crypto!.cryptoStore.markSessionsNeedingBackup([ { - senderKey: senderKey, - sessionId: sessionId, + roomId, + epochNumber, + epochCreator, }, ]); @@ -595,6 +599,7 @@ export class BackupManager { } } +/* export class Curve25519 implements BackupAlgorithm { public static algorithmName = "m.megolm_backup.v1.curve25519-aes-sha2"; @@ -656,7 +661,7 @@ export class Curve25519 implements BackupAlgorithm { public async decryptSessions( sessions: Record>, - ): Promise { + ): Promise { const privKey = await this.getKey(); const decryption = new global.Olm.PkDecryption(); try { @@ -666,7 +671,7 @@ export class Curve25519 implements BackupAlgorithm { throw new MatrixError({ errcode: MatrixClient.RESTORE_BACKUP_ERROR_BAD_KEY }); } - const keys: IMegolmSessionData[] = []; + const keys: IMlsSessionData[] = []; for (const [sessionId, sessionData] of Object.entries(sessions)) { try { @@ -705,6 +710,7 @@ export class Curve25519 implements BackupAlgorithm { this.publicKey.free(); } } +*/ function randomBytes(size: number): Uint8Array { const buf = new Uint8Array(size); @@ -712,6 +718,7 @@ function randomBytes(size: number): Uint8Array { return buf; } +/* const UNSTABLE_MSC3270_NAME = new UnstableValue( "m.megolm_backup.v1.aes-hmac-sha2", "org.matrix.msc3270.v1.aes-hmac-sha2", @@ -806,10 +813,106 @@ export class Aes256 implements BackupAlgorithm { this.key.fill(0); } } +*/ + +const UNSTABLE_MSC4038_NAME = new UnstableValue( + "m.dmls_backup.v1.aes-hmac-sha2", + "org.matrix.msc4038.v0.aes-hmac-sha2", +); + +export class MlsAes256 implements BackupAlgorithm { + public static algorithmName = UNSTABLE_MSC4038_NAME.name; + + public constructor(public readonly authData: IAes256AuthData, private readonly key: Uint8Array) {} + + public static async init(authData: IAes256AuthData, getKey: () => Promise): Promise { + if (!authData) { + throw new Error("auth_data missing"); + } + const key = await getKey(); + if (authData.mac) { + const { mac } = await calculateKeyCheck(key, authData.iv); + if (authData.mac.replace(/=+$/g, "") !== mac.replace(/=+/g, "")) { + throw new Error("Key does not match"); + } + } + return new MlsAes256(authData, key); + } + + public static async prepare(key?: string | Uint8Array | null): Promise<[Uint8Array, AuthData]> { + let outKey: Uint8Array; + const authData: Partial = {}; + if (!key) { + outKey = randomBytes(32); + } else if (key instanceof Uint8Array) { + outKey = new Uint8Array(key); + } else { + const derivation = await keyFromPassphrase(key); + authData.private_key_salt = derivation.salt; + authData.private_key_iterations = derivation.iterations; + outKey = derivation.key; + } + + const { iv, mac } = await calculateKeyCheck(outKey); + authData.iv = iv; + authData.mac = mac; + + return [outKey, authData as AuthData]; + } + + public static checkBackupVersion(info: IKeyBackupInfo): void { + if (!("iv" in info.auth_data && "mac" in info.auth_data)) { + throw new Error("Invalid backup data returned"); + } + } + + public get untrusted(): boolean { + return false; + } + + public encryptSession(data: Record): Promise { + const plainText: Record = Object.assign({}, data); + delete plainText.session_id; + delete plainText.room_id; + delete plainText.first_known_index; + return encryptAES(JSON.stringify(plainText), this.key, data.epoch.join("|")); + } + + public async decryptSessions( + sessions: Record>, + ): Promise { + const keys: IMlsSessionData[] = []; + + for (const [sessionId, sessionData] of Object.entries(sessions)) { + try { + const decrypted = JSON.parse(await decryptAES(sessionData.session_data, this.key, sessionId)); + const [epochNum, epochCreator] = sessionId.split("|"); + decrypted.epoch = [parseInt(epochNum), epochCreator]; + keys.push(decrypted); + } catch (e) { + logger.log("Failed to decrypt MLS session from backup", e, sessionData); + } + } + return keys; + } + + public async keyMatches(key: Uint8Array): Promise { + if (this.authData.mac) { + const { mac } = await calculateKeyCheck(key, this.authData.iv); + return this.authData.mac.replace(/=+$/g, "") === mac.replace(/=+/g, ""); + } else { + // if we have no information, we have to assume the key is right + return true; + } + } + + public free(): void { + this.key.fill(0); + } +} export const algorithmsByName: Record = { - [Curve25519.algorithmName]: Curve25519, - [Aes256.algorithmName]: Aes256, + [MlsAes256.algorithmName]: MlsAes256, }; -export const DefaultAlgorithm: BackupAlgorithmClass = Curve25519; +export const DefaultAlgorithm: BackupAlgorithmClass = MlsAes256; diff --git a/src/crypto/index.ts b/src/crypto/index.ts index 5500872226f..8ffc0b3c0db 100644 --- a/src/crypto/index.ts +++ b/src/crypto/index.ts @@ -28,6 +28,7 @@ import { logger } from "../logger"; import { IExportedDevice, OlmDevice } from "./OlmDevice"; import { IOlmDevice } from "./algorithms/megolm"; import * as olmlib from "./olmlib"; +import * as dmls from "./algorithms/dmls"; import { DeviceInfoMap, DeviceList } from "./DeviceList"; import { DeviceInfo, IDevice } from "./deviceinfo"; import type { DecryptionAlgorithm, EncryptionAlgorithm } from "./algorithms"; @@ -224,9 +225,16 @@ export interface IMegolmEncryptedContent { ciphertext: string; [ToDeviceMessageId]?: string; } + +export interface IMlsEncryptedContent { + algorithm: typeof dmls.MLS_ALGORITHM.name; + ciphertext: string; + epoch_creator: string; + commit_event: string; +} /* eslint-enable camelcase */ -export type IEncryptedContent = IOlmEncryptedContent | IMegolmEncryptedContent; +export type IEncryptedContent = IOlmEncryptedContent | IMegolmEncryptedContent | IMlsEncryptedContent; export enum CryptoEvent { DeviceVerificationChanged = "deviceVerificationChanged", @@ -407,9 +415,12 @@ export class Crypto extends TypedEventEmitter; + public mlsProvider: dmls.MlsProvider; + /** * Cryptography bits * @@ -531,6 +542,8 @@ export class Crypto extends TypedEventEmitter { + this.mlsProvider.signObject(deviceKeys); return this.baseApis.uploadKeysRequest({ device_keys: deviceKeys as Required, }); @@ -1822,9 +1838,10 @@ export class Crypto extends TypedEventEmitter => { - while (keyLimit > keyCount || this.getNeedsNewFallback()) { + const uploadLoop = async (otkKeyCount: number, initKeyCount: number): Promise => { + while (keyLimit > otkKeyCount || this.getNeedsNewFallback() || keyLimit > initKeyCount) { // Ask olm to generate new one time keys, then upload them to synapse. - if (keyLimit > keyCount) { + if (keyLimit > otkKeyCount) { logger.info("generating oneTimeKeys"); - const keysThisLoop = Math.min(keyLimit - keyCount, maxKeysPerCycle); + const keysThisLoop = Math.min(keyLimit - otkKeyCount, maxKeysPerCycle); await this.olmDevice.generateOneTimeKeys(keysThisLoop); } @@ -1911,39 +1928,54 @@ export class Crypto extends TypedEventEmitter initKeyCount) { + logger.info("generating initKeys"); + const keysThisLoop = Math.min(keyLimit - initKeyCount, maxKeysPerCycle); + initKeys = this.mlsProvider.backend!.make_init_keys( + this.mlsProvider.credential!, + keysThisLoop, + ) + .map((keyArr: number[]) => olmlib.encodeUnpaddedBase64(Uint8Array.from(keyArr))); + } + logger.info("calling uploadOneTimeKeys"); - const res = await this.uploadOneTimeKeys(); + const res = await this.uploadOneTimeKeys(initKeys); if (res.one_time_key_counts && res.one_time_key_counts.signed_curve25519) { // if the response contains a more up to date value use this // for the next loop - keyCount = res.one_time_key_counts.signed_curve25519; + otkKeyCount = res.one_time_key_counts.signed_curve25519; } else { throw new Error( "response for uploading keys does not contain " + "one_time_key_counts.signed_curve25519", ); } + initKeyCount = (res.one_time_key_counts && res.one_time_key_counts [dmls.INIT_KEY_ALGORITHM.name]) || 0; } }; this.oneTimeKeyCheckInProgress = true; Promise.resolve() .then(() => { - if (this.oneTimeKeyCount !== undefined) { + if (this.oneTimeKeyCount !== undefined && this.initKeyCount !== undefined) { // We already have the current one_time_key count from a /sync response. // Use this value instead of asking the server for the current key count. - return Promise.resolve(this.oneTimeKeyCount); + return Promise.resolve([this.oneTimeKeyCount, this.initKeyCount]); } // ask the server how many keys we have return this.baseApis.uploadKeysRequest({}).then((res) => { - return res.one_time_key_counts.signed_curve25519 || 0; + return [ + res.one_time_key_counts.signed_curve25519 || 0, + res.one_time_key_counts[dmls.INIT_KEY_ALGORITHM.name] || 0, + ]; }); }) - .then((keyCount) => { + .then(([otkKeyCount, initKeyCount]) => { // Start the uploadLoop with the current keyCount. The function checks if // we need to upload new keys or not. // If there are too many keys on the server then we don't need to // create any more keys. - return uploadLoop(keyCount); + return uploadLoop(otkKeyCount, initKeyCount); }) .catch((e) => { logger.error("Error uploading one-time keys", e.stack || e); @@ -1952,12 +1984,13 @@ export class Crypto extends TypedEventEmitter { + private async uploadOneTimeKeys(initKeys: string[]): Promise { const promises: Promise[] = []; let fallbackJson: Record | undefined; @@ -1973,7 +2006,7 @@ export class Crypto extends TypedEventEmitter = {}; + const oneTimeJson: Record = {}; for (const keyId in oneTimeKeys.curve25519) { if (oneTimeKeys.curve25519.hasOwnProperty(keyId)) { @@ -1987,6 +2020,12 @@ export class Crypto extends TypedEventEmitter = { one_time_keys: oneTimeJson, }; @@ -2722,6 +2761,7 @@ export class Crypto extends TypedEventEmitter { const exportedSessions: IMegolmSessionData[] = []; + /* await this.cryptoStore.doTxn("readonly", [IndexedDBCryptoStore.STORE_INBOUND_GROUP_SESSIONS], (txn) => { this.cryptoStore.getAllEndToEndInboundGroupSessions(txn, (s) => { if (s === null) return; @@ -2732,6 +2772,7 @@ export class Crypto extends TypedEventEmitter { + public importRoomKeys(keys: dmls.IMlsSessionData[], opts: IImportRoomKeysOpts = {}): Promise { let successes = 0; let failures = 0; const total = keys.length; @@ -3198,7 +3239,7 @@ export class Crypto extends TypedEventEmitter { if ( toDevice.type === EventType.RoomMessageEncrypted && - !["m.olm.v1.curve25519-aes-sha2"].includes(toDevice.content?.algorithm) + !["m.olm.v1.curve25519-aes-sha2", dmls.WELCOME_PACKAGE.name].includes(toDevice.content?.algorithm) ) { logger.log("Ignoring invalid encrypted to-device event from " + toDevice.sender); return false; diff --git a/src/crypto/store/base.ts b/src/crypto/store/base.ts index 4c88ec2872e..2fc069eb63c 100644 --- a/src/crypto/store/base.ts +++ b/src/crypto/store/base.ts @@ -18,6 +18,7 @@ import { IRoomKeyRequestBody, IRoomKeyRequestRecipient } from "../index"; import { RoomKeyRequestState } from "../OutgoingRoomKeyRequestManager"; import { ICrossSigningKey } from "../../client"; import { IOlmDevice } from "../algorithms/megolm"; +import { IMlsSessionData } from "../algorithms/dmls"; import { TrackingStatus } from "../DeviceList"; import { IRoomEncryption } from "../RoomList"; import { IDevice } from "../deviceinfo"; @@ -150,9 +151,10 @@ export interface CryptoStore { export type Mode = "readonly" | "readwrite"; export interface ISession { - senderKey: string; - sessionId: string; - sessionData?: InboundGroupSessionData; + roomId: string; + epochNumber: number; + epochCreator: string; + sessionData?: IMlsSessionData; } export interface ISessionInfo { diff --git a/src/crypto/store/indexeddb-crypto-store-backend.ts b/src/crypto/store/indexeddb-crypto-store-backend.ts index 7827697ec8d..68f16a4bba9 100644 --- a/src/crypto/store/indexeddb-crypto-store-backend.ts +++ b/src/crypto/store/indexeddb-crypto-store-backend.ts @@ -634,7 +634,7 @@ export class Backend implements CryptoStore { } public getAllEndToEndInboundGroupSessions(txn: IDBTransaction, func: (session: ISession | null) => void): void { - const objectStore = txn.objectStore("inbound_group_sessions"); + /* const objectStore = txn.objectStore("inbound_group_sessions"); const getReq = objectStore.openCursor(); getReq.onsuccess = function (): void { const cursor = getReq.result; @@ -656,7 +656,8 @@ export class Backend implements CryptoStore { abortWithException(txn, e); } } - }; + }; */ + func(null); } public addEndToEndInboundGroupSession( @@ -765,19 +766,20 @@ export class Backend implements CryptoStore { resolve(sessions); }; const objectStore = txn.objectStore("sessions_needing_backup"); - const sessionStore = txn.objectStore("inbound_group_sessions"); + //const sessionStore = txn.objectStore("inbound_group_sessions"); const getReq = objectStore.openCursor(); getReq.onsuccess = function (): void { const cursor = getReq.result; if (cursor) { - const sessionGetReq = sessionStore.get(cursor.key); + sessions.push(cursor.value); + /*const sessionGetReq = sessionStore.get(cursor.key); sessionGetReq.onsuccess = function (): void { sessions.push({ senderKey: sessionGetReq.result.senderCurve25519Key, sessionId: sessionGetReq.result.sessionId, sessionData: sessionGetReq.result.session, }); - }; + };*/ if (!limit || sessions.length < limit) { cursor.continue(); } @@ -806,7 +808,7 @@ export class Backend implements CryptoStore { await Promise.all( sessions.map((session) => { return new Promise((resolve, reject) => { - const req = objectStore.delete([session.senderKey, session.sessionId]); + const req = objectStore.delete([session.roomId, session.epochNumber, session.epochCreator]); req.onsuccess = resolve; req.onerror = reject; }); @@ -823,8 +825,9 @@ export class Backend implements CryptoStore { sessions.map((session) => { return new Promise((resolve, reject) => { const req = objectStore.put({ - senderCurve25519Key: session.senderKey, - sessionId: session.sessionId, + roomId: session.roomId, + epochNumber: session.epochNumber, + epochCreator: session.epochCreator, }); req.onsuccess = resolve; req.onerror = reject; diff --git a/src/crypto/store/localStorage-crypto-store.ts b/src/crypto/store/localStorage-crypto-store.ts index 1a9adfb25d1..2f54690522c 100644 --- a/src/crypto/store/localStorage-crypto-store.ts +++ b/src/crypto/store/localStorage-crypto-store.ts @@ -206,7 +206,7 @@ export class LocalStorageCryptoStore extends MemoryCryptoStore { } public getAllEndToEndInboundGroupSessions(txn: unknown, func: (session: ISession | null) => void): void { - for (let i = 0; i < this.store.length; ++i) { + /* for (let i = 0; i < this.store.length; ++i) { const key = this.store.key(i); if (key?.startsWith(KEY_INBOUND_SESSION_PREFIX)) { // we can't use split, as the components we are trying to split out @@ -220,7 +220,7 @@ export class LocalStorageCryptoStore extends MemoryCryptoStore { sessionData: getJsonItem(this.store, key)!, }); } - } + } */ func(null); } @@ -287,14 +287,11 @@ export class LocalStorageCryptoStore extends MemoryCryptoStore { for (const session in sessionsNeedingBackup) { if (Object.prototype.hasOwnProperty.call(sessionsNeedingBackup, session)) { // see getAllEndToEndInboundGroupSessions for the magic number explanations - const senderKey = session.slice(0, 43); - const sessionId = session.slice(44); - this.getEndToEndInboundGroupSession(senderKey, sessionId, null, (sessionData) => { - sessions.push({ - senderKey: senderKey, - sessionId: sessionId, - sessionData: sessionData!, - }); + const [roomId, epochNumber, epochCreator] = JSON.parse(session); + sessions.push({ + roomId, + epochNumber, + epochCreator, }); if (limit && sessions.length >= limit) { break; @@ -315,7 +312,7 @@ export class LocalStorageCryptoStore extends MemoryCryptoStore { [senderKeySessionId: string]: string; }>(this.store, KEY_SESSIONS_NEEDING_BACKUP) || {}; for (const session of sessions) { - delete sessionsNeedingBackup[session.senderKey + "/" + session.sessionId]; + delete sessionsNeedingBackup[JSON.stringify([session.roomId, session.epochNumber, session.epochCreator])]; } setJsonItem(this.store, KEY_SESSIONS_NEEDING_BACKUP, sessionsNeedingBackup); return Promise.resolve(); @@ -327,7 +324,7 @@ export class LocalStorageCryptoStore extends MemoryCryptoStore { [senderKeySessionId: string]: boolean; }>(this.store, KEY_SESSIONS_NEEDING_BACKUP) || {}; for (const session of sessions) { - sessionsNeedingBackup[session.senderKey + "/" + session.sessionId] = true; + sessionsNeedingBackup[JSON.stringify([session.roomId, session.epochNumber, session.epochCreator])] = true; } setJsonItem(this.store, KEY_SESSIONS_NEEDING_BACKUP, sessionsNeedingBackup); return Promise.resolve(); diff --git a/src/crypto/store/memory-crypto-store.ts b/src/crypto/store/memory-crypto-store.ts index ad779ca993b..da4d7a1e68f 100644 --- a/src/crypto/store/memory-crypto-store.ts +++ b/src/crypto/store/memory-crypto-store.ts @@ -399,7 +399,7 @@ export class MemoryCryptoStore implements CryptoStore { } public getAllEndToEndInboundGroupSessions(txn: unknown, func: (session: ISession | null) => void): void { - for (const key of Object.keys(this.inboundGroupSessions)) { + /* for (const key of Object.keys(this.inboundGroupSessions)) { // we can't use split, as the components we are trying to split out // might themselves contain '/' characters. We rely on the // senderKey being a (32-byte) curve25519 key, base64-encoded @@ -410,7 +410,7 @@ export class MemoryCryptoStore implements CryptoStore { sessionId: key.slice(44), sessionData: this.inboundGroupSessions[key], }); - } + } */ func(null); } @@ -467,7 +467,7 @@ export class MemoryCryptoStore implements CryptoStore { public getSessionsNeedingBackup(limit: number): Promise { const sessions: ISession[] = []; - for (const session in this.sessionsNeedingBackup) { + /* for (const session in this.sessionsNeedingBackup) { if (this.inboundGroupSessions[session]) { sessions.push({ senderKey: session.slice(0, 43), @@ -478,7 +478,7 @@ export class MemoryCryptoStore implements CryptoStore { break; } } - } + } */ return Promise.resolve(sessions); } @@ -487,18 +487,18 @@ export class MemoryCryptoStore implements CryptoStore { } public unmarkSessionsNeedingBackup(sessions: ISession[]): Promise { - for (const session of sessions) { + /* for (const session of sessions) { const sessionKey = session.senderKey + "/" + session.sessionId; delete this.sessionsNeedingBackup[sessionKey]; - } + } */ return Promise.resolve(); } public markSessionsNeedingBackup(sessions: ISession[]): Promise { - for (const session of sessions) { + /* for (const session of sessions) { const sessionKey = session.senderKey + "/" + session.sessionId; this.sessionsNeedingBackup[sessionKey] = true; - } + } */ return Promise.resolve(); } diff --git a/src/sliding-sync-sdk.ts b/src/sliding-sync-sdk.ts index 93e29e0baa3..62ab75e09fa 100644 --- a/src/sliding-sync-sdk.ts +++ b/src/sliding-sync-sdk.ts @@ -45,6 +45,7 @@ import { EventType } from "./@types/event"; import { IPushRules } from "./@types/PushRules"; import { RoomStateEvent } from "./models/room-state"; import { RoomMemberEvent } from "./models/room-member"; +import { INIT_KEY_ALGORITHM } from "./crypto/algorithms/dmls"; // Number of consecutive failed syncs that will lead to a syncState of ERROR as opposed // to RECONNECTING. This is needed to inform the client of server issues when the @@ -97,7 +98,8 @@ class ExtensionE2EE implements Extension