diff --git a/.gitignore b/.gitignore
index b3f6a75..6ed123d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,3 +34,4 @@ docs-site/.vitepress/.temp
 
 # Claude Code local settings
 .claude/
+docs/tasks
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3a29243..6d2cc2a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file.
 ## [Unreleased]
 
 ### Added
+- **Mixedbread model support**: Added first-class support for Mixedbread embedding and reranking models
+  - Embedding model: `mxbai-xsmall` (`mixedbread-ai/mxbai-embed-xsmall-v1`) - 384 dimensions, 4K context window
+  - Reranker: `mxbai` (`mixedbread-ai/mxbai-rerank-xsmall-v1`) - Neural cross-encoder reranker
+  - Fully local inference using ONNX Runtime with quantized models
+  - Provider abstraction for clean model selection and routing
+  - Model registry integration with `mxbai-xsmall` alias
+  - CLI support: `--model mxbai-xsmall` and `--rerank-model mxbai`
+  - MCP server support for Mixedbread models in semantic/hybrid search tools
 - **VitePress documentation site**: Comprehensive documentation with improved navigation, search, and structure in `docs-site/` directory
 - **Documentation features**: Guide pages, feature documentation, CLI reference, embedding model guide, architecture docs, and contributing guides
 - **Local search**: Built-in search functionality in documentation site
diff --git a/Cargo.lock b/Cargo.lock
index c1dd026..3d6478a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -512,7 +512,7 @@ dependencies = [
  "ck-embed",
  "hf-hub 0.3.2",
  "serde",
- "tokenizers",
+ "tokenizers 0.22.1",
  "tracing",
  "tree-sitter",
  "tree-sitter-c-sharp",
@@ -545,8 +545,15 @@ version = "0.7.1"
 dependencies = [
  "anyhow",
  "ck-core",
+ "ck-models",
  "fastembed",
+ "hf-hub 0.4.3",
+ "ndarray",
+ "num_cpus",
+ "once_cell",
+ "ort",
  "serde",
+ "tokenizers 0.20.4",
  "tokio",
 ]
@@ -1186,6 +1193,9 @@ name = "esaxx-rs"
 version = "0.1.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6"
+dependencies = [
+ "cc",
+]
 
 [[package]]
 name = "euclid"
@@ -1229,7 +1239,7 @@ dependencies = [
  "ndarray",
  "ort",
  "serde_json",
- "tokenizers",
+ "tokenizers 0.22.1",
 ]
 
 [[package]]
@@ -1542,6 +1552,12 @@ version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
 
+[[package]]
+name = "hermit-abi"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
+
 [[package]]
 name = "hf-hub"
 version = "0.3.2"
@@ -1966,6 +1982,15 @@ version = "1.70.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
 
+[[package]]
+name = "itertools"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
+dependencies = [
+ "either",
+]
+
 [[package]]
 name = "itertools"
 version = "0.12.1"
@@ -2482,6 +2507,16 @@ dependencies = [
  "libm",
 ]
 
+[[package]]
+name = "num_cpus"
+version = "1.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
+dependencies = [
+ "hermit-abi",
+ "libc",
+]
+
 [[package]]
 name = "number_prefix"
 version = "0.4.0"
@@ -3095,6 +3130,17 @@ dependencies = [
  "rayon-core",
 ]
 
+[[package]]
+name = "rayon-cond"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9"
+dependencies = [
+ "either",
+ "itertools 0.11.0",
+ "rayon",
+]
+
 [[package]]
 name = "rayon-cond"
 version = "0.4.0"
@@ -4134,6 +4180,38 @@ version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
 
+[[package]]
+name = "tokenizers"
+version = "0.20.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905"
+dependencies = [
+ "aho-corasick",
+ "derive_builder",
+ "esaxx-rs",
+ "getrandom 0.2.16",
+ "indicatif",
+ "itertools 0.12.1",
+ "lazy_static",
+ "log",
+ "macro_rules_attribute",
+ "monostate",
+ "onig",
+ "paste",
+ "rand 0.8.5",
+ "rayon",
+ "rayon-cond 0.3.0",
+ "regex",
+ "regex-syntax",
+ "serde",
+ "serde_json",
+ "spm_precompiled",
+ "thiserror 1.0.69",
+ "unicode-normalization-alignments",
+ "unicode-segmentation",
+ "unicode_categories",
+]
+
 [[package]]
 name = "tokenizers"
 version = "0.22.1"
@@ -4156,7 +4234,7 @@ dependencies = [
  "paste",
  "rand 0.9.2",
  "rayon",
- "rayon-cond",
+ "rayon-cond 0.4.0",
  "regex",
  "regex-syntax",
  "serde",
diff --git a/Cargo.toml b/Cargo.toml
index 142cceb..f9db21f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -57,3 +57,9 @@ pdf-extract = "0.9"
 uuid = { version = "1.8", features = ["v4", "serde"] }
 base64 = "0.22"
 sha2 = "0.10"
+hf-hub = { version = "0.4.3", default-features = false, features = ["ureq"] }
+tokenizers = "0.20.1"
+ort = { version = "2.0.0-rc.10", default-features = false, features = ["download-binaries"] }
+once_cell = "1.19"
+ndarray = { version = "0.16", default-features = false, features = ["std"] }
+num_cpus = "1.16"
diff --git a/README.md b/README.md
index 9719d46..45e77be 100644
--- a/README.md
+++ b/README.md
@@ -244,6 +244,9 @@ Choose the right embedding model for your needs:
 # Default: BGE-Small (fast, precise chunking)
 ck --index .
 
+# Mixedbread xsmall: Optimized for local semantic search (4K context, 384 dims)
+ck --index --model mxbai-xsmall .
+
 # Enhanced: Nomic V1.5 (8K context, optimal for large functions)
 ck --index --model nomic-v1.5 .
 
@@ -253,6 +256,7 @@ ck --index --model jina-code .
 
 **Model Comparison:**
 - **`bge-small`** (default): 400-token chunks, fast indexing, good for most code
+- **`mxbai-xsmall`**: 4K context window, 384 dimensions, optimized for local inference (Mixedbread)
 - **`nomic-v1.5`**: 1024-token chunks with 8K model capacity, better for large functions
 - **`jina-code`**: 1024-token chunks with 8K model capacity, specialized for code understanding
 
@@ -264,6 +268,7 @@ ck --status .
 
 # Clean up and rebuild / switch models
 ck --clean .
+ck --switch-model mxbai-xsmall .
 ck --switch-model nomic-v1.5 .
 ck --switch-model nomic-v1.5 --force .  # Force rebuild
diff --git a/ck-cli/src/main.rs b/ck-cli/src/main.rs
index f4c9ec6..440e58e 100644
--- a/ck-cli/src/main.rs
+++ b/ck-cli/src/main.rs
@@ -328,7 +328,7 @@ struct Cli {
     #[arg(
         long = "model",
         value_name = "MODEL",
-        help = "Embedding model to use for indexing (bge-small, nomic-v1.5, jina-code) [default: bge-small]. Only used with --index."
+        help = "Embedding model to use for indexing (bge-small, nomic-v1.5, jina-code, mxbai-xsmall) [default: bge-small]. Only used with --index."
     )]
     model: Option<String>,
 
@@ -342,7 +342,7 @@ struct Cli {
     #[arg(
         long = "rerank-model",
         value_name = "MODEL",
-        help = "Reranking model to use (jina, bge) [default: jina]"
+        help = "Reranking model to use (jina, bge, mxbai) [default: jina]"
     )]
     rerank_model: Option<String>,
 
@@ -451,46 +451,6 @@ fn build_exclude_patterns(cli: &Cli) -> Vec<String> {
     ck_core::build_exclude_patterns(&cli.exclude, !cli.no_default_excludes)
 }
 
-fn resolve_model_selection(
-    registry: &ck_models::ModelRegistry,
-    requested: Option<&str>,
-) -> Result<(String, ck_models::ModelConfig)> {
-    match requested {
-        Some(name) => {
-            if let Some(config) = registry.get_model(name) {
-                return Ok((name.to_string(), config.clone()));
-            }
-
-            if let Some((alias, config)) = registry
-                .models
-                .iter()
-                .find(|(_, config)| config.name == name)
-            {
-                return Ok((alias.clone(), config.clone()));
-            }
-
-            anyhow::bail!(
-                "Unknown model '{}'. Available models: {}",
-                name,
-                registry
-                    .models
-                    .keys()
-                    .cloned()
-                    .collect::<Vec<_>>()
-                    .join(", ")
-            );
-        }
-        None => {
-            let alias = registry.default_model.clone();
-            let config = registry
-                .get_default_model()
-                .ok_or_else(|| anyhow::anyhow!("No default model configured"))?
-                .clone();
-            Ok((alias, config))
-        }
-    }
-}
-
 async fn run_index_workflow(
     status: &StatusReporter,
     path: &Path,
@@ -983,7 +943,9 @@ async fn run_cli_mode(cli: Cli) -> Result<()> {
             .unwrap_or_else(|| PathBuf::from("."));
 
         let registry = ck_models::ModelRegistry::default();
-        let (model_alias, model_config) = resolve_model_selection(&registry, Some(model_name))?;
+        let (model_alias, model_config) = registry
+            .resolve(Some(model_name))
+            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
 
         if !cli.force {
             let manifest_path = path.join(".ck").join("manifest.json");
@@ -992,7 +954,7 @@
                 && let Ok(manifest) = serde_json::from_slice::<ck_index::IndexManifest>(&data)
                 && let Some(existing_model) = manifest.embedding_model.clone()
                 && let Ok((existing_alias, existing_config)) =
-                    resolve_model_selection(&registry, Some(existing_model.as_str()))
+                    registry.resolve(Some(existing_model.as_str()))
                 && existing_config.name == model_config.name
             {
                 status.section_header("Switching Embedding Model");
@@ -1042,7 +1004,9 @@
             .unwrap_or_else(|| PathBuf::from("."));
 
         let registry = ck_models::ModelRegistry::default();
-        let (model_alias, model_config) = resolve_model_selection(&registry, cli.model.as_deref())?;
+        let (model_alias, model_config) = registry
+            .resolve(cli.model.as_deref())
+            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
 
         run_index_workflow(
             &status,
@@ -1596,22 +1560,24 @@ async fn run_search(
     let resolved_model =
         ck_engine::resolve_model_for_path(&options.path, options.embedding_model.as_deref())?;
 
-    if resolved_model.alias == resolved_model.canonical_name {
+    if resolved_model.alias == resolved_model.canonical_name() {
         eprintln!(
             "🤖 Model: {} ({} dims)",
-            resolved_model.canonical_name, resolved_model.dimensions
+            resolved_model.canonical_name(),
+            resolved_model.dimensions()
         );
     } else {
         eprintln!(
             "🤖 Model: {} (alias '{}', {} dims)",
-            resolved_model.canonical_name, resolved_model.alias, resolved_model.dimensions
+            resolved_model.canonical_name(),
+            resolved_model.alias,
+            resolved_model.dimensions()
         );
     }
 
-    let max_tokens =
-        ck_chunk::TokenEstimator::get_model_limit(resolved_model.canonical_name.as_str());
+    let max_tokens = ck_chunk::TokenEstimator::get_model_limit(resolved_model.canonical_name());
     let (chunk_tokens, overlap_tokens) =
-        ck_chunk::get_model_chunk_config(Some(resolved_model.canonical_name.as_str()));
+        ck_chunk::get_model_chunk_config(Some(resolved_model.canonical_name()));
 
     eprintln!("📏 FastEmbed Config: {} token limit", max_tokens);
     eprintln!(
diff --git a/ck-cli/tests/integration_tests.rs b/ck-cli/tests/integration_tests.rs
index 27f1ca9..8ebe61c 100644
--- a/ck-cli/tests/integration_tests.rs
+++ b/ck-cli/tests/integration_tests.rs
@@ -821,3 +821,198 @@ fn test_add_file_with_relative_path() {
     let stdout = String::from_utf8(output.stdout).unwrap();
     assert!(stdout.contains("Relative path content"));
 }
+
+#[test]
+#[serial]
+#[ignore] // Requires models to be downloaded - run with: CK_MIXEDBREAD_MODELS_READY=1 cargo test -- --ignored
+fn test_mixedbread_index_and_search() {
+    // Skip if models aren't ready (set CK_MIXEDBREAD_MODELS_READY=1 to enable)
+    if std::env::var("CK_MIXEDBREAD_MODELS_READY").is_err() {
+        return;
+    }
+
+    let temp_dir = TempDir::new().unwrap();
+
+    // Create test files with semantic content
+    fs::write(
+        temp_dir.path().join("rust_error.rs"),
+        r#"fn handle_error() -> Result<String, Error> {
+    let result = risky_operation()?;
+    Ok(result)
+}"#,
+    )
+    .unwrap();
+    fs::write(
+        temp_dir.path().join("python_web.py"),
+        "from flask import Flask\napp = Flask(__name__)\n@app.route('/')\ndef hello(): return 'Hello'",
+    )
+    .unwrap();
+    fs::write(
+        temp_dir.path().join("error_handling.md"),
+        "Error handling in Rust uses Result and Option types for safe error propagation",
+    )
+    .unwrap();
+
+    // Test indexing with Mixedbread model
+    let output = Command::new(ck_binary())
+        .args(["--index", "--model", "mxbai-xsmall", "."])
+        .current_dir(temp_dir.path())
+        .output()
+        .expect("Failed to run ck index with Mixedbread");
+
+    assert!(
+        output.status.success(),
+        "Indexing failed: stderr: {}, stdout: {}",
+        String::from_utf8_lossy(&output.stderr),
+        String::from_utf8_lossy(&output.stdout)
+    );
+
+    // Verify index was created
+    assert!(temp_dir.path().join(".ck").exists(), "Index directory should exist");
+
+    // Check manifest contains Mixedbread model
+    let manifest_path = temp_dir.path().join(".ck").join("manifest.json");
+    let manifest_data = fs::read(&manifest_path).expect("manifest should exist");
+    let manifest: serde_json::Value =
+        serde_json::from_slice(&manifest_data).expect("valid json");
+    let embedding_model = manifest
+        .get("embedding_model")
+        .and_then(|v| v.as_str())
+        .expect("embedding_model should be set");
+    assert!(
+        embedding_model.contains("mxbai-embed-xsmall-v1"),
+        "Manifest should record Mixedbread model, got: {}",
+        embedding_model
+    );
+
+    let embedding_dimensions = manifest
+        .get("embedding_dimensions")
+        .and_then(|v| v.as_u64())
+        .expect("embedding_dimensions should be set");
+    assert_eq!(
+        embedding_dimensions, 384,
+        "Mixedbread xsmall should have 384 dimensions"
+    );
+
+    // Test semantic search with Mixedbread
+    let output = Command::new(ck_binary())
+        .args([
+            "--sem",
+            "error handling",
+            "--model",
+            "mxbai-xsmall",
+            ".",
+        ])
+        .current_dir(temp_dir.path())
+        .output()
+        .expect("Failed to run ck semantic search with Mixedbread");
+
+    assert!(
+        output.status.success(),
+        "Semantic search failed: stderr: {}, stdout: {}",
+        String::from_utf8_lossy(&output.stderr),
+        String::from_utf8_lossy(&output.stdout)
+    );
+
+    let stdout = String::from_utf8(output.stdout).unwrap();
+    // Should find error handling related content
+    assert!(
+        stdout.contains("error") || stdout.contains("Error"),
+        "Should find error handling content"
+    );
+
+    // Test reranking with Mixedbread reranker
+    let output = Command::new(ck_binary())
+        .args([
+            "--sem",
+            "error handling",
+            "--model",
+            "mxbai-xsmall",
+            "--rerank",
+            "--rerank-model",
+            "mxbai",
+            ".",
+        ])
+        .current_dir(temp_dir.path())
+        .output()
+        .expect("Failed to run ck search with Mixedbread reranker");
+
+    assert!(
+        output.status.success(),
+        "Reranked search failed: stderr: {}, stdout: {}",
+        String::from_utf8_lossy(&output.stderr),
+        String::from_utf8_lossy(&output.stdout)
+    );
+
+    let stdout = String::from_utf8(output.stdout).unwrap();
+    assert!(!stdout.is_empty(), "Should return results");
+}
+
+#[test]
+#[serial]
+#[ignore] // Requires models to be downloaded
+fn test_switch_model_to_mixedbread() {
+    if std::env::var("CK_MIXEDBREAD_MODELS_READY").is_err() {
+        return;
+    }
+
+    let temp_dir = TempDir::new().unwrap();
+    fs::write(
+        temp_dir.path().join("test.rs"),
+        "fn main() { println!(\"Hello\"); }",
+    )
+    .unwrap();
+
+    // Create index with default model
+    let output = Command::new(ck_binary())
+        .args(["--index", "."])
+        .current_dir(temp_dir.path())
+        .output()
+        .expect("Failed to create initial index");
+
+    assert!(output.status.success());
+
+    let updated_before = read_manifest_updated(temp_dir.path());
+
+    std::thread::sleep(std::time::Duration::from_secs(1));
+
+    // Switch to Mixedbread model
+    let output = Command::new(ck_binary())
+        .args(["--switch-model", "mxbai-xsmall"])
+        .current_dir(temp_dir.path())
+        .output()
+        .expect("Failed to switch model");
+
+    assert!(
+        output.status.success(),
+        "Switch model failed: stderr: {}, stdout: {}",
+        String::from_utf8_lossy(&output.stderr),
+        String::from_utf8_lossy(&output.stdout)
+    );
+
+    let stderr = String::from_utf8(output.stderr).unwrap();
+    assert!(
+        stderr.contains("Switching") || stderr.contains("Rebuilding"),
+        "Should indicate model switch"
+    );
+
+    let updated_after = read_manifest_updated(temp_dir.path());
+    assert!(
+        updated_after > updated_before,
+        "Manifest should be updated after model switch"
+    );
+
+    // Verify manifest now has Mixedbread model
+    let manifest_path = temp_dir.path().join(".ck").join("manifest.json");
+    let manifest_data = fs::read(&manifest_path).expect("manifest should exist");
+    let manifest: serde_json::Value =
+        serde_json::from_slice(&manifest_data).expect("valid json");
+    let embedding_model = manifest
+        .get("embedding_model")
+        .and_then(|v| v.as_str())
+        .expect("embedding_model should be set");
+    assert!(
+        embedding_model.contains("mxbai-embed-xsmall-v1"),
+        "Manifest should now have Mixedbread model"
+    );
+}
diff --git a/ck-embed/Cargo.toml b/ck-embed/Cargo.toml
index afba29f..dbb63cb 100644
--- a/ck-embed/Cargo.toml
+++ b/ck-embed/Cargo.toml
@@ -12,13 +12,28 @@ categories = ["science"]
 
 [dependencies]
 ck-core = { version = "0.7.1", path = "../ck-core" }
+ck-models = { version = "0.7.1", path = "../ck-models" }
 anyhow = { workspace = true }
 serde = { workspace = true }
 tokio = { workspace = true }
 fastembed = { workspace = true, optional = true }
+hf-hub = { workspace = true, optional = true }
+tokenizers = { workspace = true, optional = true }
+ort = { workspace = true, optional = true }
+once_cell = { workspace = true, optional = true }
+ndarray = { workspace = true, optional = true }
+num_cpus = { workspace = true, optional = true }
 
 [features]
-default = ["fastembed"]
-fastembed = ["dep:fastembed"]
\ No newline at end of file
+default = ["fastembed", "mixedbread"]
+fastembed = ["dep:fastembed"]
+mixedbread = [
+    "dep:hf-hub",
+    "dep:tokenizers",
+    "dep:ort",
+    "dep:once_cell",
+    "dep:ndarray",
+    "dep:num_cpus",
+]
diff --git a/ck-embed/examples/test_mixedbread.rs b/ck-embed/examples/test_mixedbread.rs
new file mode 100644
index 0000000..36eef2a
--- /dev/null
+++ b/ck-embed/examples/test_mixedbread.rs
@@ -0,0 +1,176 @@
+#[cfg(feature = "mixedbread")]
+use ck_embed::create_embedder;
+#[cfg(feature = "mixedbread")]
+use ck_embed::reranker::create_reranker;
+#[cfg(feature = "mixedbread")]
+use ck_models::{ModelRegistry, RerankModelRegistry};
+
+fn main() {
+    #[cfg(not(feature = "mixedbread"))]
+    {
+        println!("This example requires the 'mixedbread' feature to be enabled.");
+        println!("Run with: cargo run --example test_mixedbread --features mixedbread");
+        return;
+    }
+
+    #[cfg(feature = "mixedbread")]
+    run_example();
+}
+
+#[cfg(feature = "mixedbread")]
+fn run_example() {
+    println!("=== Testing Mixedbread Models ===\n");
+
+    // Test 1: Model Registry Resolution
+    println!("1. Testing Model Registry Resolution");
+    println!("   Checking if 'mxbai-xsmall' alias resolves...");
+
+    let registry = ModelRegistry::default();
+    match registry.resolve(Some("mxbai-xsmall")) {
+        Ok((alias, config)) => {
+            println!("   ✅ Resolved alias: '{}'", alias);
+            println!("      Model name: {}", config.name);
+            println!("      Provider: {}", config.provider);
+            println!("      Dimensions: {}", config.dimensions);
+            println!("      Max tokens: {}", config.max_tokens);
+        }
+        Err(e) => {
+            println!("   ❌ Failed to resolve alias: {}", e);
+            return;
+        }
+    }
+
+    // Test 2: Embedder Creation
+    println!("\n2. Testing Mixedbread Embedder Creation");
+    println!("   Attempting to create Mixedbread embedder...");
+
+    let result = create_embedder(Some("mixedbread-ai/mxbai-embed-xsmall-v1"));
+
+    match result {
+        Ok(mut embedder) => {
+            println!("   ✅ Successfully created embedder: {}", embedder.id());
+            println!("      Model name: {}", embedder.model_name());
+            println!("      Dimensions: {}", embedder.dim());
+
+            // Test 3: Embedding Generation
+            println!("\n3. Testing Embedding Generation");
+            let test_texts = vec![
+                "Hello world".to_string(),
+                "Rust programming language".to_string(),
+                "Machine learning and artificial intelligence".to_string(),
+            ];
+            println!("   Generating embeddings for {} texts...", test_texts.len());
+
+            match embedder.embed(&test_texts) {
+                Ok(embeddings) => {
+                    println!("   ✅ Successfully generated embeddings");
+                    println!("      Shape: {} embeddings of {} dimensions", embeddings.len(), embeddings[0].len());
+
+                    // Verify dimensions
+                    assert_eq!(embeddings.len(), test_texts.len(), "Should have one embedding per text");
+                    assert_eq!(embeddings[0].len(), 384, "Mixedbread xsmall should produce 384-dim vectors");
+
+                    // Check normalization (L2 norm should be ~1.0)
+                    for (i, emb) in embeddings.iter().enumerate() {
+                        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
+                        println!("      Embedding {} L2 norm: {:.6} (should be ~1.0)", i, norm);
+                        assert!((norm - 1.0).abs() < 0.01, "Embeddings should be L2-normalized");
+                    }
+                }
+                Err(e) => {
+                    println!("   ❌ Failed to generate embeddings: {}", e);
+                    return;
+                }
+            }
+        }
+        Err(e) => {
+            println!("   ❌ Failed to create Mixedbread embedder: {}", e);
+            println!("      Error details: {:?}", e);
+            return;
+        }
+    }
+
+    // Test 4: Reranker Registry Resolution
+    println!("\n4. Testing Reranker Registry Resolution");
+    println!("   Checking if 'mxbai' reranker alias resolves...");
+
+    let rerank_registry = RerankModelRegistry::default();
+    match rerank_registry.resolve(Some("mxbai")) {
+        Ok((alias, config)) => {
+            println!("   ✅ Resolved reranker alias: '{}'", alias);
+            println!("      Model name: {}", config.name);
+            println!("      Provider: {}", config.provider);
+        }
+        Err(e) => {
+            println!("   ❌ Failed to resolve reranker alias: {}", e);
+            return;
+        }
+    }
+
+    // Test 5: Reranker Creation
+    println!("\n5. Testing Mixedbread Reranker Creation");
+    println!("   Attempting to create Mixedbread reranker...");
+
+    match create_reranker(Some("mixedbread-ai/mxbai-rerank-xsmall-v1")) {
+        Ok(mut reranker) => {
+            println!("   ✅ Successfully created reranker: {}", reranker.id());
+
+            // Test 6: Reranking
+            println!("\n6. Testing Reranking");
+            let query = "error handling in Rust";
+            let documents = vec![
+                "Rust error handling with Result and Option types".to_string(),
+                "Python web development frameworks".to_string(),
+                "Rust provides excellent error handling mechanisms".to_string(),
+                "JavaScript async programming patterns".to_string(),
+            ];
+            println!("   Query: '{}'", query);
+            println!("   Reranking {} documents...", documents.len());
+
+            match reranker.rerank(query, &documents) {
+                Ok(results) => {
+                    println!("   ✅ Successfully reranked documents");
+                    println!("   Results (sorted by score):");
+                    for (i, result) in results.iter().enumerate() {
+                        println!("      {}. Score: {:.4} | Doc: {}",
+                            i + 1,
+                            result.score,
+                            if result.document.len() > 60 {
+                                &result.document[..60]
+                            } else {
+                                &result.document[..]
+                            }
+                        );
+                    }
+
+                    // Verify results are sorted by score (descending)
+                    let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
+                    let sorted_scores: Vec<f32> = {
+                        let mut s = scores.clone();
+                        s.sort_by(|a, b| b.partial_cmp(a).unwrap());
+                        s
+                    };
+                    assert_eq!(scores, sorted_scores, "Results should be sorted by score descending");
+
+                    // Verify scores are in valid range [0, 1]
+                    for result in &results {
+                        assert!(result.score >= 0.0 && result.score <= 1.0,
+                            "Rerank scores should be in [0, 1] range");
+                    }
+                }
+                Err(e) => {
+                    println!("   ❌ Failed to rerank: {}", e);
+                    return;
+                }
+            }
+        }
+        Err(e) => {
+            println!("   ❌ Failed to create Mixedbread reranker: {}", e);
+            println!("      Error details: {:?}", e);
+            return;
+        }
+    }
+
+    println!("\n=== All Tests Passed! ===");
+}
diff --git a/ck-embed/src/lib.rs b/ck-embed/src/lib.rs
index 375de2f..ae69be0 100644
--- a/ck-embed/src/lib.rs
+++ b/ck-embed/src/lib.rs
@@ -1,14 +1,21 @@
-use anyhow::Result;
-
-#[cfg(feature = "fastembed")]
+use anyhow::{Result, bail};
+use ck_models::{ModelConfig, ModelRegistry};
 use std::path::{Path, PathBuf};
 
 pub mod reranker;
 pub mod tokenizer;
 
-pub use reranker::{RerankResult, Reranker, create_reranker, create_reranker_with_progress};
+pub use reranker::{
+    RerankResult, Reranker, create_reranker, create_reranker_for_config,
+    create_reranker_with_progress,
+};
 pub use tokenizer::TokenEstimator;
 
+#[cfg(feature = "mixedbread")]
+mod mixedbread;
+#[cfg(feature = "mixedbread")]
+use mixedbread::MixedbreadEmbedder;
+
 pub trait Embedder: Send + Sync {
     fn id(&self) -> &'static str;
     fn dim(&self) -> usize;
@@ -18,6 +25,20 @@
 
 pub type ModelDownloadCallback = Box<dyn Fn(&str) + Send + Sync>;
 
+pub(crate) fn model_cache_root() -> Result<PathBuf> {
+    let base = if let Some(cache_home) = std::env::var_os("XDG_CACHE_HOME") {
+        PathBuf::from(cache_home).join("ck")
+    } else if let Some(home) = std::env::var_os("HOME") {
+        PathBuf::from(home).join(".cache").join("ck")
+    } else if let Some(appdata) = std::env::var_os("LOCALAPPDATA") {
+        PathBuf::from(appdata).join("ck").join("cache")
+    } else {
+        PathBuf::from(".ck_models")
+    };
+
+    Ok(base.join("models"))
+}
+
 pub fn create_embedder(model_name: Option<&str>) -> Result<Box<dyn Embedder>> {
     create_embedder_with_progress(model_name, None)
 }
@@ -26,22 +47,52 @@ pub fn create_embedder_with_progress(
     model_name: Option<&str>,
     progress_callback: Option<ModelDownloadCallback>,
 ) -> Result<Box<dyn Embedder>> {
-    let model = model_name.unwrap_or("BAAI/bge-small-en-v1.5");
+    let registry = ModelRegistry::default();
+    let (_, config) = registry.resolve(model_name)?;
+    create_embedder_for_config(&config, progress_callback)
+}
 
-    #[cfg(feature = "fastembed")]
-    {
-        Ok(Box::new(FastEmbedder::new_with_progress(
-            model,
-            progress_callback,
-        )?))
-    }
+pub fn create_embedder_for_config(
+    config: &ModelConfig,
+    progress_callback: Option<ModelDownloadCallback>,
+) -> Result<Box<dyn Embedder>> {
+    match config.provider.as_str() {
+        "fastembed" => {
+            #[cfg(feature = "fastembed")]
+            {
+                return Ok(Box::new(FastEmbedder::new_with_progress(
+                    config.name.as_str(),
+                    progress_callback,
+                )?));
+            }
 
-    #[cfg(not(feature = "fastembed"))]
-    {
-        if let Some(callback) = progress_callback {
-            callback("Using dummy embedder (no model download required)");
+            #[cfg(not(feature = "fastembed"))]
+            {
+                if let Some(callback) = progress_callback.as_ref() {
+                    callback("fastembed provider unavailable; using dummy embedder");
+                }
+                return Ok(Box::new(DummyEmbedder::new_with_model(
+                    config.name.as_str(),
+                )));
+            }
+        }
+        "mixedbread" => {
+            #[cfg(feature = "mixedbread")]
+            {
+                return Ok(Box::new(MixedbreadEmbedder::new(
+                    config,
+                    progress_callback,
+                )?));
+            }
+            #[cfg(not(feature = "mixedbread"))]
+            {
+                bail!(
+                    "Model '{}' requires the `mixedbread` feature. Rebuild ck with Mixedbread support.",
+                    config.name
+                );
+            }
+        }
-        Ok(Box::new(DummyEmbedder::new_with_model(model)))
+        provider => bail!("Unsupported embedding provider '{}'", provider),
     }
 }
@@ -128,7 +179,7 @@ impl FastEmbedder {
         };
 
         // Configure permanent model cache directory
-        let model_cache_dir = Self::get_model_cache_dir()?;
+        let model_cache_dir = model_cache_root()?;
         std::fs::create_dir_all(&model_cache_dir)?;
 
         if let Some(ref callback) = progress_callback {
@@ -198,22 +249,6 @@
         })
     }
 
-    fn get_model_cache_dir() -> Result<PathBuf> {
-        // Use platform-appropriate cache directory
-        let cache_dir = if let Some(cache_home) = std::env::var_os("XDG_CACHE_HOME") {
-            PathBuf::from(cache_home).join("ck")
-        } else if let Some(home) = std::env::var_os("HOME") {
-            PathBuf::from(home).join(".cache").join("ck")
-        } else if let Some(appdata) = std::env::var_os("LOCALAPPDATA") {
-            PathBuf::from(appdata).join("ck").join("cache")
-        } else {
-            // Fallback to current directory if no home found
-            PathBuf::from(".ck_models")
-        };
-
-        Ok(cache_dir.join("models"))
-    }
-
     fn check_model_exists(cache_dir: &Path, model_name: &str) -> bool {
         // Simple heuristic - check if model directory exists
         let model_dir = cache_dir.join(model_name.replace("/", "_"));
diff --git a/ck-embed/src/mixedbread.rs b/ck-embed/src/mixedbread.rs
new file mode 100644
index 0000000..75cdab1
--- /dev/null
+++ b/ck-embed/src/mixedbread.rs
@@ -0,0 +1,411 @@
+#![cfg(feature = "mixedbread")]
+
+use std::path::PathBuf;
+
+use anyhow::{Context, Result, anyhow};
+use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
+use ndarray::{Array2, ArrayView, ArrayViewD, Axis, Ix1, Ix2, Ix3};
+use num_cpus;
+use ort::session::{Session, builder::GraphOptimizationLevel};
+use ort::value::TensorRef;
+use tokenizers::{EncodeInput, Tokenizer};
+
+use crate::{
+    Embedder, ModelDownloadCallback, model_cache_root,
+    reranker::{RerankModelDownloadCallback, RerankResult, Reranker},
+};
+use ck_models::{ModelConfig, RerankModelConfig};
+
+const EMBED_TOKENIZER_PATH: &str = "tokenizer.json";
+const EMBED_MODEL_PATH: &str = "onnx/model_quantized.onnx";
+const RERANK_TOKENIZER_PATH: &str = "tokenizer.json";
+const RERANK_MODEL_PATH: &str = "onnx/model_quantized.onnx";
+
+pub struct MixedbreadEmbedder {
+    session: Session,
+    tokenizer: Tokenizer,
+    dim: usize,
+    max_length: usize,
+    model_name: String,
+    requires_token_type_ids: bool,
+}
+
+impl MixedbreadEmbedder {
+    pub fn new(
+        config: &ModelConfig,
+        progress_callback: Option<ModelDownloadCallback>,
+    ) -> Result<Self> {
+        if let Some(cb) = progress_callback.as_ref() {
+            cb(&format!(
+                "Downloading Mixedbread embedding model ({}) if needed...",
+                config.name
+            ));
+        }
+
+        let (model_path, tokenizer_path) =
+            download_assets(&config.name, EMBED_MODEL_PATH, EMBED_TOKENIZER_PATH)?;
+
+        if let Some(cb) = progress_callback.as_ref() {
+            cb("Loading Mixedbread embedder session...");
+        }
+
+        let session = Session::builder()?
+            .with_optimization_level(GraphOptimizationLevel::Level3)?
+            .with_intra_threads(num_cpus::get().max(1))?
+            .commit_from_file(&model_path)?;
+
+        let tokenizer =
+            Tokenizer::from_file(tokenizer_path).map_err(|e| anyhow!("Tokenizer error: {e}"))?;
+
+        let requires_token_type_ids = session
+            .inputs
+            .iter()
+            .any(|input| input.name == "token_type_ids");
+
+        Ok(Self {
+            session,
+            tokenizer,
+            dim: config.dimensions,
+            max_length: config.max_tokens,
+            model_name: config.name.clone(),
+            requires_token_type_ids,
+        })
+    }
+
+    fn build_inputs(
+        &self,
+        texts: &[String],
+    ) -> Result<(Array2<i64>, Array2<i64>, Option<Array2<i64>>)> {
+        let mut encodings = Vec::with_capacity(texts.len());
+        for text in texts {
+            let encoding = self
+                .tokenizer
+                .encode(text.as_str(), true)
+                .map_err(|e| anyhow!("Tokenizer encode failed: {e}"))?;
+            encodings.push(encoding);
+        }
+
+        let seq_len = encodings
+            .iter()
+            .map(|encoding| encoding.len())
+            .max()
+            .unwrap_or(1)
+            .min(self.max_length)
+            .max(1);
+
+        let batch = encodings.len();
+        let mut input_ids = vec![0i64; batch * seq_len];
+        let mut attention_mask = vec![0i64; batch * seq_len];
+        let mut token_types = if self.requires_token_type_ids {
+            Some(vec![0i64; batch * seq_len])
+        } else {
+            None
+        };
+
+        for (row, encoding) in encodings.iter().enumerate() {
+            let ids = encoding.get_ids();
+            let mask = encoding.get_attention_mask();
+            let type_ids = encoding.get_type_ids();
+            let len = ids.len().min(seq_len);
+
+            let row_offset = row * seq_len;
+            for idx in 0..len {
+                input_ids[row_offset + idx] = ids[idx] as i64;
+                attention_mask[row_offset + idx] = mask[idx] as i64;
+            }
+
+            if let Some(ref mut token_types_buf) = token_types {
+                if !type_ids.is_empty() {
+                    for idx in 0..len {
+                        token_types_buf[row_offset + idx] = type_ids[idx] as i64;
+                    }
+                }
+            }
+        }
+
+        let token_type_array =
+            token_types.map(|buf| Array2::from_shape_vec((batch, seq_len), buf).unwrap());
+
+        Ok((
+            Array2::from_shape_vec((batch, seq_len), input_ids)
+                .expect("validated dimensions for input ids"),
+            Array2::from_shape_vec((batch, seq_len), attention_mask)
+                .expect("validated dimensions for attention mask"),
+            token_type_array,
+        ))
+    }
+
+    fn normalize(rows: ArrayViewD<'_, f32>, dim: usize) -> Result<Vec<Vec<f32>>> {
+        let ndim = rows.ndim();
+        match ndim {
+            2 => {
+                let view = rows.into_dimensionality::<Ix2>()?;
+                Ok(view
+                    .rows()
+                    .into_iter()
+                    .map(|row| normalize_row(row, dim))
+                    .collect())
+            }
+            3 => {
+                let view = rows.into_dimensionality::<Ix3>()?;
+                Ok(view
+                    .outer_iter()
+                    .map(|matrix| normalize_row(matrix.index_axis(Axis(0), 0), dim))
+                    .collect())
+            }
+            other => Err(anyhow!("Unexpected embedding tensor rank: {other}")),
+        }
+    }
+}
+
+impl Embedder for MixedbreadEmbedder {
+    fn id(&self) -> &'static str {
+        "mixedbread"
+    }
+
+    fn dim(&self) -> usize {
+        self.dim
+    }
+
+    fn model_name(&self) -> &str {
+        &self.model_name
+    }
+
+    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
+        if texts.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        let (input_ids, attention_mask, token_types) = self.build_inputs(texts)?;
+
+        let outputs = if self.requires_token_type_ids {
+            let token_types = token_types.expect("token type ids required but missing");
+            self.session.run(ort::inputs![
+                TensorRef::from_array_view(input_ids.view())?,
+                TensorRef::from_array_view(attention_mask.view())?,
+                TensorRef::from_array_view(token_types.view())?
+            ])?
+        } else {
+            self.session.run(ort::inputs![
+                TensorRef::from_array_view(input_ids.view())?,
+                TensorRef::from_array_view(attention_mask.view())?
+            ])?
+        };
+
+        let embedding_tensor = outputs[0]
+            .try_extract_array::<f32>()
+            .context("Failed to extract embedding tensor")?;
+
+        Self::normalize(embedding_tensor, self.dim)
+    }
+}
+
+pub struct MixedbreadReranker {
+    session: Session,
+    tokenizer: Tokenizer,
+    max_length: usize,
+    requires_token_type_ids: bool,
+}
+
+impl MixedbreadReranker {
+    pub fn new(
+        config: &RerankModelConfig,
+        progress_callback: Option<RerankModelDownloadCallback>,
+    ) -> Result<Self> {
+        if let Some(cb) = progress_callback.as_ref() {
+            cb(&format!(
+                "Downloading Mixedbread reranker model ({}) if needed...",
+                config.name
+            ));
+        }
+
+        let (model_path, tokenizer_path) =
+            download_assets(&config.name, RERANK_MODEL_PATH, RERANK_TOKENIZER_PATH)?;
+
+        if let Some(cb) = progress_callback.as_ref() {
+            cb("Loading Mixedbread reranker session...");
+        }
+
+        let session = Session::builder()?
+            .with_optimization_level(GraphOptimizationLevel::Level3)?
+            .with_intra_threads(num_cpus::get().max(1))?
+            .commit_from_file(&model_path)?;
+
+        let tokenizer =
+            Tokenizer::from_file(tokenizer_path).map_err(|e| anyhow!("Tokenizer error: {e}"))?;
+
+        let requires_token_type_ids = session
+            .inputs
+            .iter()
+            .any(|input| input.name == "token_type_ids");
+
+        Ok(Self {
+            session,
+            tokenizer,
+            max_length: 512,
+            requires_token_type_ids,
+        })
+    }
+
+    fn build_inputs(
+        &self,
+        query: &str,
+        documents: &[String],
+    ) -> Result<(Array2<i64>, Array2<i64>, Option<Array2<i64>>)> {
+        let mut encodings = Vec::with_capacity(documents.len());
+        for doc in documents {
+            let encoding = self
+                .tokenizer
+                .encode(EncodeInput::Dual(query.into(), doc.as_str().into()), true)
+                .map_err(|e| anyhow!("Tokenizer encode failed: {e}"))?;
+            encodings.push(encoding);
+        }
+
+        let seq_len = encodings
+            .iter()
+            .map(|encoding| encoding.len())
+            .max()
+            .unwrap_or(1)
+            .min(self.max_length)
+            .max(1);
+
+        let batch = encodings.len();
+        let mut input_ids = vec![0i64; batch * seq_len];
+        let mut attention_mask = vec![0i64; batch * seq_len];
+        let mut token_types = if self.requires_token_type_ids {
+            Some(vec![0i64; batch * seq_len])
+        } else {
+            None
+        };
+
+        for (row, encoding) in encodings.iter().enumerate() {
+            let ids = encoding.get_ids();
+            let mask = encoding.get_attention_mask();
+            let type_ids = encoding.get_type_ids();
+            let len = ids.len().min(seq_len);
+            let offset = row * seq_len;
+
+            for idx in 0..len {
+                input_ids[offset + idx] = ids[idx] as i64;
+                attention_mask[offset + idx] = mask[idx] as i64;
+            }
+
+            if let Some(ref mut token_types_buf) = token_types {
+                if !type_ids.is_empty() {
+                    for idx in 0..len {
+                        token_types_buf[offset + idx] = type_ids[idx] as i64;
+                    }
+                }
+            }
+        }
+
+        let token_type_array =
+            token_types.map(|buf| Array2::from_shape_vec((batch, seq_len), buf).unwrap());
+
+        Ok((
+            Array2::from_shape_vec((batch, seq_len), input_ids)
+                .expect("validated dimensions for input ids"),
+            Array2::from_shape_vec((batch, seq_len), attention_mask)
+                .expect("validated dimensions for attention mask"),
+            token_type_array,
+        ))
+    }
+}
+
+impl Reranker for MixedbreadReranker {
+    fn id(&self) -> &'static str {
+        "mixedbread_reranker"
+    }
+
+    fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>> {
+        if documents.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        let (input_ids, attention_mask, token_types) = self.build_inputs(query, documents)?;
+
+        let outputs = if self.requires_token_type_ids {
+            let token_types = token_types.expect("token type ids required but missing");
+            self.session.run(ort::inputs![
+                TensorRef::from_array_view(input_ids.view())?,
+                TensorRef::from_array_view(attention_mask.view())?,
+                TensorRef::from_array_view(token_types.view())?
+            ])?
+        } else {
+            self.session.run(ort::inputs![
+                TensorRef::from_array_view(input_ids.view())?,
+                TensorRef::from_array_view(attention_mask.view())?
+            ])?
+        };
+
+        let logits = outputs[0]
+            .try_extract_array::<f32>()
+            .context("Failed to extract reranker logits")?
+            .into_dimensionality::<Ix2>()?;
+
+        let mut results = Vec::with_capacity(documents.len());
+        for (i, row) in logits.rows().into_iter().enumerate() {
+            let logit = row
+                .get(0)
+                .copied()
+                .unwrap_or_else(|| row.iter().copied().next().unwrap_or(0.0));
+            let score = 1.0 / (1.0 + (-logit).exp());
+            results.push(RerankResult {
+                query: query.to_string(),
+                document: documents[i].clone(),
+                score,
+            });
+        }
+
+        Ok(results)
+    }
+}
+
+fn normalize_row(row: ArrayView<'_, f32, Ix1>, dim: usize) -> Vec<f32> {
+    let take = row.len().min(dim);
+    let mut values = vec![0f32; dim];
+    let mut norm = 0.0;
+    for (idx, value) in row.iter().take(take).enumerate() {
+        values[idx] = *value;
+        norm += value * value;
+    }
+
+    if norm > 0.0 {
+        let inv = norm.sqrt().recip();
+        for value in values.iter_mut().take(take) {
+            *value *= inv;
+        }
+    }
+
+    values
+}
+
+fn download_assets(
+    model_id: &str,
+    model_path: &str,
+    tokenizer_path: &str,
+) -> Result<(PathBuf, PathBuf)> {
+    let cache_dir = model_cache_root()?;
+    std::fs::create_dir_all(&cache_dir)?;
+
+    let api = ApiBuilder::new()
+        .with_cache_dir(cache_dir)
+        .build()
+        .context("Failed to initialize Hugging Face Hub client")?;
+
+    let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, "main".to_string());
+    let tokenizer = api
+        .repo(Repo::with_revision(
+            model_id.to_string(),
+            RepoType::Model,
+            "main".to_string(),
+        ))
+        .get(tokenizer_path)
+        .with_context(|| format!("Failed to download tokenizer for {model_id}"))?;
+    let model = api
+        .repo(repo)
+        .get(model_path)
+        .with_context(|| format!("Failed to download ONNX model for {model_id}"))?;
+
+    Ok((model, tokenizer))
+}
diff --git a/ck-embed/src/reranker.rs b/ck-embed/src/reranker.rs
index 408e131..cc14cd5 100644
--- a/ck-embed/src/reranker.rs
+++ b/ck-embed/src/reranker.rs
@@ -1,4 +1,8 @@
-use anyhow::Result;
+use anyhow::{Result, bail};
+use ck_models::{RerankModelConfig, RerankModelRegistry};
+
+#[cfg(feature = "mixedbread")]
+use crate::mixedbread::MixedbreadReranker;
 
 #[cfg(feature = "fastembed")]
 use std::path::PathBuf;
@@ -25,23 +29,50 @@ pub fn create_reranker_with_progress(
     model_name: Option<&str>,
     progress_callback: Option<RerankModelDownloadCallback>,
 ) -> Result<Box<dyn Reranker>> {
-    let model = model_name.unwrap_or("jina-reranker-v1-turbo-en");
+    let registry = RerankModelRegistry::default();
+    let (_, config) = registry.resolve(model_name)?;
+    create_reranker_for_config(&config, progress_callback)
+}
 
-    #[cfg(feature = "fastembed")]
-    {
-        Ok(Box::new(FastReranker::new_with_progress(
-            model,
-            progress_callback,
-        )?))
-    }
+pub fn create_reranker_for_config(
+    config: &RerankModelConfig,
+    progress_callback: Option<RerankModelDownloadCallback>,
+) -> Result<Box<dyn Reranker>> {
+    match config.provider.as_str() {
+        "fastembed" => {
+            #[cfg(feature = "fastembed")]
+            {
+                return Ok(Box::new(FastReranker::new_with_progress(
+                    config.name.as_str(),
+                    progress_callback,
+                )?));
+            }
 
-    #[cfg(not(feature = "fastembed"))]
-    {
-        let _ = model; // Suppress unused variable warning
-        if let Some(callback) = progress_callback {
-            callback("Using dummy reranker (no model download required)");
+            #[cfg(not(feature = "fastembed"))]
+            {
+                if let Some(callback) = progress_callback.as_ref() {
+                    callback("fastembed reranker unavailable; using dummy reranker");
+                }
+                return Ok(Box::new(DummyReranker::new()));
+            }
+        }
+        "mixedbread" => {
+            #[cfg(feature = "mixedbread")]
+            {
+                return Ok(Box::new(MixedbreadReranker::new(
+                    config,
+                    progress_callback,
+                )?));
+            }
+            #[cfg(not(feature = "mixedbread"))]
+            {
+                bail!(
+                    "Reranking model '{}' requires the `mixedbread` feature. Rebuild ck with Mixedbread support.",
+                    config.name
+                );
+            }
         }
-        Ok(Box::new(DummyReranker::new()))
+        provider => bail!("Unsupported reranker provider '{}'", provider),
     }
 }
diff --git a/ck-engine/src/lib.rs b/ck-engine/src/lib.rs
index 9ec3132..fc3eb63 100644
--- a/ck-engine/src/lib.rs
+++ b/ck-engine/src/lib.rs
@@ -201,24 +201,28 @@ fn find_nearest_index_root(path: &Path) -> Option<PathBuf> {
 
 #[derive(Clone, Debug)]
 pub struct ResolvedModel {
-    pub canonical_name: String,
     pub alias: String,
-    pub dimensions: usize,
+    pub config: ck_models::ModelConfig,
 }
 
-fn find_model_entry<'a>(
-    registry: &'a ck_models::ModelRegistry,
-    key: &str,
-) -> Option<(String, &'a ck_models::ModelConfig)> {
-    if let Some(config) = registry.get_model(key) {
-        return Some((key.to_string(), config));
+impl ResolvedModel {
+    pub fn canonical_name(&self) -> &str {
+        self.config.name.as_str()
     }
 
-    registry
-        .models
-        .iter()
-        .find(|(_, config)| config.name == key)
-        .map(|(alias, config)| (alias.clone(), config))
+    pub fn dimensions(&self) -> usize {
+        self.config.dimensions
+    }
+}
+
+fn legacy_model_config(name: &str, dimensions: usize) -> ck_models::ModelConfig {
+    ck_models::ModelConfig {
+        name: name.to_string(),
+        provider: "fastembed".to_string(),
+        dimensions,
+        max_tokens: 8192,
+        description: "Legacy ck embedding model preserved for backwards compatibility".to_string(),
+    }
 }
 
 pub(crate) fn resolve_model_from_root(
@@ -236,35 +240,25 @@
     let manifest: ck_index::IndexManifest = serde_json::from_slice(&data)?;
 
     if let Some(existing_model) = manifest.embedding_model {
-        let (alias, config_opt) = find_model_entry(&registry, &existing_model)
-            .map(|(alias, config)| (alias, Some(config)))
-            .unwrap_or_else(|| (existing_model.clone(), None));
-
-        let dims = manifest
-            .embedding_dimensions
-            .or_else(|| config_opt.map(|c| c.dimensions))
-            .unwrap_or(384);
+        let dims_hint = manifest.embedding_dimensions.unwrap_or(384);
+        let resolved_existing = match registry.resolve(Some(existing_model.as_str())) {
+            Ok((alias, config)) => ResolvedModel { alias, config },
+            Err(_) => ResolvedModel {
+                alias: existing_model.clone(),
+                config: legacy_model_config(&existing_model, dims_hint),
+            },
+        };
 
         if let Some(requested) = cli_model {
-            let (_, requested_config) =
-                find_model_entry(&registry, requested).ok_or_else(|| {
-                    CkError::Embedding(format!(
-                        "Unknown model '{}'. Available models: {}",
-                        requested,
-                        registry
-                            .models
-                            .keys()
-                            .cloned()
-                            .collect::<Vec<_>>()
-                            .join(", ")
-                    ))
-                })?;
-
-            if requested_config.name != existing_model {
-                let suggested_alias = alias.clone();
+            let (requested_alias, requested_config) = registry
+                .resolve(Some(requested))
+                .map_err(|e| CkError::Embedding(e.to_string()))?;
+
+            if requested_config.name != resolved_existing.config.name {
+                let suggested_alias = resolved_existing.alias.clone();
                 return Err(CkError::Embedding(format!(
                     "Index was built with embedding model '{}' (alias '{}'), but '--model {}' was requested. To switch models run `ck --clean .` then `ck --index --model {}`. To keep using this index rerun your command with '--model {}'.",
-                    existing_model,
+                    resolved_existing.config.name,
                     suggested_alias,
                     requested,
                     requested,
@@ -272,42 +266,22 @@
                     suggested_alias
                 ))
                 .into());
             }
+
+            return Ok(ResolvedModel {
+                alias: requested_alias,
+                config: requested_config,
+            });
         }
 
-        return Ok(ResolvedModel {
-            canonical_name: existing_model,
-            alias,
-            dimensions: dims,
-        });
+        return Ok(resolved_existing);
     }
 
-    let (alias, config) = if let Some(requested) = cli_model {
-        find_model_entry(&registry, requested).ok_or_else(|| {
-            CkError::Embedding(format!(
-                "Unknown model '{}'. Available models: {}",
-                requested,
-                registry
-                    .models
-                    .keys()
-                    .cloned()
-                    .collect::<Vec<_>>()
-                    .join(", ")
-            ))
-        })?
-    } else {
-        let alias = registry.default_model.clone();
-        let config = registry.get_default_model().ok_or_else(|| {
-            CkError::Embedding("No default embedding model configured".to_string())
-        })?;
-        (alias, config)
-    };
+    let (alias, config) = registry
+        .resolve(cli_model)
+        .map_err(|e| CkError::Embedding(e.to_string()))?;
 
-    Ok(ResolvedModel {
-        canonical_name: config.name.clone(),
-        alias,
-        dimensions: config.dimensions,
-    })
+    Ok(ResolvedModel { alias, config })
 }
 
 pub fn resolve_model_for_path(path: &Path, cli_model: Option<&str>) -> Result<ResolvedModel> {
diff --git a/ck-engine/src/semantic_v3.rs b/ck-engine/src/semantic_v3.rs
index 08e0ead..81f8abe 100644
--- a/ck-engine/src/semantic_v3.rs
+++ b/ck-engine/src/semantic_v3.rs
@@ -85,13 +85,23 @@ pub async fn semantic_search_v3_with_progress(
     let resolved_model = resolve_model_from_root(&index_root, options.embedding_model.as_deref())?;
 
     if let Some(ref callback) = progress_callback {
-        callback(&format!(
-            "Using embedding model {} ({} dims)",
-            resolved_model.alias, resolved_model.dimensions
-        ));
+        if resolved_model.alias == resolved_model.canonical_name() {
+            callback(&format!(
+                "Using embedding model {} ({} dims)",
+                resolved_model.canonical_name(),
+                resolved_model.dimensions()
+            ));
+        } else {
+            callback(&format!(
+                "Using embedding model {} (alias '{}', {} dims)",
+                resolved_model.canonical_name(),
+                resolved_model.alias,
+                resolved_model.dimensions()
+            ));
+        }
     }
 
-    let mut embedder = ck_embed::create_embedder(Some(resolved_model.canonical_name.as_str()))?;
+    let mut embedder = ck_embed::create_embedder_for_config(&resolved_model.config, None)?;
 
     let query_embeddings = embedder.embed(std::slice::from_ref(&options.query))?;
     if query_embeddings.is_empty() {
@@ -208,15 +218,17 @@
             callback("Reranking results for improved relevance...");
         }
 
-        let rerank_model_name = match options.rerank_model.as_deref() {
-            Some("jina") => Some("jina-reranker-v1-base-en"),
-            Some("bge") => Some("BAAI/bge-reranker-base"),
-            Some(name) => Some(name), // Pass through custom model names
-            None => Some("jina-reranker-v1-base-en"), // Default to jina
-        };
+        let rerank_registry = ck_models::RerankModelRegistry::default();
+        let (rerank_alias, rerank_config) = rerank_registry
+            .resolve(options.rerank_model.as_deref())
+            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
 
-        match ck_embed::create_reranker(rerank_model_name) {
+        match ck_embed::create_reranker_for_config(&rerank_config, None) {
             Ok(mut reranker) => {
+                if let Some(ref callback) = progress_callback {
+                    callback(&format!("Reranking results with model {}", rerank_alias));
+                }
+
                 let documents: Vec<String> = results.iter().map(|r| r.preview.clone()).collect();
 
                 match reranker.rerank(&options.query, &documents) {
diff --git a/ck-index/src/lib.rs b/ck-index/src/lib.rs
index b8b4373..b5e0897 100644
--- a/ck-index/src/lib.rs
+++ b/ck-index/src/lib.rs
@@ -15,6 +15,16 @@ use std::time::SystemTime;
 use tempfile::NamedTempFile;
 use walkdir::WalkDir;
 
+fn legacy_model_config(name: &str, dimensions: Option<usize>) -> ck_models::ModelConfig {
+    ck_models::ModelConfig {
+        name: name.to_string(),
+        provider: "fastembed".to_string(),
+        dimensions: dimensions.unwrap_or(384),
+        max_tokens: 8192,
+        description: "Legacy ck embedding model (inferred from manifest)".to_string(),
+    }
+}
+
 pub type ProgressCallback = Box<dyn Fn(&str) + Send + Sync>;
 
 /// Detailed progress information for embedding operations
@@ -253,51 +263,27 @@ pub async fn index_directory(
 
     // Handle model configuration for embeddings
     let resolved_model = if compute_embeddings {
-        // Resolve the model name and get its dimensions
         let model_registry = ck_models::ModelRegistry::default();
-        let selected_model = if let Some(model_name) = model {
-            // User specified a model
-            if let Some(model_config) = model_registry.get_model(model_name) {
-                model_config.name.clone()
-            } else {
+        let (alias, config) = model_registry
+            .resolve(model)
+            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
+
+        if let Some(existing_model) = &manifest.embedding_model {
+            if existing_model != &config.name {
                 return Err(anyhow::anyhow!(
-                    "Unknown model '{}'. Available models: bge-small, nomic-v1.5, jina-code",
-                    model_name
+                    "Model mismatch: Index was created with '{}', but you're trying to use '{}'. \
+                     Please run 'ck --clean {}' to remove the old index, then rerun with the new model.",
+                    existing_model,
+                    config.name,
+                    path.display()
                 ));
             }
-        } else {
-            // Use default model
-            let default_config = model_registry
-                .get_default_model()
-                .ok_or_else(|| anyhow::anyhow!("No default model available"))?;
-            default_config.name.clone()
-        };
-
-        // Check for model compatibility with existing index
-        if let Some(existing_model) = &manifest.embedding_model
-            && existing_model != &selected_model
-        {
-            // Model mismatch - this is an error to prevent reusing embeddings from a different model
-            return Err(anyhow::anyhow!(
-                "Model mismatch: Index was created with '{}', but you're trying to use '{}'. \
-                 Please run 'ck --clean {}' to remove the old index, then rerun with the new model.",
-                existing_model,
-                selected_model,
-                path.display()
-            ));
         }
 
-        // Set the model info in the manifest
-        manifest.embedding_model = Some(selected_model.clone());
-        if let Some(model_name) = model {
-            if let Some(model_config) = model_registry.get_model(model_name) {
-                manifest.embedding_dimensions = Some(model_config.dimensions);
-            }
-        } else if let Some(default_config) = model_registry.get_default_model() {
-            manifest.embedding_dimensions = Some(default_config.dimensions);
-        }
+        manifest.embedding_model = Some(config.name.clone());
+        manifest.embedding_dimensions = Some(config.dimensions);
 
-        Some(selected_model)
+        Some((alias, config))
     } else {
         None
     };
@@ -307,7 +293,10 @@
     if compute_embeddings {
         // Sequential processing with small-batch embeddings for streaming performance
         tracing::info!("Creating embedder for {} files", files.len());
-        let mut embedder = ck_embed::create_embedder(resolved_model.as_deref())?;
+        let (_, config) = resolved_model
+            .as_ref()
+            .expect("resolved model must be present when computing embeddings");
+        let mut embedder = ck_embed::create_embedder_for_config(config, None)?;
 
         for file_path in files.iter() {
             match index_single_file(file_path, path, Some(&mut embedder)) {
@@ -416,9 +405,26 @@ pub async fn index_file(file_path: &Path, compute_embeddings: bool) -> Result<()
     let mut manifest = load_or_create_manifest(&manifest_path)?;
 
     let entry = if compute_embeddings {
-        // Use the model from the existing index, or default if none specified
-        let model_name = manifest.embedding_model.as_deref();
-        let mut embedder = ck_embed::create_embedder(model_name)?;
+        let model_registry = ck_models::ModelRegistry::default();
+        let (alias, config) = if let Some(existing) = manifest.embedding_model.as_deref() {
+            match model_registry.resolve(Some(existing)) {
+                Ok(resolved) => resolved,
+                Err(_) => (
+                    existing.to_string(),
+                    legacy_model_config(existing, manifest.embedding_dimensions),
+                ),
+            }
+        } else {
+            model_registry
+                .resolve(None)
+                .map_err(|e| anyhow::anyhow!(e.to_string()))?
+        };
+
+        manifest.embedding_model = Some(config.name.clone());
+        manifest.embedding_dimensions = Some(config.dimensions);
+        tracing::debug!("Using embedding model '{}' ({})", config.name, alias);
+
+        let mut embedder = ck_embed::create_embedder_for_config(&config, None)?;
+
         index_single_file(file_path, &repo_root, Some(&mut embedder))?
     } else {
         index_single_file(file_path, &repo_root, None)?
@@ -461,8 +467,30 @@ pub async fn update_index(
 
     let updates: Vec<(PathBuf, IndexEntry)> = if compute_embeddings {
         // Sequential processing when computing embeddings (for memory efficiency)
-        let model_name = manifest.embedding_model.as_deref();
-        let mut embedder = ck_embed::create_embedder(model_name)?;
+        let model_registry = ck_models::ModelRegistry::default();
+        let (alias, config) = if let Some(existing) = manifest.embedding_model.as_deref() {
+            match model_registry.resolve(Some(existing)) {
+                Ok(resolved) => resolved,
+                Err(_) => (
+                    existing.to_string(),
+                    legacy_model_config(existing, manifest.embedding_dimensions),
+                ),
+            }
+        } else {
+            model_registry
+                .resolve(None)
+                .map_err(|e| anyhow::anyhow!(e.to_string()))?
+        };
+
+        manifest.embedding_model = Some(config.name.clone());
+        manifest.embedding_dimensions = Some(config.dimensions);
+        tracing::debug!(
+            "Updating index with embedding model '{}' ({})",
+            config.name,
+            alias
+        );
+
+        let mut embedder = ck_embed::create_embedder_for_config(&config, None)?;
+
         files
             .iter()
             .filter_map(|file_path| {
@@ -737,61 +765,45 @@ pub async fn smart_update_index_with_detailed_progress(
     normalize_manifest_paths(&mut manifest, &repo_root);
 
     // Handle model configuration for embeddings
-    let (resolved_model, _model_dimensions) = if compute_embeddings {
-        // Resolve the model name and get its dimensions
+    let resolved_model = if compute_embeddings {
         let model_registry = ck_models::ModelRegistry::default();
-        let (selected_model, model_dims) = if let Some(model_name) = model {
-            // User specified a model
-            if let Some(model_config) = model_registry.get_model(model_name) {
-                (model_config.name.clone(), model_config.dimensions)
-            } else {
-                return Err(anyhow::anyhow!(
-                    "Unknown model '{}'. Available models: bge-small, nomic-v1.5, jina-code",
-                    model_name
-                ));
+
+        let resolved = if let Some(requested) = model {
+            model_registry
+                .resolve(Some(requested))
+                .map_err(|e| anyhow::anyhow!(e.to_string()))?
+        } else if let Some(existing_model) = &manifest.embedding_model {
+            match model_registry.resolve(Some(existing_model.as_str())) {
+                Ok(resolved) => resolved,
+                Err(_) => (
+                    existing_model.clone(),
+                    legacy_model_config(existing_model, manifest.embedding_dimensions),
+                ),
             }
         } else {
-            // Use default model
-            let default_config = model_registry
-                .get_default_model()
-                .ok_or_else(|| anyhow::anyhow!("No default model available"))?;
-            (default_config.name.clone(), default_config.dimensions)
+            model_registry
+                .resolve(None)
+                .map_err(|e| anyhow::anyhow!(e.to_string()))?
         };
 
-        // Check for model compatibility with existing index
-        let (final_model, final_dims) = if let Some(existing_model) = &manifest.embedding_model {
-            // If we're updating an existing index and no model was specified,
-            // use the existing model from the index
-            if model.is_none() {
-                // Use the existing model - this is an auto-update during search
-                (
-                    existing_model.clone(),
-                    manifest.embedding_dimensions.unwrap_or(384),
-                )
-            } else if existing_model != &selected_model {
-                // User explicitly specified a different model - that's an error
+        if let Some(existing_model) = &manifest.embedding_model {
+            if existing_model != &resolved.1.name {
                 return Err(anyhow::anyhow!(
                     "Model mismatch: Index was created with '{}', but you're trying to use '{}'. \
                      Please run 'ck --clean .' to remove the old index, then 'ck --index --model {}' to rebuild with the new model.",
                     existing_model,
-                    selected_model,
+                    resolved.1.name,
                     model.unwrap_or("default")
                 ));
-            } else {
-                // Model matches, proceed
-                (selected_model, model_dims)
             }
-        } else {
-            // This is either a new index or an old index without model info
-            // Set the model info in the manifest
-            manifest.embedding_model = Some(selected_model.clone());
-            manifest.embedding_dimensions = Some(model_dims);
-            (selected_model, model_dims)
-        };
+        }
 
-        (Some(final_model), Some(final_dims))
+        manifest.embedding_model = Some(resolved.1.name.clone());
+        manifest.embedding_dimensions = Some(resolved.1.dimensions);
+
+        Some(resolved)
     } else {
-        (None, None)
+        None
     };
 
     // For incremental updates, only process files in the search scope
@@ -872,7 +884,10 @@
     // Second pass: index the files that need updating
     if compute_embeddings {
         // Sequential processing with streaming - write each file immediately
-        let mut embedder = ck_embed::create_embedder(resolved_model.as_deref())?;
+        let (_, config) = resolved_model
+            .as_ref()
+            .expect("resolved model must exist for embedding updates");
+        let mut embedder = ck_embed::create_embedder_for_config(config, None)?;
         let mut _processed_count = 0;
 
         for file_path in files_to_update.iter() {
diff --git a/ck-models/src/lib.rs b/ck-models/src/lib.rs
index 4147f53..223ba05 100644
--- a/ck-models/src/lib.rs
+++ b/ck-models/src/lib.rs
@@ -1,4 +1,4 @@
-use anyhow::Result;
+use anyhow::{Result, anyhow};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::path::Path;
@@ -69,6 +69,17 @@ impl Default for ModelRegistry {
             },
         );
 
+        models.insert(
+            "mxbai-xsmall".to_string(),
+            ModelConfig {
+                name: "mixedbread-ai/mxbai-embed-xsmall-v1".to_string(),
+                provider: "mixedbread".to_string(),
+                dimensions: 384,
+                max_tokens: 4096,
+                description: "Mixedbread xsmall embedding model (4k context, 384 dims) optimized for local semantic search".to_string(),
+            },
+        );
+
         Self {
             models,
             default_model: "bge-small".to_string(), // Keep BGE as default for backward compatibility
@@ -77,6 +88,50 @@
 }
 
 impl ModelRegistry {
+    fn format_available_models(&self) -> String {
+        self.models.keys().cloned().collect::<Vec<_>>().join(", ")
+    }
+
+    fn resolve_alias_or_name(&self, key: &str) -> Option<(String, &ModelConfig)> {
+        if let Some(config) = self.models.get(key) {
+            return Some((key.to_string(), config));
+        }
+
+        self.models
+            .iter()
+            .find(|(_, config)| config.name == key)
+            .map(|(alias, config)| (alias.clone(), config))
+    }
+
+    pub fn resolve(&self, requested: Option<&str>) -> Result<(String, ModelConfig)> {
+        match requested {
+            Some(name) => {
+                let (alias, config) = self.resolve_alias_or_name(name).ok_or_else(|| {
+                    anyhow!(
+                        "Unknown model '{}'. Available models: {}",
Available models: {}", + name, + self.format_available_models() + ) + })?; + Ok((alias, config.clone())) + } + None => { + let alias = self.default_model.clone(); + let config = self + .get_default_model() + .cloned() + .ok_or_else(|| anyhow!("No default model configured in registry"))?; + Ok((alias, config)) + } + } + } + + pub fn aliases(&self) -> Vec { + let mut keys = self.models.keys().cloned().collect::>(); + keys.sort(); + keys + } + pub fn load(path: &Path) -> Result { if path.exists() { let data = std::fs::read_to_string(path)?; @@ -101,6 +156,107 @@ impl ModelRegistry { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RerankModelConfig { + pub name: String, + pub provider: String, + pub description: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RerankModelRegistry { + pub models: HashMap, + pub default_model: String, +} + +impl Default for RerankModelRegistry { + fn default() -> Self { + let mut models = HashMap::new(); + + models.insert( + "jina".to_string(), + RerankModelConfig { + name: "jina-reranker-v1-turbo-en".to_string(), + provider: "fastembed".to_string(), + description: + "Jina Turbo reranker (default) tuned for English code + text relevance" + .to_string(), + }, + ); + + models.insert( + "bge".to_string(), + RerankModelConfig { + name: "BAAI/bge-reranker-base".to_string(), + provider: "fastembed".to_string(), + description: "BGE reranker base model for multilingual use cases".to_string(), + }, + ); + + models.insert( + "mxbai".to_string(), + RerankModelConfig { + name: "mixedbread-ai/mxbai-rerank-xsmall-v1".to_string(), + provider: "mixedbread".to_string(), + description: "Mixedbread xsmall reranker (quantized) optimized for local inference" + .to_string(), + }, + ); + + Self { + models, + default_model: "jina".to_string(), + } + } +} + +impl RerankModelRegistry { + fn format_available_models(&self) -> String { + self.models.keys().cloned().collect::>().join(", ") + } + + fn resolve_alias_or_name(&self, key: &str) -> Option<(String, &RerankModelConfig)> { + if let Some(config) = self.models.get(key) { + return Some((key.to_string(), config)); + } + + self.models + .iter() + .find(|(_, config)| config.name == key) + .map(|(alias, config)| (alias.clone(), config)) + } + + pub fn resolve(&self, requested: Option<&str>) -> Result<(String, RerankModelConfig)> { + match requested { + Some(name) => { + let (alias, config) = self.resolve_alias_or_name(name).ok_or_else(|| { + anyhow!( + "Unknown rerank model '{}'. 
Available models: {}", + name, + self.format_available_models() + ) + })?; + Ok((alias, config.clone())) + } + None => { + let alias = self.default_model.clone(); + let config = self + .models + .get(&self.default_model) + .cloned() + .ok_or_else(|| anyhow!("No default reranking model configured"))?; + Ok((alias, config)) + } + } + } + + pub fn aliases(&self) -> Vec { + let mut keys = self.models.keys().cloned().collect::>(); + keys.sort(); + keys + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ProjectConfig { pub model: String, diff --git a/docs-site/guide/changelog.md b/docs-site/guide/changelog.md index 22b58ad..853051a 100644 --- a/docs-site/guide/changelog.md +++ b/docs-site/guide/changelog.md @@ -10,6 +10,13 @@ All notable changes to ck are documented here following [Semantic Versioning](ht ## [Unreleased] ### Added +- **Mixedbread model support**: Added first-class support for Mixedbread embedding and reranking models + - Embedding model: `mxbai-xsmall` (`mixedbread-ai/mxbai-embed-xsmall-v1`) - 384 dimensions, 4K context window + - Reranker: `mxbai` (`mixedbread-ai/mxbai-rerank-xsmall-v1`) - Neural cross-encoder reranker + - Fully local inference using ONNX Runtime with quantized models + - Provider abstraction for clean model selection and routing + - CLI support: `--model mxbai-xsmall` and `--rerank-model mxbai` + - MCP server support for Mixedbread models in semantic/hybrid search tools - **VitePress documentation site**: Comprehensive documentation with improved navigation, search, and structure in `docs-site/` directory - **Documentation features**: Guide pages, feature documentation, CLI reference, embedding model guide, architecture docs, and contributing guides - **Local search**: Built-in search functionality in documentation site diff --git a/docs-site/guide/faq.md b/docs-site/guide/faq.md index d1b2b30..532bc45 100644 --- a/docs-site/guide/faq.md +++ b/docs-site/guide/faq.md @@ -307,6 +307,7 @@ ck --no-ignore --no-ckignore "pattern" . | Model | Best For | Trade-off | |-------|----------|-----------| | `bge-small` | General use, fast | Smaller chunks (400 tokens) | +| `mxbai-xsmall` | Local semantic search, balanced | Newer model, requires ONNX | | `nomic-v1.5` | Large functions, docs | Larger download (~500MB) | | `jina-code` | Code-specialized | Larger download (~500MB) | @@ -314,10 +315,9 @@ See [Embedding Models](/reference/models) for detailed comparison. ### Can I use custom embedding models? -Currently limited to models supported by fastembed-rs: -- `bge-small` (default) -- `nomic-v1.5` -- `jina-code` +Currently supported models: +- **FastEmbed provider**: `bge-small` (default), `nomic-v1.5`, `jina-code` +- **Mixedbread provider**: `mxbai-xsmall` (embedding), `mxbai` (reranker) **Future**: External embedding API support (OpenAI, HuggingFace, etc.) is being considered (#49). diff --git a/docs-site/guide/limitations.md b/docs-site/guide/limitations.md index 86cc395..b23d0ee 100644 --- a/docs-site/guide/limitations.md +++ b/docs-site/guide/limitations.md @@ -135,19 +135,18 @@ Rebuilding takes time but is necessary when changing models. ### Limited Embedding Model Options -**Issue**: Only models supported by fastembed-rs are available. +**Issue**: Only models supported by the built-in providers are available. 
 **Current models**:
-- `bge-small` (default)
-- `nomic-v1.5`
-- `jina-code`
+- **FastEmbed provider**: `bge-small` (default), `nomic-v1.5`, `jina-code`
+- **Mixedbread provider**: `mxbai-xsmall` (embedding), `mxbai` (reranker)
 
 **Not supported**:
-- Custom ONNX models
+- Custom ONNX models (beyond Mixedbread)
 - External API-based models (OpenAI, Anthropic, HuggingFace Inference API)
 - Proprietary embedding services
 
-**Why**: `ck` uses fastembed-rs for fast local inference, which limits options to its supported models.
+**Why**: `ck` uses fastembed-rs and an ONNX Runtime-based Mixedbread provider for fast local inference, which limits options to the models those providers support. The provider abstraction allows adding new providers in the future.
 
 **Future consideration**: External embedding API support is being considered for users who want to use specific models.
diff --git a/docs-site/reference/models.md b/docs-site/reference/models.md
index 35bdd4e..f5144bb 100644
--- a/docs-site/reference/models.md
+++ b/docs-site/reference/models.md
@@ -1,6 +1,6 @@
 ---
 title: Embedding Models
-description: Compare BGE-Small, Nomic V1.5, and Jina Code embedding models for ck. Understand chunk sizes, context windows, and performance trade-offs.
+description: Compare BGE-Small, Mixedbread xsmall, Nomic V1.5, and Jina Code embedding models for ck. Understand chunk sizes, context windows, and performance trade-offs.
 ---
 
 # Embedding Models
@@ -95,18 +95,49 @@ ck --index --model jina-code .
 - Larger model download
 - May be overkill for simple searches
 
+### Mixedbread xsmall
+
+```bash
+ck --index --model mxbai-xsmall .
+```
+
+**Specifications:**
+- Chunk size: Variable (up to 4096 tokens)
+- Model capacity: 4096 tokens
+- Dimensions: 384
+- Size: ~150MB (quantized ONNX)
+- Provider: Mixedbread (ONNX Runtime)
+
+**Best for:**
+- Local semantic search
+- Code + natural language understanding
+- Balanced performance and quality
+- Optimized for local inference
+
+**Pros:**
+- Optimized for local inference
+- Good balance of speed and quality
+- 4K context window
+- Quantized model (smaller download)
+- Strong semantic understanding
+
+**Cons:**
+- Newer model (less field-tested than BGE)
+- Requires ONNX Runtime
+
 ## Comparison Table
 
-| Feature | BGE-Small | Nomic V1.5 | Jina Code |
-|---------|-----------|------------|-----------|
-| Chunk Size | 400 tokens | 1024 tokens | 1024 tokens |
-| Context Window | 512 tokens | 8K tokens | 8K tokens |
-| Dimensions | 384 | 768 | 768 |
-| Download Size | ~80MB | ~500MB | ~500MB |
-| Index Speed | ⚡⚡⚡ | ⚡⚡ | ⚡⚡ |
-| Memory Usage | Low | Medium | Medium |
-| Code Understanding | Good | Good | Excellent |
-| Large Functions | Fair | Excellent | Excellent |
+| Feature | BGE-Small | Mixedbread xsmall | Nomic V1.5 | Jina Code |
+|---------|-----------|-------------------|------------|-----------|
+| Chunk Size | 400 tokens | Up to 4096 tokens | 1024 tokens | 1024 tokens |
+| Context Window | 512 tokens | 4K tokens | 8K tokens | 8K tokens |
+| Dimensions | 384 | 384 | 768 | 768 |
+| Download Size | ~80MB | ~150MB | ~500MB | ~500MB |
+| Index Speed | ⚡⚡⚡ | ⚡⚡⚡ | ⚡⚡ | ⚡⚡ |
+| Memory Usage | Low | Low | Medium | Medium |
+| Code Understanding | Good | Excellent | Good | Excellent |
+| Large Functions | Fair | Good | Excellent | Excellent |
+| Provider | FastEmbed | Mixedbread | FastEmbed | FastEmbed |
 
 ## Model Selection Guide
 
@@ -122,6 +153,8 @@ ck --index --model bge-small .
 ```bash
 ck --index --model bge-small .    # Fast iteration
 # or
+ck --index --model mxbai-xsmall . # Balanced performance
+# or
 ck --index --model jina-code .    # Better understanding
 ```
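
---

Reviewer note: a minimal usage sketch of the new registry resolution API introduced in `ck-models/src/lib.rs` above. It assumes `ck_models` exposes `ModelRegistry` and `RerankModelRegistry` at the crate root (only `RerankModelRegistry` is shown with its definition in this diff); the aliases, model names, and defaults are taken from the `Default` impls in the patch.

```rust
use anyhow::Result;
// Assumed re-exports; this diff defines these types in ck-models/src/lib.rs.
use ck_models::{ModelRegistry, RerankModelRegistry};

fn main() -> Result<()> {
    let registry = ModelRegistry::default();

    // An alias and the full model name resolve to the same (alias, config) pair.
    let (alias, config) = registry.resolve(Some("mxbai-xsmall"))?;
    let (by_name, _) = registry.resolve(Some("mixedbread-ai/mxbai-embed-xsmall-v1"))?;
    assert_eq!(alias, by_name);
    assert_eq!(config.dimensions, 384);

    // `None` falls back to the registry default, which stays `bge-small`
    // for backward compatibility.
    let (default_alias, _) = registry.resolve(None)?;
    assert_eq!(default_alias, "bge-small");

    // Unknown names surface an error listing every registered alias.
    assert!(registry.resolve(Some("not-a-model")).is_err());
    println!("known embedding aliases: {:?}", registry.aliases());

    // The reranker registry mirrors the same resolve/aliases surface.
    let (rr_alias, rr_config) = RerankModelRegistry::default().resolve(Some("mxbai"))?;
    assert_eq!(rr_config.provider, "mixedbread");
    println!("reranker: {rr_alias} -> {}", rr_config.name);

    Ok(())
}
```

This mirrors the path the indexer now takes in `smart_update_index_with_detailed_progress`: it keeps the resolved `(alias, config)` tuple, records `config.name` and `config.dimensions` in the manifest, and hands the config to `ck_embed::create_embedder_for_config` instead of re-parsing a model name string.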