Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom tool invocation renderer for terminal tool #241768

Merged
merged 2 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/vs/workbench/api/common/extHost.api.impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1794,6 +1794,7 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I
LanguageModelError: extHostTypes.LanguageModelError,
LanguageModelToolResult: extHostTypes.LanguageModelToolResult,
ExtendedLanguageModelToolResult: extHostTypes.ExtendedLanguageModelToolResult,
PreparedTerminalToolInvocation: extHostTypes.PreparedTerminalToolInvocation,
LanguageModelChatToolMode: extHostTypes.LanguageModelChatToolMode,
LanguageModelPromptTsxPart: extHostTypes.LanguageModelPromptTsxPart,
NewSymbolName: extHostTypes.NewSymbolName,
Expand Down
59 changes: 40 additions & 19 deletions src/vs/workbench/api/common/extHostLanguageModelTools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape
toolInvocationToken: dto.context as vscode.ChatParticipantToolToken | undefined,
chatRequestId: dto.chatRequestId,
};
if (isProposedApiEnabled(item.extension, 'chatParticipantPrivate') && dto.toolSpecificData?.kind === 'terminal') {
options.terminalCommand = dto.toolSpecificData.command;
}

if (dto.tokenBudget !== undefined) {
options.tokenizationOptions = {
tokenBudget: dto.tokenBudget,
Expand All @@ -134,29 +138,46 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape
throw new Error(`Unknown tool ${toolId}`);
}

if (!item.tool.prepareInvocation) {
return undefined;
}

const options: vscode.LanguageModelToolInvocationPrepareOptions<any> = { input };
const result = await item.tool.prepareInvocation(options, token);
if (!result) {
return undefined;
}
if (isProposedApiEnabled(item.extension, 'chatParticipantPrivate') && item.tool.prepareInvocation2) {
const result = await item.tool.prepareInvocation2(options, token);
if (!result) {
return undefined;
}

return {
confirmationMessages: result.confirmationMessages ? {
title: result.confirmationMessages.title,
message: typeof result.confirmationMessages.message === 'string' ? result.confirmationMessages.message : typeConvert.MarkdownString.from(result.confirmationMessages.message),
} : undefined,
toolSpecificData: {
kind: 'terminal',
language: result.language,
command: result.command,
}
};
} else if (item.tool.prepareInvocation) {
const result = await item.tool.prepareInvocation(options, token);
if (!result) {
return undefined;
}

if (result.pastTenseMessage || result.presentation) {
checkProposedApiEnabled(item.extension, 'chatParticipantPrivate');
if (result.pastTenseMessage || result.presentation) {
checkProposedApiEnabled(item.extension, 'chatParticipantPrivate');
}

return {
confirmationMessages: result.confirmationMessages ? {
title: result.confirmationMessages.title,
message: typeof result.confirmationMessages.message === 'string' ? result.confirmationMessages.message : typeConvert.MarkdownString.from(result.confirmationMessages.message),
} : undefined,
invocationMessage: typeConvert.MarkdownString.fromStrict(result.invocationMessage),
pastTenseMessage: typeConvert.MarkdownString.fromStrict(result.pastTenseMessage),
presentation: result.presentation
};
}

return {
confirmationMessages: result.confirmationMessages ? {
title: result.confirmationMessages.title,
message: typeof result.confirmationMessages.message === 'string' ? result.confirmationMessages.message : typeConvert.MarkdownString.from(result.confirmationMessages.message),
} : undefined,
invocationMessage: typeConvert.MarkdownString.fromStrict(result.invocationMessage),
pastTenseMessage: typeConvert.MarkdownString.fromStrict(result.pastTenseMessage),
presentation: result.presentation
};
return undefined;
}

registerTool(extension: IExtensionDescription, id: string, tool: vscode.LanguageModelTool<any>): IDisposable {
Expand Down
8 changes: 8 additions & 0 deletions src/vs/workbench/api/common/extHostTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4742,6 +4742,14 @@ export class LanguageModelToolResultPart implements vscode.LanguageModelToolResu
}
}

