diff --git a/packages/sdk/src/mls/mlsAgent.ts b/packages/sdk/src/mls/mlsAgent.ts new file mode 100644 index 0000000000..ff3d1d6da1 --- /dev/null +++ b/packages/sdk/src/mls/mlsAgent.ts @@ -0,0 +1,280 @@ +import TypedEmitter from 'typed-emitter' +import { StreamEncryptionEvents, StreamStateEvents, SyncedStreamEvents } from '../streamEvents' +import { MlsLoop } from './mlsLoop' +import { bin_toHexString, elogger, ELogger, shortenHexString } from '@river-build/dlog' +import { MlsStream } from './mlsStream' +import { MlsProcessor } from './mlsProcessor' +import { Client } from '../client' +import { MLS_ALGORITHM } from './constants' +import { EncryptedContent } from '../encryptedContentTypes' +import { Stream } from '../stream' +import { IPersistenceStore } from '../persistenceStore' +import { MlsCryptoStore, toLocalEpochSecretDTO, toLocalViewDTO } from './mlsCryptoStore' +import { LocalView } from './view/local' +import { IndefiniteValueAwaiter } from './awaiter' +import { StreamUpdate, StreamUpdateDelegate } from './types' + +const defaultLogger = elogger('csb:mls:agent') + +export type MlsAgentOpts = { + log: ELogger + mlsAlwaysEnabled: boolean +} + +const defaultMlsAgentOpts = { + log: defaultLogger, + mlsAlwaysEnabled: false, + delayMs: 15, + sendingOptions: {}, +} + +export class MlsAgent implements StreamUpdateDelegate { + private readonly client: Client + // private readonly mlsClient: MlsClient + private readonly persistenceStore: IPersistenceStore + private readonly encryptionEmitter?: TypedEmitter + private readonly stateEmitter?: TypedEmitter + + public readonly streams: Map = new Map() + public readonly processor: MlsProcessor + public readonly loop: MlsLoop + public readonly store: MlsCryptoStore + + private initRequests: Map> = new Map() + + private log: ELogger + public mlsAlwaysEnabled: boolean = false + + public constructor( + client: Client, + processor: MlsProcessor, + loop: MlsLoop, + store: MlsCryptoStore, + persistenceStore: IPersistenceStore, + encryptionEmitter?: TypedEmitter, + stateEmitter?: TypedEmitter, + opts?: MlsAgentOpts, + ) { + this.client = client + this.persistenceStore = persistenceStore + this.encryptionEmitter = encryptionEmitter + this.stateEmitter = stateEmitter + this.processor = processor + this.loop = loop + this.store = store + + const mlsAgentOpts = { + ...defaultMlsAgentOpts, + ...opts, + } + this.log = mlsAgentOpts.log + this.mlsAlwaysEnabled = mlsAgentOpts.mlsAlwaysEnabled + } + + public start(): void { + this.encryptionEmitter?.on('mlsNewEncryptedContent', this.onNewEncryptedContent) + this.encryptionEmitter?.on('mlsNewConfirmedEvent', this.onConfirmedEvent) + this.stateEmitter?.on( + 'streamEncryptionAlgorithmUpdated', + this.onStreamEncryptionAlgorithmUpdated, + ) + this.stateEmitter?.on('streamInitialized', this.onStreamInitialized) + } + + public stop(): void { + this.encryptionEmitter?.off('mlsNewEncryptedContent', this.onNewEncryptedContent) + this.encryptionEmitter?.off('mlsNewConfirmedEvent', this.onConfirmedEvent) + this.stateEmitter?.off('streamInitialized', this.onStreamInitialized) + this.stateEmitter?.off( + 'streamEncryptionAlgorithmUpdated', + this.onStreamEncryptionAlgorithmUpdated, + ) + } + + public readonly onStreamInitialized: StreamStateEvents['streamInitialized'] = ( + streamId: string, + ): void => { + this.log.log('onStreamInitialized', streamId) + this.loop.enqueueStreamUpdate(streamId) + } + + public readonly onConfirmedEvent: StreamEncryptionEvents['mlsNewConfirmedEvent'] = ( + ...args + ): void => { + this.log.log('agent: onConfirmedEvent', { + confirmedEventNum: args[1].confirmedEventNum, + case: args[1].case, + }) + this.loop.enqueueConfirmedEvent(...args) + } + + public readonly onStreamEncryptionAlgorithmUpdated = ( + streamId: string, + encryptionAlgorithm?: string, + ): void => { + this.log.log('agent: onStreamEncryptionAlgorithmUpdated', streamId, encryptionAlgorithm) + if (encryptionAlgorithm === MLS_ALGORITHM) { + this.loop.enqueueStreamUpdate(streamId) + } + } + + public readonly onNewEncryptedContent: StreamEncryptionEvents['mlsNewEncryptedContent'] = ( + streamId: string, + eventId: string, + content: EncryptedContent, + ): void => { + this.log.log('onNewEncryptedContent', streamId, eventId, content.content.mls?.epoch) + this.loop.enqueueNewEncryptedContent(streamId, eventId, content) + } + + public readonly onStreamRemovedFromSync: SyncedStreamEvents['streamRemovedFromSync'] = ( + streamId: string, + ): void => { + this.log.log('agent: onStreamRemovedFromSync', streamId) + // TODO: Persist MLS stuff + this.streams.delete(streamId) + } + + // This potentially involves loading from storage + public async initMlsStream(stream: Stream): Promise { + this.log.log('initStream', stream.streamId) + + let mlsStream = this.streams.get(stream.streamId) + + if (mlsStream !== undefined) { + this.log.log('stream already initialized', stream.streamId) + return mlsStream + } + + const existingAwaiter = this.initRequests.get(stream.streamId) + if (existingAwaiter !== undefined) { + return existingAwaiter.promise + } + + const innerAwaiter = new IndefiniteValueAwaiter() + const awaiter = { + promise: innerAwaiter.promise.then((value) => { + this.initRequests.delete(stream.streamId) + return value + }), + resolve: innerAwaiter.resolve, + } + + this.initRequests.set(stream.streamId, awaiter) + + // fetch localview from storage + let localView: LocalView | undefined + const dtos = await this.store.getLocalViewDTO(stream.streamId) + if (dtos !== undefined) { + this.log.log('loading local view', stream.streamId) + this.log.log('loading group', bin_toHexString(dtos.viewDTO.groupId)) + try { + localView = await this.processor.loadLocalView(dtos.viewDTO) + for (const localEpochSecretDTO of dtos.epochSecretDTOs) { + const epochSecret = { + epoch: BigInt(localEpochSecretDTO.epoch), + secret: localEpochSecretDTO.secret, + derivedKeys: { + publicKey: localEpochSecretDTO.derivedKeys.publicKey, + secretKey: localEpochSecretDTO.derivedKeys.secretKey, + }, + } + localView.epochSecrets.set(epochSecret.epoch, epochSecret) + } + } catch (e) { + this.log.error?.('loadLocalView error', stream.streamId, e) + } + } + + const mlsStreamOpts = { log: this.log.extend(shortenHexString(stream.streamId)) } + mlsStream = new MlsStream( + stream.streamId, + stream, + this.persistenceStore, + localView, + mlsStreamOpts, + ) + this.streams.set(stream.streamId, mlsStream) + awaiter.resolve(mlsStream) + + return awaiter.promise + } + + public async getMlsStream(stream: Stream): Promise { + const mlsStream = this.streams.get(stream.streamId) + if (mlsStream !== undefined) { + return mlsStream + } + return this.initMlsStream(stream) + } + + public async handleStreamUpdate(streamUpdate: StreamUpdate): Promise { + const streamId = streamUpdate.streamId + const stream = this.client.streams.get(streamId) + if (stream === undefined) { + throw new Error('stream not initialized') + } + + const encryptionAlgorithm = stream.view.membershipContent.encryptionAlgorithm + this.log.log('algorithm', encryptionAlgorithm) + + const mlsEnabled = encryptionAlgorithm === MLS_ALGORITHM || this.mlsAlwaysEnabled + + const mlsStream = await this.getMlsStream(stream) + + this.log.log('agent: mlsEnabled', streamId, mlsEnabled) + + if (mlsEnabled) { + // this.log.debug?.('agent: updated onchain view', streamId, mlsStream.onChainView) + await mlsStream.handleStreamUpdate(streamUpdate) + // TODO: this is potentially slow + await mlsStream.retryDecryptionFailures() + + this.log.log('agent: ', { + status: mlsStream.localView?.status ?? 'missing', + onChain: { + accepted: mlsStream.onChainView.accepted.size, + rejected: mlsStream.onChainView.rejected.size, + commits: mlsStream.onChainView.commits.size, + sealed: mlsStream.onChainView.sealedEpochSecrets.keys(), + }, + local: { + secrets: mlsStream.localView?.epochSecrets.keys() ?? [], + }, + }) + + if (mlsStream.localView?.status === 'active') { + this.log.log('agent: active', streamId) + // TODO: welcome new Clients + } else { + this.log.log('agent: inactive', streamId) + // are there any pending encrypts or decrypts? + const areTherePendingEncryptsOrDecrypts = + mlsStream.decryptionFailures.size > 0 || + mlsStream.awaitingActiveLocalView !== undefined + if (mlsEnabled || areTherePendingEncryptsOrDecrypts) { + this.log.log('agent: initializeOrJoinGroup', streamId) + try { + await this.processor.initializeOrJoinGroup(mlsStream) + } catch (e) { + this.log.error?.('agent: initializeOrJoinGroup error', streamId) + this.log.error?.('enqueue retry') + this.loop.enqueueStreamUpdate(streamId) + } + } + } + await this.processor.announceEpochSecrets(mlsStream) + + // Persisting the group to storage + if (mlsStream.localView !== undefined) { + const localViewDTO = toLocalViewDTO(mlsStream.streamId, mlsStream.localView) + const epochSecretsDTOs = Array.from(mlsStream.localView.epochSecrets.values()).map( + (epochSecret) => toLocalEpochSecretDTO(mlsStream.streamId, epochSecret), + ) + await this.store.saveLocalViewDTO(localViewDTO, epochSecretsDTOs) + this.log.log('saving group', bin_toHexString(mlsStream.localView.group.groupId)) + await mlsStream.localView.group.writeToStorage() + } + } + } +} diff --git a/packages/sdk/src/streamEvents.ts b/packages/sdk/src/streamEvents.ts index a869a0aed6..ae6408685b 100644 --- a/packages/sdk/src/streamEvents.ts +++ b/packages/sdk/src/streamEvents.ts @@ -20,6 +20,7 @@ import { KeySolicitationContent, UserDevice } from '@river-build/encryption' import { EncryptedContent } from './encryptedContentTypes' import { SyncState } from './syncedStreamsLoop' import { Pin } from './streamStateView_Members' +import { MlsConfirmedEvent } from './mls/types' export type StreamChange = { prepended?: RemoteTimelineEvent[] @@ -68,6 +69,8 @@ export type StreamEncryptionEvents = { groupInfoMessage: Uint8Array, ) => void mlsEpochSecrets: (streamId: string, secrets: { epoch: bigint; secret: Uint8Array }[]) => void + // MLS-specific confirmed events + mlsNewConfirmedEvent: (streamId: string, event: MlsConfirmedEvent) => void } export type SyncedStreamEvents = { diff --git a/packages/sdk/src/streamStateView_Mls.ts b/packages/sdk/src/streamStateView_Mls.ts index 589c6632a0..7e6850765d 100644 --- a/packages/sdk/src/streamStateView_Mls.ts +++ b/packages/sdk/src/streamStateView_Mls.ts @@ -123,6 +123,15 @@ export class StreamStateView_Mls extends StreamStateView_AbstractContent { default: break } + + const confirmedEvent = { + confirmedEventNum: event.confirmedEventNum, + miniblockNum: event.miniblockNum, + eventId: event.remoteEvent.hashStr, + ...payload.clone().content, + } + + encryptionEmitter?.emit('mlsNewConfirmedEvent', this.streamId, confirmedEvent) } addSignaturePublicKey(userId: string, signaturePublicKey: Uint8Array): void {