diff --git a/src/memos/api/client.py b/src/memos/api/client.py index d45276f2c..912f883a7 100644 --- a/src/memos/api/client.py +++ b/src/memos/api/client.py @@ -50,6 +50,7 @@ def get_message( ) response.raise_for_status() response_data = response.json() + return MemOSGetMessagesResponse(**response_data) except Exception as e: logger.error(f"Failed to get messages (retry {retry + 1}/3): {e}") @@ -74,6 +75,7 @@ def add_message( ) response.raise_for_status() response_data = response.json() + return MemOSAddResponse(**response_data) except Exception as e: logger.error(f"Failed to add memory (retry {retry + 1}/3): {e}") @@ -102,6 +104,7 @@ def search_memory( ) response.raise_for_status() response_data = response.json() + return MemOSSearchResponse(**response_data) except Exception as e: logger.error(f"Failed to search memory (retry {retry + 1}/3): {e}") diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 355ee0385..c9ff70d4e 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -518,6 +518,13 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "embedder": APIConfig.get_embedder_config(), "internet_retriever": internet_config, "reranker": APIConfig.get_reranker_config(), + "reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower() + == "true", + "memory_size": { + "WorkingMemory": os.getenv("NEBULAR_WORKING_MEMORY", 20), + "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), + "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), + }, }, }, "act_mem": {} @@ -575,6 +582,11 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: "reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower() == "true", "internet_retriever": internet_config, + "memory_size": { + "WorkingMemory": os.getenv("NEBULAR_WORKING_MEMORY", 20), + "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), + "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), + }, }, }, "act_mem": {} diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 7e425415b..2d03d2946 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -191,7 +191,7 @@ class GetMessagesData(BaseModel): """Data model for get messages response based on actual API.""" message_detail_list: list[MessageDetail] = Field( - default_factory=list, alias="memory_detail_list", description="List of message details" + default_factory=list, alias="message_detail_list", description="List of message details" ) diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 66ad894ad..45656b770 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -188,6 +188,19 @@ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "N client = cls._CLIENT_CACHE.get(key) if client is None: # Connection setting + + tmp_client = NebulaClient( + hosts=cfg.uri, + username=cfg.user, + password=cfg.password, + session_config=SessionConfig(graph=None), + session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000), + ) + try: + cls._ensure_space_exists(tmp_client, cfg) + finally: + tmp_client.close() + conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None) if conn_conf is None: conn_conf = ConnectionConfig.from_defults( @@ -318,6 +331,7 @@ def __init__(self, config: NebulaGraphDBConfig): } """ + assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED" self.config = config self.db_name = config.space self.user_name = config.user_name @@ -429,15 +443,21 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: if not self.config.use_multi_db and self.config.user_name: optional_condition = f"AND n.user_name = '{self.config.user_name}'" - query = f""" - MATCH (n@Memory) - WHERE n.memory_type = '{memory_type}' - {optional_condition} - ORDER BY n.updated_at DESC - OFFSET {keep_latest} - DETACH DELETE n - """ - self.execute_query(query) + count = self.count_nodes(memory_type) + + if count > keep_latest: + delete_query = f""" + MATCH (n@Memory) + WHERE n.memory_type = '{memory_type}' + {optional_condition} + ORDER BY n.updated_at DESC + OFFSET {keep_latest} + DETACH DELETE n + """ + try: + self.execute_query(delete_query) + except Exception as e: + logger.warning(f"Delete old mem error: {e}") @timed def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: @@ -597,14 +617,19 @@ def get_memory_count(self, memory_type: str) -> int: return -1 @timed - def count_nodes(self, scope: str) -> int: - query = f""" - MATCH (n@Memory) - WHERE n.memory_type = "{scope}" - """ + def count_nodes(self, scope: str | None = None) -> int: + query = "MATCH (n@Memory)" + conditions = [] + + if scope: + conditions.append(f'n.memory_type = "{scope}"') if not self.config.use_multi_db and self.config.user_name: user_name = self.config.user_name - query += f"\nAND n.user_name = '{user_name}'" + conditions.append(f"n.user_name = '{user_name}'") + + if conditions: + query += "\nWHERE " + " AND ".join(conditions) + query += "\nRETURN count(n) AS count" result = self.execute_query(query) @@ -985,8 +1010,7 @@ def search_by_embedding( dim = len(vector) vector_str = ",".join(f"{float(x)}" for x in vector) gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])" - - where_clauses = [] + where_clauses = [f"n.{self.dim_field} IS NOT NULL"] if scope: where_clauses.append(f'n.memory_type = "{scope}"') if status: @@ -1008,15 +1032,12 @@ def search_by_embedding( where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" gql = f""" - MATCH (n@Memory) + let a = {gql_vector} + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) {where_clause} - ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC - APPROXIMATE + ORDER BY inner_product(n.{self.dim_field}, a) DESC LIMIT {top_k} - OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }} - RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score - """ - + RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score""" try: result = self.execute_query(gql) except Exception as e: @@ -1471,6 +1492,25 @@ def merge_nodes(self, id1: str, id2: str) -> str: """ raise NotImplementedError + @classmethod + def _ensure_space_exists(cls, tmp_client, cfg): + """Lightweight check to ensure target graph (space) exists.""" + db_name = getattr(cfg, "space", None) + if not db_name: + logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.") + return + + try: + res = tmp_client.execute("SHOW GRAPHS;") + existing = {row.values()[0].as_string() for row in res} + if db_name not in existing: + tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type;") + logger.info(f"✅ Graph `{db_name}` created before session binding.") + else: + logger.debug(f"Graph `{db_name}` already exists.") + except Exception: + logger.exception("[NebulaGraphDBSync] Failed to ensure space exists") + @timed def _ensure_database_exists(self): graph_type_name = "MemOSBgeM3Type" diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index f324f41c9..0048f4a59 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -326,10 +326,10 @@ def load(self, dir: str) -> None: except Exception as e: logger.error(f"An error occurred while loading memories: {e}") - def dump(self, dir: str) -> None: + def dump(self, dir: str, include_embedding: bool = False) -> None: """Dump memories to os.path.join(dir, self.config.memory_filename)""" try: - json_memories = self.graph_store.export_graph() + json_memories = self.graph_store.export_graph(include_embedding=include_embedding) os.makedirs(dir, exist_ok=True) memory_file = os.path.join(dir, self.config.memory_filename) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index c9cd4de8a..b0224655c 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -44,6 +44,7 @@ def __init__( "LongTermMemory": 1500, "UserMemory": 480, } + logger.info(f"MemorySize is {self.memory_size}") self._threshold = threshold self.is_reorganize = is_reorganize self.reorganizer = GraphStructureReorganizer( @@ -66,30 +67,33 @@ def add(self, memories: list[TextualMemoryItem]) -> list[str]: except Exception as e: logger.exception("Memory processing error: ", exc_info=e) - try: - self.graph_store.remove_oldest_memory( - memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"] - ) - except Exception: - logger.warning(f"Remove WorkingMemory error: {traceback.format_exc()}") - - try: - self.graph_store.remove_oldest_memory( - memory_type="LongTermMemory", keep_latest=self.memory_size["LongTermMemory"] - ) - except Exception: - logger.warning(f"Remove LongTermMemory error: {traceback.format_exc()}") - - try: - self.graph_store.remove_oldest_memory( - memory_type="UserMemory", keep_latest=self.memory_size["UserMemory"] - ) - except Exception: - logger.warning(f"Remove UserMemory error: {traceback.format_exc()}") + # Only clean up if we're close to or over the limit + self._cleanup_memories_if_needed() self._refresh_memory_size() return added_ids + def _cleanup_memories_if_needed(self) -> None: + """ + Only clean up memories if we're close to or over the limit. + This reduces unnecessary database operations. + """ + cleanup_threshold = 0.8 # Clean up when 80% full + + for memory_type, limit in self.memory_size.items(): + current_count = self.current_memory_size.get(memory_type, 0) + threshold = int(limit * cleanup_threshold) + + # Only clean up if we're at or above the threshold + if current_count >= threshold: + try: + self.graph_store.remove_oldest_memory( + memory_type=memory_type, keep_latest=limit + ) + logger.debug(f"Cleaned up {memory_type}: {current_count} -> {limit}") + except Exception: + logger.warning(f"Remove {memory_type} error: {traceback.format_exc()}") + def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: """ Replace WorkingMemory