diff --git a/package.json b/package.json index c4d264e58e..f1d8def9fa 100644 --- a/package.json +++ b/package.json @@ -130,5 +130,6 @@ "outputDirectory": "coverage", "outputName": "jest-sonar-report.xml", "relativePaths": true - } + }, + "packageManager": "yarn@1.22.22+sha512.a6b2f7906b721bba3d67d4aff083df04dad64c399707841b7acf00f6b133b7ac24255f2652fa22ae3534329dc6180534e98d17432037ff6fd140556e2bb3137e" } diff --git a/spec/unit/matrixrtc/KeyBuffer.spec.ts b/spec/unit/matrixrtc/KeyBuffer.spec.ts new file mode 100644 index 0000000000..5abcde3d0a --- /dev/null +++ b/spec/unit/matrixrtc/KeyBuffer.spec.ts @@ -0,0 +1,70 @@ +/* +Copyright 2025 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import { KeyBuffer } from "../../../src/matrixrtc/utils.ts"; +import { type InboundEncryptionSession } from "../../../src/matrixrtc"; + +describe("KeyBuffer Test", () => { + it("Should buffer and disambiguate keys by timestamp", () => { + jest.useFakeTimers(); + + const buffer = new KeyBuffer(1000); + + const aKey = fakeInboundSessionWithTimestamp(1000); + const olderKey = fakeInboundSessionWithTimestamp(300); + // Simulate receiving out of order keys + + const init = buffer.disambiguate(aKey.participantId, aKey); + expect(init).toEqual(aKey); + // Some time pass + jest.advanceTimersByTime(600); + // Then we receive the most recent key out of order + + const key = buffer.disambiguate(aKey.participantId, olderKey); + // this key is older and should be ignored even if received after + expect(key).toBe(null); + }); + + it("Should clear buffer after ttl", () => { + jest.useFakeTimers(); + + const buffer = new KeyBuffer(1000); + + const aKey = fakeInboundSessionWithTimestamp(1000); + const olderKey = fakeInboundSessionWithTimestamp(300); + // Simulate receiving out of order keys + + const init = buffer.disambiguate(aKey.participantId, aKey); + expect(init).toEqual(aKey); + + // Similar to previous test but there is too much delay + // We don't want to keep key material for too long + jest.advanceTimersByTime(1200); + + const key = buffer.disambiguate(aKey.participantId, olderKey); + // The buffer is cleared so should return this key + expect(key).toBe(olderKey); + }); + + function fakeInboundSessionWithTimestamp(ts: number): InboundEncryptionSession { + return { + keyIndex: 0, + creationTS: ts, + participantId: "@alice:localhost|ABCDE", + key: new Uint8Array(16), + }; + } +}); diff --git a/spec/unit/matrixrtc/MatrixRTCSession.spec.ts b/spec/unit/matrixrtc/MatrixRTCSession.spec.ts index 3508e01a64..5267177151 100644 --- a/spec/unit/matrixrtc/MatrixRTCSession.spec.ts +++ b/spec/unit/matrixrtc/MatrixRTCSession.spec.ts @@ -21,6 +21,7 @@ import { MatrixRTCSession, MatrixRTCSessionEvent } from "../../../src/matrixrtc/ import { type EncryptionKeysEventContent } from "../../../src/matrixrtc/types"; import { secureRandomString } from "../../../src/randomstring"; import { makeMockEvent, makeMockRoom, makeMockRoomState, membershipTemplate, makeKey } from "./mocks"; +import { RTCEncryptionManager } from "../../../src/matrixrtc/RTCEncryptionManager.ts"; const mockFocus = { type: "mock" }; @@ -878,11 +879,27 @@ describe("MatrixRTCSession", () => { expect(sendKeySpy).toHaveBeenCalledTimes(1); // check that we send the key with index 1 even though the send gets delayed when leaving. // this makes sure we do not use an index that is one too old. - expect(sendKeySpy).toHaveBeenLastCalledWith(expect.any(String), 1, sess.memberships); + expect(sendKeySpy).toHaveBeenLastCalledWith( + expect.any(String), + 1, + sess.memberships.map((m) => ({ + userId: m.sender, + deviceId: m.deviceId, + membershipTs: m.createdTs(), + })), + ); // fake a condition in which we send another encryption key event. // this could happen do to someone joining the call. (sess as unknown as any).encryptionManager.sendEncryptionKeysEvent(); - expect(sendKeySpy).toHaveBeenLastCalledWith(expect.any(String), 1, sess.memberships); + expect(sendKeySpy).toHaveBeenLastCalledWith( + expect.any(String), + 1, + sess.memberships.map((m) => ({ + userId: m.sender, + deviceId: m.deviceId, + membershipTs: m.createdTs(), + })), + ); jest.advanceTimersByTime(7000); const secondKeysPayload = await keysSentPromise2; @@ -996,10 +1013,14 @@ describe("MatrixRTCSession", () => { useNewMembershipManager: true, useExperimentalToDeviceTransport: true, }); + sess.onRTCSessionMemberUpdate(); await keySentPromise; expect(sendToDeviceMock).toHaveBeenCalled(); + + // Access private to test + expect(sess["encryptionManager"]).toBeInstanceOf(RTCEncryptionManager); } finally { jest.useRealTimers(); } diff --git a/spec/unit/matrixrtc/RTCEncrytionManager.spec.ts b/spec/unit/matrixrtc/RTCEncrytionManager.spec.ts new file mode 100644 index 0000000000..135f978bc6 --- /dev/null +++ b/spec/unit/matrixrtc/RTCEncrytionManager.spec.ts @@ -0,0 +1,598 @@ +/* +Copyright 2025 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import { type Mocked } from "jest-mock"; + +import { RTCEncryptionManager } from "../../../src/matrixrtc/RTCEncryptionManager.ts"; +import { type CallMembership, type Statistics } from "../../../src/matrixrtc"; +import { type ToDeviceKeyTransport } from "../../../src/matrixrtc/ToDeviceKeyTransport.ts"; +import { KeyTransportEvents, type KeyTransportEventsHandlerMap } from "../../../src/matrixrtc/IKeyTransport.ts"; +import { membershipTemplate, mockCallMembership } from "./mocks.ts"; +import { decodeBase64, TypedEventEmitter } from "../../../src"; +import { RoomAndToDeviceTransport } from "../../../src/matrixrtc/RoomAndToDeviceKeyTransport.ts"; +import { type RoomKeyTransport } from "../../../src/matrixrtc/RoomKeyTransport.ts"; +import type { Logger } from "../../../src/logger.ts"; + +describe("RTCEncryptionManager", () => { + // The manager being tested + let encryptionManager: RTCEncryptionManager; + let getMembershipMock: jest.Mock; + let mockTransport: Mocked; + let statistics: Statistics; + let onEncryptionKeysChanged: jest.Mock; + + beforeEach(() => { + statistics = { + counters: { + roomEventEncryptionKeysSent: 0, + roomEventEncryptionKeysReceived: 0, + }, + totals: { + roomEventEncryptionKeysReceivedTotalAge: 0, + }, + }; + getMembershipMock = jest.fn().mockReturnValue([]); + onEncryptionKeysChanged = jest.fn(); + mockTransport = { + start: jest.fn(), + stop: jest.fn(), + sendKey: jest.fn().mockResolvedValue(undefined), + on: jest.fn(), + off: jest.fn(), + } as unknown as Mocked; + + encryptionManager = new RTCEncryptionManager( + "@alice:example.org", + "DEVICE01", + getMembershipMock, + mockTransport, + statistics, + onEncryptionKeysChanged, + ); + }); + + it("should start and stop the transport properly", () => { + encryptionManager.join(undefined); + + expect(mockTransport.start).toHaveBeenCalledTimes(1); + expect(mockTransport.on).toHaveBeenCalledTimes(1); + expect(mockTransport.on).toHaveBeenCalledWith(KeyTransportEvents.ReceivedKeys, expect.any(Function)); + encryptionManager.leave(); + expect(mockTransport.stop).toHaveBeenCalledTimes(1); + expect(mockTransport.off).toHaveBeenCalledWith(KeyTransportEvents.ReceivedKeys, expect.any(Function)); + }); + + describe("Sharing Keys", () => { + it("Set up my key asap even if no key distribution is needed", () => { + getMembershipMock.mockReturnValue([]); + + encryptionManager.join(undefined); + // After join it is too early, key might be lost as no one is listening yet + expect(onEncryptionKeysChanged).not.toHaveBeenCalled(); + encryptionManager.onMembershipsUpdate([]); + // The key should have been rolled out immediately + expect(onEncryptionKeysChanged).toHaveBeenCalled(); + }); + + it("Should distribute keys to members on join", async () => { + const members = [ + aCallMembership("@bob:example.org", "BOBDEVICE"), + aCallMembership("@bob:example.org", "BOBDEVICE2"), + aCallMembership("@carl:example.org", "CARLDEVICE"), + ]; + getMembershipMock.mockReturnValue(members); + + encryptionManager.join(undefined); + encryptionManager.onMembershipsUpdate([]); + + expect(mockTransport.sendKey).toHaveBeenCalledTimes(1); + expect(mockTransport.sendKey).toHaveBeenCalledWith( + expect.any(String), + // It is the first key + 0, + members.map((m) => ({ userId: m.sender, deviceId: m.deviceId, membershipTs: m.createdTs() })), + ); + await jest.runOnlyPendingTimersAsync(); + // The key should have been rolled out immediately + expect(onEncryptionKeysChanged).toHaveBeenCalled(); + expect(onEncryptionKeysChanged).toHaveBeenCalledWith( + expect.any(Uint8Array), + 0, + "@alice:example.org:DEVICE01", + ); + }); + + it("Should re-distribute keys to members whom callMemberhsip ts has changed", async () => { + let members = [aCallMembership("@bob:example.org", "BOBDEVICE", 1000)]; + getMembershipMock.mockReturnValue(members); + + encryptionManager.join(undefined); + encryptionManager.onMembershipsUpdate([]); + + expect(mockTransport.sendKey).toHaveBeenCalledTimes(1); + expect(mockTransport.sendKey).toHaveBeenCalledWith( + expect.any(String), + // It is the first key + 0, + [ + { + userId: "@bob:example.org", + deviceId: "BOBDEVICE", + membershipTs: 1000, + }, + ], + ); + await jest.runOnlyPendingTimersAsync(); + // The key should have been rolled out immediately + expect(onEncryptionKeysChanged).toHaveBeenCalled(); + + mockTransport.sendKey.mockClear(); + onEncryptionKeysChanged.mockClear(); + + members = [aCallMembership("@bob:example.org", "BOBDEVICE", 2000)]; + getMembershipMock.mockReturnValue(members); + + // There are no membership change but the callMembership ts has changed (reset?) + // Resend the key + encryptionManager.onMembershipsUpdate(members); + await jest.runOnlyPendingTimersAsync(); + + expect(mockTransport.sendKey).toHaveBeenCalledTimes(1); + expect(mockTransport.sendKey).toHaveBeenCalledWith( + expect.any(String), + // Re send the same key to that user + 0, + [ + { + userId: "@bob:example.org", + deviceId: "BOBDEVICE", + membershipTs: 2000, + }, + ], + ); + }); + + it("Should not rotate key when a user join", async () => { + jest.useFakeTimers(); + + const members = [ + aCallMembership("@bob:example.org", "BOBDEVICE"), + aCallMembership("@bob:example.org", "BOBDEVICE2"), + ]; + getMembershipMock.mockReturnValue(members); + + // initial rollout + encryptionManager.join(undefined); + encryptionManager.onMembershipsUpdate([]); + await jest.runOnlyPendingTimersAsync(); + + expect(mockTransport.sendKey).toHaveBeenCalledTimes(1); + expect(mockTransport.sendKey).toHaveBeenCalledWith( + expect.any(String), + // It is the first key + 0, + members.map((m) => ({ userId: m.sender, deviceId: m.deviceId, membershipTs: m.createdTs() })), + ); + onEncryptionKeysChanged.mockClear(); + mockTransport.sendKey.mockClear(); + + const updatedMembers = [ + aCallMembership("@bob:example.org", "BOBDEVICE"), + aCallMembership("@bob:example.org", "BOBDEVICE2"), + aCallMembership("@carl:example.org", "CARLDEVICE"), + ]; + getMembershipMock.mockReturnValue(updatedMembers); + + encryptionManager.onMembershipsUpdate(updatedMembers); + + await jest.runOnlyPendingTimersAsync(); + + expect(mockTransport.sendKey).toHaveBeenCalledWith( + expect.any(String), + // It should not have incremented the key index + 0, + // And send it to the newly joined only + [{ userId: "@carl:example.org", deviceId: "CARLDEVICE", membershipTs: 1000 }], + ); + + expect(onEncryptionKeysChanged).not.toHaveBeenCalled(); + await jest.advanceTimersByTimeAsync(1000); + + expect(statistics.counters.roomEventEncryptionKeysSent).toBe(2); + }); + + it("Should not resend keys when no changes", async () => { + jest.useFakeTimers(); + + const members = [ + aCallMembership("@bob:example.org", "BOBDEVICE"), + aCallMembership("@bob:example.org", "BOBDEVICE2"), + ]; + getMembershipMock.mockReturnValue(members); + + // initial rollout + encryptionManager.join(undefined); + encryptionManager.onMembershipsUpdate([]); + await jest.runOnlyPendingTimersAsync(); + + expect(mockTransport.sendKey).toHaveBeenCalledTimes(1); + onEncryptionKeysChanged.mockClear(); + mockTransport.sendKey.mockClear(); + + encryptionManager.onMembershipsUpdate(members); + await jest.advanceTimersByTimeAsync(200); + encryptionManager.onMembershipsUpdate(members); + await jest.advanceTimersByTimeAsync(100); + encryptionManager.onMembershipsUpdate(members); + await jest.advanceTimersByTimeAsync(50); + encryptionManager.onMembershipsUpdate(members); + await jest.advanceTimersByTimeAsync(100); + + expect(mockTransport.sendKey).not.toHaveBeenCalled(); + }); + + it("Should rotate key when a user leaves and delay the rollout", async () => { + jest.useFakeTimers(); + + const members = [ + aCallMembership("@bob:example.org", "BOBDEVICE"), + aCallMembership("@bob:example.org", "BOBDEVICE2"), + aCallMembership("@carl:example.org", "CARLDEVICE"), + ]; + getMembershipMock.mockReturnValue(members); + + encryptionManager.join(undefined); + encryptionManager.onMembershipsUpdate([]); + await jest.advanceTimersByTimeAsync(10); + + expect(mockTransport.sendKey).toHaveBeenCalledTimes(1); + expect(mockTransport.sendKey).toHaveBeenCalledWith( + expect.any(String), + // It is the first key + 0, + members.map((m) => ({ userId: m.sender, deviceId: m.deviceId, membershipTs: m.createdTs() })), + ); + // initial rollout + expect(mockTransport.sendKey).toHaveBeenCalled(); + expect(onEncryptionKeysChanged).toHaveBeenCalledTimes(1); + onEncryptionKeysChanged.mockClear(); + + const updatedMembers = [ + aCallMembership("@bob:example.org", "BOBDEVICE"), + aCallMembership("@bob:example.org", "BOBDEVICE2"), + ]; + getMembershipMock.mockReturnValue(updatedMembers); + + encryptionManager.onMembershipsUpdate(updatedMembers); + + await jest.advanceTimersByTimeAsync(200); + // The is rotated but not rolled out yet to give time for the key to be sent + expect(mockTransport.sendKey).toHaveBeenCalledWith( + expect.any(String), + // It should have incremented the key index + 1, + // And send it to the updated members + updatedMembers.map((m) => ({ userId: m.sender, deviceId: m.deviceId, membershipTs: m.createdTs() })), + ); + + expect(onEncryptionKeysChanged).not.toHaveBeenCalled(); + await jest.advanceTimersByTimeAsync(1000); + + // now should be rolled out + expect(onEncryptionKeysChanged).toHaveBeenCalledWith( + expect.any(Uint8Array), + 1, + "@alice:example.org:DEVICE01", + ); + + expect(statistics.counters.roomEventEncryptionKeysSent).toBe(2); + }); + }); + + describe("Receiving Keys", () => { + beforeEach(() => { + const emitter = new TypedEventEmitter(); + mockTransport = { + start: jest.fn(), + stop: jest.fn(), + sendKey: jest.fn().mockResolvedValue(undefined), + on: emitter.on.bind(emitter), + off: emitter.off.bind(emitter), + emit: emitter.emit.bind(emitter), + } as unknown as Mocked; + encryptionManager = new RTCEncryptionManager( + "@alice:example.org", + "DEVICE01", + getMembershipMock, + mockTransport, + statistics, + onEncryptionKeysChanged, + ); + }); + + it("should accept keys from transport", async () => { + jest.useFakeTimers(); + + const members = [ + aCallMembership("@bob:example.org", "BOBDEVICE"), + aCallMembership("@bob:example.org", "BOBDEVICE2"), + aCallMembership("@carl:example.org", "CARLDEVICE"), + ]; + getMembershipMock.mockReturnValue(members); + + encryptionManager.join(undefined); + encryptionManager.onMembershipsUpdate([]); + await jest.advanceTimersByTimeAsync(10); + + mockTransport.emit( + KeyTransportEvents.ReceivedKeys, + "@bob:example.org", + "BOBDEVICE", + "AAAAAAAAAAA", + 0 /* KeyId */, + 0 /* Timestamp */, + ); + mockTransport.emit( + KeyTransportEvents.ReceivedKeys, + "@bob:example.org", + "BOBDEVICE2", + "BBBBBBBBBBB", + 4 /* KeyId */, + 0 /* Timestamp */, + ); + mockTransport.emit( + KeyTransportEvents.ReceivedKeys, + "@carl:example.org", + "CARLDEVICE", + "CCCCCCCCCC", + 8 /* KeyId */, + 0 /* Timestamp */, + ); + + expect(onEncryptionKeysChanged).toHaveBeenCalledTimes(4); + expect(onEncryptionKeysChanged).toHaveBeenCalledWith( + decodeBase64("AAAAAAAAAAA"), + 0, + "@bob:example.org:BOBDEVICE", + ); + + expect(onEncryptionKeysChanged).toHaveBeenCalledWith( + decodeBase64("BBBBBBBBBBB"), + 4, + "@bob:example.org:BOBDEVICE2", + ); + + expect(onEncryptionKeysChanged).toHaveBeenCalledWith( + decodeBase64("CCCCCCCCCC"), + 8, + "@carl:example.org:CARLDEVICE", + ); + + expect(statistics.counters.roomEventEncryptionKeysReceived).toBe(3); + }); + + it("Should support quick re-joiner if keys received out of order", async () => { + jest.useFakeTimers(); + + const members = [aCallMembership("@carl:example.org", "CARLDEVICE")]; + getMembershipMock.mockReturnValue(members); + + // Let's join + encryptionManager.join(undefined); + await jest.advanceTimersByTimeAsync(10); + + // Simulate Carl leaving then joining back, and key received out of order + const initialKey0TimeStamp = 1000; + const newKey0TimeStamp = 2000; + + mockTransport.emit( + KeyTransportEvents.ReceivedKeys, + "@carol:example.org", + "CAROLDEVICE", + "BBBBBBBBBBB", + 0 /* KeyId */, + newKey0TimeStamp, + ); + + await jest.advanceTimersByTimeAsync(20); + + mockTransport.emit( + KeyTransportEvents.ReceivedKeys, + "@bob:example.org", + "CAROLDEVICE", + "AAAAAAAAAAA", + 0 /* KeyId */, + initialKey0TimeStamp, + ); + + await jest.advanceTimersByTimeAsync(20); + + // The latest key used for carol should be the one with the latest timestamp + + expect(onEncryptionKeysChanged).toHaveBeenCalledWith( + decodeBase64("BBBBBBBBBBB"), + 0, + "@carol:example.org:CAROLDEVICE", + ); + }); + }); + + it("Should only rotate once again if several membership changes during a rollout", async () => { + jest.useFakeTimers(); + + let members = [ + aCallMembership("@bob:example.org", "BOBDEVICE"), + aCallMembership("@bob:example.org", "BOBDEVICE2"), + aCallMembership("@carl:example.org", "CARLDEVICE"), + ]; + getMembershipMock.mockReturnValue(members); + + // Let's join + encryptionManager.join(undefined); + encryptionManager.onMembershipsUpdate([]); + await jest.advanceTimersByTimeAsync(10); + + // The initial rollout + expect(onEncryptionKeysChanged).toHaveBeenCalledWith( + expect.any(Uint8Array), + 0, + "@alice:example.org:DEVICE01", + ); + onEncryptionKeysChanged.mockClear(); + + // Trigger a key rotation with a leaver + members = [aCallMembership("@bob:example.org", "BOBDEVICE"), aCallMembership("@bob:example.org", "BOBDEVICE2")]; + getMembershipMock.mockReturnValue(members); + + // This should start a new key rollout + encryptionManager.onMembershipsUpdate(members); + await jest.advanceTimersByTimeAsync(10); + + // Now simulate a new leaver + members = [aCallMembership("@bob:example.org", "BOBDEVICE")]; + getMembershipMock.mockReturnValue(members); + + // The key `1` rollout is in progress + encryptionManager.onMembershipsUpdate(members); + await jest.advanceTimersByTimeAsync(10); + + // And another one ( plus a joiner) + const lastMembership = [aCallMembership("@bob:example.org", "BOBDEVICE3")]; + getMembershipMock.mockReturnValue(lastMembership); + // The key `1` rollout is still in progress + encryptionManager.onMembershipsUpdate(lastMembership); + await jest.advanceTimersByTimeAsync(10); + + // Let all rollouts finish + await jest.advanceTimersByTimeAsync(2000); + + // There should 2 rollout. The `1` rollout, then just one additional one + // that has "buffered" the 2 membership changes with leavers + expect(onEncryptionKeysChanged).toHaveBeenCalledTimes(2); + expect(onEncryptionKeysChanged).toHaveBeenCalledWith( + expect.any(Uint8Array), + 1, + "@alice:example.org:DEVICE01", + ); + expect(onEncryptionKeysChanged).toHaveBeenCalledWith( + expect.any(Uint8Array), + 2, + "@alice:example.org:DEVICE01", + ); + + // Key `2` should only be distributed to the last membership + expect(mockTransport.sendKey).toHaveBeenLastCalledWith( + expect.any(String), + 2, + // And send only to the last membership + [ + { + userId: "@bob:example.org", + deviceId: "BOBDEVICE3", + membershipTs: 1000, + }, + ], + ); + }); + + it("Should re-distribute key on transport switch", async () => { + const toDeviceEmitter = new TypedEventEmitter(); + const mockToDeviceTransport = { + start: jest.fn(), + stop: jest.fn(), + sendKey: jest.fn().mockResolvedValue(undefined), + on: toDeviceEmitter.on.bind(toDeviceEmitter), + off: toDeviceEmitter.off.bind(toDeviceEmitter), + emit: toDeviceEmitter.emit.bind(toDeviceEmitter), + setParentLogger: jest.fn(), + } as unknown as Mocked; + + const roomEmitter = new TypedEventEmitter(); + const mockRoomTransport = { + start: jest.fn(), + stop: jest.fn(), + sendKey: jest.fn().mockResolvedValue(undefined), + on: roomEmitter.on.bind(roomEmitter), + off: roomEmitter.off.bind(roomEmitter), + emit: roomEmitter.emit.bind(roomEmitter), + setParentLogger: jest.fn(), + } as unknown as Mocked; + + const mockLogger = { + debug: jest.fn(), + warn: jest.fn(), + } as unknown as Mocked; + + const transport = new RoomAndToDeviceTransport(mockToDeviceTransport, mockRoomTransport, { + getChild: jest.fn().mockReturnValue(mockLogger), + } as unknown as Mocked); + + encryptionManager = new RTCEncryptionManager( + "@alice:example.org", + "DEVICE01", + getMembershipMock, + transport, + statistics, + onEncryptionKeysChanged, + ); + + const members = [ + aCallMembership("@bob:example.org", "BOBDEVICE"), + aCallMembership("@bob:example.org", "BOBDEVICE2"), + aCallMembership("@carl:example.org", "CARLDEVICE"), + ]; + getMembershipMock.mockReturnValue(members); + + // Let's join + encryptionManager.join(undefined); + encryptionManager.onMembershipsUpdate([]); + await jest.advanceTimersByTimeAsync(10); + + // Should have sent the key to the toDevice transport + expect(mockToDeviceTransport.sendKey).toHaveBeenCalledTimes(1); + expect(mockRoomTransport.sendKey).not.toHaveBeenCalled(); + + // Simulate receiving a key by room transport + roomEmitter.emit( + KeyTransportEvents.ReceivedKeys, + "@bob:example.org", + "BOBDEVICE", + "AAAAAAAAAAA", + 0 /* KeyId */, + 0 /* Timestamp */, + ); + + await jest.runOnlyPendingTimersAsync(); + + // The key should have beed re-distributed to the room transport + expect(mockRoomTransport.sendKey).toHaveBeenCalled(); + expect(mockToDeviceTransport.sendKey).toHaveBeenCalledWith( + expect.any(String), + // It is the first key re-distributed + 0, + // to all the members + members.map((m) => ({ userId: m.sender, deviceId: m.deviceId, membershipTs: m.createdTs() })), + ); + }); + + function aCallMembership(userId: string, deviceId: string, ts: number = 1000): CallMembership { + return mockCallMembership( + Object.assign({}, membershipTemplate, { device_id: deviceId, created_ts: ts }), + "!room:id", + userId, + ); + } +}); diff --git a/spec/unit/matrixrtc/RoomAndToDeviceTransport.spec.ts b/spec/unit/matrixrtc/RoomAndToDeviceTransport.spec.ts index a4ce40aa67..82a6b7356b 100644 --- a/spec/unit/matrixrtc/RoomAndToDeviceTransport.spec.ts +++ b/spec/unit/matrixrtc/RoomAndToDeviceTransport.spec.ts @@ -16,7 +16,7 @@ limitations under the License. import { type Mocked } from "jest-mock"; -import { makeKey, makeMockEvent, makeMockRoom, membershipTemplate, mockCallMembership } from "./mocks"; +import { makeKey, makeMockEvent, makeMockRoom } from "./mocks"; import { EventType, type IRoomTimelineData, type Room, RoomEvent, type MatrixClient } from "../../../src"; import { ToDeviceKeyTransport } from "../../../src/matrixrtc/ToDeviceKeyTransport.ts"; import { @@ -88,7 +88,9 @@ describe("RoomAndToDeviceTransport", () => { }); it("only sends to device keys when sending a key", async () => { transport.start(); - await transport.sendKey("1235", 0, [mockCallMembership(membershipTemplate, roomId, "@alice:example.org")]); + await transport.sendKey("1235", 0, [ + { userId: "@alice:example.org", deviceId: "ALICEDEVICE", membershipTs: 1234 }, + ]); expect(toDeviceSendKeySpy).toHaveBeenCalledTimes(1); expect(roomSendKeySpy).toHaveBeenCalledTimes(0); expect(transport.enabled.room).toBeFalsy(); @@ -118,7 +120,9 @@ describe("RoomAndToDeviceTransport", () => { expect(transport.enabled.room).toBeTruthy(); expect(transport.enabled.toDevice).toBeFalsy(); - await transport.sendKey("1235", 0, [mockCallMembership(membershipTemplate, roomId, "@alice:example.org")]); + await transport.sendKey("1235", 0, [ + { userId: "@alice:example.org", deviceId: "AlICEDEV", membershipTs: 1234 }, + ]); expect(sendEventMock).toHaveBeenCalledTimes(1); expect(roomSendKeySpy).toHaveBeenCalledTimes(1); expect(toDeviceSendKeySpy).toHaveBeenCalledTimes(0); diff --git a/spec/unit/matrixrtc/ToDeviceKeyTransport.spec.ts b/spec/unit/matrixrtc/ToDeviceKeyTransport.spec.ts index e120eeb887..23c20a24ca 100644 --- a/spec/unit/matrixrtc/ToDeviceKeyTransport.spec.ts +++ b/spec/unit/matrixrtc/ToDeviceKeyTransport.spec.ts @@ -16,7 +16,7 @@ limitations under the License. import { type Mocked } from "jest-mock"; -import { makeMockEvent, membershipTemplate, mockCallMembership } from "./mocks"; +import { makeMockEvent } from "./mocks"; import { ClientEvent, EventType, type MatrixClient } from "../../../src"; import { ToDeviceKeyTransport } from "../../../src/matrixrtc/ToDeviceKeyTransport.ts"; import { getMockClientWithEventEmitter } from "../../test-utils/client.ts"; @@ -62,21 +62,9 @@ describe("ToDeviceKeyTransport", () => { const keyBase64Encoded = "ABCDEDF"; const keyIndex = 2; await transport.sendKey(keyBase64Encoded, keyIndex, [ - mockCallMembership( - Object.assign({}, membershipTemplate, { device_id: "BOBDEVICE" }), - roomId, - "@bob:example.org", - ), - mockCallMembership( - Object.assign({}, membershipTemplate, { device_id: "CARLDEVICE" }), - roomId, - "@carl:example.org", - ), - mockCallMembership( - Object.assign({}, membershipTemplate, { device_id: "MATDEVICE" }), - roomId, - "@mat:example.org", - ), + { userId: "@bob:example.org", deviceId: "BOBDEVICE", membershipTs: 1234 }, + { userId: "@carl:example.org", deviceId: "CARLDEVICE", membershipTs: 1234 }, + { userId: "@mat:example.org", deviceId: "MATDEVICE", membershipTs: 1234 }, ]); expect(mockClient.encryptAndSendToDevice).toHaveBeenCalledTimes(1); @@ -101,6 +89,7 @@ describe("ToDeviceKeyTransport", () => { call_id: "", scope: "m.room", }, + sent_ts: expect.any(Number), }, ); @@ -149,11 +138,7 @@ describe("ToDeviceKeyTransport", () => { const keyBase64Encoded = "ABCDEDF"; const keyIndex = 2; await transport.sendKey(keyBase64Encoded, keyIndex, [ - mockCallMembership( - Object.assign({}, membershipTemplate, { device_id: "MYDEVICE" }), - roomId, - "@alice:example.org", - ), + { userId: "@alice:example.org", deviceId: "MYDEVICE", membershipTs: 1234 }, ]); transport.start(); diff --git a/src/matrixrtc/EncryptionManager.ts b/src/matrixrtc/EncryptionManager.ts index 0b0fe9cb92..bfcc5d8887 100644 --- a/src/matrixrtc/EncryptionManager.ts +++ b/src/matrixrtc/EncryptionManager.ts @@ -6,6 +6,7 @@ import { safeGetRetryAfterMs } from "../http-api/errors.ts"; import { type CallMembership } from "./CallMembership.ts"; import { type KeyTransportEventListener, KeyTransportEvents, type IKeyTransport } from "./IKeyTransport.ts"; import { isMyMembership, type Statistics } from "./types.ts"; +import { getParticipantId } from "./utils.ts"; import { type EnabledTransports, RoomAndToDeviceEvents, @@ -82,6 +83,7 @@ export class EncryptionManager implements IEncryptionManager { private latestGeneratedKeyIndex = -1; private joinConfig: EncryptionConfig | undefined; private logger: Logger; + public constructor( private userId: string, private deviceId: string, @@ -280,7 +282,18 @@ export class EncryptionManager implements IEncryptionManager { try { this.statistics.counters.roomEventEncryptionKeysSent += 1; - await this.transport.sendKey(encodeUnpaddedBase64(keyToSend), keyIndexToSend, this.getMemberships()); + const targets = this.getMemberships() + .filter((membership) => { + return membership.sender != undefined; + }) + .map((membership) => { + return { + userId: membership.sender!, + deviceId: membership.deviceId, + membershipTs: membership.createdTs(), + }; + }); + await this.transport.sendKey(encodeUnpaddedBase64(keyToSend), keyIndexToSend, targets); this.logger.debug( `sendEncryptionKeysEvent participantId=${this.userId}:${this.deviceId} numKeys=${myKeys.length} currentKeyIndex=${this.latestGeneratedKeyIndex} keyIndexToSend=${keyIndexToSend}`, this.encryptionKeys, @@ -408,8 +421,6 @@ export class EncryptionManager implements IEncryptionManager { }; } -const getParticipantId = (userId: string, deviceId: string): string => `${userId}:${deviceId}`; - function keysEqual(a: Uint8Array | undefined, b: Uint8Array | undefined): boolean { if (a === b) return true; return !!a && !!b && a.length === b.length && a.every((x, i) => x === b[i]); diff --git a/src/matrixrtc/IKeyTransport.ts b/src/matrixrtc/IKeyTransport.ts index b9776b90fc..17cfa7682a 100644 --- a/src/matrixrtc/IKeyTransport.ts +++ b/src/matrixrtc/IKeyTransport.ts @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -import { type CallMembership } from "./CallMembership.ts"; +import { type ParticipantDeviceInfo } from "./types.ts"; export enum KeyTransportEvents { ReceivedKeys = "received_keys", @@ -43,7 +43,7 @@ export interface IKeyTransport { * @param index * @param members - The participants that should get they key */ - sendKey(keyBase64Encoded: string, index: number, members: CallMembership[]): Promise; + sendKey(keyBase64Encoded: string, index: number, members: ParticipantDeviceInfo[]): Promise; /** Subscribe to keys from this transport. */ on(event: KeyTransportEvents.ReceivedKeys, listener: KeyTransportEventListener): this; diff --git a/src/matrixrtc/MatrixRTCSession.ts b/src/matrixrtc/MatrixRTCSession.ts index d20773d2d7..62ad994aa0 100644 --- a/src/matrixrtc/MatrixRTCSession.ts +++ b/src/matrixrtc/MatrixRTCSession.ts @@ -31,6 +31,7 @@ import { logDurationSync } from "../utils.ts"; import { type Statistics } from "./types.ts"; import { RoomKeyTransport } from "./RoomKeyTransport.ts"; import type { IMembershipManager } from "./IMembershipManager.ts"; +import { RTCEncryptionManager } from "./RTCEncryptionManager.ts"; import { RoomAndToDeviceEvents, type RoomAndToDeviceEventsHandlerMap, @@ -398,6 +399,7 @@ export class MatrixRTCSession extends TypedEventEmitter< // Create Encryption manager let transport; if (joinConfig?.useExperimentalToDeviceTransport) { + this.logger.info("Using experimental to-device transport for encryption keys"); this.logger.info("Using to-device with room fallback transport for encryption keys"); const [uId, dId] = [this.client.getUserId()!, this.client.getDeviceId()!]; const [room, client, statistics] = [this.roomSubset, this.client, this.statistics]; @@ -408,20 +410,40 @@ export class MatrixRTCSession extends TypedEventEmitter< // Expose the changes so the ui can display the currently used transport. this.reEmitter.reEmit(transport, [RoomAndToDeviceEvents.EnabledTransportsChanged]); + this.encryptionManager = new RTCEncryptionManager( + this.client.getUserId()!, + this.client.getDeviceId()!, + () => this.memberships, + transport, + this.statistics, + (keyBin: Uint8Array, encryptionKeyIndex: number, participantId: string) => { + this.emit( + MatrixRTCSessionEvent.EncryptionKeyChanged, + keyBin, + encryptionKeyIndex, + participantId, + ); + }, + this.logger, + ); } else { transport = new RoomKeyTransport(this.roomSubset, this.client, this.statistics); + this.encryptionManager = new EncryptionManager( + this.client.getUserId()!, + this.client.getDeviceId()!, + () => this.memberships, + transport, + this.statistics, + (keyBin: Uint8Array, encryptionKeyIndex: number, participantId: string) => { + this.emit( + MatrixRTCSessionEvent.EncryptionKeyChanged, + keyBin, + encryptionKeyIndex, + participantId, + ); + }, + ); } - this.encryptionManager = new EncryptionManager( - this.client.getUserId()!, - this.client.getDeviceId()!, - () => this.memberships, - transport, - this.statistics, - (keyBin: Uint8Array, encryptionKeyIndex: number, participantId: string) => { - this.emit(MatrixRTCSessionEvent.EncryptionKeyChanged, keyBin, encryptionKeyIndex, participantId); - }, - this.logger, - ); } // Join! diff --git a/src/matrixrtc/RTCEncryptionManager.ts b/src/matrixrtc/RTCEncryptionManager.ts new file mode 100644 index 0000000000..fa077a5d7b --- /dev/null +++ b/src/matrixrtc/RTCEncryptionManager.ts @@ -0,0 +1,321 @@ +/* +Copyright 2025 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import { type IEncryptionManager } from "./EncryptionManager.ts"; +import { type EncryptionConfig } from "./MatrixRTCSession.ts"; +import { type CallMembership } from "./CallMembership.ts"; +import { decodeBase64, encodeBase64 } from "../base64.ts"; +import { type IKeyTransport, type KeyTransportEventListener, KeyTransportEvents } from "./IKeyTransport.ts"; +import { logger as rootLogger, type Logger } from "../logger.ts"; +import { sleep } from "../utils.ts"; +import type { InboundEncryptionSession, ParticipantDeviceInfo, ParticipantId, Statistics } from "./types.ts"; +import { getParticipantId, KeyBuffer } from "./utils.ts"; +import { + type EnabledTransports, + RoomAndToDeviceEvents, + RoomAndToDeviceTransport, +} from "./RoomAndToDeviceKeyTransport.ts"; + +type OutboundEncryptionSession = { + key: Uint8Array; + creationTS: number; + sharedWith: Array; + // This is an index acting as the id of the key + keyId: number; +}; + +/** + * RTCEncryptionManager is used to manage the encryption keys for a call. + * + * It is responsible for distributing the keys to the other participants and rotating the keys if needed. + * + * This manager when used with to-device transport will share the existing key only to new joiners, and rotate + * if there is a leaver. + * + * XXX In the future we want to distribute a ratcheted key not the current one for new joiners. + */ +export class RTCEncryptionManager implements IEncryptionManager { + // The current per-sender media key for this device + private outboundSession: OutboundEncryptionSession | null = null; + + /** + * Ensures that there is only one distribute operation at a time for that call. + */ + private currentKeyDistributionPromise: Promise | null = null; + /** The time to wait before using outbound session after it has been distributed */ + private delayRolloutTimeMillis = 1000; + /** + * If a new key distribution is being requested while one is going on, we will set this flag to true. + * This will ensure that a new round is started after the current one. + * @private + */ + private needToEnsureKeyAgain = false; + + /** + * There is a possibility that keys arrive in wrong order. + * For example after a quick join/leave/join, there will be 2 keys of index 0 distributed and + * if they are received in wrong order the stream won't be decryptable. + * For that reason we keep a small buffer of keys for a limited time to disambiguate. + * @private + */ + private keyBuffer = new KeyBuffer(1000 /** 1 second */); + + private logger: Logger; + + public constructor( + private userId: string, + private deviceId: string, + private getMemberships: () => CallMembership[], + private transport: IKeyTransport, + private statistics: Statistics, + private onEncryptionKeysChanged: ( + keyBin: Uint8Array, + encryptionKeyIndex: number, + participantId: ParticipantId, + ) => void, + parentLogger?: Logger, + ) { + this.logger = (parentLogger ?? rootLogger).getChild(`[EncryptionManager]`); + } + + public getEncryptionKeys(): Map> { + // This is deprecated should be ignored. Only use by tests? + return new Map(); + } + + public join(joinConfig: EncryptionConfig | undefined): void { + this.logger.info(`Joining room`); + this.delayRolloutTimeMillis = joinConfig?.useKeyDelay ?? 1000; + this.transport.on(KeyTransportEvents.ReceivedKeys, this.onNewKeyReceived); + // Deprecate RoomKeyTransport: this can get removed. + if (this.transport instanceof RoomAndToDeviceTransport) { + this.transport.on(RoomAndToDeviceEvents.EnabledTransportsChanged, this.onTransportChanged); + } + + this.transport.start(); + } + + public leave(): void { + this.keyBuffer.clear(); + this.transport.off(KeyTransportEvents.ReceivedKeys, this.onNewKeyReceived); + this.transport.stop(); + } + + private onTransportChanged: (enabled: EnabledTransports) => void = () => { + this.logger.info("Transport change detected, restarting key distribution"); + // Temporary for backwards compatibility + if (this.currentKeyDistributionPromise) { + this.currentKeyDistributionPromise + .then(() => { + if (this.outboundSession) { + this.outboundSession.sharedWith = []; + this.ensureMediaKey(); + } + }) + .catch((e) => { + this.logger.error("Failed to restart key distribution", e); + }); + } else { + if (this.outboundSession) { + this.outboundSession.sharedWith = []; + this.ensureMediaKey(); + } + } + }; + + /** + * Will ensure that a new key is distributed and used to encrypt our media. + * If this function is called repeatidly, the calls will be buffered to a single key rotation. + */ + private ensureMediaKey(): void { + if (this.currentKeyDistributionPromise == null) { + this.logger.debug(`No active rollout, start a new one`); + // start a rollout + this.currentKeyDistributionPromise = this.rolloutOutboundKey().then(() => { + this.logger.debug(`Rollout completed`); + this.currentKeyDistributionPromise = null; + if (this.needToEnsureKeyAgain) { + this.logger.debug(`New Rollout needed`); + this.needToEnsureKeyAgain = false; + // rollout a new one + this.ensureMediaKey(); + } + }); + } else { + // There is a rollout in progress, but a key rotation is requested (could be caused by a membership change) + // Remember that a new rotation is needed after the current one. + this.logger.debug(`Rollout in progress, a new rollout will be started after the current one`); + this.needToEnsureKeyAgain = true; + } + } + + public onNewKeyReceived: KeyTransportEventListener = (userId, deviceId, keyBase64Encoded, index, timestamp) => { + this.logger.debug(`Received key over transport ${userId}:${deviceId} at index ${index}`); + + // We received a new key, notify the video layer of this new key so that it can decrypt the frames properly. + const participantId = getParticipantId(userId, deviceId); + const keyBin = decodeBase64(keyBase64Encoded); + const candidateInboundSession: InboundEncryptionSession = { + key: keyBin, + participantId, + keyIndex: index, + creationTS: timestamp, + }; + + const validSession = this.keyBuffer.disambiguate(participantId, candidateInboundSession); + if (validSession) { + this.onEncryptionKeysChanged(validSession.key, validSession.keyIndex, validSession.participantId); + this.statistics.counters.roomEventEncryptionKeysReceived += 1; + } else { + this.logger.info(`Received an out of order key for ${userId}:${deviceId}, dropping it`); + } + }; + + /** + * Called when the membership of the call changes. + * This encryption manager is very basic, it will rotate the key everytime this is called. + * @param oldMemberships + */ + public onMembershipsUpdate(oldMemberships: CallMembership[]): void { + this.logger.trace(`onMembershipsUpdate`); + + // Ensure the key is distributed. This will be no-op if the key is already being distributed to everyone. + // If there is an ongoing distribution, it will be completed before a new one is started. + this.ensureMediaKey(); + } + + private async rolloutOutboundKey(): Promise { + const isFirstKey = this.outboundSession == null; + if (isFirstKey) { + // create the first key + this.outboundSession = { + key: this.generateRandomKey(), + creationTS: Date.now(), + sharedWith: [], + keyId: 0, + }; + this.onEncryptionKeysChanged( + this.outboundSession.key, + this.outboundSession.keyId, + getParticipantId(this.userId, this.deviceId), + ); + } + // get current memberships + const toShareWith: ParticipantDeviceInfo[] = this.getMemberships() + .filter((membership) => { + return membership.sender != undefined; + }) + .map((membership) => { + return { + userId: membership.sender!, + deviceId: membership.deviceId, + membershipTs: membership.createdTs(), + }; + }); + + let alreadySharedWith = this.outboundSession?.sharedWith ?? []; + + // Some users might have rotate their membership event (formally called fingerprint) meaning they might have + // clear their key. Reset the `alreadySharedWith` flag for them. + alreadySharedWith = alreadySharedWith.filter( + (x) => + // If there was a member with same userId and deviceId but different membershipTs, we need to clear it + !toShareWith.some( + (o) => x.userId == o.userId && x.deviceId == o.deviceId && x.membershipTs != o.membershipTs, + ), + ); + + const anyLeft = alreadySharedWith.filter( + (x) => + !toShareWith.some( + (o) => x.userId == o.userId && x.deviceId == o.deviceId && x.membershipTs == o.membershipTs, + ), + ); + const anyJoined = toShareWith.filter( + (x) => + !alreadySharedWith.some( + (o) => x.userId == o.userId && x.deviceId == o.deviceId && x.membershipTs == o.membershipTs, + ), + ); + + let toDistributeTo: ParticipantDeviceInfo[] = []; + let outboundKey: OutboundEncryptionSession; + let hasKeyChanged = false; + if (anyLeft.length > 0) { + // We need to rotate the key + const newOutboundKey: OutboundEncryptionSession = { + key: this.generateRandomKey(), + creationTS: Date.now(), + sharedWith: [], + keyId: this.nextKeyIndex(), + }; + hasKeyChanged = true; + + this.logger.info(`creating new outbound key index:${newOutboundKey.keyId}`); + // Set this new key as the current one + this.outboundSession = newOutboundKey; + + // Send + toDistributeTo = toShareWith; + outboundKey = newOutboundKey; + } else if (anyJoined.length > 0) { + // keep the same key + // XXX In the future we want to distribute a ratcheted key not the current one + toDistributeTo = anyJoined; + outboundKey = this.outboundSession!; + } else { + // no changes + return; + } + + try { + this.logger.trace(`Sending key...`); + await this.transport.sendKey(encodeBase64(outboundKey.key), outboundKey.keyId, toDistributeTo); + this.statistics.counters.roomEventEncryptionKeysSent += 1; + outboundKey.sharedWith.push(...toDistributeTo); + this.logger.trace( + `key index:${outboundKey.keyId} sent to ${outboundKey.sharedWith.map((m) => `${m.userId}:${m.deviceId}`).join(",")}`, + ); + if (hasKeyChanged) { + // Delay a bit before using this key + // It is recommended not to start using a key immediately but instead wait for a short time to make sure it is delivered. + this.logger.trace(`Delay Rollout for key:${outboundKey.keyId}...`); + await sleep(this.delayRolloutTimeMillis); + this.logger.trace(`...Delayed rollout of index:${outboundKey.keyId} `); + this.onEncryptionKeysChanged( + outboundKey.key, + outboundKey.keyId, + getParticipantId(this.userId, this.deviceId), + ); + } + } catch (err) { + this.logger.error(`Failed to rollout key`, err); + } + } + + private nextKeyIndex(): number { + if (this.outboundSession) { + return (this.outboundSession!.keyId + 1) % 256; + } + return 0; + } + + private generateRandomKey(): Uint8Array { + const key = new Uint8Array(16); + globalThis.crypto.getRandomValues(key); + return key; + } +} diff --git a/src/matrixrtc/RoomAndToDeviceKeyTransport.ts b/src/matrixrtc/RoomAndToDeviceKeyTransport.ts index a3d656eeb4..c513b1f33c 100644 --- a/src/matrixrtc/RoomAndToDeviceKeyTransport.ts +++ b/src/matrixrtc/RoomAndToDeviceKeyTransport.ts @@ -16,10 +16,10 @@ limitations under the License. import { logger as rootLogger, type Logger } from "../logger.ts"; import { KeyTransportEvents, type KeyTransportEventsHandlerMap, type IKeyTransport } from "./IKeyTransport.ts"; -import { type CallMembership } from "./CallMembership.ts"; import type { RoomKeyTransport } from "./RoomKeyTransport.ts"; import type { ToDeviceKeyTransport } from "./ToDeviceKeyTransport.ts"; import { TypedEventEmitter } from "../models/typed-event-emitter.ts"; +import { type ParticipantDeviceInfo } from "./types.ts"; // Deprecate RoomAndToDeviceTransport: This whole class is only a stop gap until we remove RoomKeyTransport. export interface EnabledTransports { @@ -106,7 +106,7 @@ export class RoomAndToDeviceTransport this.toDeviceTransport.stop(); } - public async sendKey(keyBase64Encoded: string, index: number, members: CallMembership[]): Promise { + public async sendKey(keyBase64Encoded: string, index: number, members: ParticipantDeviceInfo[]): Promise { this.logger.debug( `Sending key with index ${index} to call members (count=${members.length}) via:` + (this._enabled.room ? "room transport" : "") + diff --git a/src/matrixrtc/RoomKeyTransport.ts b/src/matrixrtc/RoomKeyTransport.ts index 5f12d9c556..312a6ce257 100644 --- a/src/matrixrtc/RoomKeyTransport.ts +++ b/src/matrixrtc/RoomKeyTransport.ts @@ -15,13 +15,12 @@ limitations under the License. */ import type { MatrixClient } from "../client.ts"; -import type { EncryptionKeysEventContent, Statistics } from "./types.ts"; +import { type EncryptionKeysEventContent, type ParticipantDeviceInfo, type Statistics } from "./types.ts"; import { EventType } from "../@types/event.ts"; import { type MatrixError } from "../http-api/errors.ts"; import { logger as rootLogger, type Logger } from "../logger.ts"; import { KeyTransportEvents, type KeyTransportEventsHandlerMap, type IKeyTransport } from "./IKeyTransport.ts"; import { type MatrixEvent } from "../models/event.ts"; -import { type CallMembership } from "./CallMembership.ts"; import { TypedEventEmitter } from "../models/typed-event-emitter.ts"; import { type Room, RoomEvent } from "../models/room.ts"; @@ -81,7 +80,7 @@ export class RoomKeyTransport } /** implements {@link IKeyTransport#sendKey} */ - public async sendKey(keyBase64Encoded: string, index: number, members: CallMembership[]): Promise { + public async sendKey(keyBase64Encoded: string, index: number, members: ParticipantDeviceInfo[]): Promise { // members not used in room transports as the keys are sent to all room members const content: EncryptionKeysEventContent = { keys: [ diff --git a/src/matrixrtc/ToDeviceKeyTransport.ts b/src/matrixrtc/ToDeviceKeyTransport.ts index 8486b42052..4a6454d0ab 100644 --- a/src/matrixrtc/ToDeviceKeyTransport.ts +++ b/src/matrixrtc/ToDeviceKeyTransport.ts @@ -17,8 +17,7 @@ limitations under the License. import { TypedEventEmitter } from "../models/typed-event-emitter.ts"; import { type IKeyTransport, KeyTransportEvents, type KeyTransportEventsHandlerMap } from "./IKeyTransport.ts"; import { type Logger, logger as rootLogger } from "../logger.ts"; -import type { CallMembership } from "./CallMembership.ts"; -import type { EncryptionKeysToDeviceEventContent, Statistics } from "./types.ts"; +import { type EncryptionKeysToDeviceEventContent, type ParticipantDeviceInfo, type Statistics } from "./types.ts"; import { ClientEvent, type MatrixClient } from "../client.ts"; import type { MatrixEvent } from "../models/event.ts"; import { EventType } from "../@types/event.ts"; @@ -32,6 +31,7 @@ export class ToDeviceKeyTransport implements IKeyTransport { private logger: Logger = rootLogger; + public setParentLogger(parentLogger: Logger): void { this.logger = parentLogger.getChild(`[ToDeviceKeyTransport]`); } @@ -56,7 +56,7 @@ export class ToDeviceKeyTransport this.client.off(ClientEvent.ToDeviceEvent, this.onToDeviceEvent); } - public async sendKey(keyBase64Encoded: string, index: number, members: CallMembership[]): Promise { + public async sendKey(keyBase64Encoded: string, index: number, members: ParticipantDeviceInfo[]): Promise { const content: EncryptionKeysToDeviceEventContent = { keys: { index: index, @@ -71,24 +71,18 @@ export class ToDeviceKeyTransport application: "m.call", scope: "m.room", }, + sent_ts: Date.now(), }; const targets = members - .filter((member) => { - // filter malformed call members - if (member.sender == undefined || member.deviceId == undefined) { - this.logger.warn(`Malformed call member: ${member.sender}|${member.deviceId}`); - return false; - } - // Filter out me - return !(member.sender == this.userId && member.deviceId == this.deviceId); - }) .map((member) => { return { - userId: member.sender!, + userId: member.userId!, deviceId: member.deviceId!, }; - }); + }) + // filter out me + .filter((member) => !(member.userId == this.userId && member.deviceId == this.deviceId)); if (targets.length > 0) { await this.client.encryptAndSendToDevice(EventType.CallEncryptionKeysPrefix, targets, content); diff --git a/src/matrixrtc/types.ts b/src/matrixrtc/types.ts index d408080dfa..d8932359a7 100644 --- a/src/matrixrtc/types.ts +++ b/src/matrixrtc/types.ts @@ -16,11 +16,29 @@ limitations under the License. import type { IMentions } from "../matrix.ts"; import type { CallMembership } from "./CallMembership.ts"; +export type ParticipantId = string; + export interface EncryptionKeyEntry { index: number; key: string; } +export type ParticipantDeviceInfo = { + userId: string; + deviceId: string; + membershipTs: number; +}; + +/** + * A type representing the information needed to decrypt video streams. + */ +export type InboundEncryptionSession = { + key: Uint8Array; + participantId: ParticipantId; + keyIndex: number; + creationTS: number; +}; + export interface EncryptionKeysEventContent { keys: EncryptionKeyEntry[]; device_id: string; diff --git a/src/matrixrtc/utils.ts b/src/matrixrtc/utils.ts new file mode 100644 index 0000000000..cf41602976 --- /dev/null +++ b/src/matrixrtc/utils.ts @@ -0,0 +1,85 @@ +/* +Copyright 2025 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import type { InboundEncryptionSession, ParticipantId } from "./types.ts"; + +type BufferEntry = { + keys: Map; + timeout: any; +}; + +/** + * Holds the key received for a few seconds before dropping them in order to support some edge case with + * out of order keys. + */ +export class KeyBuffer { + private readonly ttl; + + private buffer: Map = new Map(); + + public constructor(ttl?: number) { + this.ttl = ttl ?? 1000; // Default 1 second + } + + /** + * Check if there is a recent key with the same keyId (index) and then use the creationTS to decide what to + * do with the key. If the key received is older than the one already in the buffer, it is ignored. + * @param participantId + * @param item + */ + public disambiguate(participantId: ParticipantId, item: InboundEncryptionSession): InboundEncryptionSession | null { + if (!this.buffer.has(participantId)) { + const timeout = setTimeout(() => { + this.buffer.delete(participantId); + }, this.ttl); + + const map = new Map(); + map.set(item.keyIndex, item); + const entry: BufferEntry = { + keys: map, + timeout, + }; + this.buffer.set(participantId, entry); + return item; + } + + const entry = this.buffer.get(participantId)!; + clearTimeout(entry.timeout); + entry.timeout = setTimeout(() => { + this.buffer.delete(participantId); + }, this.ttl); + + const existing = entry.keys.get(item.keyIndex); + if (existing && existing.creationTS > item.creationTS) { + // The existing is more recent just ignore this one, it is a key received out of order + return null; + } else { + entry.keys.set(item.keyIndex, item); + return item; + } + } + + public clear(): void { + this.buffer.forEach((entry) => { + clearTimeout(entry.timeout); + }); + this.buffer.clear(); + } +} + +export function getParticipantId(userId: string, deviceId: string): ParticipantId { + return `${userId}:${deviceId}`; +}