Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/commands/general/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,12 @@ export async function initAction(options: InitActionOptions, simulatorService: I
const ollamaAction = new OllamaAction();
const configManager = new ConfigFileManager();
const config = configManager.getConfig()
let ollamaModel = config.defaultOllamaModel;

let ollamaModel = config.defaultOllamaModel || 'llama3';
if(!config.defaultOllamaModel){
configManager.writeConfig('defaultOllamaModel', 'llama3');
ollamaModel = 'llama3'
}

console.log(`Pulling ${ollamaModel} from Ollama...`);

Expand Down
6 changes: 5 additions & 1 deletion src/commands/update/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Command } from "commander";
import { OllamaAction } from "./ollama";
import { ConfigFileManager } from "../../lib/config/ConfigFileManager";

export function initializeUpdateCommands(program: Command) {
const updateCommand = program
Expand All @@ -12,7 +13,10 @@ export function initializeUpdateCommands(program: Command) {
.option("--model [model-name]", "Specify the model to update or remove")
.option("--remove", "Remove the specified model instead of updating")
.action(async (options) => {
const modelName = options.model || "default-model";
const configManager = new ConfigFileManager();
const config = configManager.getConfig()

const modelName = options.model || config.defaultOllamaModel;
const ollamaAction = new OllamaAction();

if (options.remove) {
Expand Down
77 changes: 69 additions & 8 deletions src/commands/update/ollama.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Docker from "dockerode"
import Docker from "dockerode";
import { rpcClient } from "../../lib/clients/jsonRpcClient";

export class OllamaAction {
private docker: Docker;
Expand All @@ -8,34 +9,94 @@ export class OllamaAction {
}

async updateModel(modelName: string) {
await this.executeModelCommand("pull", modelName, `Model "${modelName}" updated successfully`);
const providersAndModels = await rpcClient.request({
method: "sim_getProvidersAndModels",
params: [],
});

const existingOllamaProvider = providersAndModels.result.find(
(entry: any) => entry.plugin === "ollama"
);

if (!existingOllamaProvider) {
throw new Error("No existing 'ollama' provider found. Unable to add/update a model.");
}

await this.executeModelCommand(
"pull",
modelName,
`Model "${modelName}" updated successfully`
);

const existingModel = providersAndModels.result.some(
(entry: any) =>
entry.plugin === "ollama" && entry.model === modelName
);

if (!existingModel) {
console.log(`Model "${modelName}" not found in Provider Presets. Adding...`);

const newModelConfig = {
config: existingOllamaProvider.config,
model: modelName,
plugin: "ollama",
plugin_config: existingOllamaProvider.plugin_config,
provider: "ollama",
};

await rpcClient.request({
method: "sim_addProvider",
params: [newModelConfig],
});

console.log(`Model "${modelName}" added successfully.`);
}
}

async removeModel(modelName: string) {
await this.executeModelCommand("rm", modelName, `Model "${modelName}" removed successfully`);
await this.executeModelCommand(
"rm",
modelName,
`Model "${modelName}" removed successfully`
);
}

private async executeModelCommand(command: string, modelName: string, successMessage: string) {
private async executeModelCommand(
command: string,
modelName: string,
successMessage: string
) {
try {
let success = false;
const ollamaContainer = this.docker.getContainer("ollama");
const exec = await ollamaContainer.exec({
Cmd: ["ollama", command, modelName],
AttachStdout: true,
AttachStderr: true,
});
const stream = await exec.start({ Detach: false, Tty: false });
const stream = await exec.start({Detach: false, Tty: false});

stream.on("data", (chunk: any) => {
console.log(chunk.toString());
const chunkStr = chunk.toString();
console.log(chunkStr);
if (chunkStr.includes("success") || chunkStr.includes("deleted")) {
success = true;
}
});

await new Promise<void>((resolve, reject) => {
stream.on("end", resolve);
stream.on("end", () => {
if (success) {
resolve();
} else {
reject('internal error');
}
});
stream.on("error", reject);
});

console.log(successMessage);
} catch (error) {
}catch (error) {
console.error(`Error executing command "${command}" on model "${modelName}":`, error);
}
}
Expand Down
39 changes: 37 additions & 2 deletions tests/actions/init.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ import fs from "fs";
import * as dotenv from "dotenv";
import {localnetCompatibleVersion} from "../../src/lib/config/simulator";
import { OllamaAction } from "../../src/commands/update/ollama";
import { ConfigFileManager } from "../../src/lib/config/ConfigFileManager";


vi.mock("fs");
vi.mock("dotenv");
vi.mock("../../src/commands/update/ollama")
vi.mock("../../src/lib/config/ConfigFileManager");


const tempDir = mkdtempSync(join(tmpdir(), "test-initAction-"));
Expand Down Expand Up @@ -278,7 +280,7 @@ describe("init action", () => {
simServDeleteAllValidators.mockResolvedValue(true);
simServResetDockerContainers.mockResolvedValue(true);
simServResetDockerImages.mockResolvedValue(true);
vi.mocked(fs.readFileSync).mockReturnValue(JSON.stringify({}));
vi.mocked(ConfigFileManager.prototype.getConfig).mockReturnValue({});

await initAction(defaultActionOptions, simulatorService);

Expand Down Expand Up @@ -310,14 +312,47 @@ describe("init action", () => {
simServDeleteAllValidators.mockResolvedValue(true);
simServResetDockerContainers.mockResolvedValue(true);
simServResetDockerImages.mockResolvedValue(true);
vi.mocked(fs.readFileSync).mockReturnValue(JSON.stringify({defaultOllamaModel: ollamaModel}));
vi.mocked(ConfigFileManager.prototype.getConfig).mockReturnValue({defaultOllamaModel: ollamaModel});

await initAction(defaultActionOptions, simulatorService);

expect(log).toHaveBeenCalledWith(`Pulling ${ollamaModel} from Ollama...`);
expect(OllamaAction.prototype.updateModel).toHaveBeenCalled();
});

test("should set defaultOllamaModel to llama 3 if no defaultOllamaModel is provided", async () => {

inquirerPrompt.mockResolvedValue({
confirmReset: true,
confirmDownload: true,
selectedLlmProviders: ["openai", "heuristai", "ollama"],
openai: "API_KEY1",
heuristai: "API_KEY2",
ollama: "API_KEY3",
});
simServgetAiProvidersOptions.mockReturnValue([
{ name: "OpenAI", value: "openai" },
{ name: "Heurist", value: "heuristai" },
{ name: "Ollama", value: "ollama" },
]);

vi.mocked(ConfigFileManager.prototype.getConfig).mockResolvedValueOnce({})
vi.mocked(OllamaAction.prototype.updateModel).mockResolvedValueOnce(undefined);

simServRunSimulator.mockResolvedValue(true);
simServWaitForSimulator.mockResolvedValue({ initialized: true });
simServDeleteAllValidators.mockResolvedValue(true);
simServResetDockerContainers.mockResolvedValue(true);
simServResetDockerImages.mockResolvedValue(true);
vi.mocked(fs.readFileSync).mockReturnValue(JSON.stringify({}));

await initAction(defaultActionOptions, simulatorService);

expect(ConfigFileManager.prototype.writeConfig).toHaveBeenCalledWith('defaultOllamaModel', 'llama3')
expect(log).toHaveBeenCalledWith(`Pulling llama3 from Ollama...`);
expect(OllamaAction.prototype.updateModel).toHaveBeenCalled();
});

test("logs error if checkVersionRequirements throws", async () => {
simServCheckInstallRequirements.mockResolvedValue({ git: true, docker: true });
const errorMsg = new Error("checkVersionRequirements error");
Expand Down
97 changes: 88 additions & 9 deletions tests/actions/ollama.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import {describe, test, vi, beforeEach, afterEach, expect, Mock} from "vitest";
import { describe, test, vi, beforeEach, afterEach, expect, Mock } from "vitest";
import { OllamaAction } from "../../src/commands/update/ollama";
import { rpcClient } from "../../src/lib/clients/jsonRpcClient";

import Docker from "dockerode";

vi.mock("dockerode");
vi.mock("../../src/lib/clients/jsonRpcClient");

describe("OllamaAction", () => {
let ollamaAction: OllamaAction;
Expand Down Expand Up @@ -41,9 +44,22 @@ describe("OllamaAction", () => {
});

test("should update the model using 'pull'", async () => {
mockStream.on.mockImplementation((event: any, callback:any) => {
if (event === "data") callback(Buffer.from("Mocked output"));
if (event === "end") callback();
const mockProvider = {
plugin: "ollama",
config: { key: "value" },
plugin_config: { pluginKey: "pluginValue" },
};
vi.mocked(rpcClient.request).mockResolvedValueOnce({
result: [mockProvider],
});

mockStream.on.mockImplementation((event: any, callback: any) => {
if (event === "data") {
callback(Buffer.from("Mocked output success"));
}
if (event === "end") {
callback();
}
});

console.log = vi.fn();
Expand All @@ -59,14 +75,18 @@ describe("OllamaAction", () => {
expect(mockStart).toHaveBeenCalledWith({ Detach: false, Tty: false });
expect(mockStream.on).toHaveBeenCalledWith("data", expect.any(Function));
expect(mockStream.on).toHaveBeenCalledWith("end", expect.any(Function));
expect(console.log).toHaveBeenCalledWith("Mocked output");
expect(console.log).toHaveBeenCalledWith("Mocked output success");
expect(console.log).toHaveBeenCalledWith('Model "mocked_model" updated successfully');
});

test("should remove the model using 'rm'", async () => {
mockStream.on.mockImplementation((event:any, callback:any) => {
if (event === "data") callback(Buffer.from("Mocked output"));
if (event === "end") callback();
mockStream.on.mockImplementation((event: any, callback: any) => {
if (event === "data") {
callback(Buffer.from("Mocked output success"));
}
if (event === "end") {
callback();
}
});

console.log = vi.fn();
Expand All @@ -82,11 +102,20 @@ describe("OllamaAction", () => {
expect(mockStart).toHaveBeenCalledWith({ Detach: false, Tty: false });
expect(mockStream.on).toHaveBeenCalledWith("data", expect.any(Function));
expect(mockStream.on).toHaveBeenCalledWith("end", expect.any(Function));
expect(console.log).toHaveBeenCalledWith("Mocked output");
expect(console.log).toHaveBeenCalledWith("Mocked output success");
expect(console.log).toHaveBeenCalledWith('Model "mocked_model" removed successfully');
});

test("should log an error if an exception occurs during 'pull'", async () => {
const mockProvider = {
plugin: "ollama",
config: { key: "value" },
plugin_config: { pluginKey: "pluginValue" },
};
vi.mocked(rpcClient.request).mockResolvedValueOnce({
result: [mockProvider],
});

const error = new Error("Mocked error");
mockGetContainer.mockReturnValueOnce(
{
Expand Down Expand Up @@ -126,4 +155,54 @@ describe("OllamaAction", () => {
error
);
});

test("should throw an error if no 'ollama' provider exists during updateModel", async () => {
vi.mocked(rpcClient.request).mockResolvedValueOnce({
result: [],
});

const modelName = "mocked_model";

await expect(ollamaAction.updateModel(modelName)).rejects.toThrowError(
"No existing 'ollama' provider found. Unable to add/update a model."
);

expect(rpcClient.request).toHaveBeenCalledWith({
method: "sim_getProvidersAndModels",
params: [],
});
});

test("should reject with an error if success is not set to true", async () => {
console.error = vi.fn();

const mockProvider = {
plugin: "ollama",
config: { key: "value" },
plugin_config: { pluginKey: "pluginValue" },
};

vi.mocked(rpcClient.request).mockResolvedValueOnce({
result: [mockProvider],
});

mockStream.on.mockImplementation((event: any, callback: any) => {
if (event === "data") {
callback(Buffer.from("Mocked output failure"));
}
if (event === "end") {
callback();
}
});

console.log = vi.fn();
console.error = vi.fn();

await ollamaAction.updateModel("mocked_model");

expect(console.error).toHaveBeenCalledWith(
'Error executing command "pull" on model "mocked_model":', 'internal error'
);
});

});
5 changes: 5 additions & 0 deletions tests/commands/update.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@ import { Command } from "commander";
import { vi, describe, beforeEach, afterEach, test, expect } from "vitest";
import { initializeUpdateCommands } from "../../src/commands/update";
import { OllamaAction } from "../../src/commands/update/ollama";
import { ConfigFileManager } from "../../src/lib/config/ConfigFileManager";

vi.mock("../../src/commands/update/ollama");
vi.mock("../../src/lib/config/ConfigFileManager");

describe("ollama command", () => {
let program: Command;

beforeEach(() => {
program = new Command();
initializeUpdateCommands(program);

const mockConfig = { defaultOllamaModel: "default-model" };
vi.mocked(ConfigFileManager.prototype.getConfig).mockReturnValue(mockConfig);
});

afterEach(() => {
Expand Down
Loading