Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
349 changes: 340 additions & 9 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,18 @@ num_cpus = "1.16"
indicatif = "0.17"
inquire = "0.7"
console = "0.15"
sqlite-vec = "0.1.6"
regex = "1.10"
lazy_static = "1.4"

# MLX for embedding generation (macOS only)
[target.'cfg(target_os = "macos")'.dependencies]
mlx-rs = { version = "0.25", optional = true }

[features]
default = ["reqwest"]
reqwest = ["dep:reqwest"]
mlx = ["mlx-rs"]

[dev-dependencies]
tempfile = "3.8"
Expand Down Expand Up @@ -164,4 +170,4 @@ path = "tests/contract/test_cli_add_command.rs"
[profile.release]
lto = true
codegen-units = 1
panic = "abort"
panic = "abort"
11 changes: 11 additions & 0 deletions migrations/011_add_message_embeddings.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Add embedding column to messages table
-- Migration: 008_add_message_embeddings
-- Description: Add embedding vector column (768 dimensions) for semantic search

-- Add embedding column to messages table (can be NULL for now)
-- Using BLOB to store float32 vectors of 768 dimensions
ALTER TABLE messages ADD COLUMN embedding BLOB;

-- Create virtual table for vector similarity search using sqlite-vec
-- Note: This will be created dynamically when needed in the application code
-- since sqlite-vec needs to be loaded as an extension first
15 changes: 13 additions & 2 deletions src/cli/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ pub enum Commands {
/// Messages until this time (e.g., "now", "2024-10-31", "today")
#[arg(long)]
until: Option<String>,
/// Use embedding-based semantic search instead of full-text search
#[arg(long)]
use_embedding: bool,
},
/// [Alias for 'retrospect execute'] Review and analyze a chat session
Review {
Expand Down Expand Up @@ -187,6 +190,9 @@ pub enum QueryCommands {
/// Messages until this time (e.g., "now", "2024-10-31", "today")
#[arg(long)]
until: Option<String>,
/// Use embedding-based semantic search instead of full-text search
#[arg(long)]
use_embedding: bool,
},
/// Query messages by time range
Timeline {
Expand Down Expand Up @@ -286,7 +292,11 @@ impl Cli {
limit,
since,
until,
} => query::handle_search_command(query, limit, since, until).await,
use_embedding,
} => {
query::handle_search_command(query, limit, since, until, use_embedding)
.await
}
QueryCommands::Timeline {
since,
until,
Expand Down Expand Up @@ -363,7 +373,8 @@ impl Cli {
limit,
since,
until,
} => query::handle_search_command(query, limit, since, until).await,
use_embedding,
} => query::handle_search_command(query, limit, since, until, use_embedding).await,
Commands::Review { session_id } => {
// For now, delegate to retrospect execute
// TODO: Could make this more interactive
Expand Down
9 changes: 8 additions & 1 deletion src/cli/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ pub async fn handle_search_command(
limit: Option<i32>,
since: Option<String>,
until: Option<String>,
use_embedding: bool,
) -> Result<()> {
let db_path = crate::database::config::get_default_db_path()?;
let db_manager = DatabaseManager::new(&db_path).await?;
Expand Down Expand Up @@ -154,14 +155,20 @@ pub async fn handle_search_command(
None
};

let search_type = if use_embedding {
Some("embedding".to_string())
} else {
None
};

let request = SearchRequest {
query,
page: Some(1),
page_size: limit,
date_range,
projects: None,
providers: None,
search_type: None,
search_type,
};

let response = query_service.search_messages(request).await?;
Expand Down
53 changes: 44 additions & 9 deletions src/database/message_repo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,37 @@ impl MessageRepository {
}
}

/// Convert embedding vector to BLOB (bytes) for storage
fn embedding_to_blob(embedding: &[f32]) -> Vec<u8> {
embedding.iter().flat_map(|&f| f.to_le_bytes()).collect()
}

/// Convert BLOB (bytes) back to embedding vector
fn blob_to_embedding(blob: &[u8]) -> Option<Vec<f32>> {
if !blob.len().is_multiple_of(4) {
return None;
}

let mut embedding = Vec::with_capacity(blob.len() / 4);
for chunk in blob.chunks_exact(4) {
let bytes: [u8; 4] = chunk.try_into().ok()?;
embedding.push(f32::from_le_bytes(bytes));
}
Some(embedding)
}

pub async fn create(&self, message: &Message) -> AnyhowResult<()> {
let embedding_blob = message
.embedding
.as_ref()
.map(|emb| Self::embedding_to_blob(emb));

sqlx::query(
r#"
INSERT INTO messages (
id, session_id, role, content, timestamp, token_count,
metadata, sequence_number, message_type, tool_operation_id
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
metadata, sequence_number, message_type, tool_operation_id, embedding
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(message.id.to_string())
Expand All @@ -37,6 +61,7 @@ impl MessageRepository {
.bind(message.sequence_number)
.bind(message.message_type.to_string())
.bind(message.tool_operation_id.map(|id| id.to_string()))
.bind(embedding_blob)
.execute(&self.pool)
.await
.context("Failed to create message")?;
Expand All @@ -48,7 +73,7 @@ impl MessageRepository {
let row = sqlx::query(
r#"
SELECT id, session_id, role, content, timestamp, token_count,
metadata, sequence_number, message_type, tool_operation_id
metadata, sequence_number, message_type, tool_operation_id, embedding
FROM messages
WHERE id = ?
"#,
Expand All @@ -71,7 +96,7 @@ impl MessageRepository {
let rows = sqlx::query(
r#"
SELECT id, session_id, role, content, timestamp, token_count,
metadata, sequence_number, message_type, tool_operation_id
metadata, sequence_number, message_type, tool_operation_id, embedding
FROM messages
WHERE session_id = ?
ORDER BY sequence_number ASC
Expand Down Expand Up @@ -142,7 +167,7 @@ impl MessageRepository {
let mut sql = r#"
SELECT m.id, m.session_id, m.role, m.content, m.timestamp,
m.token_count, m.metadata, m.sequence_number,
m.message_type, m.tool_operation_id
m.message_type, m.tool_operation_id, m.embedding
FROM messages m
JOIN messages_fts fts ON m.rowid = fts.rowid
WHERE messages_fts MATCH ?
Expand Down Expand Up @@ -197,7 +222,7 @@ impl MessageRepository {
let mut sql = r#"
SELECT m.id, m.session_id, m.role, m.content, m.timestamp,
m.token_count, m.metadata, m.sequence_number,
m.message_type, m.tool_operation_id
m.message_type, m.tool_operation_id, m.embedding
FROM messages m
JOIN messages_fts fts ON m.rowid = fts.rowid
WHERE messages_fts MATCH ?
Expand Down Expand Up @@ -291,7 +316,7 @@ impl MessageRepository {
r#"
SELECT m.id, m.session_id, m.role, m.content, m.timestamp,
m.token_count, m.metadata, m.sequence_number,
m.message_type, m.tool_operation_id
m.message_type, m.tool_operation_id, m.embedding
FROM messages m
"#,
);
Expand Down Expand Up @@ -376,12 +401,17 @@ impl MessageRepository {
.context("Failed to start transaction")?;

for message in messages {
let embedding_blob = message
.embedding
.as_ref()
.map(|emb| Self::embedding_to_blob(emb));

sqlx::query(
r#"
INSERT INTO messages (
id, session_id, role, content, timestamp, token_count,
metadata, sequence_number, message_type, tool_operation_id
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
metadata, sequence_number, message_type, tool_operation_id, embedding
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(message.id.to_string())
Expand All @@ -394,6 +424,7 @@ impl MessageRepository {
.bind(message.sequence_number)
.bind(message.message_type.to_string())
.bind(message.tool_operation_id.map(|id| id.to_string()))
.bind(embedding_blob)
.execute(&mut *tx)
.await
.context("Failed to insert message in bulk")?;
Expand All @@ -417,6 +448,7 @@ impl MessageRepository {
let sequence_number: i64 = row.try_get("sequence_number")?;
let message_type_str: String = row.try_get("message_type")?;
let tool_operation_id_str: Option<String> = row.try_get("tool_operation_id")?;
let embedding_blob: Option<Vec<u8>> = row.try_get("embedding")?;

let id = Uuid::parse_str(&id_str).context("Invalid message ID format")?;
let session_id = Uuid::parse_str(&session_id_str).context("Invalid session ID format")?;
Expand All @@ -433,6 +465,8 @@ impl MessageRepository {
None
};

let embedding = embedding_blob.and_then(|blob| Self::blob_to_embedding(&blob));

let metadata: Option<serde_json::Value> = serde_json::from_str("{}").ok();

Ok(Message {
Expand All @@ -448,6 +482,7 @@ impl MessageRepository {
tool_operation_id,
tool_uses: None,
tool_results: None,
embedding,
})
}
}
8 changes: 8 additions & 0 deletions src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,11 @@ pub mod retrospection {
/// Maximum concurrent analysis operations
pub const CONCURRENT: &str = "RETROCHAT_CONCURRENT";
}

/// Embedding service configuration
pub mod embedding {
/// Enable MLX-based embedding generation (macOS only)
/// When enabled on macOS, uses MLX for embedding extraction
/// Shows warning and disables on unsupported platforms (Windows, Linux)
pub const USE_MLX: &str = "RETROCHAT_USE_MLX";
}
9 changes: 9 additions & 0 deletions src/models/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ pub struct Message {
pub tool_uses: Option<Vec<ToolUse>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_results: Option<Vec<ToolResult>>,
/// Embedding vector for semantic search (768 dimensions, f32 values)
#[serde(skip_serializing_if = "Option::is_none")]
pub embedding: Option<Vec<f32>>,
}

impl Message {
Expand All @@ -142,6 +145,7 @@ impl Message {
tool_operation_id: None,
tool_uses: None,
tool_results: None,
embedding: None,
}
}

Expand Down Expand Up @@ -177,6 +181,11 @@ impl Message {
self
}

pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
self.embedding = Some(embedding);
self
}

pub fn is_valid(&self) -> bool {
!self.content.is_empty()
}
Expand Down
Loading