Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4,819 changes: 4,419 additions & 400 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ rusqlite = { version = "0.30", features = ["bundled", "backup"] }
rmcp = { version = "0.11", features = ["server", "macros", "transport-io"] }
schemars = { version = "1.0", features = ["chrono04", "uuid1"] }

# Semantic Search / Vector DB
lancedb = "0.15"
arrow = { version = "53", default-features = false, features = ["chrono-tz"] }
arrow-array = "53"
arrow-schema = "53"
fastembed = "4"
sha2 = "0.10"

[workspace.package]
version = "0.1.0"
edition = "2021"
Expand Down
9 changes: 9 additions & 0 deletions crates/retrochat-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,18 @@ notify = { workspace = true }
similar = { workspace = true }
crossterm = { workspace = true }

# Semantic Search (optional behind the `semantic-search` feature; sha2 is always enabled)
lancedb = { workspace = true, optional = true }
arrow = { workspace = true, optional = true }
arrow-array = { workspace = true, optional = true }
arrow-schema = { workspace = true, optional = true }
fastembed = { workspace = true, optional = true }
sha2 = { workspace = true }

[features]
default = ["reqwest"]
reqwest = ["dep:reqwest"]
semantic-search = ["lancedb", "arrow", "arrow-array", "arrow-schema", "fastembed"]

[dev-dependencies]
tempfile = "3.8"
Expand Down
99 changes: 99 additions & 0 deletions crates/retrochat-core/src/embedding/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
//! Configuration for embedding generation.

use std::path::PathBuf;

use super::models::EmbeddingModel;

/// Configuration for the embedding service.
///
/// Construct with [`EmbeddingConfig::default`] or [`EmbeddingConfig::new`],
/// then adjust via the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// The embedding model to use.
    pub model: EmbeddingModel,

    /// Directory to cache downloaded models.
    /// Defaults to `~/.retrochat/models/` if not specified.
    pub cache_dir: Option<PathBuf>,

    /// Whether to show download progress when fetching models.
    pub show_download_progress: bool,
}

impl Default for EmbeddingConfig {
    /// Defaults: BGE Small EN v1.5, no explicit cache directory,
    /// download progress shown.
    fn default() -> Self {
        let model = EmbeddingModel::BGESmallENV15;
        Self {
            model,
            cache_dir: None,
            show_download_progress: true,
        }
    }
}

impl EmbeddingConfig {
    /// Build a configuration for `model`, keeping every other default.
    pub fn new(model: EmbeddingModel) -> Self {
        Self {
            model,
            ..Self::default()
        }
    }

    /// Builder-style setter: override the model cache directory.
    pub fn with_cache_dir(mut self, path: PathBuf) -> Self {
        self.cache_dir = Some(path);
        self
    }

    /// Builder-style setter: toggle download-progress output.
    pub fn with_show_download_progress(mut self, show: bool) -> Self {
        self.show_download_progress = show;
        self
    }

