Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions src/vs/workbench/api/browser/mainThreadLanguageModels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,27 @@ import { SerializableObjectWithBuffers } from '../../services/extensions/common/
import { ExtHostContext, ExtHostLanguageModelsShape, MainContext, MainThreadLanguageModelsShape } from '../common/extHost.protocol.js';
import { LanguageModelError } from '../common/extHostTypes.js';

class RequestCancellationTokenSource extends Disposable {

private readonly _source: CancellationTokenSource;

constructor(parent: CancellationToken, onCancellationRequested?: () => void) {
super();
this._source = this._register(new CancellationTokenSource(parent));
if (onCancellationRequested) {
this._register(this._source.token.onCancellationRequested(onCancellationRequested));
}
}

get token(): CancellationToken {
return this._source.token;
}

cancel(): void {
this._source.cancel();
}
}

@extHostNamedCustomer(MainContext.MainThreadLanguageModels)
export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {

Expand All @@ -34,7 +55,7 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {
private readonly _providerRegistrations = new DisposableMap<string>();
private readonly _lmProviderChange = new Emitter<{ vendor: string }>();
private readonly _pendingProgress = new Map<number, { defer: DeferredPromise<unknown>; stream: AsyncIterableSource<IChatResponsePart | IChatResponsePart[]> }>();
private readonly _pendingCancelCTS = new DisposableMap<number, CancellationTokenSource>();
private readonly _pendingCancelCTS = new DisposableMap<number, RequestCancellationTokenSource>();
private readonly _ignoredFileProviderRegistrations = new DisposableMap<number>();

constructor(
Expand Down Expand Up @@ -100,11 +121,10 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {
try {
this._pendingProgress.set(requestId, { defer, stream });

const cts = new CancellationTokenSource(token);
this._pendingCancelCTS.set(requestId, cts);
cts.token.onCancellationRequested(() => {
const cts = new RequestCancellationTokenSource(token, () => {
this._proxy.$cancelLanguageModelChatRequest(requestId);
});
this._pendingCancelCTS.set(requestId, cts);

await Promise.all(
messages.flatMap(msg => msg.content)
Expand Down Expand Up @@ -194,7 +214,7 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {
// Create a local CTS so cancellation can be signalled via
// $cancelLanguageModelChatRequest even after the RPC cancel
// handler for the original token has been removed.
const cts = new CancellationTokenSource(token);
const cts = new RequestCancellationTokenSource(token);
this._pendingCancelCTS.set(requestId, cts);

let response: ILanguageModelChatResponse;
Expand Down
46 changes: 44 additions & 2 deletions src/vs/workbench/api/test/browser/mainThreadLanguageModels.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
import assert from 'assert';
import { CancellationToken, CancellationTokenSource } from '../../../../base/common/cancellation.js';
import { Emitter } from '../../../../base/common/event.js';
import { DisposableStore } from '../../../../base/common/lifecycle.js';
import { Disposable, DisposableStore } from '../../../../base/common/lifecycle.js';
import { ExtensionIdentifier } from '../../../../platform/extensions/common/extensions.js';
import { mock } from '../../../../base/test/common/mock.js';
import { ensureNoDisposablesAreLeakedInTestSuite } from '../../../../base/test/common/utils.js';
import { NullLogService } from '../../../../platform/log/common/log.js';
import { IAuthenticationService } from '../../../services/authentication/common/authentication.js';
import { IAuthenticationAccessService } from '../../../services/authentication/browser/authenticationAccessService.js';
import { ILanguageModelIgnoredFilesService } from '../../../contrib/chat/common/ignoredFiles.js';
import { ILanguageModelsService, IChatMessage } from '../../../contrib/chat/common/languageModels.js';
import { ILanguageModelChatProvider, ILanguageModelsService, IChatMessage } from '../../../contrib/chat/common/languageModels.js';
import { SerializableObjectWithBuffers } from '../../../services/extensions/common/proxyIdentifier.js';
import { TestExtensionService } from '../../../test/common/workbenchTestServices.js';
import { MainThreadLanguageModels } from '../../browser/mainThreadLanguageModels.js';
Expand Down Expand Up @@ -165,4 +165,46 @@ suite('MainThreadLanguageModels', function () {
// Should not throw
mainThread.$cancelLanguageModelChatRequest(999999);
});

test('disposes the provider request cancellation listener when the response completes', async () => {
const store = disposables.add(new DisposableStore());
let provider: ILanguageModelChatProvider | undefined;
let requestId: number | undefined;
let cancelCount = 0;
const proxy: Partial<ExtHostLanguageModelsShape> = {
$startChatRequest: async (_modelId, id) => {
requestId = id;
},
$cancelLanguageModelChatRequest: () => {
cancelCount++;
},
};
const languageModelsService = new class extends mock<ILanguageModelsService>() {
override readonly onDidChangeLanguageModels = store.add(new Emitter<string>()).event;
override getLanguageModelIds(): string[] { return []; }
override registerLanguageModelProvider(_vendor: string, value: ILanguageModelChatProvider) {
provider = value;
return Disposable.None;
}
};

const mainThread = store.add(new MainThreadLanguageModels(
SingleProxyRPCProtocol(proxy),
languageModelsService,
new NullLogService(),
new class extends mock<IAuthenticationService>() { },
new class extends mock<IAuthenticationAccessService>() { },
new TestExtensionService(),
new class extends mock<ILanguageModelIgnoredFilesService>() { },
));
mainThread.$registerLanguageModelProvider('test');

const cts = store.add(new CancellationTokenSource());
const response = await provider!.sendChatRequest('model-1', [], undefined, {}, cts.token);
await mainThread.$reportResponseDone(requestId!, undefined);
await response.result;
cts.cancel();

assert.strictEqual(cancelCount, 0);
});
});
6 changes: 3 additions & 3 deletions src/vs/workbench/services/host/browser/toasts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ export async function showBrowserToast(controller: IShowToastController, options

disposables.add(cts.token.onCancellationRequested(() => resolve({ supported: true, clicked: false })));

Event.once(toast.onClick)(() => resolve({ supported: true, clicked: true }));
Event.once(toast.onClose)(() => resolve({ supported: true, clicked: false }));
Event.once(toast.onError)(() => resolve({ supported: false, clicked: false }));
disposables.add(Event.once(toast.onClick)(() => resolve({ supported: true, clicked: true })));
disposables.add(Event.once(toast.onClose)(() => resolve({ supported: true, clicked: false })));
disposables.add(Event.once(toast.onError)(() => resolve({ supported: false, clicked: false })));
});
}

Expand Down
Loading