Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 153 additions & 3 deletions src/agentic_layer/memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
generate_multi_queries,
)
from agentic_layer.retrieval_utils import reciprocal_rank_fusion
from memory_layer.query_expansion import expand_query, merge_hits_by_id

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -623,7 +624,14 @@ async def get_vector_search_results(
async def retrieve_mem_hybrid(
self, retrieve_mem_request: 'RetrieveMemRequest'
) -> RetrieveMemResponse:
"""Hybrid memory retrieval: keyword + vector + rerank"""
"""Hybrid memory retrieval with LLM query expansion.

Generates 2-3 paraphrase variants of the query before retrieval so
that memories using different vocabulary are still recalled. Results
from all variants are union-merged (deduplicated by memory id) and
re-ranked against the original query. Falls back to plain hybrid
retrieval if expansion fails or produces no variants.
"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
Expand All @@ -632,7 +640,7 @@ async def retrieve_mem_hybrid(
)

try:
hits = await self._search_hybrid(
hits = await self._search_hybrid_with_query_expansion(
retrieve_mem_request, retrieve_method=RetrieveMethod.HYBRID.value
)
duration = time.perf_counter() - start_time
Expand Down Expand Up @@ -699,7 +707,7 @@ async def _search_hybrid(
request: 'RetrieveMemRequest',
retrieve_method: str = RetrieveMethod.HYBRID.value,
) -> List[Dict]:
"""Core hybrid search: keyword + vector + rerank, returns flat list"""
"""Core hybrid search: keyword + vector + rerank, returns flat list."""
memory_type = (
request.memory_types[0].value if request.memory_types else 'unknown'
)
Expand All @@ -717,6 +725,148 @@ async def _search_hybrid(
request.query, merged_results, request.top_k, memory_type, retrieve_method
)

async def _search_hybrid_with_query_expansion(
    self,
    request: 'RetrieveMemRequest',
    retrieve_method: str = RetrieveMethod.HYBRID.value,
) -> List[Dict]:
    """Hybrid search augmented with LLM query expansion (RAG-Fusion style).

    Before running hybrid retrieval, an LLM generates 2-3 paraphrase
    variants of the original query. Retrieval is executed in parallel for
    the original query *and* every variant. Results are then union-merged
    and deduplicated by memory ``id``, ensuring that memories using
    different vocabulary than the original query can still be surfaced.

    Falls back silently to plain ``_search_hybrid`` when:
    - No query text is present in the request.
    - The LLM provider cannot be constructed or the expansion fails.
    - The expansion produces no usable variants.

    The final rerank step always uses the *original* query so that
    relevance scoring is consistent with the user's intent.

    Args:
        request: The retrieval request; ``request.query`` drives expansion.
        retrieve_method: Label forwarded to the sub-searches and rerank.

    Returns:
        Reranked flat list of hit dicts, at most ``request.top_k`` long
        (per ``_rerank``'s contract — confirm against its implementation).
    """
    original_query: Optional[str] = request.query

    # Skip expansion when there is no query (e.g. keyword-only or empty)
    if not original_query or not original_query.strip():
        return await self._search_hybrid(request, retrieve_method=retrieve_method)

    # Build a lightweight LLM provider for expansion (reuses env config)
    try:
        llm_provider = LLMProvider(
            provider_type=os.getenv("LLM_PROVIDER", "openai"),
            model=os.getenv("LLM_MODEL", "openai/gpt-4.1-mini"),  # skip-sensitive-check
            base_url=os.getenv("LLM_BASE_URL"),
            api_key=os.getenv("LLM_API_KEY", "your-api-key"),  # skip-sensitive-check
            temperature=float(os.getenv("LLM_TEMPERATURE", "0.3")),
            max_tokens=256,
        )
    except Exception as exc:
        logger.warning(
            "_search_hybrid_with_query_expansion: could not create LLM "
            "provider, falling back to plain hybrid search: %s",
            exc,
        )
        return await self._search_hybrid(request, retrieve_method=retrieve_method)

    # Generate paraphrase variants (2 variants → 3 queries total with original).
    # expand_query swallows its own errors and returns [] on any failure.
    variants: List[str] = await expand_query(
        query=original_query,
        llm_provider=llm_provider,
        n_variants=2,
        temperature=0.6,
    )

    if not variants:
        # Expansion failed or produced nothing — plain hybrid is fine
        logger.debug(
            "_search_hybrid_with_query_expansion: no variants generated, "
            "falling back to plain hybrid search"
        )
        return await self._search_hybrid(request, retrieve_method=retrieve_method)

    all_queries: List[str] = [original_query] + variants
    logger.info(
        "_search_hybrid_with_query_expansion: running %d queries in parallel "
        "(original + %d variants)",
        len(all_queries),
        len(variants),
    )

    # Hoisted: identical for every sub-query and for the final rerank.
    # (Previously recomputed inside _raw_hybrid where it was never used.)
    memory_type = (
        request.memory_types[0].value if request.memory_types else 'unknown'
    )

    # Run hybrid search (keyword + vector) for each query in parallel.
    # Note: we intentionally skip the rerank step here — rerank happens
    # once on the merged result set below using the original query.
    async def _raw_hybrid(q: str) -> List[Dict]:
        """Hybrid search without rerank for a single query string."""
        sub_request = RetrieveMemRequest(
            query=q,
            user_id=request.user_id,
            group_id=request.group_id,
            # Fetch more candidates per variant so the union is rich
            top_k=min(request.top_k * 2, 100),
            memory_types=request.memory_types,
            start_time=request.start_time,
            end_time=request.end_time,
            retrieve_method=request.retrieve_method,
            radius=request.radius,
        )
        kw, vec = await asyncio.gather(
            self.get_keyword_search_results(
                sub_request, retrieve_method=retrieve_method
            ),
            self.get_vector_search_results(
                sub_request, retrieve_method=retrieve_method
            ),
        )
        # Bug fix: exclude missing/None ids from the dedup set. Previously a
        # single keyword hit without an 'id' put None into ``seen``, after
        # which every id-less vector hit was silently dropped. Id-less hits
        # are now always kept, matching merge_hits_by_id's behaviour.
        seen: set = {h.get('id') for h in kw if h.get('id')}
        return kw + [
            h for h in vec if not h.get('id') or h.get('id') not in seen
        ]

    raw_results = await asyncio.gather(
        *[_raw_hybrid(q) for q in all_queries], return_exceptions=True
    )

    # Collect valid (non-exception) result sets
    hits_per_query: List[List[Dict]] = []
    for i, result in enumerate(raw_results):
        if isinstance(result, Exception):
            logger.warning(
                "_search_hybrid_with_query_expansion: query %d failed: %s",
                i,
                result,
            )
        else:
            hits_per_query.append(result)

    if not hits_per_query:
        # All parallel searches failed; fall back to original query
        logger.warning(
            "_search_hybrid_with_query_expansion: all expanded queries "
            "failed, retrying with original query only"
        )
        return await self._search_hybrid(request, retrieve_method=retrieve_method)

    # Union-merge, deduplicated by memory id. Hits from the earliest
    # surviving query (the original, unless it failed) win on ties.
    merged: List[Dict] = merge_hits_by_id(hits_per_query)
    logger.info(
        "_search_hybrid_with_query_expansion: merged %d unique hits from "
        "%d query result sets",
        len(merged),
        len(hits_per_query),
    )

    # Final rerank on the *original* query keeps scoring consistent
    return await self._rerank(
        original_query, merged, request.top_k, memory_type, retrieve_method
    )

async def _search_rrf(
self,
request: 'RetrieveMemRequest',
Expand Down
178 changes: 178 additions & 0 deletions src/memory_layer/query_expansion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""LLM-based query expansion for improved memory retrieval.

Generates paraphrase variants of a query so that retrieval can surface
memories that use different vocabulary than the original query. Results
from each variant are merged as a union, deduplicated by memory id (first
occurrence wins — see ``merge_hits_by_id``).

Inspired by RAG-Fusion / HyDE: instead of a single fixed query string,
we produce 2-3 semantically equivalent rewordings and fuse their recall
sets. This directly addresses the "vocabulary mismatch" failure mode
where stored memories use different terms than the retrieval query
(e.g. "rescue inhaler protocol" stored vs "gym bag" queried, or vice
versa).
"""

from __future__ import annotations

import asyncio
import json
import logging
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Prompt template
# ---------------------------------------------------------------------------

# Prompt sent to the expansion LLM. The reply is parsed by _parse_variants,
# which tolerates markdown fences and truncates to the requested count.
# NOTE(review): the example reply always shows 3 entries even when {n} < 3;
# harmless because of the truncation, but it may bias the model toward
# producing 3 variants — confirm this is intended.
_PARAPHRASE_PROMPT = """\
You are a query expansion assistant helping a memory retrieval system.

Given the user query below, generate {n} paraphrase variants that capture
the same intent using different vocabulary. The variants should:
- Cover synonyms and related terms the user might not have typed
- Use both formal and informal phrasings when applicable
- Stay concise (under 120 characters each)
- NOT repeat the original query verbatim

Original query:
{query}

Reply with ONLY a JSON object in this exact format (no markdown fences):
{{
"variants": [
"paraphrase 1",
"paraphrase 2",
"paraphrase 3"
]
}}
"""


# ---------------------------------------------------------------------------
# Core function
# ---------------------------------------------------------------------------


async def expand_query(
    query: str,
    llm_provider: Any,
    n_variants: int = 2,
    temperature: float = 0.6,
) -> List[str]:
    """Generate paraphrase variants for *query* using the LLM.

    Args:
        query: The original retrieval query.
        llm_provider: Any object exposing ``async generate(prompt, temperature,
            max_tokens) -> str``; the ``LLMProvider`` from
            ``memory_layer.llm.llm_provider`` satisfies this interface.
        n_variants: How many paraphrases to request (default 2; clamped
            into the range 1..3).
        temperature: Sampling temperature for the LLM call. Higher values
            produce more diverse paraphrases.

    Returns:
        A list of paraphrase strings. On any failure the list is empty so
        callers can fall back to the original query without crashing.
    """
    # Nothing to expand without a provider or a non-blank query.
    if llm_provider is None or not query or not query.strip():
        return []

    # Clamp the request into [1, 3] variants.
    requested = min(max(n_variants, 1), 3)
    prompt = _PARAPHRASE_PROMPT.format(query=query.strip(), n=requested)

    try:
        response = await llm_provider.generate(
            prompt=prompt,
            temperature=temperature,
            max_tokens=256,
        )
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "query_expansion: LLM call failed, skipping expansion: %s", exc
        )
        return []

    return _parse_variants(response, query, requested)


def _parse_variants(
raw: str,
original_query: str,
max_variants: int,
) -> List[str]:
"""Extract and validate the paraphrase list from the LLM response."""
try:
# Strip any accidental markdown fences
text = raw.strip()
if text.startswith("```"):
lines = text.splitlines()
# drop first and last fence lines
text = "\n".join(lines[1:-1] if lines[-1].startswith("```") else lines[1:])

start = text.find("{")
end = text.rfind("}") + 1
if start == -1 or end == 0:
raise ValueError("No JSON object found in LLM response")

parsed: Dict[str, Any] = json.loads(text[start:end])
variants_raw: List[Any] = parsed.get("variants", [])

if not isinstance(variants_raw, list):
raise ValueError("'variants' is not a list")

variants: List[str] = []
orig_lower = original_query.lower().strip()
for item in variants_raw[:max_variants]:
if not isinstance(item, str):
continue
item = item.strip()
# Drop empty strings or verbatim copies of the original query
if item and item.lower() != orig_lower:
variants.append(item)

logger.debug("query_expansion: generated %d variants: %s", len(variants), variants)
return variants

except Exception as exc: # noqa: BLE001
logger.warning(
"query_expansion: failed to parse LLM response, skipping: %s", exc
)
return []


# ---------------------------------------------------------------------------
# Merge helper (union dedup by dict key "id")
# ---------------------------------------------------------------------------


def merge_hits_by_id(
    hits_per_query: List[List[Dict[str, Any]]],
) -> List[Dict[str, Any]]:
    """Flatten multiple hit lists into one, keeping the first hit per ``id``.

    Lists are consumed in order, so on duplicate ids the hit from the
    earliest list — typically the original query's results — wins and keeps
    its score. Hits whose ``id`` is missing or empty are always kept (they
    are never considered duplicates of each other).

    Args:
        hits_per_query: One list of hit dicts per query (original + variants).

    Returns:
        A flat deduplicated list of hit dicts.
    """
    result: List[Dict[str, Any]] = []
    ids_taken: set = set()

    for hit_list in hits_per_query:
        for candidate in hit_list:
            key = candidate.get("id", "")
            if not key:
                # No usable id — keep unconditionally.
                result.append(candidate)
            elif key not in ids_taken:
                ids_taken.add(key)
                result.append(candidate)

    return result