export class PreparedTerminalToolInvocation {
constructor(
public readonly command: string,
public readonly language: string,
public readonly confirmationMessages?: vscode.LanguageModelToolConfirmationMessages,
) { }
}

export class LanguageModelChatMessage implements vscode.LanguageModelChatMessage {

static User(content: string | (LanguageModelTextPart | LanguageModelToolResultPart | LanguageModelToolCallPart)[], name?: string): LanguageModelChatMessage {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class AcceptToolConfirmation extends Action2 {
f1: false,
category: CHAT_CATEGORY,
keybinding: {
when: ChatContextKeys.inChatInput,
when: ChatContextKeys.inChatSession,
primary: KeyMod.CtrlCmd | KeyCode.Enter,
weight: KeybindingWeight.EditorContrib
},
Expand All @@ -33,7 +33,8 @@ class AcceptToolConfirmation extends Action2 {

run(accessor: ServicesAccessor, ...args: any[]) {
const chatWidgetService = accessor.get(IChatWidgetService);
const lastItem = chatWidgetService.lastFocusedWidget?.viewModel?.getItems().at(-1);
const widget = chatWidgetService.lastFocusedWidget;
const lastItem = widget?.viewModel?.getItems().at(-1);
if (!isResponseVM(lastItem)) {
return;
}
Expand All @@ -42,6 +43,9 @@ class AcceptToolConfirmation extends Action2 {
if (unconfirmedToolInvocation) {
unconfirmedToolInvocation.confirmed.complete(true);
}

// Return focus to the chat input, in case it was in the tool confirmation editor
widget?.focusInput();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ import { IMarkdownString, MarkdownString } from '../../../../../base/common/html
import { Disposable, DisposableStore, IDisposable } from '../../../../../base/common/lifecycle.js';
import { ThemeIcon } from '../../../../../base/common/themables.js';
import { MarkdownRenderer } from '../../../../../editor/browser/widget/markdownRenderer/browser/markdownRenderer.js';
import { ILanguageService } from '../../../../../editor/common/languages/language.js';
import { IModelService } from '../../../../../editor/common/services/model.js';
import { localize } from '../../../../../nls.js';
import { IInstantiationService } from '../../../../../platform/instantiation/common/instantiation.js';
import { IKeybindingService } from '../../../../../platform/keybinding/common/keybinding.js';
import { IChatMarkdownContent, IChatProgressMessage, IChatToolInvocation, IChatToolInvocationSerialized } from '../../common/chatService.js';
import { IChatMarkdownContent, IChatProgressMessage, IChatTerminalToolInvocationData, IChatToolInvocation, IChatToolInvocationSerialized } from '../../common/chatService.js';
import { IChatRendererContent } from '../../common/chatViewModel.js';
import { CodeBlockModelCollection } from '../../common/codeBlockModelCollection.js';
import { IToolResult } from '../../common/languageModelToolsService.js';
Expand Down Expand Up @@ -49,7 +51,7 @@ export class ChatToolInvocationPart extends Disposable implements IChatContentPa
renderer: MarkdownRenderer,
listPool: CollapsibleListPool,
editorPool: EditorPool,
currentWidth: number,
currentWidthDelegate: () => number,
codeBlockModelCollection: CodeBlockModelCollection,
codeBlockStartIndex: number,
@IInstantiationService instantiationService: IInstantiationService,
Expand All @@ -69,7 +71,7 @@ export class ChatToolInvocationPart extends Disposable implements IChatContentPa
dom.clearNode(this.domNode);
partStore.clear();

this.subPart = partStore.add(instantiationService.createInstance(ChatToolInvocationSubPart, toolInvocation, context, renderer, listPool, editorPool, currentWidth, codeBlockModelCollection, codeBlockStartIndex));
this.subPart = partStore.add(instantiationService.createInstance(ChatToolInvocationSubPart, toolInvocation, context, renderer, listPool, editorPool, currentWidthDelegate, codeBlockModelCollection, codeBlockStartIndex));
this.domNode.appendChild(this.subPart.domNode);
partStore.add(this.subPart.onDidChangeHeight(() => this._onDidChangeHeight.fire()));
partStore.add(this.subPart.onNeedsRerender(() => {
Expand All @@ -90,6 +92,9 @@ export class ChatToolInvocationPart extends Disposable implements IChatContentPa
}

class ChatToolInvocationSubPart extends Disposable {
private static idPool = 0;
private readonly _codeblocksPartId = 'tool-' + (ChatToolInvocationSubPart.idPool++);

public readonly domNode: HTMLElement;

private _onNeedsRerender = this._register(new Emitter<void>());
Expand All @@ -99,12 +104,14 @@ class ChatToolInvocationSubPart extends Disposable {
public readonly onDidChangeHeight = this._onDidChangeHeight.event;

private markdownPart: ChatMarkdownContentPart | undefined;
private _codeblocks: IChatCodeBlockInfo[] = [];
public get codeblocks(): IChatCodeBlockInfo[] {
return this.markdownPart?.codeblocks ?? [];
// TODO this is weird, the separate cases should maybe be their own "subparts"
return this.markdownPart?.codeblocks ?? this._codeblocks;
}

public get codeblocksPartId(): string | undefined {
return this.markdownPart?.codeblocksPartId;
public get codeblocksPartId(): string {
return this.markdownPart?.codeblocksPartId ?? this._codeblocksPartId;
}

constructor(
Expand All @@ -113,18 +120,24 @@ class ChatToolInvocationSubPart extends Disposable {
private readonly renderer: MarkdownRenderer,
private readonly listPool: CollapsibleListPool,
private readonly editorPool: EditorPool,
private readonly currentWidth: number,
private readonly currentWidthDelegate: () => number,
private readonly codeBlockModelCollection: CodeBlockModelCollection,
private readonly codeBlockStartIndex: number,
@IInstantiationService private readonly instantiationService: IInstantiationService,
@IKeybindingService private readonly keybindingService: IKeybindingService,
@IModelService private readonly modelService: IModelService,
@ILanguageService private readonly languageService: ILanguageService,
) {
super();

if (toolInvocation.kind === 'toolInvocation' && toolInvocation.confirmationMessages) {
this.domNode = this.createConfirmationWidget(toolInvocation);
} else if (toolInvocation.presentation === 'withCodeblocks' && typeof toolInvocation.invocationMessage !== 'string') {
this.domNode = this.createMarkdownWithCodeblocksProgressPart(toolInvocation);
if (toolInvocation.toolSpecificData?.kind === 'terminal') {
this.domNode = this.createTerminalConfirmationWidget(toolInvocation, toolInvocation.toolSpecificData);
} else {
this.domNode = this.createConfirmationWidget(toolInvocation);
}
} else if (toolInvocation.toolSpecificData?.kind === 'terminal') {
this.domNode = this.createTerminalMarkdownProgressPart(toolInvocation, toolInvocation.toolSpecificData);
} else if (toolInvocation.resultDetails?.length) {
this.domNode = this.createResultList(toolInvocation.pastTenseMessage ?? toolInvocation.invocationMessage, toolInvocation.resultDetails);
} else {
Expand Down Expand Up @@ -182,7 +195,7 @@ class ChatToolInvocationSubPart extends Disposable {
wordWrap: 'on'
}
};
this.markdownPart = this._register(this.instantiationService.createInstance(ChatMarkdownContentPart, chatMarkdownContent, this.context, this.editorPool, false, this.codeBlockStartIndex, this.renderer, this.currentWidth, this.codeBlockModelCollection, { codeBlockRenderOptions }));
this.markdownPart = this._register(this.instantiationService.createInstance(ChatMarkdownContentPart, chatMarkdownContent, this.context, this.editorPool, false, this.codeBlockStartIndex, this.renderer, this.currentWidthDelegate(), this.codeBlockModelCollection, { codeBlockRenderOptions }));
this._register(this.markdownPart.onDidChangeHeight(() => this._onDidChangeHeight.fire()));
confirmWidget = this._register(this.instantiationService.createInstance(
ChatCustomConfirmationWidget,
Expand All @@ -202,6 +215,92 @@ class ChatToolInvocationSubPart extends Disposable {
return confirmWidget.domNode;
}

private createTerminalConfirmationWidget(toolInvocation: IChatToolInvocation, terminalData: IChatTerminalToolInvocationData): HTMLElement {
if (!toolInvocation.confirmationMessages) {
throw new Error('Confirmation messages are missing');
}
const title = toolInvocation.confirmationMessages.title;
const message = toolInvocation.confirmationMessages.message;
const continueLabel = localize('continue', "Continue");
const continueKeybinding = this.keybindingService.lookupKeybinding(AcceptToolConfirmationActionId)?.getLabel();
const continueTooltip = continueKeybinding ? `${continueLabel} (${continueKeybinding})` : continueLabel;
const cancelLabel = localize('cancel', "Cancel");
const cancelKeybinding = this.keybindingService.lookupKeybinding(CancelChatActionId)?.getLabel();
const cancelTooltip = cancelKeybinding ? `${cancelLabel} (${cancelKeybinding})` : cancelLabel;

const buttons: IChatConfirmationButton[] = [
{
label: continueLabel,
data: true,
tooltip: continueTooltip
},
{
label: cancelLabel,
data: false,
isSecondary: true,
tooltip: cancelTooltip
}];
const renderedMessage = this._register(this.renderer.render(
typeof message === 'string' ? new MarkdownString(message) : message,
{ asyncRenderCallback: () => this._onDidChangeHeight.fire() }
));
const codeBlockRenderOptions: ICodeBlockRenderOptions = {
hideToolbar: true,
reserveWidth: 19,
verticalPadding: 5,
editorOptions: {
wordWrap: 'on',
readOnly: false
}
};
const langId = this.languageService.getLanguageIdByLanguageName(terminalData.language ?? 'sh') ?? 'shellscript';
const model = this.modelService.createModel(terminalData.command, this.languageService.createById(langId));
const editor = this._register(this.editorPool.get());
editor.object.render({
codeBlockIndex: this.codeBlockStartIndex,
codeBlockPartIndex: 0,
element: this.context.element,
languageId: langId,
renderOptions: codeBlockRenderOptions,
textModel: Promise.resolve(model)
}, this.currentWidthDelegate());
this._codeblocks.push({
codeBlockIndex: this.codeBlockStartIndex,
codemapperUri: undefined,
elementId: this.context.element.id,
focus: () => editor.object.focus(),
isStreaming: false,
ownerMarkdownPartId: this.codeblocksPartId,
uri: model.uri,
uriPromise: Promise.resolve(model.uri)
});
this._register(editor.object.onDidChangeContentHeight(() => {
editor.object.layout(this.currentWidthDelegate());
this._onDidChangeHeight.fire();
}));
this._register(model.onDidChangeContent(e => {
terminalData.command = model.getValue();
}));
const element = dom.$('');
dom.append(element, editor.object.element);
dom.append(element, renderedMessage.element);
const confirmWidget = this._register(this.instantiationService.createInstance(
ChatCustomConfirmationWidget,
title,
element,
buttons
));

this._register(confirmWidget.onDidClick(button => {
toolInvocation.confirmed.complete(button.data);
}));
this._register(confirmWidget.onDidChangeHeight(() => this._onDidChangeHeight.fire()));
toolInvocation.confirmed.p.then(() => {
this._onNeedsRerender.fire();
});
return confirmWidget.domNode;
}

private createProgressPart(): HTMLElement {
let content: IMarkdownString;
if (this.toolInvocation.isComplete && this.toolInvocation.isConfirmed !== false && this.toolInvocation.pastTenseMessage) {
Expand All @@ -218,18 +317,16 @@ class ChatToolInvocationSubPart extends Disposable {
kind: 'progressMessage',
content
};
const iconOverride = this.toolInvocation.isConfirmed === false ?
const iconOverride = !this.toolInvocation.isConfirmed ?
Codicon.error :
this.toolInvocation.isComplete ?
Codicon.check : undefined;
const progressPart = this._register(this.instantiationService.createInstance(ChatProgressContentPart, progressMessage, this.renderer, this.context, undefined, true, iconOverride));
return progressPart.domNode;
}

private createMarkdownWithCodeblocksProgressPart(toolInvocation: IChatToolInvocation | IChatToolInvocationSerialized): HTMLElement {
const content = toolInvocation.isComplete ?
(toolInvocation.pastTenseMessage ?? toolInvocation.invocationMessage)
: toolInvocation.invocationMessage;
private createTerminalMarkdownProgressPart(toolInvocation: IChatToolInvocation | IChatToolInvocationSerialized, terminalData: IChatTerminalToolInvocationData): HTMLElement {
const content = new MarkdownString(`\`\`\`${terminalData.language}\n${terminalData.command}\n\`\`\``);
const chatMarkdownContent: IChatMarkdownContent = {
kind: 'markdownContent',
content: content as IMarkdownString,
Expand All @@ -243,9 +340,9 @@ class ChatToolInvocationSubPart extends Disposable {
wordWrap: 'on'
}
};
this.markdownPart = this._register(this.instantiationService.createInstance(ChatMarkdownContentPart, chatMarkdownContent, this.context, this.editorPool, false, this.codeBlockStartIndex, this.renderer, this.currentWidth, this.codeBlockModelCollection, { codeBlockRenderOptions }));
this.markdownPart = this._register(this.instantiationService.createInstance(ChatMarkdownContentPart, chatMarkdownContent, this.context, this.editorPool, false, this.codeBlockStartIndex, this.renderer, this.currentWidthDelegate(), this.codeBlockModelCollection, { codeBlockRenderOptions }));
this._register(this.markdownPart.onDidChangeHeight(() => this._onDidChangeHeight.fire()));
const icon = this.toolInvocation.isConfirmed === false ?
const icon = !this.toolInvocation.isConfirmed ?
Codicon.error :
this.toolInvocation.isComplete ?
Codicon.check : ThemeIcon.modify(Codicon.loading, 'spin');
Expand Down
2 changes: 1 addition & 1 deletion src/vs/workbench/contrib/chat/browser/chatListRenderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ export class ChatListItemRenderer extends Disposable implements ITreeRenderer<Ch

private renderToolInvocation(toolInvocation: IChatToolInvocation | IChatToolInvocationSerialized, context: IChatContentPartRenderContext, templateData: IChatListItemTemplate): IChatContentPart | undefined {
const codeBlockStartIndex = this.getCodeBlockStartIndex(context);
const part = this.instantiationService.createInstance(ChatToolInvocationPart, toolInvocation, context, this.renderer, this._contentReferencesListPool, this._editorPool, this._currentLayoutWidth, this._toolInvocationCodeBlockCollection, codeBlockStartIndex);
const part = this.instantiationService.createInstance(ChatToolInvocationPart, toolInvocation, context, this.renderer, this._contentReferencesListPool, this._editorPool, () => this._currentLayoutWidth, this._toolInvocationCodeBlockCollection, codeBlockStartIndex);
part.addDisposable(part.onDidChangeHeight(() => {
this.updateItemHeight(templateData);
}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo
if (!userConfirmed) {
throw new CancellationError();
}

dto.toolSpecificData = toolInvocation?.toolSpecificData;
}
} else {
const prepared = tool.impl.prepareToolInvocation ?
Expand Down
2 changes: 1 addition & 1 deletion src/vs/workbench/contrib/chat/browser/media/chat.css
Original file line number Diff line number Diff line change
Expand Up @@ -1595,7 +1595,7 @@ have to be updated for changes to the rules above, or to support more deeply nes
padding: 4px 8px;
}

.interactive-item-container .chat-confirmation-widget .rendered-markdown [data-code] {
.interactive-item-container .chat-confirmation-widget .interactive-result-code-block {
margin-bottom: 8px;
}

Expand Down
Loading
Loading