diff --git a/api/src/run.ts b/api/src/run.ts index 1624808b..455084f0 100644 --- a/api/src/run.ts +++ b/api/src/run.ts @@ -3,4 +3,21 @@ import { startServer } from "./server"; const repoDir = `${process.cwd()}/example-repo`; console.log("repoDir", repoDir); -startServer({ port: 4000, repoDir }); +const args = process.argv.slice(2); + +const options = { + copilotIpAddress: "127.0.0.1", + copilotPort: 9090, +}; + +for (let i = 0; i < args.length; i++) { + if (args[i] === "--copilotIP" && i + 1 < args.length) { + options.copilotIpAddress = args[i + 1]; + i++; + } else if (args[i] === "--copilotPort" && i + 1 < args.length) { + options.copilotPort = Number(args[i + 1]); + i++; + } +} + +startServer({ port: 4000, repoDir, ...options }); diff --git a/api/src/server.ts b/api/src/server.ts index c8ad55f9..1a523a40 100644 --- a/api/src/server.ts +++ b/api/src/server.ts @@ -10,7 +10,12 @@ import { bindState, writeState } from "./yjs/yjs-blob"; import cors from "cors"; import { createSpawnerRouter, router } from "./spawner/trpc"; -export async function startServer({ port, repoDir }) { +export async function startServer({ + port, + repoDir, + copilotIpAddress = "127.0.0.1", + copilotPort = 9090, +}) { console.log("starting server .."); const app = express(); app.use(express.json({ limit: "20mb" })); @@ -27,7 +32,11 @@ export async function startServer({ port, repoDir }) { "/trpc", trpcExpress.createExpressMiddleware({ router: router({ - spawner: createSpawnerRouter(yjsServerUrl), + spawner: createSpawnerRouter( + yjsServerUrl, + copilotIpAddress, + copilotPort + ), }), }) ); @@ -59,6 +68,8 @@ export async function startServer({ port, repoDir }) { }); http_server.listen({ port }, () => { - console.log(`🚀 Server ready at http://localhost:${port}`); + console.log( + `🚀 Server ready at http://localhost:${port}, LLM Copilot is hosted at ${copilotIpAddress}:${copilotPort}` + ); }); } diff --git a/api/src/spawner/trpc.ts b/api/src/spawner/trpc.ts index 87d1631e..49c0b801 100644 --- a/api/src/spawner/trpc.ts +++ b/api/src/spawner/trpc.ts @@ -3,6 +3,7 @@ const t = initTRPC.create(); export const router = t.router; export const publicProcedure = t.procedure; +import express from "express"; import Y from "yjs"; import WebSocket from "ws"; import { z } from "zod"; @@ -17,6 +18,12 @@ import { connectSocket, runtime2socket, RuntimeInfo } from "./yjs_runtime"; // FIXME need to have a TTL to clear the ydoc. const docs: Map = new Map(); +// FIXME hard-coded yjs server url +const yjsServerUrl = `ws://localhost:4000/socket`; + +const app = express(); +const http = require("http"); + async function getMyYDoc({ repoId, yjsServerUrl }): Promise { return new Promise((resolve, reject) => { const oldydoc = docs.get(repoId); @@ -52,7 +59,11 @@ async function getMyYDoc({ repoId, yjsServerUrl }): Promise { const routingTable: Map = new Map(); -export function createSpawnerRouter(yjsServerUrl) { +export function createSpawnerRouter( + yjsServerUrl, + copilotIpAddress, + copilotPort +) { return router({ spawnRuntime: publicProcedure .input(z.object({ runtimeId: z.string(), repoId: z.string() })) @@ -227,11 +238,100 @@ export function createSpawnerRouter(yjsServerUrl) { ); return true; }), + codeAutoComplete: publicProcedure + .input( + z.object({ + inputPrefix: z.string(), + inputSuffix: z.string(), + podId: z.string(), + }) + ) + .mutation(async ({ input: { inputPrefix, inputSuffix, podId } }) => { + console.log( + `======= codeAutoComplete of pod ${podId} ========\n`, + inputPrefix, + inputSuffix + ); + let data = ""; + let options = {}; + if (inputSuffix.length == 0) { + data = JSON.stringify({ + prompt: inputPrefix, + temperature: 0.1, + top_k: 40, + top_p: 0.9, + repeat_penalty: 1.05, + // large n_predict significantly slows down the server, a small value is good enough for testing purposes + n_predict: 128, + stream: false, + }); + + options = { + hostname: copilotIpAddress, + port: copilotPort, + path: "/completion", + method: "POST", + headers: { + "Content-Type": "application/json", + "Content-Length": data.length, + }, + }; + } else { + data = JSON.stringify({ + input_prefix: inputPrefix, + input_suffix: inputSuffix, + temperature: 0.1, + top_k: 40, + top_p: 0.9, + repeat_penalty: 1.05, + // large n_predict significantly slows down the server, a small value is good enough for testing purposes + n_predict: 128, + }); + + options = { + hostname: copilotIpAddress, + port: copilotPort, + path: "/infill", + method: "POST", + headers: { + "Content-Type": "application/json", + "Content-Length": data.length, + }, + }; + } + + return new Promise((resolve, reject) => { + const req = http.request(options, (res) => { + let responseData = ""; + + res.on("data", (chunk) => { + responseData += chunk; + }); + + res.on("end", () => { + if (responseData.toString() === "") { + resolve(""); // Resolve with an empty string if no data + } + const resData = JSON.parse(responseData.toString()); + console.log(res.statusCode, resData["content"]); + resolve(resData["content"]); // Resolve the Promise with the response data + }); + }); + + req.on("error", (error) => { + console.error(error); + reject(error); // Reject the Promise if an error occurs + }); + + req.write(data); + req.end(); + }); + }), }); } // This is only used for frontend to get the type of router. const _appRouter_for_type = router({ - spawner: createSpawnerRouter(null), // put procedures under "post" namespace + spawner: createSpawnerRouter(null, null, null), // put procedures under "post" namespace }); export type AppRouter = typeof _appRouter_for_type; diff --git a/ui/src/components/MyMonaco.tsx b/ui/src/components/MyMonaco.tsx index c1ba2570..49d1188f 100644 --- a/ui/src/components/MyMonaco.tsx +++ b/ui/src/components/MyMonaco.tsx @@ -13,6 +13,8 @@ import { Annotation } from "../lib/parser"; import { useApolloClient } from "@apollo/client"; import { trpc } from "../lib/trpc"; +import { llamaInlineCompletionProvider } from "../lib/llamaCompletionProvider"; + const theme: monaco.editor.IStandaloneThemeData = { base: "vs", inherit: true, @@ -404,6 +406,8 @@ export const MyMonaco = memo(function MyMonaco({ (state) => state.parseResult[id]?.annotations ); const showAnnotations = useStore(store, (state) => state.showAnnotations); + const copilotManualMode = useStore(store, (state) => state.copilotManualMode); + const scopedVars = useStore(store, (state) => state.scopedVars); const updateView = useStore(store, (state) => state.updateView); @@ -437,6 +441,7 @@ export const MyMonaco = memo(function MyMonaco({ const selectPod = useStore(store, (state) => state.selectPod); const resetSelection = useStore(store, (state) => state.resetSelection); const editMode = useStore(store, (state) => state.editMode); + const { client } = trpc.useUtils(); // FIXME useCallback? function onEditorDidMount( @@ -488,11 +493,33 @@ export const MyMonaco = memo(function MyMonaco({ }, }); + editor.addAction({ + id: "trigger-inline-suggest", + label: "Trigger Suggest", + keybindings: [ + monaco.KeyMod.WinCtrl | monaco.KeyMod.Shift | monaco.KeyCode.Space, + ], + run: () => { + editor.trigger(null, "editor.action.inlineSuggest.trigger", null); + }, + }); + // editor.onDidChangeModelContent(async (e) => { // // content is value? // updateGitGutter(editor); // }); + const llamaCompletionProvider = new llamaInlineCompletionProvider( + id, + editor, + client, + copilotManualMode || false + ); + monaco.languages.registerInlineCompletionsProvider( + "python", + llamaCompletionProvider + ); + // bind it to the ytext with pod id if (!codeMap.has(id)) { throw new Error("codeMap doesn't have pod " + id); diff --git a/ui/src/components/Sidebar.tsx b/ui/src/components/Sidebar.tsx index 22f4bc03..3aec8edf 100644 --- a/ui/src/components/Sidebar.tsx +++ b/ui/src/components/Sidebar.tsx @@ -64,6 +64,12 @@ function SidebarSettings() { ); const devMode = useStore(store, (state) => state.devMode); const setDevMode = useStore(store, (state) => state.setDevMode); + const copilotManualMode = useStore(store, (state) => state.copilotManualMode); + const setCopilotManualMode = useStore( + store, + (state) => state.setCopilotManualMode + ); + const showLineNumbers = useStore(store, (state) => state.showLineNumbers); const setShowLineNumbers = useStore( store, @@ -488,6 +494,26 @@ function SidebarSettings() { /> + + + ) => { + setCopilotManualMode(event.target.checked); + }} + /> + } + label="Trigger Copilot Manually" + /> + + {showAnnotations && ( Function Definition diff --git a/ui/src/lib/llamaCompletionProvider.ts b/ui/src/lib/llamaCompletionProvider.ts new file mode 100644 index 00000000..58e08586 --- /dev/null +++ b/ui/src/lib/llamaCompletionProvider.ts @@ -0,0 +1,114 @@ +import { monaco } from "react-monaco-editor"; + +export class llamaInlineCompletionProvider + implements monaco.languages.InlineCompletionsProvider +{ + private readonly podId: string; + private readonly editor: monaco.editor.IStandaloneCodeEditor; + private readonly trpc: any; + private readonly manualMode: boolean; + private isFetchingSuggestions: boolean; // Flag to track if a fetch operation is in progress + + constructor( + podId: string, + editor: monaco.editor.IStandaloneCodeEditor, + trpc: any, + manualMode: boolean + ) { + this.podId = podId; + this.editor = editor; + this.trpc = trpc; + this.manualMode = manualMode; + this.isFetchingSuggestions = false; // Initialize the flag + } + + private async provideSuggestions(prefix: string, suffix: string) { + const suggestion = await this.trpc.spawner.codeAutoComplete.mutate({ + inputPrefix: prefix, + inputSuffix: suffix, + podId: this.podId, + }); + return suggestion; + } + public async provideInlineCompletions( + model: monaco.editor.ITextModel, + position: monaco.IPosition, + context: monaco.languages.InlineCompletionContext, + token: monaco.CancellationToken + ): Promise { + if (!this.editor.hasTextFocus() || token.isCancellationRequested) { + return; + } + if ( + context.triggerKind === + monaco.languages.InlineCompletionTriggerKind.Automatic && + this.manualMode + ) { + return; + } + + if (!this.isFetchingSuggestions) { + this.isFetchingSuggestions = true; + try { + // Get text before the position + let inputPrefix = model.getValueInRange({ + startLineNumber: 1, + startColumn: 1, + endLineNumber: position.lineNumber, + endColumn: position.column, + }); + + // Get text after the position + let inputSuffix = model.getValueInRange({ + startLineNumber: position.lineNumber, + startColumn: position.column, + endLineNumber: model.getLineCount(), + endColumn: model.getLineMaxColumn(model.getLineCount()), + }); + + console.log(inputPrefix); + console.log(inputSuffix); + + if (/^\s*$/.test(inputPrefix || " ")) { + inputPrefix = inputPrefix.trim(); + } + if (/^\s*$/.test(inputSuffix || " ")) { + inputSuffix = inputSuffix.trim(); + } + + if (inputPrefix === "" && inputSuffix === "") { + return; + } + const suggestion = await this.provideSuggestions( + inputPrefix, + inputSuffix + ); + + return { + items: [ + { + insertText: suggestion, + range: new monaco.Range( + position.lineNumber, + position.column, + position.lineNumber, + position.column + ), + }, + ], + }; + } finally { + this.isFetchingSuggestions = false; + } + } + } + + handleItemDidShow?( + completions: monaco.languages.InlineCompletions, + item: monaco.languages.InlineCompletion + ): void {} + + freeInlineCompletions( + completions: monaco.languages.InlineCompletions + ): void {} +} diff --git a/ui/src/lib/store/settingSlice.ts b/ui/src/lib/store/settingSlice.ts index a942402a..e7c0b689 100644 --- a/ui/src/lib/store/settingSlice.ts +++ b/ui/src/lib/store/settingSlice.ts @@ -8,6 +8,8 @@ export interface SettingSlice { setShowAnnotations: (b: boolean) => void; devMode?: boolean; setDevMode: (b: boolean) => void; + copilotManualMode?: boolean; + setCopilotManualMode: (b: boolean) => void; autoRunLayout?: boolean; setAutoRunLayout: (b: boolean) => void; contextualZoomParams: Record; @@ -56,6 +58,17 @@ export const createSettingSlice: StateCreator = ( // also write to local storage localStorage.setItem("devMode", JSON.stringify(b)); }, + + copilotManualMode: localStorage.getItem("copilotManualMode") + ? JSON.parse(localStorage.getItem("copilotManualMode")!) + : false, + setCopilotManualMode: (b: boolean) => { + // set it + set({ copilotManualMode: b }); + // also write to local storage + localStorage.setItem("copilotManualMode", JSON.stringify(b)); + }, + autoRunLayout: localStorage.getItem("autoRunLayout") ? JSON.parse(localStorage.getItem("autoRunLayout")!) : true,