From 22235f30b0b057143b1b86ced1a33f3337d34209 Mon Sep 17 00:00:00 2001
From: luca cappa <lucappa@microsoft.com>
Date: Tue, 3 Dec 2024 17:50:52 -0800
Subject: [PATCH] code snippet provider

---
 Extension/src/LanguageServer/client.ts        |  26 ++-
 .../copilotCompletionContextProvider.ts       | 185 ++++++++++++++++++
 .../LanguageServer/copilotContextTelemetry.ts |  72 +++++++
 Extension/src/LanguageServer/extension.ts     |   8 +
 4 files changed, 289 insertions(+), 2 deletions(-)
 create mode 100644 Extension/src/LanguageServer/copilotCompletionContextProvider.ts
 create mode 100644 Extension/src/LanguageServer/copilotContextTelemetry.ts

diff --git a/Extension/src/LanguageServer/client.ts b/Extension/src/LanguageServer/client.ts
index d937cff044..c5c30fe740 100644
--- a/Extension/src/LanguageServer/client.ts
+++ b/Extension/src/LanguageServer/client.ts
@@ -53,9 +53,10 @@ import {
 } from './codeAnalysis';
 import { Location, TextEdit, WorkspaceEdit } from './commonTypes';
 import * as configs from './configurations';
+import { CopilotCompletionContextProvider } from './copilotCompletionContextProvider';
 import { DataBinding } from './dataBinding';
 import { cachedEditorConfigSettings, getEditorConfigSettings } from './editorConfig';
-import { CppSourceStr, clients, configPrefix, updateLanguageConfigurations, usesCrashHandler, watchForCrashes } from './extension';
+import { CppSourceStr, SnippetEntry, clients, configPrefix, updateLanguageConfigurations, usesCrashHandler, watchForCrashes } from './extension';
 import { LocalizeStringParams, getLocaleId, getLocalizedString } from './localization';
 import { PersistentFolderState, PersistentWorkspaceState } from './persistentState';
 import { RequestCancelled, ServerCancelled, createProtocolFilter } from './protocolFilter';
@@ -554,6 +555,15 @@ export interface ProjectContextResult {
     fileContext: FileContextResult;
 }
 
+export interface CompletionContextsResult {
+    context: SnippetEntry[];
+}
+
+export interface CompletionContextParams {
+    file: string;
+    caretOffset: number;
+}
+
 // Requests
 const PreInitializationRequest: RequestType<void, string, void> = new RequestType<void, string, void>('cpptools/preinitialize');
 const InitializationRequest: RequestType<CppInitializationParams, void, void> = new RequestType<CppInitializationParams, void, void>('cpptools/initialize');
@@ -575,6 +585,7 @@ const ChangeCppPropertiesRequest: RequestType<CppPropertiesParams, void, void> =
 const IncludesRequest: RequestType<GetIncludesParams, GetIncludesResult, void> = new RequestType<GetIncludesParams, GetIncludesResult, void>('cpptools/getIncludes');
 const CppContextRequest: RequestType<TextDocumentIdentifier, ChatContextResult, void> = new RequestType<TextDocumentIdentifier, ChatContextResult, void>('cpptools/getChatContext');
 const ProjectContextRequest: RequestType<TextDocumentIdentifier, ProjectContextResult, void> = new RequestType<TextDocumentIdentifier, ProjectContextResult, void>('cpptools/getProjectContext');
+const CompletionContextRequest: RequestType<CompletionContextParams, CompletionContextsResult, void> = new RequestType<CompletionContextParams, CompletionContextsResult, void>('cpptools/getCompletionContext');
 
 // Notifications to the server
 const DidOpenNotification: NotificationType<DidOpenTextDocumentParams> = new NotificationType<DidOpenTextDocumentParams>('textDocument/didOpen');
