diff --git a/src/commands/general/init.ts b/src/commands/general/init.ts index f0ce5b16..7657fe2d 100644 --- a/src/commands/general/init.ts +++ b/src/commands/general/init.ts @@ -184,8 +184,12 @@ export async function initAction(options: InitActionOptions, simulatorService: I const ollamaAction = new OllamaAction(); const configManager = new ConfigFileManager(); const config = configManager.getConfig() + let ollamaModel = config.defaultOllamaModel; - let ollamaModel = config.defaultOllamaModel || 'llama3'; + if(!config.defaultOllamaModel){ + configManager.writeConfig('defaultOllamaModel', 'llama3'); + ollamaModel = 'llama3' + } console.log(`Pulling ${ollamaModel} from Ollama...`); diff --git a/src/commands/update/index.ts b/src/commands/update/index.ts index 161995bb..e0a19f74 100644 --- a/src/commands/update/index.ts +++ b/src/commands/update/index.ts @@ -1,5 +1,6 @@ import { Command } from "commander"; import { OllamaAction } from "./ollama"; +import { ConfigFileManager } from "../../lib/config/ConfigFileManager"; export function initializeUpdateCommands(program: Command) { const updateCommand = program @@ -12,7 +13,10 @@ export function initializeUpdateCommands(program: Command) { .option("--model [model-name]", "Specify the model to update or remove") .option("--remove", "Remove the specified model instead of updating") .action(async (options) => { - const modelName = options.model || "default-model"; + const configManager = new ConfigFileManager(); + const config = configManager.getConfig() + + const modelName = options.model || config.defaultOllamaModel; const ollamaAction = new OllamaAction(); if (options.remove) { diff --git a/src/commands/update/ollama.ts b/src/commands/update/ollama.ts index 1821a2de..6f96f54e 100644 --- a/src/commands/update/ollama.ts +++ b/src/commands/update/ollama.ts @@ -1,4 +1,5 @@ -import Docker from "dockerode" +import Docker from "dockerode"; +import { rpcClient } from "../../lib/clients/jsonRpcClient"; export class OllamaAction { private docker: Docker; @@ -8,34 +9,94 @@ export class OllamaAction { } async updateModel(modelName: string) { - await this.executeModelCommand("pull", modelName, `Model "${modelName}" updated successfully`); + const providersAndModels = await rpcClient.request({ + method: "sim_getProvidersAndModels", + params: [], + }); + + const existingOllamaProvider = providersAndModels.result.find( + (entry: any) => entry.plugin === "ollama" + ); + + if (!existingOllamaProvider) { + throw new Error("No existing 'ollama' provider found. Unable to add/update a model."); + } + + await this.executeModelCommand( + "pull", + modelName, + `Model "${modelName}" updated successfully` + ); + + const existingModel = providersAndModels.result.some( + (entry: any) => + entry.plugin === "ollama" && entry.model === modelName + ); + + if (!existingModel) { + console.log(`Model "${modelName}" not found in Provider Presets. Adding...`); + + const newModelConfig = { + config: existingOllamaProvider.config, + model: modelName, + plugin: "ollama", + plugin_config: existingOllamaProvider.plugin_config, + provider: "ollama", + }; + + await rpcClient.request({ + method: "sim_addProvider", + params: [newModelConfig], + }); + + console.log(`Model "${modelName}" added successfully.`); + } } async removeModel(modelName: string) { - await this.executeModelCommand("rm", modelName, `Model "${modelName}" removed successfully`); + await this.executeModelCommand( + "rm", + modelName, + `Model "${modelName}" removed successfully` + ); } - private async executeModelCommand(command: string, modelName: string, successMessage: string) { + private async executeModelCommand( + command: string, + modelName: string, + successMessage: string + ) { try { + let success = false; const ollamaContainer = this.docker.getContainer("ollama"); const exec = await ollamaContainer.exec({ Cmd: ["ollama", command, modelName], AttachStdout: true, AttachStderr: true, }); - const stream = await exec.start({ Detach: false, Tty: false }); + const stream = await exec.start({Detach: false, Tty: false}); stream.on("data", (chunk: any) => { - console.log(chunk.toString()); + const chunkStr = chunk.toString(); + console.log(chunkStr); + if (chunkStr.includes("success") || chunkStr.includes("deleted")) { + success = true; + } }); await new Promise((resolve, reject) => { - stream.on("end", resolve); + stream.on("end", () => { + if (success) { + resolve(); + } else { + reject('internal error'); + } + }); stream.on("error", reject); }); console.log(successMessage); - } catch (error) { + }catch (error) { console.error(`Error executing command "${command}" on model "${modelName}":`, error); } } diff --git a/tests/actions/init.test.ts b/tests/actions/init.test.ts index 522c2369..0a173eb0 100644 --- a/tests/actions/init.test.ts +++ b/tests/actions/init.test.ts @@ -9,11 +9,13 @@ import fs from "fs"; import * as dotenv from "dotenv"; import {localnetCompatibleVersion} from "../../src/lib/config/simulator"; import { OllamaAction } from "../../src/commands/update/ollama"; +import { ConfigFileManager } from "../../src/lib/config/ConfigFileManager"; vi.mock("fs"); vi.mock("dotenv"); vi.mock("../../src/commands/update/ollama") +vi.mock("../../src/lib/config/ConfigFileManager"); const tempDir = mkdtempSync(join(tmpdir(), "test-initAction-")); @@ -278,7 +280,7 @@ describe("init action", () => { simServDeleteAllValidators.mockResolvedValue(true); simServResetDockerContainers.mockResolvedValue(true); simServResetDockerImages.mockResolvedValue(true); - vi.mocked(fs.readFileSync).mockReturnValue(JSON.stringify({})); + vi.mocked(ConfigFileManager.prototype.getConfig).mockReturnValue({}); await initAction(defaultActionOptions, simulatorService); @@ -310,7 +312,7 @@ describe("init action", () => { simServDeleteAllValidators.mockResolvedValue(true); simServResetDockerContainers.mockResolvedValue(true); simServResetDockerImages.mockResolvedValue(true); - vi.mocked(fs.readFileSync).mockReturnValue(JSON.stringify({defaultOllamaModel: ollamaModel})); + vi.mocked(ConfigFileManager.prototype.getConfig).mockReturnValue({defaultOllamaModel: ollamaModel}); await initAction(defaultActionOptions, simulatorService); @@ -318,6 +320,39 @@ describe("init action", () => { expect(OllamaAction.prototype.updateModel).toHaveBeenCalled(); }); + test("should set defaultOllamaModel to llama 3 if no defaultOllamaModel is provided", async () => { + + inquirerPrompt.mockResolvedValue({ + confirmReset: true, + confirmDownload: true, + selectedLlmProviders: ["openai", "heuristai", "ollama"], + openai: "API_KEY1", + heuristai: "API_KEY2", + ollama: "API_KEY3", + }); + simServgetAiProvidersOptions.mockReturnValue([ + { name: "OpenAI", value: "openai" }, + { name: "Heurist", value: "heuristai" }, + { name: "Ollama", value: "ollama" }, + ]); + + vi.mocked(ConfigFileManager.prototype.getConfig).mockResolvedValueOnce({}) + vi.mocked(OllamaAction.prototype.updateModel).mockResolvedValueOnce(undefined); + + simServRunSimulator.mockResolvedValue(true); + simServWaitForSimulator.mockResolvedValue({ initialized: true }); + simServDeleteAllValidators.mockResolvedValue(true); + simServResetDockerContainers.mockResolvedValue(true); + simServResetDockerImages.mockResolvedValue(true); + vi.mocked(fs.readFileSync).mockReturnValue(JSON.stringify({})); + + await initAction(defaultActionOptions, simulatorService); + + expect(ConfigFileManager.prototype.writeConfig).toHaveBeenCalledWith('defaultOllamaModel', 'llama3') + expect(log).toHaveBeenCalledWith(`Pulling llama3 from Ollama...`); + expect(OllamaAction.prototype.updateModel).toHaveBeenCalled(); + }); + test("logs error if checkVersionRequirements throws", async () => { simServCheckInstallRequirements.mockResolvedValue({ git: true, docker: true }); const errorMsg = new Error("checkVersionRequirements error"); diff --git a/tests/actions/ollama.test.ts b/tests/actions/ollama.test.ts index c3b0eeb3..6827c089 100644 --- a/tests/actions/ollama.test.ts +++ b/tests/actions/ollama.test.ts @@ -1,8 +1,11 @@ -import {describe, test, vi, beforeEach, afterEach, expect, Mock} from "vitest"; +import { describe, test, vi, beforeEach, afterEach, expect, Mock } from "vitest"; import { OllamaAction } from "../../src/commands/update/ollama"; +import { rpcClient } from "../../src/lib/clients/jsonRpcClient"; + import Docker from "dockerode"; vi.mock("dockerode"); +vi.mock("../../src/lib/clients/jsonRpcClient"); describe("OllamaAction", () => { let ollamaAction: OllamaAction; @@ -41,9 +44,22 @@ describe("OllamaAction", () => { }); test("should update the model using 'pull'", async () => { - mockStream.on.mockImplementation((event: any, callback:any) => { - if (event === "data") callback(Buffer.from("Mocked output")); - if (event === "end") callback(); + const mockProvider = { + plugin: "ollama", + config: { key: "value" }, + plugin_config: { pluginKey: "pluginValue" }, + }; + vi.mocked(rpcClient.request).mockResolvedValueOnce({ + result: [mockProvider], + }); + + mockStream.on.mockImplementation((event: any, callback: any) => { + if (event === "data") { + callback(Buffer.from("Mocked output success")); + } + if (event === "end") { + callback(); + } }); console.log = vi.fn(); @@ -59,14 +75,18 @@ describe("OllamaAction", () => { expect(mockStart).toHaveBeenCalledWith({ Detach: false, Tty: false }); expect(mockStream.on).toHaveBeenCalledWith("data", expect.any(Function)); expect(mockStream.on).toHaveBeenCalledWith("end", expect.any(Function)); - expect(console.log).toHaveBeenCalledWith("Mocked output"); + expect(console.log).toHaveBeenCalledWith("Mocked output success"); expect(console.log).toHaveBeenCalledWith('Model "mocked_model" updated successfully'); }); test("should remove the model using 'rm'", async () => { - mockStream.on.mockImplementation((event:any, callback:any) => { - if (event === "data") callback(Buffer.from("Mocked output")); - if (event === "end") callback(); + mockStream.on.mockImplementation((event: any, callback: any) => { + if (event === "data") { + callback(Buffer.from("Mocked output success")); + } + if (event === "end") { + callback(); + } }); console.log = vi.fn(); @@ -82,11 +102,20 @@ describe("OllamaAction", () => { expect(mockStart).toHaveBeenCalledWith({ Detach: false, Tty: false }); expect(mockStream.on).toHaveBeenCalledWith("data", expect.any(Function)); expect(mockStream.on).toHaveBeenCalledWith("end", expect.any(Function)); - expect(console.log).toHaveBeenCalledWith("Mocked output"); + expect(console.log).toHaveBeenCalledWith("Mocked output success"); expect(console.log).toHaveBeenCalledWith('Model "mocked_model" removed successfully'); }); test("should log an error if an exception occurs during 'pull'", async () => { + const mockProvider = { + plugin: "ollama", + config: { key: "value" }, + plugin_config: { pluginKey: "pluginValue" }, + }; + vi.mocked(rpcClient.request).mockResolvedValueOnce({ + result: [mockProvider], + }); + const error = new Error("Mocked error"); mockGetContainer.mockReturnValueOnce( { @@ -126,4 +155,54 @@ describe("OllamaAction", () => { error ); }); + + test("should throw an error if no 'ollama' provider exists during updateModel", async () => { + vi.mocked(rpcClient.request).mockResolvedValueOnce({ + result: [], + }); + + const modelName = "mocked_model"; + + await expect(ollamaAction.updateModel(modelName)).rejects.toThrowError( + "No existing 'ollama' provider found. Unable to add/update a model." + ); + + expect(rpcClient.request).toHaveBeenCalledWith({ + method: "sim_getProvidersAndModels", + params: [], + }); + }); + + test("should reject with an error if success is not set to true", async () => { + console.error = vi.fn(); + + const mockProvider = { + plugin: "ollama", + config: { key: "value" }, + plugin_config: { pluginKey: "pluginValue" }, + }; + + vi.mocked(rpcClient.request).mockResolvedValueOnce({ + result: [mockProvider], + }); + + mockStream.on.mockImplementation((event: any, callback: any) => { + if (event === "data") { + callback(Buffer.from("Mocked output failure")); + } + if (event === "end") { + callback(); + } + }); + + console.log = vi.fn(); + console.error = vi.fn(); + + await ollamaAction.updateModel("mocked_model"); + + expect(console.error).toHaveBeenCalledWith( + 'Error executing command "pull" on model "mocked_model":', 'internal error' + ); + }); + }); diff --git a/tests/commands/update.test.ts b/tests/commands/update.test.ts index 00dcb947..dc30bb50 100644 --- a/tests/commands/update.test.ts +++ b/tests/commands/update.test.ts @@ -2,8 +2,10 @@ import { Command } from "commander"; import { vi, describe, beforeEach, afterEach, test, expect } from "vitest"; import { initializeUpdateCommands } from "../../src/commands/update"; import { OllamaAction } from "../../src/commands/update/ollama"; +import { ConfigFileManager } from "../../src/lib/config/ConfigFileManager"; vi.mock("../../src/commands/update/ollama"); +vi.mock("../../src/lib/config/ConfigFileManager"); describe("ollama command", () => { let program: Command; @@ -11,6 +13,9 @@ describe("ollama command", () => { beforeEach(() => { program = new Command(); initializeUpdateCommands(program); + + const mockConfig = { defaultOllamaModel: "default-model" }; + vi.mocked(ConfigFileManager.prototype.getConfig).mockReturnValue(mockConfig); }); afterEach(() => {