diff --git a/src/vs/workbench/api/browser/mainThreadLanguageModels.ts b/src/vs/workbench/api/browser/mainThreadLanguageModels.ts index db30c270ad91c..1f952e1c0e653 100644 --- a/src/vs/workbench/api/browser/mainThreadLanguageModels.ts +++ b/src/vs/workbench/api/browser/mainThreadLanguageModels.ts @@ -26,6 +26,27 @@ import { SerializableObjectWithBuffers } from '../../services/extensions/common/ import { ExtHostContext, ExtHostLanguageModelsShape, MainContext, MainThreadLanguageModelsShape } from '../common/extHost.protocol.js'; import { LanguageModelError } from '../common/extHostTypes.js'; +class RequestCancellationTokenSource extends Disposable { + + private readonly _source: CancellationTokenSource; + + constructor(parent: CancellationToken, onCancellationRequested?: () => void) { + super(); + this._source = this._register(new CancellationTokenSource(parent)); + if (onCancellationRequested) { + this._register(this._source.token.onCancellationRequested(onCancellationRequested)); + } + } + + get token(): CancellationToken { + return this._source.token; + } + + cancel(): void { + this._source.cancel(); + } +} + @extHostNamedCustomer(MainContext.MainThreadLanguageModels) export class MainThreadLanguageModels implements MainThreadLanguageModelsShape { @@ -34,7 +55,7 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape { private readonly _providerRegistrations = new DisposableMap(); private readonly _lmProviderChange = new Emitter<{ vendor: string }>(); private readonly _pendingProgress = new Map; stream: AsyncIterableSource }>(); - private readonly _pendingCancelCTS = new DisposableMap(); + private readonly _pendingCancelCTS = new DisposableMap(); private readonly _ignoredFileProviderRegistrations = new DisposableMap(); constructor( @@ -100,11 +121,10 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape { try { this._pendingProgress.set(requestId, { defer, stream }); - const cts = new CancellationTokenSource(token); - this._pendingCancelCTS.set(requestId, cts); - cts.token.onCancellationRequested(() => { + const cts = new RequestCancellationTokenSource(token, () => { this._proxy.$cancelLanguageModelChatRequest(requestId); }); + this._pendingCancelCTS.set(requestId, cts); await Promise.all( messages.flatMap(msg => msg.content) @@ -194,7 +214,7 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape { // Create a local CTS so cancellation can be signalled via // $cancelLanguageModelChatRequest even after the RPC cancel // handler for the original token has been removed. - const cts = new CancellationTokenSource(token); + const cts = new RequestCancellationTokenSource(token); this._pendingCancelCTS.set(requestId, cts); let response: ILanguageModelChatResponse; diff --git a/src/vs/workbench/api/test/browser/mainThreadLanguageModels.test.ts b/src/vs/workbench/api/test/browser/mainThreadLanguageModels.test.ts index eb78b2c4f9633..9c15b812ccf2a 100644 --- a/src/vs/workbench/api/test/browser/mainThreadLanguageModels.test.ts +++ b/src/vs/workbench/api/test/browser/mainThreadLanguageModels.test.ts @@ -6,7 +6,7 @@ import assert from 'assert'; import { CancellationToken, CancellationTokenSource } from '../../../../base/common/cancellation.js'; import { Emitter } from '../../../../base/common/event.js'; -import { DisposableStore } from '../../../../base/common/lifecycle.js'; +import { Disposable, DisposableStore } from '../../../../base/common/lifecycle.js'; import { ExtensionIdentifier } from '../../../../platform/extensions/common/extensions.js'; import { mock } from '../../../../base/test/common/mock.js'; import { ensureNoDisposablesAreLeakedInTestSuite } from '../../../../base/test/common/utils.js'; @@ -14,7 +14,7 @@ import { NullLogService } from '../../../../platform/log/common/log.js'; import { IAuthenticationService } from '../../../services/authentication/common/authentication.js'; import { IAuthenticationAccessService } from '../../../services/authentication/browser/authenticationAccessService.js'; import { ILanguageModelIgnoredFilesService } from '../../../contrib/chat/common/ignoredFiles.js'; -import { ILanguageModelsService, IChatMessage } from '../../../contrib/chat/common/languageModels.js'; +import { ILanguageModelChatProvider, ILanguageModelsService, IChatMessage } from '../../../contrib/chat/common/languageModels.js'; import { SerializableObjectWithBuffers } from '../../../services/extensions/common/proxyIdentifier.js'; import { TestExtensionService } from '../../../test/common/workbenchTestServices.js'; import { MainThreadLanguageModels } from '../../browser/mainThreadLanguageModels.js'; @@ -165,4 +165,46 @@ suite('MainThreadLanguageModels', function () { // Should not throw mainThread.$cancelLanguageModelChatRequest(999999); }); + + test('disposes the provider request cancellation listener when the response completes', async () => { + const store = disposables.add(new DisposableStore()); + let provider: ILanguageModelChatProvider | undefined; + let requestId: number | undefined; + let cancelCount = 0; + const proxy: Partial = { + $startChatRequest: async (_modelId, id) => { + requestId = id; + }, + $cancelLanguageModelChatRequest: () => { + cancelCount++; + }, + }; + const languageModelsService = new class extends mock() { + override readonly onDidChangeLanguageModels = store.add(new Emitter()).event; + override getLanguageModelIds(): string[] { return []; } + override registerLanguageModelProvider(_vendor: string, value: ILanguageModelChatProvider) { + provider = value; + return Disposable.None; + } + }; + + const mainThread = store.add(new MainThreadLanguageModels( + SingleProxyRPCProtocol(proxy), + languageModelsService, + new NullLogService(), + new class extends mock() { }, + new class extends mock() { }, + new TestExtensionService(), + new class extends mock() { }, + )); + mainThread.$registerLanguageModelProvider('test'); + + const cts = store.add(new CancellationTokenSource()); + const response = await provider!.sendChatRequest('model-1', [], undefined, {}, cts.token); + await mainThread.$reportResponseDone(requestId!, undefined); + await response.result; + cts.cancel(); + + assert.strictEqual(cancelCount, 0); + }); }); diff --git a/src/vs/workbench/services/host/browser/toasts.ts b/src/vs/workbench/services/host/browser/toasts.ts index fcec62cf82a4c..32c8e6d532303 100644 --- a/src/vs/workbench/services/host/browser/toasts.ts +++ b/src/vs/workbench/services/host/browser/toasts.ts @@ -43,9 +43,9 @@ export async function showBrowserToast(controller: IShowToastController, options disposables.add(cts.token.onCancellationRequested(() => resolve({ supported: true, clicked: false }))); - Event.once(toast.onClick)(() => resolve({ supported: true, clicked: true })); - Event.once(toast.onClose)(() => resolve({ supported: true, clicked: false })); - Event.once(toast.onError)(() => resolve({ supported: false, clicked: false })); + disposables.add(Event.once(toast.onClick)(() => resolve({ supported: true, clicked: true }))); + disposables.add(Event.once(toast.onClose)(() => resolve({ supported: true, clicked: false }))); + disposables.add(Event.once(toast.onError)(() => resolve({ supported: false, clicked: false }))); }); }