diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 355ee0385..9a226cf30 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -76,6 +76,24 @@ def get_activation_config() -> dict[str, Any]: }, } + @staticmethod + def get_memreader_config() -> dict[str, Any]: + """Get MemReader configuration.""" + return { + "backend": "openai", + "config": { + "model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"), + "temperature": 0.6, + "max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "5000")), + "top_p": 0.95, + "top_k": 20, + "api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"), + "api_base": os.getenv("MEMRADER_API_BASE"), + "remove_think_prefix": True, + "extra_body": {"chat_template_kwargs": {"enable_thinking": False}}, + }, + } + @staticmethod def get_activation_vllm_config() -> dict[str, Any]: """Get Ollama configuration.""" @@ -351,10 +369,7 @@ def get_product_default_config() -> dict[str, Any]: "mem_reader": { "backend": "simple_struct", "config": { - "llm": { - "backend": "openai", - "config": openai_config, - }, + "llm": APIConfig.get_memreader_config(), "embedder": APIConfig.get_embedder_config(), "chunker": { "backend": "sentence", @@ -447,10 +462,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "mem_reader": { "backend": "simple_struct", "config": { - "llm": { - "backend": "openai", - "config": openai_config, - }, + "llm": APIConfig.get_memreader_config(), "embedder": APIConfig.get_embedder_config(), "chunker": { "backend": "sentence", diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 10c3c75d0..a6f6b82a4 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -432,7 +432,7 @@ def remove_oldest_memory( optional_condition = f"AND n.user_name = '{user_name}'" query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE n.memory_type = '{memory_type}' {optional_condition} ORDER BY n.updated_at DESC @@ -1158,7 +1158,7 @@ def get_grouped_counts( group_by_fields.append(alias) # Full GQL query construction gql = f""" - MATCH (n) + MATCH (n /*+ INDEX(idx_memory_user_name) */) {where_clause} RETURN {", ".join(return_fields)}, COUNT(n) AS count GROUP BY {", ".join(group_by_fields)} diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index ccc91c48b..55db60ed2 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -101,12 +101,13 @@ def create_index( # Create indexes self._create_basic_property_indexes() - def get_memory_count(self, memory_type: str) -> int: + def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name query = """ MATCH (n:Memory) WHERE n.memory_type = $memory_type """ - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): query += "\nAND n.user_name = $user_name" query += "\nRETURN COUNT(n) AS count" with self.driver.session(database=self.db_name) as session: @@ -114,17 +115,18 @@ def get_memory_count(self, memory_type: str) -> int: query, { "memory_type": memory_type, - "user_name": self.config.user_name if self.config.user_name else None, + "user_name": user_name, }, ) return result.single()["count"] - def node_not_exist(self, scope: str) -> int: + def node_not_exist(self, scope: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name query = """ MATCH (n:Memory) WHERE n.memory_type = $scope """ - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): query += "\nAND n.user_name = $user_name" query += "\nRETURN n LIMIT 1" @@ -133,12 +135,14 @@ def node_not_exist(self, scope: str) -> int: query, { "scope": scope, - "user_name": self.config.user_name if self.config.user_name else None, + "user_name": user_name, }, ) return result.single() is None - def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: + def remove_oldest_memory( + self, memory_type: str, keep_latest: int, user_name: str | None = None + ) -> None: """ Remove all WorkingMemory nodes except the latest `keep_latest` entries. @@ -146,12 +150,13 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). keep_latest (int): Number of latest WorkingMemory entries to keep. """ + user_name = user_name if user_name else self.config.user_name query = f""" MATCH (n:Memory) WHERE n.memory_type = '{memory_type}' """ - if not self.config.use_multi_db and self.config.user_name: - query += f"\nAND n.user_name = '{self.config.user_name}'" + if not self.config.use_multi_db and (self.config.user_name or user_name): + query += f"\nAND n.user_name = '{user_name}'" query += f""" WITH n ORDER BY n.updated_at DESC @@ -161,9 +166,12 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: with self.driver.session(database=self.db_name) as session: session.run(query) - def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name + def add_node( + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + ) -> None: + user_name = user_name if user_name else self.config.user_name + if not self.config.use_multi_db and (self.config.user_name or user_name): + metadata["user_name"] = user_name # Safely process metadata metadata = _prepare_node_metadata(metadata) @@ -195,10 +203,11 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: metadata=metadata, ) - def update_node(self, id: str, fields: dict[str, Any]) -> None: + def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: """ Update node fields in Neo4j, auto-converting `created_at` and `updated_at` to datetime type if present. """ + user_name = user_name if user_name else self.config.user_name fields = fields.copy() # Avoid mutating external dict set_clauses = [] params = {"id": id, "fields": fields} @@ -215,27 +224,28 @@ def update_node(self, id: str, fields: dict[str, Any]) -> None: query = """ MATCH (n:Memory {id: $id}) """ - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): query += "\nWHERE n.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name query += f"\nSET {set_clause_str}" with self.driver.session(database=self.db_name) as session: session.run(query, **params) - def delete_node(self, id: str) -> None: + def delete_node(self, id: str, user_name: str | None = None) -> None: """ Delete a node from the graph. Args: id: Node identifier to delete. """ + user_name = user_name if user_name else self.config.user_name query = "MATCH (n:Memory {id: $id})" params = {"id": id} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): query += " WHERE n.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name query += " DETACH DELETE n" @@ -243,7 +253,9 @@ def delete_node(self, id: str) -> None: session.run(query, **params) # Edge (Relationship) Management - def add_edge(self, source_id: str, target_id: str, type: str) -> None: + def add_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: """ Create an edge from source node to target node. Args: @@ -251,23 +263,26 @@ def add_edge(self, source_id: str, target_id: str, type: str) -> None: target_id: ID of the target node. type: Relationship type (e.g., 'RELATE_TO', 'PARENT'). """ + user_name = user_name if user_name else self.config.user_name query = """ MATCH (a:Memory {id: $source_id}) MATCH (b:Memory {id: $target_id}) """ params = {"source_id": source_id, "target_id": target_id} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): query += """ WHERE a.user_name = $user_name AND b.user_name = $user_name """ - params["user_name"] = self.config.user_name + params["user_name"] = user_name query += f"\nMERGE (a)-[:{type}]->(b)" with self.driver.session(database=self.db_name) as session: session.run(query, params) - def delete_edge(self, source_id: str, target_id: str, type: str) -> None: + def delete_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: """ Delete a specific edge between two nodes. Args: @@ -275,6 +290,7 @@ def delete_edge(self, source_id: str, target_id: str, type: str) -> None: target_id: ID of the target node. type: Relationship type to remove. """ + user_name = user_name if user_name else self.config.user_name query = f""" MATCH (a:Memory {{id: $source}}) -[r:{type}]-> @@ -282,9 +298,9 @@ def delete_edge(self, source_id: str, target_id: str, type: str) -> None: """ params = {"source": source_id, "target": target_id} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): query += "\nWHERE a.user_name = $user_name AND b.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name query += "\nDELETE r" @@ -292,7 +308,12 @@ def delete_edge(self, source_id: str, target_id: str, type: str) -> None: session.run(query, params) def edge_exists( - self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING" + self, + source_id: str, + target_id: str, + type: str = "ANY", + direction: str = "OUTGOING", + user_name: str | None = None, ) -> bool: """ Check if an edge exists between two nodes. @@ -305,6 +326,7 @@ def edge_exists( Returns: True if the edge exists, otherwise False. """ + user_name = user_name if user_name else self.config.user_name # Prepare the relationship pattern rel = "r" if type == "ANY" else f"r:{type}" @@ -322,9 +344,9 @@ def edge_exists( query = f"MATCH {pattern}" params = {"source": source_id, "target": target_id} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): query += "\nWHERE a.user_name = $user_name AND b.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name query += "\nRETURN r" @@ -342,12 +364,12 @@ def get_node(self, id: str, **kwargs) -> dict[str, Any] | None: Returns: Dictionary of node fields, or None if not found. """ - + user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name where_user = "" params = {"id": id} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): where_user = " AND n.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name query = f"MATCH (n:Memory) WHERE n.id = $id {where_user} RETURN n" @@ -370,16 +392,16 @@ def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]: if not ids: return [] - + user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name where_user = "" params = {"ids": ids} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): where_user = " AND n.user_name = $user_name" if kwargs.get("cube_name"): params["user_name"] = kwargs["cube_name"] else: - params["user_name"] = self.config.user_name + params["user_name"] = user_name query = f"MATCH (n:Memory) WHERE n.id IN $ids{where_user} RETURN n" @@ -387,7 +409,9 @@ def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]: results = session.run(query, params) return [self._parse_node(dict(record["n"])) for record in results] - def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]: + def get_edges( + self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None + ) -> list[dict[str, str]]: """ Get edges connected to a node, with optional type and direction filter. @@ -403,6 +427,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ ... ] """ + user_name = user_name if user_name else self.config.user_name # Build relationship type filter rel_type = "" if type == "ANY" else f":{type}" @@ -421,9 +446,9 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ params = {"id": id} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): where_clause += " AND a.user_name = $user_name AND b.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name query = f""" MATCH {pattern} @@ -441,7 +466,11 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ return edges def get_neighbors( - self, id: str, type: str, direction: Literal["in", "out", "both"] = "out" + self, + id: str, + type: str, + direction: Literal["in", "out", "both"] = "out", + user_name: str | None = None, ) -> list[str]: """ Get connected node IDs in a specific direction and relationship type. @@ -460,6 +489,7 @@ def get_neighbors_by_tag( exclude_ids: list[str], top_k: int = 5, min_overlap: int = 1, + user_name: str | None = None, ) -> list[dict[str, Any]]: """ Find top-K neighbor nodes with maximum tag overlap. @@ -473,6 +503,7 @@ def get_neighbors_by_tag( Returns: List of dicts with node details and overlap count. """ + user_name = user_name if user_name else self.config.user_name where_user = "" params = { "tags": tags, @@ -481,9 +512,9 @@ def get_neighbors_by_tag( "top_k": top_k, } - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): where_user = "AND n.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name query = f""" MATCH (n:Memory) @@ -503,13 +534,16 @@ def get_neighbors_by_tag( result = session.run(query, params) return [self._parse_node(dict(record["n"])) for record in result] - def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]: + def get_children_with_embeddings( + self, id: str, user_name: str | None = None + ) -> list[dict[str, Any]]: + user_name = user_name if user_name else self.config.user_name where_user = "" params = {"id": id} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): where_user = "AND p.user_name = $user_name AND c.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name query = f""" MATCH (p:Memory)-[:PARENT]->(c:Memory) @@ -523,7 +557,9 @@ def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]: {"id": r["id"], "embedding": r["embedding"], "memory": r["memory"]} for r in result ] - def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: + def get_path( + self, source_id: str, target_id: str, max_depth: int = 3, user_name: str | None = None + ) -> list[str]: """ Get the path of nodes from source to target within a limited depth. Args: @@ -536,7 +572,11 @@ def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[s raise NotImplementedError def get_subgraph( - self, center_id: str, depth: int = 2, center_status: str = "activated" + self, + center_id: str, + depth: int = 2, + center_status: str = "activated", + user_name: str | None = None, ) -> dict[str, Any]: """ Retrieve a local subgraph centered at a given node. @@ -551,15 +591,16 @@ def get_subgraph( "edges": [...] } """ + user_name = user_name if user_name else self.config.user_name with self.driver.session(database=self.db_name) as session: params = {"center_id": center_id} center_user_clause = "" neighbor_user_clause = "" - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): center_user_clause = " AND center.user_name = $user_name" neighbor_user_clause = " WHERE neighbor.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name status_clause = f" AND center.status = '{center_status}'" if center_status else "" query = f""" @@ -618,6 +659,7 @@ def search_by_embedding( status: str | None = None, threshold: float | None = None, search_filter: dict | None = None, + user_name: str | None = None, **kwargs, ) -> list[dict]: """ @@ -645,13 +687,14 @@ def search_by_embedding( - Typical use case: restrict to 'status = activated' to avoid matching archived or merged nodes. """ + user_name = user_name if user_name else self.config.user_name # Build WHERE clause dynamically where_clauses = [] if scope: where_clauses.append("node.memory_type = $scope") if status: where_clauses.append("node.status = $status") - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): where_clauses.append("node.user_name = $user_name") # Add search_filter conditions @@ -677,11 +720,11 @@ def search_by_embedding( parameters["scope"] = scope if status: parameters["status"] = status - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): if kwargs.get("cube_name"): parameters["user_name"] = kwargs["cube_name"] else: - parameters["user_name"] = self.config.user_name + parameters["user_name"] = user_name # Add search_filter parameters if search_filter: @@ -699,7 +742,9 @@ def search_by_embedding( return records - def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: + def get_by_metadata( + self, filters: list[dict[str, Any]], user_name: str | None = None + ) -> list[str]: """ TODO: 1. ADD logic: "AND" vs "OR"(support logic combination); @@ -724,6 +769,7 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: - Supports structured querying such as tag/category/importance/time filtering. - Can be used for faceted recall or prefiltering before embedding rerank. """ + user_name = user_name if user_name else self.config.user_name where_clauses = [] params = {} @@ -755,9 +801,9 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: else: raise ValueError(f"Unsupported operator: {op}") - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): where_clauses.append("n.user_name = $user_name") - params["user_name"] = self.config.user_name + params["user_name"] = user_name where_str = " AND ".join(where_clauses) query = f"MATCH (n:Memory) WHERE {where_str} RETURN n.id AS id" @@ -771,6 +817,7 @@ def get_grouped_counts( group_fields: list[str], where_clause: str = "", params: dict[str, Any] | None = None, + user_name: str | None = None, ) -> list[dict[str, Any]]: """ Count nodes grouped by any fields. @@ -784,14 +831,15 @@ def get_grouped_counts( Returns: list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] """ + user_name = user_name if user_name else self.config.user_name if not group_fields: raise ValueError("group_fields cannot be empty") final_params = params.copy() if params else {} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): user_clause = "n.user_name = $user_name" - final_params["user_name"] = self.config.user_name + final_params["user_name"] = user_name if where_clause: where_clause = where_clause.strip() if where_clause.upper().startswith("WHERE"): @@ -845,14 +893,15 @@ def merge_nodes(self, id1: str, id2: str) -> str: raise NotImplementedError # Utilities - def clear(self) -> None: + def clear(self, user_name: str | None = None) -> None: """ Clear the entire graph if the target database exists. """ + user_name = user_name if user_name else self.config.user_name try: - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): query = "MATCH (n:Memory) WHERE n.user_name = $user_name DETACH DELETE n" - params = {"user_name": self.config.user_name} + params = {"user_name": user_name} else: query = "MATCH (n) DETACH DELETE n" params = {} @@ -876,16 +925,17 @@ def export_graph(self, **kwargs) -> dict[str, Any]: "edges": [ { "source": ..., "target": ..., "type": ... }, ... ] } """ + user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name with self.driver.session(database=self.db_name) as session: # Export nodes node_query = "MATCH (n:Memory)" edge_query = "MATCH (a:Memory)-[r]->(b:Memory)" params = {} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): node_query += " WHERE n.user_name = $user_name" edge_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name node_result = session.run(f"{node_query} RETURN n", params) nodes = [self._parse_node(dict(record["n"])) for record in node_result] @@ -901,19 +951,20 @@ def export_graph(self, **kwargs) -> dict[str, Any]: return {"nodes": nodes, "edges": edges} - def import_graph(self, data: dict[str, Any]) -> None: + def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: """ Import the entire graph from a serialized dictionary. Args: data: A dictionary containing all nodes and edges to be loaded. """ + user_name = user_name if user_name else self.config.user_name with self.driver.session(database=self.db_name) as session: for node in data.get("nodes", []): id, memory, metadata = _compose_node(node) - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name + if not self.config.use_multi_db and (self.config.user_name or user_name): + metadata["user_name"] = user_name metadata = _prepare_node_metadata(metadata) @@ -958,15 +1009,16 @@ def get_all_memory_items(self, scope: str, **kwargs) -> list[dict]: Returns: list[dict]: Full list of memory items under this scope. """ + user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: raise ValueError(f"Unsupported memory type scope: {scope}") where_clause = "WHERE n.memory_type = $scope" params = {"scope": scope} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): where_clause += " AND n.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name query = f""" MATCH (n:Memory) @@ -984,7 +1036,7 @@ def get_structure_optimization_candidates(self, scope: str, **kwargs) -> list[di - Isolated nodes, nodes with empty background, or nodes with exactly one child. - Plus: the child of any parent node that has exactly one child. """ - + user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name where_clause = """ WHERE n.memory_type = $scope AND n.status = 'activated' @@ -992,9 +1044,9 @@ def get_structure_optimization_candidates(self, scope: str, **kwargs) -> list[di """ params = {"scope": scope} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): where_clause += " AND n.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name query = f""" MATCH (n:Memory) diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index 54000a51d..6f7786834 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -1,4 +1,5 @@ import json + from typing import Any from memos.configs.graph_db import Neo4jGraphDBConfig @@ -43,9 +44,12 @@ def create_index( # Create indexes self._create_basic_property_indexes() - def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name + def add_node( + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + ) -> None: + user_name = user_name if user_name else self.config.user_name + if not self.config.use_multi_db and (self.config.user_name or user_name): + metadata["user_name"] = user_name # Safely process metadata metadata = _prepare_node_metadata(metadata) @@ -98,13 +102,16 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: metadata=metadata, ) - def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]: + def get_children_with_embeddings( + self, id: str, user_name: str | None = None + ) -> list[dict[str, Any]]: + user_name = user_name if user_name else self.config.user_name where_user = "" params = {"id": id} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): where_user = "AND p.user_name = $user_name AND c.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name query = f""" MATCH (p:Memory)-[:PARENT]->(c:Memory) @@ -135,6 +142,7 @@ def search_by_embedding( status: str | None = None, threshold: float | None = None, search_filter: dict | None = None, + user_name: str | None = None, **kwargs, ) -> list[dict]: """ @@ -159,6 +167,7 @@ def search_by_embedding( - If 'search_filter' is provided, it applies additional metadata-based filtering. - The returned IDs can be used to fetch full node data from Neo4j if needed. """ + user_name = user_name if user_name else self.config.user_name # Build VecDB filter vec_filter = {} if scope: @@ -169,7 +178,7 @@ def search_by_embedding( if kwargs.get("cube_name"): vec_filter["user_name"] = kwargs["cube_name"] else: - vec_filter["user_name"] = self.config.user_name + vec_filter["user_name"] = user_name # Add search_filter conditions if search_filter: @@ -194,15 +203,16 @@ def get_all_memory_items(self, scope: str, **kwargs) -> list[dict]: Returns: list[dict]: Full list of memory items under this scope. """ + user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory"}: raise ValueError(f"Unsupported memory type scope: {scope}") where_clause = "WHERE n.memory_type = $scope" params = {"scope": scope} - if not self.config.use_multi_db and self.config.user_name: + if not self.config.use_multi_db and (self.config.user_name or user_name): where_clause += " AND n.user_name = $user_name" - params["user_name"] = self.config.user_name + params["user_name"] = user_name query = f""" MATCH (n:Memory) @@ -214,23 +224,24 @@ def get_all_memory_items(self, scope: str, **kwargs) -> list[dict]: results = session.run(query, params) return [self._parse_node(dict(record["n"])) for record in results] - def clear(self) -> None: + def clear(self, user_name: str | None = None) -> None: """ Clear the entire graph if the target database exists. """ # Step 1: clear Neo4j part via parent logic - super().clear() + user_name = user_name if user_name else self.config.user_name + super().clear(user_name=user_name) # Step2: Clear the vector db try: - items = self.vec_db.get_by_filter({"user_name": self.config.user_name}) + items = self.vec_db.get_by_filter({"user_name": user_name}) if items: self.vec_db.delete([item.id for item in items]) - logger.info(f"Cleared {len(items)} vectors for user '{self.config.user_name}'.") + logger.info(f"Cleared {len(items)} vectors for user '{user_name}'.") else: - logger.info(f"No vectors to clear for user '{self.config.user_name}'.") + logger.info(f"No vectors to clear for user '{user_name}'.") except Exception as e: - logger.warning(f"Failed to clear vector DB for user '{self.config.user_name}': {e}") + logger.warning(f"Failed to clear vector DB for user '{user_name}': {e}") def drop_database(self) -> None: """