diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9254114..dce1000 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,3 +12,4 @@ The format is based on Keep a Changelog and this project uses Semantic Versionin
 - Baseline local test harness via `scripts/test.sh` running package `xcodebuild` tests.
 - Tahoe-first glass-surface UI with fallback styling and floating desktop window behavior.
 - Microphone permission-gated audio startup coordinator for wake-word and STT flow.
+- Model runtime bootstrap with first-run downloader, hash verification, and persona prompt builder.
diff --git a/KAMIBotApp/Sources/KAMIBotApp/AppContainer.swift b/KAMIBotApp/Sources/KAMIBotApp/AppContainer.swift
index 4b29af0..14d27bb 100644
--- a/KAMIBotApp/Sources/KAMIBotApp/AppContainer.swift
+++ b/KAMIBotApp/Sources/KAMIBotApp/AppContainer.swift
@@ -8,6 +8,7 @@ import VisionPipeline
 struct AppContainer {
     let agent: BMOAgent
     let audioStartupCoordinator: AudioStartupCoordinator
+    let modelStartupCoordinator: ModelStartupCoordinator

     init(config: AgentConfig = AgentConfig()) {
         let wakeWord = PorcupineWakeWordService(keyword: config.wakeWord)
@@ -18,6 +19,13 @@
         let modelStore = URL(fileURLWithPath: FileManager.default.currentDirectoryPath)
             .appendingPathComponent("models", isDirectory: true)
         let llm = MLXLLMService(modelID: config.llmModelID, modelStore: modelStore)
+        let modelDownloader = ModelDownloader(baseDirectory: modelStore)
+        let modelDescriptor = Self.resolveModelDescriptor(for: config)
+        self.modelStartupCoordinator = ModelStartupCoordinator(
+            downloader: modelDownloader,
+            descriptor: modelDescriptor,
+            llmService: llm
+        )
         let vision = SnapshotVisionService(enabled: config.visionEnabled)

         self.audioStartupCoordinator = AudioStartupCoordinator(permissionProvider: permissionProvider)
@@ -31,4 +39,28 @@
             visionService: vision
         )
     }
+
+    private static func resolveModelDescriptor(for config: AgentConfig) -> ModelDescriptor {
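+        // A pinned manifest supplied via environment variables takes precedence,
+        // so local runs can point at real weights; otherwise use the catalog default.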
+        let env = ProcessInfo.processInfo.environment
+        if let urlString = env["KAMI_BOT_MODEL_URL"],
+           let url = URL(string: urlString),
+           let sha = env["KAMI_BOT_MODEL_SHA256"],
+           !sha.isEmpty {
+            return ModelDescriptor(
+                id: config.llmModelID,
+                url: url,
+                sha256: sha,
+                license: env["KAMI_BOT_MODEL_LICENSE"] ?? "Custom"
+            )
+        }
+
+        return ModelDescriptor(
+            id: config.llmModelID,
+            url: ModelCatalog.llama31_8B4bit.url,
+            sha256: ModelCatalog.llama31_8B4bit.sha256,
+            license: ModelCatalog.llama31_8B4bit.license
+        )
+    }
 }
diff --git a/KAMIBotApp/Sources/KAMIBotApp/BMOViewModel.swift b/KAMIBotApp/Sources/KAMIBotApp/BMOViewModel.swift
index fed9670..8768724 100644
--- a/KAMIBotApp/Sources/KAMIBotApp/BMOViewModel.swift
+++ b/KAMIBotApp/Sources/KAMIBotApp/BMOViewModel.swift
@@ -1,5 +1,6 @@
 import AudioPipeline
 import CoreAgent
+import ModelRuntime
 import Observation

 @MainActor
@@ -7,15 +8,21 @@ import Observation
 final class BMOViewModel {
     private let agent: BMOAgent
     private let audioStartupCoordinator: AudioStartupCoordinator
+    private let modelStartupCoordinator: ModelStartupCoordinator
     private var streamTask: Task<Void, Never>?

     var state: BMOState = .idle
     var expression: FaceExpression = .happy
     var transcript: [String] = []

-    init(agent: BMOAgent, audioStartupCoordinator: AudioStartupCoordinator) {
+    init(
+        agent: BMOAgent,
+        audioStartupCoordinator: AudioStartupCoordinator,
+        modelStartupCoordinator: ModelStartupCoordinator
+    ) {
         self.agent = agent
         self.audioStartupCoordinator = audioStartupCoordinator
+        self.modelStartupCoordinator = modelStartupCoordinator
     }

     func start() {
@@ -38,6 +45,15 @@
             return
         }

+        do {
+            _ = try await modelStartupCoordinator.prepareModel()
+        } catch {
+            state = .error
+            transcript.append("Error: Model startup failed: \(error.localizedDescription)")
+            streamTask = nil
+            return
+        }
+
         await agent.start()
         for await event in agent.eventStream() {
             switch event {
diff --git a/KAMIBotApp/Sources/KAMIBotApp/KAMIBotApp.swift b/KAMIBotApp/Sources/KAMIBotApp/KAMIBotApp.swift
index 49fd00f..bc1f8cc 100644
--- a/KAMIBotApp/Sources/KAMIBotApp/KAMIBotApp.swift
+++ b/KAMIBotApp/Sources/KAMIBotApp/KAMIBotApp.swift
@@ -9,7 +9,8 @@ struct KAMIBotApp: App {
         _viewModel = State(
             initialValue: BMOViewModel(
                 agent: container.agent,
-                audioStartupCoordinator: container.audioStartupCoordinator
+                audioStartupCoordinator: container.audioStartupCoordinator,
+                modelStartupCoordinator: container.modelStartupCoordinator
             )
         )
     }
diff --git a/KAMIBotApp/Tests/KAMIBotAppTests/KAMIBotAppTests.swift b/KAMIBotApp/Tests/KAMIBotAppTests/KAMIBotAppTests.swift
index 70d9737..0a06409 100644
--- a/KAMIBotApp/Tests/KAMIBotAppTests/KAMIBotAppTests.swift
+++ b/KAMIBotApp/Tests/KAMIBotAppTests/KAMIBotAppTests.swift
@@ -7,6 +7,7 @@ final class KAMIBotAppTests: XCTestCase {
         let container = AppContainer()
         _ = container.agent
         _ = container.audioStartupCoordinator
+        _ = container.modelStartupCoordinator
     }

     func testFloatingWindowConfigDefaults() {
diff --git a/Packages/ModelRuntime/Sources/ModelRuntime/ModelRuntime.swift b/Packages/ModelRuntime/Sources/ModelRuntime/ModelRuntime.swift
index 235a029..0dd8a09 100644
--- a/Packages/ModelRuntime/Sources/ModelRuntime/ModelRuntime.swift
+++ b/Packages/ModelRuntime/Sources/ModelRuntime/ModelRuntime.swift
@@ -2,6 +2,21 @@ import CoreAgent
 import CryptoKit
 import Foundation

+public protocol LLMGenerating: Sendable {
+    func generate(systemPrompt: String, prompt: String) async throws -> String
+}
+
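+// Stand-in engine that echoes the prompt back; MLXLLMService accepts any
+// LLMGenerating implementation, so a real MLX-backed engine can replace it.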
+public struct PromptEchoEngine: LLMGenerating {
+    public init() {}
+
+    public func generate(systemPrompt: String, prompt: String) async throws -> String {
+        let prefix = "[BMO]"
+        return "\(prefix) \(prompt)"
+    }
+}
+
 public struct ModelDescriptor: Sendable, Codable {
     public var id: String
     public var url: URL
@@ -16,10 +31,21 @@
     }
 }

+public enum ModelCatalog {
+    public static let llama31_8B4bit = ModelDescriptor(
+        id: "llama-3.1-8b-4bit",
+        url: URL(string: "https://huggingface.co/mlx-community/Meta-Llama-3.1-8B-Instruct-4bit/resolve/main/model.safetensors")!,
+        // Placeholder hash until release packaging flow pins a verified artifact.
+        sha256: "replace-with-verified-sha256-from-release-manifest",
+        license: "Llama 3.1 Community License"
+    )
+}
+
 public enum ModelRuntimeError: Error, Equatable {
     case modelNotFound(String)
     case downloadFailed(String)
     case hashMismatch(expected: String, got: String)
+    case invalidManifest(String)
 }

 public actor ModelDownloader {
@@ -36,6 +62,12 @@
             return destination
         }

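+        // Reject malformed or placeholder digests (like the unpinned catalog default) before downloading.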
+        let hashPattern = #"^[a-f0-9]{64}$"#
+        if descriptor.sha256.range(of: hashPattern, options: .regularExpression) == nil {
+            throw ModelRuntimeError.invalidManifest("Model SHA256 must be a pinned 64-char lowercase hex digest")
+        }
+
         try FileManager.default.createDirectory(at: baseDirectory, withIntermediateDirectories: true)

         do {
@@ -57,11 +89,17 @@
 public actor MLXLLMService: LLMService {
     private let modelID: String
     private let modelStore: URL
+    private let engine: any LLMGenerating
     private(set) var loadedModelPath: URL?

-    public init(modelID: String, modelStore: URL) {
+    public init(
+        modelID: String,
+        modelStore: URL,
+        engine: any LLMGenerating = PromptEchoEngine()
+    ) {
         self.modelID = modelID
         self.modelStore = modelStore
+        self.engine = engine
     }

     public func loadIfNeeded() throws {
@@ -81,10 +119,31 @@
             try loadIfNeeded()
         }

-        let prefix = "[BMO]"
-        if let context {
-            return "\(prefix) I can see \(context.summary). You said: \(prompt)"
-        }
-        return "\(prefix) You said: \(prompt)"
+        let llmPrompt = PersonaPromptBuilder.makePrompt(userPrompt: prompt, visionContext: context)
+        let runtimePrompt = "\(systemPrompt)\n\n\(llmPrompt)"
+        return try await engine.generate(systemPrompt: systemPrompt, prompt: runtimePrompt)
+    }
+}
+
+public actor ModelStartupCoordinator {
+    private let downloader: ModelDownloader
+    private let descriptor: ModelDescriptor
+    private let llmService: MLXLLMService
+
+    public init(
+        downloader: ModelDownloader,
+        descriptor: ModelDescriptor,
+        llmService: MLXLLMService
+    ) {
+        self.downloader = downloader
+        self.descriptor = descriptor
+        self.llmService = llmService
+    }
+
+    @discardableResult
+    public func prepareModel() async throws -> URL {
+        let localURL = try await downloader.ensureModelAvailable(descriptor)
+        try await llmService.loadIfNeeded()
+        return localURL
     }
 }
diff --git a/Packages/ModelRuntime/Sources/ModelRuntime/PromptTemplates.swift b/Packages/ModelRuntime/Sources/ModelRuntime/PromptTemplates.swift
new file mode 100644
index 0000000..c7bd794
--- /dev/null
+++ b/Packages/ModelRuntime/Sources/ModelRuntime/PromptTemplates.swift
@@ -0,0 +1,13 @@
+import CoreAgent
+
+public enum PersonaPromptBuilder {
+    public static func makePrompt(userPrompt: String, visionContext: VisionContext?) -> String {
+        let persona = "You are BMO. Keep responses concise, kind, and playful."
+
+        if let visionContext {
+            return "\(persona)\nVision context: \(visionContext.summary)\nUser: \(userPrompt)"
+        }
+
+        return "\(persona)\nUser: \(userPrompt)"
+    }
+}
diff --git a/Packages/ModelRuntime/Tests/ModelRuntimeTests/ModelRuntimeTests.swift b/Packages/ModelRuntime/Tests/ModelRuntimeTests/ModelRuntimeTests.swift
index e2d697e..8cd9b03 100644
--- a/Packages/ModelRuntime/Tests/ModelRuntimeTests/ModelRuntimeTests.swift
+++ b/Packages/ModelRuntime/Tests/ModelRuntimeTests/ModelRuntimeTests.swift
@@ -1,5 +1,7 @@
 import Foundation
+import CryptoKit
 import XCTest
+import CoreAgent
 @testable import ModelRuntime

 final class ModelRuntimeTests: XCTestCase {
@@ -24,7 +26,7 @@ final class ModelRuntimeTests: XCTestCase {
         let descriptor = ModelDescriptor(
             id: "llama-3.1-8b-4bit",
             url: URL(string: "https://invalid.invalid/not-found.bin")!,
-            sha256: "deadbeef",
+            sha256: String(repeating: "a", count: 64),
             license: "custom"
         )

@@ -37,4 +39,66 @@ final class ModelRuntimeTests: XCTestCase {
             XCTFail("Unexpected error: \(error)")
         }
     }
+
+    func testInvalidManifestHashRejectedEarly() async {
+        let tmp = URL(fileURLWithPath: NSTemporaryDirectory()).appendingPathComponent("kami-manifest-tests-\(UUID().uuidString)")
+        let downloader = ModelDownloader(baseDirectory: tmp)
+
+        let descriptor = ModelDescriptor(
+            id: "llama-3.1-8b-4bit",
+            url: URL(string: "https://example.com/model.bin")!,
+            sha256: "not-a-valid-digest",
+            license: "custom"
+        )
+
+        do {
+            _ = try await downloader.ensureModelAvailable(descriptor)
+            XCTFail("Expected invalidManifest")
+        } catch ModelRuntimeError.invalidManifest {
+            // expected
+        } catch {
+            XCTFail("Unexpected error: \(error)")
+        }
+    }
+
+    func testModelDownloaderWritesVerifiedModelFromLocalURL() async {
+        let sourceDir = URL(fileURLWithPath: NSTemporaryDirectory()).appendingPathComponent("kami-model-source-\(UUID().uuidString)")
+        let sourceFile = sourceDir.appendingPathComponent("model.bin")
+        let outputDir = URL(fileURLWithPath: NSTemporaryDirectory()).appendingPathComponent("kami-model-out-\(UUID().uuidString)")
+        let data = Data("hello-model".utf8)
+        let digest = SHA256.hash(data: data).map { String(format: "%02x", $0) }.joined()
+
+        do {
+            try FileManager.default.createDirectory(at: sourceDir, withIntermediateDirectories: true)
+            try data.write(to: sourceFile)
+        } catch {
+            XCTFail("Failed to create local source model: \(error)")
+            return
+        }
+
+        let descriptor = ModelDescriptor(
+            id: "llama-3.1-8b-4bit",
+            url: sourceFile,
+            sha256: digest,
+            license: "test"
+        )
+
+        let downloader = ModelDownloader(baseDirectory: outputDir)
+        do {
+            let destination = try await downloader.ensureModelAvailable(descriptor)
+            XCTAssertTrue(FileManager.default.fileExists(atPath: destination.path()))
+        } catch {
+            XCTFail("Expected local model download to succeed: \(error)")
+        }
+    }
+
+    func testPersonaPromptBuilderIncludesVisionContext() {
+        let prompt = PersonaPromptBuilder.makePrompt(
+            userPrompt: "What do you see?",
+            visionContext: .init(summary: "A keyboard")
+        )
+
+        XCTAssertTrue(prompt.contains("You are BMO"))
+        XCTAssertTrue(prompt.contains("Vision context: A keyboard"))
+    }
 }
diff --git a/README.md b/README.md
index 59efb22..9e13955 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,24 @@ swift build --package-path KAMIBotApp
 ./scripts/test.sh
 ```

+### Model Bootstrap (First Run)
+
+For local bootstrap without committing model weights, provide a pinned model manifest via env vars:
+
+```bash
+export KAMI_BOT_MODEL_URL="https://example.com/path/to/model.bin"
+export KAMI_BOT_MODEL_SHA256="<64-char-lowercase-hex>"
+export KAMI_BOT_MODEL_LICENSE="Model license name"
+```
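+
+To compute the pinned digest for an artifact you have fetched locally, macOS's built-in `shasum` works (the path below is a placeholder):
+
+```bash
+shasum -a 256 /path/to/model.bin
+```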
+
+If these are not set, KAMI BOT falls back to the default catalog entry, which must be pinned before production release.
+
 ## Model and License Policy

 - Model weights are not committed to this repository by default.
diff --git a/docs/architecture.md b/docs/architecture.md
index 49394ad..7095706 100644
--- a/docs/architecture.md
+++ b/docs/architecture.md
@@ -77,3 +77,4 @@ The view model consumes `AsyncStream` from `BMOAgent`.
 - `BMOFaceView` uses `matchedGeometryEffect` for expression transitions.
 - `FloatingWindowStyler` configures a borderless, transparent, always-on-top desktop companion window.
 - `AudioStartupCoordinator` enforces microphone permission before activating wake-word listening.
+- `ModelStartupCoordinator` performs first-run model download and hash verification before LLM use.
diff --git a/docs/dependencies.md b/docs/dependencies.md
index 13d3709..0395eea 100644
--- a/docs/dependencies.md
+++ b/docs/dependencies.md
@@ -15,3 +15,4 @@ This file tracks third-party libraries, runtimes, and model-license obligations.
 | MLX Swift | Local model runtime | TBD by upstream package | https://github.com/ml-explore/mlx-swift | Planned |
 | Porcupine | Wake-word engine | Commercial + free tiers | https://picovoice.ai | Planned |
 | Whisper Core ML models | STT | Varies by model | https://github.com/openai/whisper | Planned |
+| Llama 3.1 8B 4-bit | Default LLM weights | Llama 3.1 Community License | https://huggingface.co/mlx-community/Meta-Llama-3.1-8B-Instruct-4bit | Planned (hash must be pinned before release) |