Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ services:
- memos_network

neo4j:
image: neo4j:5.26.4
image: neo4j:5.26.6
container_name: neo4j-docker
ports:
- "7474:7474" # HTTP
Expand Down
83 changes: 48 additions & 35 deletions examples/basic_modules/neo4j_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,36 @@

from datetime import datetime

from dotenv import load_dotenv

from memos.configs.embedder import EmbedderConfigFactory
from memos.configs.graph_db import GraphDBConfigFactory
from memos.embedders.factory import EmbedderFactory
from memos.graph_dbs.factory import GraphStoreFactory
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata


load_dotenv()

NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "12345678")
NEO4J_DB_NAME = os.getenv("NEO4J_DB_NAME", "neo4j")
EMBEDDING_DIMENSION = int(os.getenv("EMBEDDING_DIMENSION", "3072"))

QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))

embedder_config = EmbedderConfigFactory.model_validate(
{
"backend": "universal_api",
"backend": os.getenv("MOS_EMBEDDER_BACKEND", "universal_api"),
"config": {
"provider": "openai",
"api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
"model_name_or_path": "text-embedding-3-large",
"base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
"provider": os.getenv("MOS_EMBEDDER_PROVIDER", "openai"),
"api_key": os.getenv("MOS_EMBEDDER_API_KEY", os.getenv("OPENAI_API_KEY", "")),
"model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"),
"base_url": os.getenv(
"MOS_EMBEDDER_API_BASE", os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
),
},
}
)
Expand All @@ -31,12 +46,12 @@ def get_neo4j_graph(db_name: str = "paper"):
config = GraphDBConfigFactory(
backend="neo4j",
config={
"uri": "bolt://xxxx:7687",
"user": "neo4j",
"password": "xxxx",
"uri": NEO4J_URI,
"user": NEO4J_USER,
"password": NEO4J_PASSWORD,
"db_name": db_name,
"auto_create": True,
"embedding_dimension": 3072,
"embedding_dimension": EMBEDDING_DIMENSION,
"use_multi_db": True,
},
)
Expand All @@ -49,12 +64,12 @@ def example_multi_db(db_name: str = "paper"):
config = GraphDBConfigFactory(
backend="neo4j",
config={
"uri": "bolt://localhost:7687",
"user": "neo4j",
"password": "12345678",
"uri": NEO4J_URI,
"user": NEO4J_USER,
"password": NEO4J_PASSWORD,
"db_name": db_name,
"auto_create": True,
"embedding_dimension": 3072,
"embedding_dimension": EMBEDDING_DIMENSION,
"use_multi_db": True,
},
)
Expand Down Expand Up @@ -288,14 +303,14 @@ def example_shared_db(db_name: str = "shared-traval-group"):
config = GraphDBConfigFactory(
backend="neo4j",
config={
"uri": "bolt://localhost:7687",
"user": "neo4j",
"password": "12345678",
"uri": NEO4J_URI,
"user": NEO4J_USER,
"password": NEO4J_PASSWORD,
"db_name": db_name,
"user_name": user_name,
"use_multi_db": False,
"auto_create": True,
"embedding_dimension": 3072,
"embedding_dimension": EMBEDDING_DIMENSION,
},
)
# Step 2: Instantiate graph store
Expand Down Expand Up @@ -353,12 +368,12 @@ def example_shared_db(db_name: str = "shared-traval-group"):
config_alice = GraphDBConfigFactory(
backend="neo4j",
config={
"uri": "bolt://localhost:7687",
"user": "neo4j",
"password": "12345678",
"uri": NEO4J_URI,
"user": NEO4J_USER,
"password": NEO4J_PASSWORD,
"db_name": db_name,
"user_name": user_list[0],
"embedding_dimension": 3072,
"embedding_dimension": EMBEDDING_DIMENSION,
},
)
graph_alice = GraphStoreFactory.from_config(config_alice)
Expand All @@ -382,24 +397,22 @@ def run_user_session(
config = GraphDBConfigFactory(
backend="neo4j-community",
config={
"uri": "bolt://localhost:7687",
"user": "neo4j",
"password": "12345678",
"uri": NEO4J_URI,
"user": NEO4J_USER,
"password": NEO4J_PASSWORD,
"db_name": db_name,
"user_name": user_name,
"use_multi_db": False,
"auto_create": False, # Neo4j Community does not allow auto DB creation
"embedding_dimension": 3072,
"auto_create": False,
"embedding_dimension": EMBEDDING_DIMENSION,
"vec_config": {
# Pass nested config to initialize external vector DB
# If you use qdrant, please use Server instead of local mode.
"backend": "qdrant",
"config": {
"collection_name": "neo4j_vec_db",
"vector_dimension": 3072,
"vector_dimension": EMBEDDING_DIMENSION,
"distance_metric": "cosine",
"host": "localhost",
"port": 6333,
"host": QDRANT_HOST,
"port": QDRANT_PORT,
},
},
},
Expand All @@ -408,14 +421,14 @@ def run_user_session(
config = GraphDBConfigFactory(
backend="neo4j",
config={
"uri": "bolt://localhost:7687",
"user": "neo4j",
"password": "12345678",
"uri": NEO4J_URI,
"user": NEO4J_USER,
"password": NEO4J_PASSWORD,
"db_name": db_name,
"user_name": user_name,
"use_multi_db": False,
"auto_create": True,
"embedding_dimension": 3072,
"embedding_dimension": EMBEDDING_DIMENSION,
},
)
graph = GraphStoreFactory.from_config(config)
Expand Down
1 change: 1 addition & 0 deletions src/memos/api/server_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def health_check():
"version": app.version,
}


# Request validation failed
app.exception_handler(RequestValidationError)(APIExceptionHandler.validation_error_handler)
# Invalid business code parameters
Expand Down
54 changes: 36 additions & 18 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
metadata["embedding"] = [float(x) for x in embedding]

# serialization
if metadata["sources"]:
if metadata.get("sources"):
for idx in range(len(metadata["sources"])):
metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
return metadata
Expand Down Expand Up @@ -73,7 +73,10 @@ def _flatten_info_fields(metadata: dict[str, Any]) -> dict[str, Any]:


class Neo4jGraphDB(BaseGraphDB):
"""Neo4j-based implementation of a graph memory store."""
"""Neo4j-based implementation of a graph memory store.

Requires Neo4j >= 5.18 for vector.similarity.cosine() pre-filtering support.
"""

@require_python_package(
import_name="neo4j",
Expand Down Expand Up @@ -226,7 +229,7 @@ def add_node(
"""

# serialization
if metadata["sources"]:
if metadata.get("sources"):
for idx in range(len(metadata["sources"])):
metadata["sources"][idx] = json.dumps(metadata["sources"][idx])

Expand Down Expand Up @@ -843,13 +846,14 @@ def search_by_embedding(
If return_fields is specified, each dict also includes the requested fields.

Notes:
- This method uses Neo4j native vector indexing to search for similar nodes.
- If scope is provided, it restricts results to nodes with matching memory_type.
- If 'status' is provided, only nodes with the matching status will be returned.
- When filters are present (scope, status, user_name, etc.), this method uses
Neo4j 5.18+ pre-filtering: MATCH + WHERE narrows candidates first, then
vector.similarity.cosine() computes similarity only on the filtered set.
This avoids the post-filter problem where queryNodes' global top-k excludes
the target user's nodes in a multi-tenant shared database.
- When no filters are present, the ANN vector index (db.index.vector.queryNodes)
is used for maximum efficiency.
- If threshold is provided, only results with score >= threshold will be returned.
- If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
- Typical use case: restrict to 'status = activated' to avoid
matching archived or merged nodes.
"""
user_name = user_name if user_name else self.config.user_name
# Build WHERE clause dynamically
Expand Down Expand Up @@ -901,14 +905,28 @@ def search_by_embedding(
if extra_fields:
return_clause = f"RETURN node.id AS id, score, {extra_fields}"

query = f"""
CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding)
YIELD node, score
{where_clause}
{return_clause}
"""

parameters = {"embedding": vector, "k": top_k}
if where_clause:
# Pre-filtering (Neo4j 5.18+): filter nodes first, then compute similarity.
# This avoids the post-filter problem where relevant nodes are excluded
# from the global top-k returned by queryNodes.
where_clause += " AND node.embedding IS NOT NULL"
query = f"""
MATCH (node:Memory)
{where_clause}
WITH node, vector.similarity.cosine(node.embedding, $embedding) AS score
{return_clause}
ORDER BY score DESC
LIMIT $top_k
"""
parameters = {"embedding": vector, "top_k": top_k}
else:
# No filter: use ANN vector index for efficiency.
query = f"""
CALL db.index.vector.queryNodes('memory_vector_index', $top_k, $embedding)
YIELD node, score
{return_clause}
"""
parameters = {"embedding": vector, "top_k": top_k}

if scope:
parameters["scope"] = scope
Expand Down Expand Up @@ -1842,7 +1860,7 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
if not (
isinstance(node["sources"][idx], str)
and node["sources"][idx][0] == "{"
and node["sources"][idx][0] == "}"
and node["sources"][idx][-1] == "}"
):
break
node["sources"][idx] = json.loads(node["sources"][idx])
Expand Down
70 changes: 37 additions & 33 deletions src/memos/graph_dbs/neo4j_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,34 +61,35 @@ def add_node(
metadata.setdefault("delete_record_id", "")

# serialization
if metadata["sources"]:
if metadata.get("sources"):
for idx in range(len(metadata["sources"])):
metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
# Extract required fields
embedding = metadata.pop("embedding", None)
if embedding is None:
raise ValueError(f"Missing 'embedding' in metadata for node {id}")

# Merge node and set metadata
created_at = metadata.pop("created_at")
updated_at = metadata.pop("updated_at")
vector_sync_status = "success"
vector_sync_status = "skipped"

try:
# Write to Vector DB
item = VecDBItem(
id=id,
vector=embedding,
payload={
"memory": memory,
"vector_sync": vector_sync_status,
**metadata, # unpack all metadata keys to top-level
},
)
self.vec_db.add([item])
except Exception as e:
logger.warning(f"[VecDB] Vector insert failed for node {id}: {e}")
vector_sync_status = "failed"
if embedding is not None:
vector_sync_status = "success"
try:
item = VecDBItem(
id=id,
vector=embedding,
payload={
"memory": memory,
"vector_sync": vector_sync_status,
**metadata,
},
)
self.vec_db.add([item])
except Exception as e:
logger.warning(f"[VecDB] Vector insert failed for node {id}: {e}")
vector_sync_status = "failed"
else:
logger.warning(f"[add_node] No embedding for node {id}, skipping vector DB insert")

metadata["vector_sync"] = vector_sync_status
query = """
Expand Down Expand Up @@ -141,18 +142,21 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N

embedding = metadata.pop("embedding", None)

vector_sync_status = "success"
vec_items.append(
VecDBItem(
id=node_id,
vector=embedding,
payload={
"memory": memory,
"vector_sync": vector_sync_status,
**metadata,
},
if embedding is not None:
vector_sync_status = "success"
vec_items.append(
VecDBItem(
id=node_id,
vector=embedding,
payload={
"memory": memory,
"vector_sync": vector_sync_status,
**metadata,
},
)
)
)
else:
vector_sync_status = "skipped"

created_at = metadata.pop("created_at")
updated_at = metadata.pop("updated_at")
Expand Down Expand Up @@ -1138,12 +1142,12 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
node[time_field] = node[time_field].isoformat()
node.pop("user_name", None)
# serialization
if node["sources"]:
if node.get("sources"):
for idx in range(len(node["sources"])):
if not (
isinstance(node["sources"][idx], str)
and node["sources"][idx][0] == "{"
and node["sources"][idx][0] == "}"
and node["sources"][idx][-1] == "}"
):
break
node["sources"][idx] = json.loads(node["sources"][idx])
Expand Down Expand Up @@ -1179,7 +1183,7 @@ def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]]
if not (
isinstance(node["sources"][idx], str)
and node["sources"][idx][0] == "{"
and node["sources"][idx][0] == "}"
and node["sources"][idx][-1] == "}"
):
break
node["sources"][idx] = json.loads(node["sources"][idx])
Expand Down
Loading
Loading