@@ -807,6 +818,7 @@ export interface Client {
     getIncludes(maxDepth: number): Promise<GetIncludesResult>;
     getChatContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise<ChatContextResult>;
     getProjectContext(uri: vscode.Uri): Promise<ProjectContextResult>;
+    getCompletionContext(fileName: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise<CompletionContextsResult>;
 }
 
 export function createClient(workspaceFolder?: vscode.WorkspaceFolder): Client {
@@ -839,7 +851,7 @@ export class DefaultClient implements Client {
     private settingsTracker: SettingsTracker;
     private loggingLevel: number = 1;
     private configurationProvider?: string;
-
+    private copilotCompletionProvider?: CopilotCompletionContextProvider;
     public lastCustomBrowseConfiguration: PersistentFolderState<WorkspaceBrowseConfiguration | undefined> | undefined;
     public lastCustomBrowseConfigurationProviderId: PersistentFolderState<string | undefined> | undefined;
     public lastCustomBrowseConfigurationProviderVersion: PersistentFolderState<Version> | undefined;
@@ -1298,6 +1310,8 @@ export class DefaultClient implements Client {
                     this.semanticTokensProviderDisposable = vscode.languages.registerDocumentSemanticTokensProvider(util.documentSelector, this.semanticTokensProvider, semanticTokensLegend);
                 }
 
+                this.copilotCompletionProvider = await CopilotCompletionContextProvider.Create();
+
                 // Listen for messages from the language server.
                 this.registerNotifications();
 
@@ -1807,6 +1821,7 @@ export class DefaultClient implements Client {
         if (diagnosticsCollectionIntelliSense) {
             diagnosticsCollectionIntelliSense.delete(document.uri);
         }
+        this.copilotCompletionProvider?.removeFile(uri);
         openFileVersions.delete(uri);
     }
 
@@ -2255,6 +2270,12 @@ export class DefaultClient implements Client {
             () => this.languageClient.sendRequest(CppContextRequest, params, token), token);
     }
 
+    public async getCompletionContext(file: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise<CompletionContextsResult> {
+        await withCancellation(this.ready, token);
+        return DefaultClient.withLspCancellationHandling(
+            () => this.languageClient.sendRequest(CompletionContextRequest, { file: file.toString(), caretOffset }, token), token);
+    }
+
     /**
      * a Promise that can be awaited to know when it's ok to proceed.
      *
@@ -4159,4 +4180,5 @@ class NullClient implements Client {
     getIncludes(maxDepth: number): Promise<GetIncludesResult> { return Promise.resolve({} as GetIncludesResult); }
     getChatContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise<ChatContextResult> { return Promise.resolve({} as ChatContextResult); }
     getProjectContext(uri: vscode.Uri): Promise<ProjectContextResult> { return Promise.resolve({} as ProjectContextResult); }
+    getCompletionContext(file: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise<CompletionContextsResult> { return Promise.resolve({} as CompletionContextsResult); }
 }
diff --git a/Extension/src/LanguageServer/copilotCompletionContextProvider.ts b/Extension/src/LanguageServer/copilotCompletionContextProvider.ts
new file mode 100644
index 0000000000..5c3e67b27f
--- /dev/null
+++ b/Extension/src/LanguageServer/copilotCompletionContextProvider.ts
@@ -0,0 +1,185 @@
+/* --------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All Rights Reserved.
+ * See 'LICENSE' in the project root for license information.
+ * ------------------------------------------------------------------------------------------ */
+import * as vscode from 'vscode';
+import { DocumentSelector } from 'vscode-languageserver-protocol';
+import { getOutputChannelLogger, Logger } from '../logger';
+import * as telemetry from '../telemetry';
+import { CopilotContextTelemetry } from './copilotContextTelemetry';
+import { getCopilotApi } from './copilotProviders';
+import { clients } from './extension';
+import { CodeSnippet, CompletionContext, ContextProviderApiV1, ContextResolver } from './tmp/contextProviderV1';
+
+class DefaultValueFallback extends Error {
+    static readonly DefaultValue = "DefaultValue";
+    constructor() { super(DefaultValueFallback.DefaultValue); }
+}
+
+class CancellationError extends Error {
+    static readonly Cancelled = "Cancelled";
+    constructor() { super(CancellationError.Cancelled); }
+}
+
+// Mutually exclusive values for the kind of snippets. They either are:
+// - computed.
+// - obtained from the cache.
+// - missing and the computation is taking too long and no cache is present (cache miss). The value
+//   is asynchronously computed and stored in cache.
+// - the token is signaled as cancelled, in which case all the operations are aborted.
+// - an unknown state.
+enum SnippetsKind {
+    Computed = 'computed',
+    GotFromCache = 'gotFromCacheHit',
+    MissingCacheMiss = 'missingCacheMiss',
+    Cancelled = 'cancelled',
+    Unknown = 'unknown'
+}
+
+export class CopilotCompletionContextProvider implements ContextResolver<CodeSnippet> {
+    private static readonly providerId = 'cppTools';
+    private readonly completionContextCache: Map<string, CodeSnippet[]> = new Map<string, CodeSnippet[]>();
+    private static readonly defaultCppDocumentSelector: DocumentSelector = [{ language: 'cpp' }, { language: 'c' }, { language: 'cuda-cpp' }];
+    private static readonly defaultTimeBudgetFactor: number = 0.5;
+    private completionContextCancellation = new vscode.CancellationTokenSource();
+
+    // Get the default value if the timeout expires, but throws an exception if the token is cancelled.
+    private async waitForCompletionWithTimeoutAndCancellation<T>(promise: Promise<T>, defaultValue: T | undefined,
+        timeout: number, token: vscode.CancellationToken): Promise<[T | undefined, SnippetsKind]> {
+        const defaultValuePromise = new Promise<T>((resolve, reject) => setTimeout(() => {
+            if (token.isCancellationRequested) {
+                reject(new CancellationError());
+            } else {
+                reject(new DefaultValueFallback());
+            }
+        }, timeout));
+        const cancellationPromise = new Promise<T>((_, reject) => {
+            token.onCancellationRequested(() => {
+                reject(new CancellationError());
+            });
+        });
+        let snippetsOrNothing: T | undefined;
+        try {
+            snippetsOrNothing = await Promise.race([promise, cancellationPromise, defaultValuePromise]);
+        } catch (e) {
+            if (e instanceof DefaultValueFallback) {
+                return [defaultValue, defaultValue !== undefined ? SnippetsKind.GotFromCache : SnippetsKind.MissingCacheMiss];
+            } else if (e instanceof CancellationError) {
+                return [undefined, SnippetsKind.Cancelled];
+            } else {
+                throw e;
+            }
+        }
+
+        return [snippetsOrNothing, SnippetsKind.Computed];
+    }
+
+    // Get the completion context with a timeout and a cancellation token.
+    // The cancellationToken indicates that the value should not be returned nor cached.
+    private async getCompletionContextWithCancellation(documentUri: string, caretOffset: number,
+        startTime: number, out: Logger, telemetry: CopilotContextTelemetry, token: vscode.CancellationToken): Promise<CodeSnippet[]> {
+        try {
+            const docUri = vscode.Uri.parse(documentUri);
+            const snippets = await clients.getClientFor(docUri).getCompletionContext(docUri, caretOffset, token);
+
+            const codeSnippets = snippets.context.map((item) => {
+                if (token.isCancellationRequested) {
+                    telemetry.addCancelledLate();
+                    throw new CancellationError();
+                }
+                return {
+                    importance: item.importance, uri: item.uri, value: item.text
+                };
+            });
+
+            this.completionContextCache.set(documentUri, codeSnippets);
+            const duration: number = performance.now() - startTime;
+            out.appendLine(`Copilot: getCompletionContextWithCancellation(): Cached in [ms]: ${duration}`);
+            telemetry.addSnippetCount(codeSnippets?.length);
+            telemetry.addCacheComputedElapsed(duration);
+
+            return codeSnippets;
+        } catch (e) {
+            const err = e as Error;
+            out.appendLine(`Copilot: getCompletionContextWithCancellation(): Error: '${err?.message}', stack '${err?.stack}`);
+            telemetry.addError();
+            return [];
+        }
+    }
+
+    private async fetchTimeBudgetFactor(context: CompletionContext): Promise<number> {
+        const budgetFactor = context.activeExperiments.get("CppToolsCopilotTimeBudget");
+        return (budgetFactor as number) !== undefined ? budgetFactor as number : CopilotCompletionContextProvider.defaultTimeBudgetFactor;
+    }
+
+    public static async Create() {
+        const copilotCompletionProvider = new CopilotCompletionContextProvider();
+        await copilotCompletionProvider.registerCopilotContextProvider();
+        return copilotCompletionProvider;
+    }
+
+    public removeFile(fileUri: string): void {
+        this.completionContextCache.delete(fileUri);
+    }
+
+    public async resolve(context: CompletionContext, copilotAborts: vscode.CancellationToken): Promise<CodeSnippet[]> {
+        const startTime = performance.now();
+        const out: Logger = getOutputChannelLogger();
+        const timeBudgetFactor = await this.fetchTimeBudgetFactor(context);
+        const telemetry = new CopilotContextTelemetry();
+        let codeSnippets: CodeSnippet[] | undefined;
+        let codeSnippetsKind: SnippetsKind = SnippetsKind.Unknown;
+        try {
+            this.completionContextCancellation.cancel();
+            this.completionContextCancellation = new vscode.CancellationTokenSource();
+            const docUri = context.documentContext.uri;
+            const cachedValue: CodeSnippet[] | undefined = this.completionContextCache.get(docUri.toString());
+            const snippetsPromise = this.getCompletionContextWithCancellation(docUri,
+                context.documentContext.offset, startTime, out, telemetry.fork(), this.completionContextCancellation.token);
+            [codeSnippets, codeSnippetsKind] = await this.waitForCompletionWithTimeoutAndCancellation(
+                snippetsPromise, cachedValue, context.timeBudget * timeBudgetFactor, copilotAborts);
+            if (codeSnippetsKind === SnippetsKind.Cancelled) {
+                const duration: number = performance.now() - startTime;
+                out.appendLine(`Copilot: getCompletionContext(): cancelled, elapsed time (ms) : ${duration}`);
+                telemetry.addCancelled();
+                telemetry.addCancellationElapsed(duration);
+                throw new CancellationError();
+            }
+            telemetry.addSnippetCount(codeSnippets?.length);
+            return codeSnippets ?? [];
+        } catch (e: any) {
+            telemetry.addError();
+            throw e;
+        } finally {
+            telemetry.addKind(codeSnippetsKind.toString());
+            const duration: number = performance.now() - startTime;
+            if (codeSnippets === undefined) {
+                out.appendLine(`Copilot: getCompletionContext(): no snkppets provided (${codeSnippetsKind.toString()}), elapsed time (ms): ${duration}`);
+            } else {
+                out.appendLine(`Copilot: getCompletionContext(): provided ${codeSnippets?.length} snippets (${codeSnippetsKind.toString()}), elapsed time (ms): ${duration}`);
+            }
+            telemetry.addResolvedElapsed(duration);
+            telemetry.addCacheSize(this.completionContextCache.size);
+            // //?? TODO telemetry.file();
+        }
+
+        return [];
+    }
+
+    public async registerCopilotContextProvider(): Promise<void> {
+        try {
+            const isCustomSnippetProviderApiEnabled = await telemetry.isExperimentEnabled("CppToolsCustomSnippetsApi");
+            if (isCustomSnippetProviderApiEnabled) {
+                const contextAPI = (await getCopilotApi() as any).getContextProviderAPI('v1') as ContextProviderApiV1;
+                contextAPI.registerContextProvider({
+                    id: CopilotCompletionContextProvider.providerId,
+                    selector: CopilotCompletionContextProvider.defaultCppDocumentSelector,
+                    resolver: this
+                });
+            }
+        } catch {
+            console.warn("Failed to register the Copilot Context Provider.");
+            telemetry.logCopilotEvent("registerCopilotContextProviderError", { "message": "Failed to register the Copilot Context Provider." });
+        }
+    }
+}
diff --git a/Extension/src/LanguageServer/copilotContextTelemetry.ts b/Extension/src/LanguageServer/copilotContextTelemetry.ts
new file mode 100644
index 0000000000..5160ca5fae
--- /dev/null
+++ b/Extension/src/LanguageServer/copilotContextTelemetry.ts
@@ -0,0 +1,72 @@
+/* --------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All Rights Reserved.
+ * See 'LICENSE' in the project root for license information.
+ * ------------------------------------------------------------------------------------------ */
+import { randomUUID } from 'crypto';
+import * as telemetry from '../telemetry';
+
+export class CopilotContextTelemetry {
+    private static readonly correlationIdKey = 'correlationId';
+    private static readonly copilotEventName = 'copilotContextProvider';
+    private readonly metrics: Record<string, number> = {};
+    private readonly properties: Record<string, string> = {};
+    private readonly id: string;
+    constructor(correlationId?: string) {
+        this.id = correlationId ?? randomUUID().toString();
+    }
+
+    private addMetric(key: string, value: number): void {
+        this.metrics[key] = value;
+    }
+
+    private addProperty(key: string, value: string): void {
+        this.properties[key] = value;
+    }
+
+    public addCancelled(): void {
+        this.addProperty('cancelled', 'true');
+    }
+
+    public addCancellationElapsed(duration: number): void {
+        this.addMetric('cancellationElapsedMs', duration);
+    }
+
+    public addCancelledLate(): void {
+        this.addProperty('cancelledLate', 'true');
+    }
+
+    public addError(): void {
+        this.addProperty('error', 'true');
+    }
+
+    public addKind(snippetsKind: string): void {
+        this.addProperty('kind', snippetsKind.toString());
+    }
+
+    public addResolvedElapsed(duration: number): void {
+        this.addMetric('overallResolveElapsedMs', duration);
+    }
+
+    public addCacheSize(size: number): void {
+        this.addMetric('cacheSize', size);
+    }
+
+    public addCacheComputedElapsed(duration: number): void {
+        this.addMetric('cacheComputedElapsedMs', duration);
+    }
+
+    // count can be undefined, in which case the count is set to -1 to indicate
+    // snippets are not available (different than having 0 snippets).
+    public addSnippetCount(count?: number) {
+        this.addMetric('snippetsCount', count ?? -1);
+    }
+
+    public file(): void {
+        this.properties[CopilotContextTelemetry.correlationIdKey] = this.id;
+        telemetry.logCopilotEvent(CopilotContextTelemetry.copilotEventName, this.properties, this.metrics);
+    }
+
+    public fork(): CopilotContextTelemetry {
+        return new CopilotContextTelemetry(this.id);
+    }
+}
diff --git a/Extension/src/LanguageServer/extension.ts b/Extension/src/LanguageServer/extension.ts
index 8bc64f82f8..37ec03fbce 100644
--- a/Extension/src/LanguageServer/extension.ts
+++ b/Extension/src/LanguageServer/extension.ts
@@ -34,6 +34,14 @@ import { CppSettings } from './settings';
 import { LanguageStatusUI, getUI } from './ui';
 import { makeLspRange, rangeEquals, showInstallCompilerWalkthrough } from './utils';
 
+export interface SnippetEntry {
+    uri: string;
+    text: string;
+    startLine: number;
+    endLine: number;
+    importance: number;
+}
+
 nls.config({ messageFormat: nls.MessageFormat.bundle, bundleFormat: nls.BundleFormat.standalone })();
 const localize: nls.LocalizeFunc = nls.loadMessageBundle();
 export const CppSourceStr: string = "C/C++";