diff --git a/src/services/client.ts b/src/services/client.ts
index 79224a7..3c64ea7 100644
--- a/src/services/client.ts
+++ b/src/services/client.ts
@@ -97,10 +97,18 @@ export class LocalMemoryClient {
   }
 
   async searchMemories(query: string, containerTag: string) {
+    return this.hybridSearch(query, containerTag);
+  }
+
+  /**
+   * Keyword-only search: queries the FTS index of every shard for the container tag,
+   * keeps the highest-similarity hit per memory id and returns them best first.
+   * A failing shard is logged and skipped; other errors yield `success: false`.
+   */
+  async fullTextSearch(query: string, containerTag: string) {
     try {
       await this.initialize();
 
-      const queryVector = await embeddingService.embedWithTimeout(query);
       const { scope, hash } = extractScopeFromContainerTag(containerTag);
       const shards = shardManager.getAllShards(scope, hash);
 
@@ -108,19 +116,119 @@ export class LocalMemoryClient {
         return { success: true as const, results: [], total: 0, timing: 0 };
       }
 
-      const results = await vectorSearch.searchAcrossShards(
-        shards,
-        queryVector,
-        containerTag,
-        CONFIG.maxMemories,
-        CONFIG.similarityThreshold,
-        query
+      const shardPromises = shards.map(async (shard) => {
+        try {
+          const db = connectionManager.getConnection(shard.dbPath);
+          return vectorSearch.fullTextSearch(db, query, containerTag, CONFIG.maxMemories);
+        } catch (error) {
+          log("fullTextSearch: shard search error", { shardId: shard.id, error: String(error) });
+          return [];
+        }
+      });
+
+      const results = (await Promise.all(shardPromises)).flat();
+      const dedupedById = new Map();
+
+      for (const result of results) {
+        const existing = dedupedById.get(result.id);
+        if (!existing || result.similarity > existing.similarity) {
+          dedupedById.set(result.id, result);
+        }
+      }
+
+      const dedupedResults = Array.from(dedupedById.values()).sort(
+        (a, b) => b.similarity - a.similarity
       );
 
-      return { success: true as const, results, total: results.length, timing: 0 };
+      return {
+        success: true as const,
+        results: dedupedResults.slice(0, CONFIG.maxMemories),
+        total: dedupedResults.length,
+        timing: 0,
+      };
+    } catch (error) {
+      const errorMessage = error instanceof Error ? error.message : String(error);
+      log("fullTextSearch: error", { error: errorMessage });
+      return { success: false as const, error: errorMessage, results: [], total: 0, timing: 0 };
+    }
+  }
+
+  /**
+   * Runs vector search and full-text search concurrently and fuses both rankings by
+   * weighted reciprocal rank (0.6 vector, 0.4 full-text); the fused score replaces
+   * `similarity` on the returned results. A failed full-text search counts as empty.
+   */
+  async hybridSearch(query: string, containerTag: string) {
+    try {
+      await this.initialize();
+
+      const { scope, hash } = extractScopeFromContainerTag(containerTag);
+      const shards = shardManager.getAllShards(scope, hash);
+
+      if (shards.length === 0) {
+        return { success: true as const, results: [], total: 0, timing: 0 };
+      }
+
+      const vectorPromise = (async () => {
+        const queryVector = await embeddingService.embedWithTimeout(query);
+        return vectorSearch.searchAcrossShards(
+          shards,
+          queryVector,
+          containerTag,
+          CONFIG.maxMemories,
+          CONFIG.similarityThreshold,
+          query
+        );
+      })();
+
+      const ftsPromise = this.fullTextSearch(query, containerTag);
+      const [vectorResults, ftsResponse] = await Promise.all([vectorPromise, ftsPromise]);
+      const ftsResults = ftsResponse.success ? ftsResponse.results : [];
+
+      const fusedById = new Map();
+
+      for (const [i, result] of vectorResults.entries()) {
+        const vectorRank = 1 / (i + 1);
+        fusedById.set(result.id, {
+          result,
+          score: 0.6 * vectorRank,
+        });
+      }
+
+      for (const [i, result] of ftsResults.entries()) {
+        const ftsRank = 1 / (i + 1);
+        const existing = fusedById.get(result.id);
+
+        if (existing) {
+          existing.score += 0.4 * ftsRank;
+          if (result.similarity > existing.result.similarity) {
+            existing.result = result;
+          }
+        } else {
+          fusedById.set(result.id, {
+            result,
+            score: 0.4 * ftsRank,
+          });
+        }
+      }
+
+      const combinedResults = Array.from(fusedById.values())
+        .sort((a, b) => b.score - a.score)
+        .slice(0, CONFIG.maxMemories)
+        .map((entry) => ({
+          ...entry.result,
+          similarity: entry.score,
+        }));
+
+      return {
+        success: true as const,
+        results: combinedResults,
+        total: combinedResults.length,
+        timing: 0,
+      };
     } catch (error) {
       const errorMessage = error instanceof Error ? error.message : String(error);
-      log("searchMemories: error", { error: errorMessage });
+      log("hybridSearch: error", { error: errorMessage });
       return { success: false as const, error: errorMessage, results: [], total: 0, timing: 0 };
     }
   }
diff --git a/src/services/sqlite/connection-manager.ts b/src/services/sqlite/connection-manager.ts
index 34b3118..ad5b966 100644
--- a/src/services/sqlite/connection-manager.ts
+++ b/src/services/sqlite/connection-manager.ts
@@ -112,8 +112,9 @@ export class ConnectionManager {
     try {
       const columns = db.prepare("PRAGMA table_info(memories)").all() as any[];
       const hasTags = columns.some((c) => c.name === "tags");
+      const hasMemoriesTable = columns.length > 0;
 
-      if (!hasTags && columns.length > 0) {
+      if (!hasTags && hasMemoriesTable) {
         db.run("ALTER TABLE memories ADD COLUMN tags TEXT");
       }
 
@@ -123,6 +124,34 @@ export class ConnectionManager {
           embedding float32[${CONFIG.embeddingDimensions}] distance_metric=cosine
         )
       `);
+
+      if (hasMemoriesTable) {
+        // Probe before CREATE: the backfill below must run only when this migration
+        // creates memories_fts, otherwise existing rows would be indexed twice.
+        const hasFtsTable = Boolean(
+          db
+            .prepare("SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'memories_fts'")
+            .get()
+        );
+
+        db.run(`
+          CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
+            content,
+            tags,
+            container_tag UNINDEXED,
+            memory_id UNINDEXED,
+            tokenize='porter unicode61'
+          )
+        `);
+
+        if (!hasFtsTable) {
+          db.run(`
+            INSERT INTO memories_fts(content, tags, container_tag, memory_id)
+            SELECT content, COALESCE(tags, ''), container_tag, id
+            FROM memories
+          `);
+        }
+      }
     } catch (error) {
       log("Schema migration error", { error: String(error) });
     }
diff --git a/src/services/sqlite/shard-manager.ts b/src/services/sqlite/shard-manager.ts
index 11abf9f..318e62d 100644
--- a/src/services/sqlite/shard-manager.ts
+++ b/src/services/sqlite/shard-manager.ts
@@ -182,6 +182,16 @@ export class ShardManager {
       )
     `);
 
+    db.run(`
+      CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
+        content,
+        tags,
+        container_tag UNINDEXED,
+        memory_id UNINDEXED,
+        tokenize='porter unicode61'
+      )
+    `);
+
     db.run(`CREATE INDEX IF NOT EXISTS idx_container_tag ON memories(container_tag)`);
     db.run(`CREATE INDEX IF NOT EXISTS idx_type ON memories(type)`);
     db.run(`CREATE INDEX IF NOT EXISTS idx_created_at ON memories(created_at DESC)`);
diff --git a/src/services/sqlite/vector-search.ts b/src/services/sqlite/vector-search.ts
index 1503de8..5eb2f8f 100644
--- a/src/services/sqlite/vector-search.ts
+++ b/src/services/sqlite/vector-search.ts
@@ -4,6 +4,49 @@ import { log } from "../logger.js";
 import type { MemoryRecord, SearchResult, ShardInfo } from "./types.js";
 
 export class VectorSearch {
+  /**
+   * Turns free-form input into an FTS5 MATCH expression: each whitespace-separated
+   * token is lowercased, stripped to letters, digits, `_` and `-`, quoted, given a
+   * prefix `*` and AND-ed with the rest (first 24 tokens only).
+   * Returns "" when no searchable token remains.
+   */
+  private sanitizeFtsQuery(query: string): string {
+    const tokens = query
+      .normalize("NFKC")
+      .toLowerCase()
+      .split(/\s+/)
+      .map((token) => token.replace(/[^\p{L}\p{N}_-]/gu, "").trim())
+      // Skip tokens with no letter or digit (e.g. a lone "-"): quoted, they produce a
+      // phrase with no searchable terms, which would make the AND-ed query match nothing.
+      .filter((token) => /[\p{L}\p{N}]/u.test(token))
+      .slice(0, 24);
+
+    if (tokens.length === 0) {
+      return "";
+    }
+
+    return tokens.map((token) => `"${token}"*`).join(" AND ");
+  }
+
+  /** Maps a `memories` row (snake_case columns) to the camelCase search-result shape. */
+  private mapRowToSearchResult(row: any, similarity: number) {
+    return {
+      id: row.id,
+      memory: row.content,
+      similarity,
+      tags: row.tags ? row.tags.split(",") : [],
+      metadata: row.metadata ? JSON.parse(row.metadata) : undefined,
+      containerTag: row.container_tag,
+      displayName: row.display_name,
+      userName: row.user_name,
+      userEmail: row.user_email,
+      projectPath: row.project_path,
+      projectName: row.project_name,
+      gitRepoUrl: row.git_repo_url,
+      isPinned: row.is_pinned,
+    };
+  }
+
   insertVector(db: Database, record: MemoryRecord): void {
     const insertMemory = db.prepare(`
       INSERT INTO memories (
@@ -37,6 +80,11 @@ export class VectorSearch {
     `);
     insertVec.run(record.id, vectorBuffer);
 
+    const insertFts = db.prepare(`
+      INSERT INTO memories_fts (content, tags, container_tag, memory_id) VALUES (?, ?, ?, ?)
+    `);
+    insertFts.run(record.content, record.tags || "", record.containerTag, record.id);
+
     if (record.tagsVector) {
       const tagsVectorBuffer = new Uint8Array(record.tagsVector.buffer);
       const insertTagsVec = db.prepare(`
@@ -124,21 +172,7 @@ export class VectorSearch {
       const tagSim = Math.max(scores.tagsSim, exactMatchBoost);
       const similarity = tagSim * 0.8 + scores.contentSim * 0.2;
 
-      return {
-        id: row.id,
-        memory: row.content,
-        similarity,
-        tags: memoryTagsStr ? memoryTagsStr.split(",") : [],
-        metadata: row.metadata ? JSON.parse(row.metadata) : undefined,
-        containerTag: row.container_tag,
-        displayName: row.display_name,
-        userName: row.user_name,
-        userEmail: row.user_email,
-        projectPath: row.project_path,
-        projectName: row.project_name,
-        gitRepoUrl: row.git_repo_url,
-        isPinned: row.is_pinned,
-      };
+      return this.mapRowToSearchResult(row, similarity);
     });
   }
 
@@ -166,7 +200,44 @@ export class VectorSearch {
     return allResults.filter((r) => r.similarity >= similarityThreshold).slice(0, limit);
   }
 
+  /**
+   * Runs an FTS5 MATCH for `query` against one shard, restricted to `containerTag`,
+   * and returns up to `limit` rows in rank order with a similarity derived from the rank.
+   * Returns [] when the query has no searchable tokens.
+   */
+  fullTextSearch(db: Database, query: string, containerTag: string, limit: number): SearchResult[] {
+    const sanitizedQuery = this.sanitizeFtsQuery(query);
+
+    if (!sanitizedQuery) {
+      return [];
+    }
+
+    const rows = db
+      .prepare(
+        `
+      SELECT m.*, fts.rank
+      FROM memories_fts fts
+      JOIN memories m ON m.id = fts.memory_id
+      WHERE memories_fts MATCH ? AND fts.container_tag = ?
+      ORDER BY fts.rank
+      LIMIT ?
+    `
+      )
+      .all(sanitizedQuery, containerTag, limit) as any[];
+
+    return rows.map((row: any, index: number) => {
+      const rank = Number(row.rank);
+      // FTS5 `rank` is the bm25 score: lower (more negative) means a better match, so
+      // negate it before squashing into [0, 1); clamping the raw value at 0 would
+      // score every hit as 1. Fall back to the row position if rank is not a number.
+      const relevance = Math.max(-rank, 0);
+      const rankSimilarity = Number.isFinite(rank) ? relevance / (1 + relevance) : 1 / (index + 1);
+      return this.mapRowToSearchResult(row, rankSimilarity);
+    });
+  }
+
   deleteVector(db: Database, memoryId: string): void {
+    db.prepare(`DELETE FROM memories_fts WHERE memory_id = ?`).run(memoryId);
     db.prepare(`DELETE FROM vec_memories WHERE memory_id = ?`).run(memoryId);
     db.prepare(`DELETE FROM vec_tags WHERE memory_id = ?`).run(memoryId);
     db.prepare(`DELETE FROM memories WHERE id = ?`).run(memoryId);