    /// Resolve the effective cache directory.
    ///
    /// Falls back to `<home>/.retrochat/models` when none was set; if the
    /// home directory cannot be determined, uses `./.retrochat/models`.
    pub fn get_cache_dir(&self) -> PathBuf {
        match &self.cache_dir {
            Some(dir) => dir.clone(),
            None => {
                let base = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
                base.join(".retrochat").join("models")
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Default config: BGE small model, no cache-dir override, progress on.
    #[test]
    fn test_default_config() {
        let config = EmbeddingConfig::default();
        assert!(matches!(config.model, EmbeddingModel::BGESmallENV15));
        assert!(config.cache_dir.is_none());
        assert!(config.show_download_progress);
    }

    // Builder methods chain and each override sticks.
    #[test]
    fn test_config_builder() {
        let config = EmbeddingConfig::new(EmbeddingModel::AllMiniLML6V2)
            .with_cache_dir(PathBuf::from("/tmp/models"))
            .with_show_download_progress(false);

        assert!(matches!(config.model, EmbeddingModel::AllMiniLML6V2));
        assert_eq!(config.cache_dir, Some(PathBuf::from("/tmp/models")));
        assert!(!config.show_download_progress);
    }

    // With no override, the resolved path ends in .retrochat/models.
    // (Only substring-checked: the home-dir prefix varies by machine.)
    #[test]
    fn test_get_cache_dir_default() {
        let config = EmbeddingConfig::default();
        let cache_dir = config.get_cache_dir();
        assert!(cache_dir.to_string_lossy().contains(".retrochat"));
        assert!(cache_dir.to_string_lossy().contains("models"));
    }

    // An explicit cache_dir is returned verbatim, no default applied.
    #[test]
    fn test_get_cache_dir_custom() {
        let config = EmbeddingConfig::default().with_cache_dir(PathBuf::from("/custom/path"));
        assert_eq!(config.get_cache_dir(), PathBuf::from("/custom/path"));
    }
}
12 changes: 12 additions & 0 deletions crates/retrochat-core/src/embedding/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//! Embedding generation module for semantic search.
//!
//! This module provides text embedding generation using FastEmbed-rs
//! for local, CPU-based inference without external API calls.

mod config;
mod models;
mod service;

pub use config::EmbeddingConfig;
pub use models::{EmbeddingModel, ModelInfo};
pub use service::EmbeddingService;
181 changes: 181 additions & 0 deletions crates/retrochat-core/src/embedding/models.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
//! Supported embedding models and their metadata.

use std::fmt;

/// Supported embedding models.
///
/// These map to FastEmbed model variants. Quantized versions (with Q suffix)
/// are smaller and faster but may have slightly lower quality.
///
/// `Display` renders the variant identifier (e.g. `"BGESmallENV15"`), and
/// `FromStr` accepts that identifier case-insensitively, so the two
/// round-trip.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingModel {
    /// BGE Small English v1.5 - Good balance of quality and speed.
    /// 384 dimensions, ~33MB model size.
    BGESmallENV15,

    /// BGE Small English v1.5 (Quantized) - Faster, smaller.
    /// 384 dimensions, ~17MB model size.
    BGESmallENV15Q,

    /// All MiniLM L6 v2 - Fast and lightweight.
    /// 384 dimensions, ~23MB model size.
    AllMiniLML6V2,

    /// All MiniLM L6 v2 (Quantized) - Fastest option.
    /// 384 dimensions, ~12MB model size.
    AllMiniLML6V2Q,

    /// BGE Base English v1.5 - Higher quality, slower.
    /// 768 dimensions, ~110MB model size.
    BGEBaseENV15,

    /// BGE Base English v1.5 (Quantized).
    /// 768 dimensions, ~55MB model size.
    BGEBaseENV15Q,
}

impl EmbeddingModel {
    /// Length of the embedding vectors this model produces.
    pub fn dimensions(&self) -> usize {
        match self {
            Self::BGEBaseENV15 | Self::BGEBaseENV15Q => 768,
            Self::BGESmallENV15
            | Self::BGESmallENV15Q
            | Self::AllMiniLML6V2
            | Self::AllMiniLML6V2Q => 384,
        }
    }

    /// Approximate on-disk model size in MB (arms sorted ascending).
    pub fn model_size_mb(&self) -> usize {
        match self {
            Self::AllMiniLML6V2Q => 12,
            Self::BGESmallENV15Q => 17,
            Self::AllMiniLML6V2 => 23,
            Self::BGESmallENV15 => 33,
            Self::BGEBaseENV15Q => 55,
            Self::BGEBaseENV15 => 110,
        }
    }

    /// Whether this is a quantized variant (Q suffix).
    pub fn is_quantized(&self) -> bool {
        match self {
            Self::BGESmallENV15Q | Self::AllMiniLML6V2Q | Self::BGEBaseENV15Q => true,
            Self::BGESmallENV15 | Self::AllMiniLML6V2 | Self::BGEBaseENV15 => false,
        }
    }

    /// FastEmbed/Hugging Face repository name for this model.
    /// Quantized and non-quantized variants share the same repo name.
    pub fn fastembed_name(&self) -> &'static str {
        match self {
            Self::BGESmallENV15 | Self::BGESmallENV15Q => "BAAI/bge-small-en-v1.5",
            Self::AllMiniLML6V2 | Self::AllMiniLML6V2Q => {
                "sentence-transformers/all-MiniLM-L6-v2"
            }
            Self::BGEBaseENV15 | Self::BGEBaseENV15Q => "BAAI/bge-base-en-v1.5",
        }
    }
}

impl fmt::Display for EmbeddingModel {
    /// Writes the variant's canonical name, identical to the enum identifier.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::BGESmallENV15 => "BGESmallENV15",
            Self::BGESmallENV15Q => "BGESmallENV15Q",
            Self::AllMiniLML6V2 => "AllMiniLML6V2",
            Self::AllMiniLML6V2Q => "AllMiniLML6V2Q",
            Self::BGEBaseENV15 => "BGEBaseENV15",
            Self::BGEBaseENV15Q => "BGEBaseENV15Q",
        };
        f.write_str(name)
    }
}

