diff --git a/Cargo.lock b/Cargo.lock index 7e48ad9..986e217 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -191,6 +191,26 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" +[[package]] +name = "bindgen" +version = "0.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f" +dependencies = [ + "bitflags 2.9.4", + "cexpr", + "clang-sys", + "itertools 0.11.0", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn 2.0.106", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -255,6 +275,15 @@ dependencies = [ "shlex", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.3" @@ -302,6 +331,17 @@ dependencies = [ "half", ] +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "4.5.48" @@ -342,6 +382,15 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + [[package]] name = "colorchoice" version = "1.0.4" @@ -552,6 +601,41 @@ dependencies = [ "typenum", ] +[[package]] +name = "darling" +version = "0.21.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.106", +] + +[[package]] +name = "darling_macro" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.106", +] + [[package]] name = "der" version = "0.7.10" @@ -929,6 +1013,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + [[package]] name = "h2" version = "0.3.27" @@ -1233,6 +1323,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "1.1.0" @@ -1381,6 +1477,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.15" @@ -1432,6 +1537,16 @@ version = "0.2.175" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" +[[package]] +name = "libloading" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +dependencies = [ + "cfg-if", + "windows-link 0.2.0", +] + [[package]] name = "libm" version = "0.2.15" @@ -1497,6 +1612,15 @@ dependencies = [ "hashbrown 0.15.5", ] +[[package]] +name = "mach-sys" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48460c2e82a3a0de197152fdf8d2c2d5e43adc501501553e439bf2156e6f87c7" +dependencies = [ + "fastrand", +] + [[package]] name = "md-5" version = "0.10.6" @@ -1558,6 +1682,66 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "mlx-internal-macros" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc6be59ba4e17f894ffcf3d01f54ac37a3adec81710d1da25a3c91403e99694b" +dependencies = [ + "darling", + "itertools 0.14.0", + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "mlx-macros" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85cdc61be8c860eae6e4a7940db464def29c0ff60da986bd09c64f1981a839ab" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "mlx-rs" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95ab8ddcd8cec3911a80620e338b906f3d4b6686c7db1652b7045463ebd5c0ae" +dependencies = [ + "dyn-clone", + "half", + "itertools 0.14.0", + "libc", + "mach-sys", + "mlx-internal-macros", + "mlx-macros", + "mlx-sys", + "num-complex", + "num-traits", + "num_enum", + "parking_lot", + "paste", + "smallvec", + "strum 0.27.2", + "thiserror 2.0.17", +] + +[[package]] +name = "mlx-sys" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2e3bc3880111918b2d5018f845d48fd995f9901f16efc81d1fcfd2f4210b8219" +dependencies = [ + "bindgen", + "cc", + "cmake", +] + [[package]] name = "mockall" version = "0.12.1" @@ -1654,6 +1838,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -1700,6 +1893,28 @@ dependencies = [ "libc", ] +[[package]] +name = "num_enum" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1207a7e20ad57b847bbddc6776b968420d38292bbfe2089accff5e19e82454c" +dependencies = [ + "num_enum_derive", + "rustversion", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff32365de1b6743cb203b710788263c44a03de03802daf96092f2da4fe6ba4d7" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "number_prefix" version = "0.4.0" @@ -1906,6 +2121,25 @@ dependencies = [ "termtree", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn 2.0.106", +] + +[[package]] +name = "proc-macro-crate" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983" +dependencies = [ + "toml_edit", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -1996,7 +2230,7 @@ dependencies = [ "itertools 0.11.0", "lru", "paste", - "strum", + "strum 0.25.0", "unicode-segmentation", "unicode-width 0.1.14", ] @@ -2038,7 +2272,7 @@ checksum = 
"ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom 0.2.16", "libredox", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -2132,6 +2366,7 @@ dependencies = [ "indicatif", "inquire", "lazy_static", + "mlx-rs", "mockall", "notify", "num_cpus", @@ -2143,9 +2378,10 @@ dependencies = [ "serde", "serde_json", "similar", + "sqlite-vec", "sqlx", "tempfile", - "thiserror", + "thiserror 1.0.69", "tokio", "tokio-util", "tracing", @@ -2208,6 +2444,12 @@ version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustix" version = "1.1.2" @@ -2488,6 +2730,15 @@ dependencies = [ "unicode_categories", ] +[[package]] +name = "sqlite-vec" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec77b84fb8dd5f0f8def127226db83b5d1152c5bf367f09af03998b76ba554a" +dependencies = [ + "cc", +] + [[package]] name = "sqlx" version = "0.7.4" @@ -2536,7 +2787,7 @@ dependencies = [ "sha2", "smallvec", "sqlformat", - "thiserror", + "thiserror 1.0.69", "tokio", "tokio-stream", "tracing", @@ -2622,7 +2873,7 @@ dependencies = [ "smallvec", "sqlx-core", "stringprep", - "thiserror", + "thiserror 1.0.69", "tracing", "uuid", "whoami", @@ -2662,7 +2913,7 @@ dependencies = [ "smallvec", "sqlx-core", "stringprep", - "thiserror", + "thiserror 1.0.69", "tracing", "uuid", "whoami", @@ -2722,7 +2973,16 @@ version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" dependencies = [ - "strum_macros", + "strum_macros 0.25.3", +] + +[[package]] +name = "strum" +version = "0.27.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +dependencies = [ + "strum_macros 0.27.2", ] [[package]] @@ -2738,6 +2998,18 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "strum_macros" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "subtle" version = "2.6.1" @@ -2829,7 +3101,16 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +dependencies = [ + "thiserror-impl 2.0.17", ] [[package]] @@ -2843,6 +3124,17 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "thiserror-impl" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "thread_local" version = "1.1.9" @@ -2983,6 +3275,36 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml_datetime" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2cdb639ebbc97961c51720f858597f7f24c4fc295327923af55b74c3c724533" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_edit" +version = "0.23.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6485ef6d0d9b5d0ec17244ff7eb05310113c3f316f2d14200d4de56b3cb98f8d" +dependencies = [ + "indexmap", + 
"toml_datetime", + "toml_parser", + "winnow", +] + +[[package]] +name = "toml_parser" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0cbe268d35bdb4bb5a56a2de88d0ad0eb70af5384a99d648cd4b3d04039800e" +dependencies = [ + "winnow", +] + [[package]] name = "tower-service" version = "0.3.3" @@ -3008,7 +3330,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3566e8ce28cc0a3fe42519fc80e6b4c943cc4c8cef275620eb8dac2d3d4e06cf" dependencies = [ "crossbeam-channel", - "thiserror", + "thiserror 1.0.69", "time", "tracing-subscriber", ] @@ -3672,6 +3994,15 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" +[[package]] +name = "winnow" +version = "0.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21a0236b59786fed61e2a80582dd500fe61f18b5dca67a4a067d0bc9039339cf" +dependencies = [ + "memchr", +] + [[package]] name = "winreg" version = "0.50.0" diff --git a/Cargo.toml b/Cargo.toml index 5109f2c..adbaa1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,12 +39,18 @@ num_cpus = "1.16" indicatif = "0.17" inquire = "0.7" console = "0.15" +sqlite-vec = "0.1.6" regex = "1.10" lazy_static = "1.4" +# MLX for embedding generation (macOS only) +[target.'cfg(target_os = "macos")'.dependencies] +mlx-rs = { version = "0.25", optional = true } + [features] default = ["reqwest"] reqwest = ["dep:reqwest"] +mlx = ["mlx-rs"] [dev-dependencies] tempfile = "3.8" @@ -164,4 +170,4 @@ path = "tests/contract/test_cli_add_command.rs" [profile.release] lto = true codegen-units = 1 -panic = "abort" \ No newline at end of file +panic = "abort" diff --git a/migrations/011_add_message_embeddings.sql b/migrations/011_add_message_embeddings.sql new file mode 100644 index 0000000..45b9983 --- /dev/null +++ b/migrations/011_add_message_embeddings.sql @@ -0,0 +1,11 @@ +-- Add embedding 
column to messages table +-- Migration: 011_add_message_embeddings +-- Description: Add embedding vector column (768 dimensions) for semantic search + +-- Add embedding column to messages table (can be NULL for now) +-- Using BLOB to store float32 vectors of 768 dimensions +ALTER TABLE messages ADD COLUMN embedding BLOB; + +-- Create virtual table for vector similarity search using sqlite-vec +-- Note: This will be created dynamically when needed in the application code +-- since sqlite-vec needs to be loaded as an extension first diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 06ead46..77bcf88 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -130,6 +130,9 @@ pub enum Commands { /// Messages until this time (e.g., "now", "2024-10-31", "today") #[arg(long)] until: Option, + /// Use embedding-based semantic search instead of full-text search + #[arg(long)] + use_embedding: bool, }, /// [Alias for 'retrospect execute'] Review and analyze a chat session Review { @@ -187,6 +190,9 @@ pub enum QueryCommands { /// Messages until this time (e.g., "now", "2024-10-31", "today") #[arg(long)] until: Option, + /// Use embedding-based semantic search instead of full-text search + #[arg(long)] + use_embedding: bool, }, /// Query messages by time range Timeline { @@ -286,7 +292,11 @@ impl Cli { limit, since, until, - } => query::handle_search_command(query, limit, since, until).await, + use_embedding, + } => { + query::handle_search_command(query, limit, since, until, use_embedding) + .await + } QueryCommands::Timeline { since, until, @@ -363,7 +373,8 @@ impl Cli { limit, since, until, - } => query::handle_search_command(query, limit, since, until).await, + use_embedding, + } => query::handle_search_command(query, limit, since, until, use_embedding).await, Commands::Review { session_id } => { // For now, delegate to retrospect execute // TODO: Could make this more interactive diff --git a/src/cli/query.rs b/src/cli/query.rs index 7276d80..047f240 100644 --- a/src/cli/query.rs 
+++ b/src/cli/query.rs @@ -123,6 +123,7 @@ pub async fn handle_search_command( limit: Option, since: Option, until: Option, + use_embedding: bool, ) -> Result<()> { let db_path = crate::database::config::get_default_db_path()?; let db_manager = DatabaseManager::new(&db_path).await?; @@ -154,6 +155,12 @@ pub async fn handle_search_command( None }; + let search_type = if use_embedding { + Some("embedding".to_string()) + } else { + None + }; + let request = SearchRequest { query, page: Some(1), @@ -161,7 +168,7 @@ pub async fn handle_search_command( date_range, projects: None, providers: None, - search_type: None, + search_type, }; let response = query_service.search_messages(request).await?; diff --git a/src/database/message_repo.rs b/src/database/message_repo.rs index 30aba48..20f8380 100644 --- a/src/database/message_repo.rs +++ b/src/database/message_repo.rs @@ -18,13 +18,37 @@ impl MessageRepository { } } + /// Convert embedding vector to BLOB (bytes) for storage + fn embedding_to_blob(embedding: &[f32]) -> Vec { + embedding.iter().flat_map(|&f| f.to_le_bytes()).collect() + } + + /// Convert BLOB (bytes) back to embedding vector + fn blob_to_embedding(blob: &[u8]) -> Option> { + if !blob.len().is_multiple_of(4) { + return None; + } + + let mut embedding = Vec::with_capacity(blob.len() / 4); + for chunk in blob.chunks_exact(4) { + let bytes: [u8; 4] = chunk.try_into().ok()?; + embedding.push(f32::from_le_bytes(bytes)); + } + Some(embedding) + } + pub async fn create(&self, message: &Message) -> AnyhowResult<()> { + let embedding_blob = message + .embedding + .as_ref() + .map(|emb| Self::embedding_to_blob(emb)); + sqlx::query( r#" INSERT INTO messages ( id, session_id, role, content, timestamp, token_count, - metadata, sequence_number, message_type, tool_operation_id - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + metadata, sequence_number, message_type, tool_operation_id, embedding + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
"#, ) .bind(message.id.to_string()) @@ -37,6 +61,7 @@ impl MessageRepository { .bind(message.sequence_number) .bind(message.message_type.to_string()) .bind(message.tool_operation_id.map(|id| id.to_string())) + .bind(embedding_blob) .execute(&self.pool) .await .context("Failed to create message")?; @@ -48,7 +73,7 @@ impl MessageRepository { let row = sqlx::query( r#" SELECT id, session_id, role, content, timestamp, token_count, - metadata, sequence_number, message_type, tool_operation_id + metadata, sequence_number, message_type, tool_operation_id, embedding FROM messages WHERE id = ? "#, @@ -71,7 +96,7 @@ impl MessageRepository { let rows = sqlx::query( r#" SELECT id, session_id, role, content, timestamp, token_count, - metadata, sequence_number, message_type, tool_operation_id + metadata, sequence_number, message_type, tool_operation_id, embedding FROM messages WHERE session_id = ? ORDER BY sequence_number ASC @@ -142,7 +167,7 @@ impl MessageRepository { let mut sql = r#" SELECT m.id, m.session_id, m.role, m.content, m.timestamp, m.token_count, m.metadata, m.sequence_number, - m.message_type, m.tool_operation_id + m.message_type, m.tool_operation_id, m.embedding FROM messages m JOIN messages_fts fts ON m.rowid = fts.rowid WHERE messages_fts MATCH ? @@ -197,7 +222,7 @@ impl MessageRepository { let mut sql = r#" SELECT m.id, m.session_id, m.role, m.content, m.timestamp, m.token_count, m.metadata, m.sequence_number, - m.message_type, m.tool_operation_id + m.message_type, m.tool_operation_id, m.embedding FROM messages m JOIN messages_fts fts ON m.rowid = fts.rowid WHERE messages_fts MATCH ? 
@@ -291,7 +316,7 @@ impl MessageRepository { r#" SELECT m.id, m.session_id, m.role, m.content, m.timestamp, m.token_count, m.metadata, m.sequence_number, - m.message_type, m.tool_operation_id + m.message_type, m.tool_operation_id, m.embedding FROM messages m "#, ); @@ -376,12 +401,17 @@ impl MessageRepository { .context("Failed to start transaction")?; for message in messages { + let embedding_blob = message + .embedding + .as_ref() + .map(|emb| Self::embedding_to_blob(emb)); + sqlx::query( r#" INSERT INTO messages ( id, session_id, role, content, timestamp, token_count, - metadata, sequence_number, message_type, tool_operation_id - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + metadata, sequence_number, message_type, tool_operation_id, embedding + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) "#, ) .bind(message.id.to_string()) @@ -394,6 +424,7 @@ impl MessageRepository { .bind(message.sequence_number) .bind(message.message_type.to_string()) .bind(message.tool_operation_id.map(|id| id.to_string())) + .bind(embedding_blob) .execute(&mut *tx) .await .context("Failed to insert message in bulk")?; @@ -417,6 +448,7 @@ impl MessageRepository { let sequence_number: i64 = row.try_get("sequence_number")?; let message_type_str: String = row.try_get("message_type")?; let tool_operation_id_str: Option = row.try_get("tool_operation_id")?; + let embedding_blob: Option> = row.try_get("embedding")?; let id = Uuid::parse_str(&id_str).context("Invalid message ID format")?; let session_id = Uuid::parse_str(&session_id_str).context("Invalid session ID format")?; @@ -433,6 +465,8 @@ impl MessageRepository { None }; + let embedding = embedding_blob.and_then(|blob| Self::blob_to_embedding(&blob)); + let metadata: Option = serde_json::from_str("{}").ok(); Ok(Message { @@ -448,6 +482,7 @@ impl MessageRepository { tool_operation_id, tool_uses: None, tool_results: None, + embedding, }) } } diff --git a/src/env.rs b/src/env.rs index c2acbd8..cdc607d 100644 --- a/src/env.rs +++ b/src/env.rs @@ -56,3 
+56,11 @@ pub mod retrospection { /// Maximum concurrent analysis operations pub const CONCURRENT: &str = "RETROCHAT_CONCURRENT"; } + +/// Embedding service configuration +pub mod embedding { + /// Enable MLX-based embedding generation (macOS only) + /// When enabled on macOS, uses MLX for embedding extraction + /// Shows warning and disables on unsupported platforms (Windows, Linux) + pub const USE_MLX: &str = "RETROCHAT_USE_MLX"; +} diff --git a/src/models/message.rs b/src/models/message.rs index b20c83e..daf0344 100644 --- a/src/models/message.rs +++ b/src/models/message.rs @@ -119,6 +119,9 @@ pub struct Message { pub tool_uses: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub tool_results: Option>, + /// Embedding vector for semantic search (768 dimensions, f32 values) + #[serde(skip_serializing_if = "Option::is_none")] + pub embedding: Option>, } impl Message { @@ -142,6 +145,7 @@ impl Message { tool_operation_id: None, tool_uses: None, tool_results: None, + embedding: None, } } @@ -177,6 +181,11 @@ impl Message { self } + pub fn with_embedding(mut self, embedding: Vec) -> Self { + self.embedding = Some(embedding); + self + } + pub fn is_valid(&self) -> bool { !self.content.is_empty() } diff --git a/src/services/embedding_service.rs b/src/services/embedding_service.rs new file mode 100644 index 0000000..28d257b --- /dev/null +++ b/src/services/embedding_service.rs @@ -0,0 +1,221 @@ +//! Embedding service for generating text embeddings +//! +//! This service provides text embedding generation with support for: +//! - MLX-based embeddings on macOS (when RETROCHAT_USE_MLX is enabled) +//! - Dummy embeddings (768 dimensions) for development and testing +//! +//! Platform Support: +//! - macOS: Full MLX support when enabled +//! 
- Windows/Linux: Shows warning, falls back to dummy embeddings + +use anyhow::Result; +use tracing::{info, warn}; + +/// Standard embedding dimension size (compatible with many embedding models) +pub const EMBEDDING_DIM: usize = 768; + +/// Embedding service for generating text embeddings +pub struct EmbeddingService { + enabled: bool, + mlx_available: bool, +} + +impl EmbeddingService { + /// Create a new embedding service + /// + /// Checks platform support and environment configuration: + /// - On macOS with RETROCHAT_USE_MLX=true: Enables MLX-based embeddings + /// - On other platforms or when disabled: Uses dummy embeddings + pub fn new() -> Self { + let use_mlx_env = std::env::var(crate::env::embedding::USE_MLX) + .unwrap_or_else(|_| "false".to_string()) + .to_lowercase(); + let use_mlx = use_mlx_env == "true" || use_mlx_env == "1"; + + let mlx_available = Self::check_mlx_support(); + + let enabled = if use_mlx { + if mlx_available { + info!("Embedding service enabled with MLX support"); + true + } else { + warn!( + "RETROCHAT_USE_MLX is enabled but MLX is not supported on this platform. \ + MLX only works on macOS. Embedding service will use dummy embeddings." + ); + false + } + } else { + info!("Embedding service using dummy embeddings (RETROCHAT_USE_MLX not enabled)"); + false + }; + + Self { + enabled, + mlx_available, + } + } + + /// Check if MLX is supported on the current platform + fn check_mlx_support() -> bool { + #[cfg(target_os = "macos")] + { + // On macOS, MLX is available if the feature is enabled + #[cfg(feature = "mlx")] + { + true + } + #[cfg(not(feature = "mlx"))] + { + false + } + } + #[cfg(not(target_os = "macos"))] + { + false + } + } + + /// Generate embedding for the given text + /// + /// Returns a 768-dimensional embedding vector. + /// Currently returns dummy embeddings; will be replaced with actual MLX implementation. 
+ pub fn generate_embedding(&self, text: &str) -> Result> { + if self.enabled && self.mlx_available { + self.generate_mlx_embedding(text) + } else { + self.generate_dummy_embedding(text) + } + } + + /// Generate embedding using MLX (macOS only) + /// + /// TODO: Implement actual MLX-based embedding extraction + /// For now, returns dummy embeddings even when MLX is available + #[allow(unused_variables)] + fn generate_mlx_embedding(&self, text: &str) -> Result> { + #[cfg(all(target_os = "macos", feature = "mlx"))] + { + // TODO: Implement MLX-based embedding generation + // For now, return dummy embeddings + warn!("MLX embedding generation not yet implemented, using dummy embeddings"); + self.generate_dummy_embedding(text) + } + #[cfg(not(all(target_os = "macos", feature = "mlx")))] + { + self.generate_dummy_embedding(text) + } + } + + /// Generate dummy embedding for development/testing + /// + /// Creates a deterministic 768-dimensional embedding based on text content. + /// Uses a simple hash-based approach to ensure consistency. 
+ fn generate_dummy_embedding(&self, text: &str) -> Result> { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + // Create a deterministic seed from the text + let mut hasher = DefaultHasher::new(); + text.hash(&mut hasher); + let seed = hasher.finish(); + + // Generate deterministic pseudo-random values + let mut embedding = Vec::with_capacity(EMBEDDING_DIM); + let mut rng_state = seed; + + for _ in 0..EMBEDDING_DIM { + // Simple LCG (Linear Congruential Generator) + rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1); + let value = (rng_state >> 32) as u32; + // Normalize to [-1, 1] range + let normalized = (value as f32 / u32::MAX as f32) * 2.0 - 1.0; + embedding.push(normalized); + } + + // Normalize the vector to unit length (L2 normalization) + let magnitude: f32 = embedding.iter().map(|x| x * x).sum::().sqrt(); + if magnitude > 0.0 { + for value in &mut embedding { + *value /= magnitude; + } + } + + Ok(embedding) + } + + /// Check if embedding service is enabled + pub fn is_enabled(&self) -> bool { + self.enabled + } + + /// Check if MLX is available on this platform + pub fn is_mlx_available(&self) -> bool { + self.mlx_available + } +} + +impl Default for EmbeddingService { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dummy_embedding_generation() { + let service = EmbeddingService { + enabled: false, + mlx_available: false, + }; + + let embedding = service.generate_embedding("test text").unwrap(); + assert_eq!(embedding.len(), EMBEDDING_DIM); + + // Check that values are normalized (approximately unit length) + let magnitude: f32 = embedding.iter().map(|x| x * x).sum::().sqrt(); + assert!((magnitude - 1.0).abs() < 0.001); + } + + #[test] + fn test_embedding_deterministic() { + let service = EmbeddingService { + enabled: false, + mlx_available: false, + }; + + let embedding1 = service.generate_embedding("test text").unwrap(); + let 
embedding2 = service.generate_embedding("test text").unwrap(); + + // Same text should produce same embedding + assert_eq!(embedding1, embedding2); + } + + #[test] + fn test_embedding_different_text() { + let service = EmbeddingService { + enabled: false, + mlx_available: false, + }; + + let embedding1 = service.generate_embedding("text one").unwrap(); + let embedding2 = service.generate_embedding("text two").unwrap(); + + // Different text should produce different embeddings + assert_ne!(embedding1, embedding2); + } + + #[test] + fn test_platform_support_check() { + let is_supported = EmbeddingService::check_mlx_support(); + + #[cfg(all(target_os = "macos", feature = "mlx"))] + assert!(is_supported); + + #[cfg(not(all(target_os = "macos", feature = "mlx")))] + assert!(!is_supported); + } +} diff --git a/src/services/mod.rs b/src/services/mod.rs index 6df7fbc..08f6440 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -1,5 +1,6 @@ pub mod analytics_service; pub mod auto_detect; +pub mod embedding_service; pub mod google_ai; pub mod import_service; pub mod parser_service; @@ -12,6 +13,7 @@ pub use analytics_service::{ MessageRoleDistribution, ProjectStats, ProviderStats, UsageInsights, }; pub use auto_detect::{AutoDetectService, DetectedProvider}; +pub use embedding_service::{EmbeddingService, EMBEDDING_DIM}; pub use google_ai::{ GenerateContentRequest, GenerateContentResponse, GoogleAiClient, GoogleAiConfig, GoogleAiError, }; diff --git a/src/services/query_service.rs b/src/services/query_service.rs index 318e848..70c385c 100644 --- a/src/services/query_service.rs +++ b/src/services/query_service.rs @@ -16,8 +16,8 @@ pub enum MessageGroup { Single(Message), /// A tool use message paired with its corresponding tool result message ToolPair { - tool_use_message: Message, - tool_result_message: Message, + tool_use_message: Box, + tool_result_message: Box, }, } @@ -70,8 +70,8 @@ impl MessageGroup { if has_matching_result { // Create a ToolPair and skip the 
next message groups.push(MessageGroup::ToolPair { - tool_use_message: current.clone(), - tool_result_message: next.clone(), + tool_use_message: Box::new(current.clone()), + tool_result_message: Box::new(next.clone()), }); i += 2; // Skip both messages continue; @@ -95,7 +95,7 @@ impl MessageGroup { MessageGroup::ToolPair { tool_use_message, tool_result_message, - } => vec![tool_use_message, tool_result_message], + } => vec![tool_use_message.as_ref(), tool_result_message.as_ref()], } } diff --git a/tests/contract/test_cli_add_command.rs b/tests/contract/test_cli_add_command.rs index 3b87220..9e41db2 100644 --- a/tests/contract/test_cli_add_command.rs +++ b/tests/contract/test_cli_add_command.rs @@ -93,6 +93,7 @@ fn test_search_command_structure() { limit: Some(10), since: None, until: None, + use_embedding: false, }; match search_cmd { @@ -101,11 +102,13 @@ fn test_search_command_structure() { limit, since, until, + use_embedding, } => { assert_eq!(query, "test query"); assert_eq!(limit, Some(10)); assert!(since.is_none()); assert!(until.is_none()); + assert!(!use_embedding); } _ => panic!("Expected Search command"), } @@ -118,6 +121,7 @@ fn test_search_command_without_limit() { limit: None, since: None, until: None, + use_embedding: false, }; match search_cmd { @@ -126,11 +130,13 @@ fn test_search_command_without_limit() { limit, since, until, + use_embedding, } => { assert_eq!(query, "test"); assert!(limit.is_none()); assert!(since.is_none()); assert!(until.is_none()); + assert!(!use_embedding); } _ => panic!("Expected Search command"), } @@ -143,6 +149,7 @@ fn test_search_command_with_time_range() { limit: Some(10), since: Some("7 days ago".to_string()), until: Some("now".to_string()), + use_embedding: false, }; match search_cmd { @@ -151,11 +158,41 @@ fn test_search_command_with_time_range() { limit, since, until, + use_embedding, } => { assert_eq!(query, "test"); assert_eq!(limit, Some(10)); assert_eq!(since, Some("7 days ago".to_string())); assert_eq!(until, 
Some("now".to_string())); + assert!(!use_embedding); + } + _ => panic!("Expected Search command"), + } +} + +#[test] +fn test_search_command_with_embedding() { + let search_cmd = Commands::Search { + query: "semantic search test".to_string(), + limit: Some(5), + since: None, + until: None, + use_embedding: true, + }; + + match search_cmd { + Commands::Search { + query, + limit, + since, + until, + use_embedding, + } => { + assert_eq!(query, "semantic search test"); + assert_eq!(limit, Some(5)); + assert!(since.is_none()); + assert!(until.is_none()); + assert!(use_embedding); } _ => panic!("Expected Search command"), }