impl std::str::FromStr for EmbeddingModel {
    type Err = String;

    /// Parses a model name, case-insensitively.
    ///
    /// Accepts both the enum identifier form (e.g. `"BGESmallENV15"`) and
    /// the hyphenated model-name form (e.g. `"bge-small-en-v1.5"`).
    /// Surrounding whitespace is ignored, so values read from config files
    /// or environment variables parse without manual trimming.
    ///
    /// # Errors
    /// Returns a descriptive message when the input matches no known model.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.trim().to_lowercase().as_str() {
            "bgesmallenv15" | "bge-small-en-v1.5" => Ok(Self::BGESmallENV15),
            "bgesmallenv15q" | "bge-small-en-v1.5-q" => Ok(Self::BGESmallENV15Q),
            "allminiml6v2" | "all-minilm-l6-v2" => Ok(Self::AllMiniLML6V2),
            "allminiml6v2q" | "all-minilm-l6-v2-q" => Ok(Self::AllMiniLML6V2Q),
            "bgebaseenv15" | "bge-base-en-v1.5" => Ok(Self::BGEBaseENV15),
            "bgebaseenv15q" | "bge-base-en-v1.5-q" => Ok(Self::BGEBaseENV15Q),
            _ => Err(format!("Unknown embedding model: {s}")),
        }
    }
}

/// Information about the loaded embedding model.
///
/// Built via `From<EmbeddingModel>`; all fields are derived from the
/// model variant's metadata methods.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// The model variant.
    pub model: EmbeddingModel,

    /// Human-readable model name (the variant's `Display` output).
    pub name: String,

    /// Number of embedding dimensions.
    pub dimensions: usize,

    /// Whether the model is quantized.
    pub quantized: bool,
}

impl From<EmbeddingModel> for ModelInfo {
    /// Snapshots the variant's metadata into an owned `ModelInfo`.
    fn from(model: EmbeddingModel) -> Self {
        let name = model.to_string();
        let dimensions = model.dimensions();
        let quantized = model.is_quantized();
        Self {
            model,
            name,
            dimensions,
            quantized,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Dimensions: small/MiniLM families are 384-d, base family is 768-d.
    #[test]
    fn test_model_dimensions() {
        assert_eq!(EmbeddingModel::BGESmallENV15.dimensions(), 384);
        assert_eq!(EmbeddingModel::BGESmallENV15Q.dimensions(), 384);
        assert_eq!(EmbeddingModel::AllMiniLML6V2.dimensions(), 384);
        assert_eq!(EmbeddingModel::BGEBaseENV15.dimensions(), 768);
    }

    // Only the Q-suffixed variants report as quantized.
    #[test]
    fn test_model_is_quantized() {
        assert!(!EmbeddingModel::BGESmallENV15.is_quantized());
        assert!(EmbeddingModel::BGESmallENV15Q.is_quantized());
        assert!(!EmbeddingModel::AllMiniLML6V2.is_quantized());
        assert!(EmbeddingModel::AllMiniLML6V2Q.is_quantized());
    }

    // FromStr accepts both the identifier and hyphenated spellings,
    // and rejects unknown names.
    #[test]
    fn test_model_from_str() {
        assert_eq!(
            "BGESmallENV15".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BGESmallENV15
        );
        assert_eq!(
            "bge-small-en-v1.5".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BGESmallENV15
        );
        assert!("invalid".parse::<EmbeddingModel>().is_err());
    }

    // Display renders the exact variant identifier.
    #[test]
    fn test_model_display() {
        assert_eq!(EmbeddingModel::BGESmallENV15.to_string(), "BGESmallENV15");
        assert_eq!(EmbeddingModel::AllMiniLML6V2Q.to_string(), "AllMiniLML6V2Q");
    }

    // ModelInfo::from copies name, dimensions and quantized flag.
    #[test]
    fn test_model_info_from_model() {
        let info = ModelInfo::from(EmbeddingModel::BGESmallENV15Q);
        assert_eq!(info.name, "BGESmallENV15Q");
        assert_eq!(info.dimensions, 384);
        assert!(info.quantized);
    }
}
Loading