diff --git a/.gitignore b/.gitignore index ae7bdc4d6..8319a4d2f 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ evaluation/.env !evaluation/configs-example/*.json evaluation/configs/* **tree_textual_memory_locomo** +**script.py** .env evaluation/scripts/personamem diff --git a/examples/mem_reader/reader.py b/examples/mem_reader/reader.py index e26d00a67..3da5d5e76 100644 --- a/examples/mem_reader/reader.py +++ b/examples/mem_reader/reader.py @@ -2,6 +2,11 @@ from memos.configs.mem_reader import SimpleStructMemReaderConfig from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.memories.textual.item import ( + SourceMessage, + TextualMemoryItem, + TreeNodeTextualMemoryMetadata, +) def main(): @@ -11,7 +16,7 @@ def main(): ) reader = SimpleStructMemReader(reader_config) - # 3. Define scene data + # 2. Define scene data scene_data = [ [ {"role": "user", "chat_time": "3 May 2025", "content": "I'm feeling a bit down today."}, @@ -187,32 +192,389 @@ def main(): ], ] - # 4. Acquiring memories + print("=== Mem-Reader Fast vs Fine Mode Comparison ===\n") + + # 3. Test Fine Mode (default) + print("🔄 Testing FINE mode (default, with LLM processing)...") + start_time = time.time() + fine_memory = reader.get_memory( + scene_data, type="chat", info={"user_id": "user1", "session_id": "session1"}, mode="fine" + ) + fine_time = time.time() - start_time + print(f"✅ Fine mode completed in {fine_time:.2f} seconds") + print(f"📊 Fine mode generated {sum(len(mem_list) for mem_list in fine_memory)} memory items") + + # 4. 
Test Fast Mode + print("\n⚡ Testing FAST mode (quick processing, no LLM calls)...") start_time = time.time() - chat_memory = reader.get_memory( - scene_data, type="chat", info={"user_id": "user1", "session_id": "session1"} + fast_memory = reader.get_memory( + scene_data, type="chat", info={"user_id": "user1", "session_id": "session1"}, mode="fast" ) - print("\nChat Memory:\n", chat_memory) + fast_time = time.time() - start_time + print(f"✅ Fast mode completed in {fast_time:.2f} seconds") + print(f"📊 Fast mode generated {sum(len(mem_list) for mem_list in fast_memory)} memory items") + + # 5. Performance Comparison + print("\n📈 Performance Comparison:") + print(f" Fine mode: {fine_time:.2f}s") + print(f" Fast mode: {fast_time:.2f}s") + print(f" Speed improvement: {fine_time / fast_time:.1f}x faster") + + # 6. Show sample results from both modes + print("\n🔍 Sample Results Comparison:") + print("\n--- FINE Mode Results (first 3 items) ---") + for i, mem_list in enumerate(fine_memory[:3]): + for j, mem_item in enumerate(mem_list[:2]): # Show first 2 items from each list + print(f" [{i}][{j}] {mem_item.memory[:100]}...") - # 5. Example of processing documents - print("\n=== Processing Documents ===") + print("\n--- FAST Mode Results (first 3 items) ---") + for i, mem_list in enumerate(fast_memory[:3]): + for j, mem_item in enumerate(mem_list[:2]): # Show first 2 items from each list + print(f" [{i}][{j}] {mem_item.memory[:100]}...") + + # 7. Example of transfer fast mode result into fine result + fast_mode_memories = [ + TextualMemoryItem( + id="4553141b-3a33-4548-b779-e677ec797a9f", + memory="user: Nate:Oh cool! I might check that one out some time soon! I do love watching classics.\nassistant: Joanna:Yep, that movie is awesome. I first watched it around 3 years ago. I even went out and got a physical copy!\nuser: Nate:Sounds cool! Have you seen it a lot? sounds like you know the movie well!\nassistant: Joanna:A few times. It's one of my favorites! 
I really like the idea and the acting.\nuser: Nate:Cool! I'll definitely check it out. Thanks for the recommendation!\nassistant: Joanna:No problem, Nate! Let me know if you like it!\n", + metadata=TreeNodeTextualMemoryMetadata( + user_id="nate_test", + session_id="root_session", + status="activated", + type="fact", + key="user: Nate:Oh cool", + confidence=0.9900000095367432, + source=None, + tags=["mode:fast", "lang:en", "role:assistant", "role:user"], + visibility=None, + updated_at="2025-10-16T17:16:30.094877+08:00", + memory_type="LongTermMemory", + sources=[ + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=0, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=1, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=2, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=3, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=4, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=5, + ), + ], + embedding=None, + created_at="2025-10-16T17:16:30.094919+08:00", + usage=[], + background="", + ), + ), + TextualMemoryItem( + id="752e42fa-92b6-491a-a430-6864a7730fba", + memory="user: Nate:It was! How about you? Do you have any hobbies you love?\nassistant: Joanna:Yeah! Besides writing, I also enjoy reading, watching movies, and exploring nature. 
Anything else you enjoy doing, Nate?\nuser: Nate:Playing video games and watching movies are my main hobbies.\nassistant: Joanna:Cool, Nate! So we both have similar interests. What type of movies do you like best?\nuser: Nate:I love action and sci-fi movies, the effects are so cool! What about you, what's your favorite genre?\nassistant: Joanna:I'm all about dramas and romcoms. I love getting immersed in the feelings and plots.\nuser: Nate:Wow, movies can be so powerful! Do you have any recommendations for me?\nassistant: Joanna:Yeah, totally! Have you seen this romantic drama that's all about memory and relationships? It's such a good one.\nuser: Nate:Oh cool! I might check that one out some time soon! I do love watching classics.\nassistant: Joanna:Yep, that movie is awesome. I first watched it around 3 years ago. I even went out and got a physical copy!\n", + metadata=TreeNodeTextualMemoryMetadata( + user_id="nate_test", + session_id="root_session", + status="activated", + type="fact", + key="user: Nate:It was", + confidence=0.9900000095367432, + source=None, + tags=["mode:fast", "lang:en", "role:assistant", "role:user"], + visibility=None, + updated_at="2025-10-16T17:16:30.095726+08:00", + memory_type="LongTermMemory", + sources=[ + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=0, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=1, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=2, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=3, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + 
message_id=None, + content=None, + doc_path=None, + index=4, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=5, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=6, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=7, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=8, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=9, + ), + ], + embedding=None, + created_at="2025-10-16T17:16:30.095767+08:00", + usage=[], + background="", + ), + ), + TextualMemoryItem( + id="c9cf448c-deee-43a8-bafd-eb15fde535b2", + memory="user: Nate:Hey Joanna! Long time no see! What's up? Anything fun going on?\nassistant: Joanna:Hey Nate! Long time no see! I've been working on a project lately - it's been pretty cool. What about you - any fun projects or hobbies?\nuser: Nate:Hey Joanna! That's cool! I won my first video game tournament last week - so exciting!\nassistant: Joanna:Wow Nate! Congrats on winning! Tell me more - what game was it?\nuser: Nate:Thanks! it's a team shooter game.\nassistant: Joanna:Wow, great job! What was is called?\nuser: Nate:The game was called Counter-Strike: Global Offensive, and me and my team had a blast to the very end!\nassistant: Joanna:Cool, Nate! Sounds like a fun experience, even if I'm not into games.\nuser: Nate:It was! How about you? Do you have any hobbies you love?\nassistant: Joanna:Yeah! Besides writing, I also enjoy reading, watching movies, and exploring nature. 
Anything else you enjoy doing, Nate?\n", + metadata=TreeNodeTextualMemoryMetadata( + user_id="nate_test", + session_id="root_session", + status="activated", + type="fact", + key="user: Nate:Hey Joanna", + confidence=0.9900000095367432, + source=None, + tags=["mode:fast", "lang:en", "role:assistant", "role:user"], + visibility=None, + updated_at="2025-10-16T17:16:30.098208+08:00", + memory_type="LongTermMemory", + sources=[ + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=0, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=1, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=2, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=3, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=4, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=5, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=6, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=7, + ), + SourceMessage( + type="chat", + role="user", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, + doc_path=None, + index=8, + ), + SourceMessage( + type="chat", + role="assistant", + chat_time="7:31 pm on 21 January, 2022", + message_id=None, + content=None, 
+ doc_path=None, + index=9, + ), + ], + embedding=None, + created_at="2025-10-16T17:16:30.098246+08:00", + usage=[], + background="", + ), + ), + ] + fine_memories = reader.fine_transfer_simple_mem(fast_mode_memories, type="chat") + print("\n--- Transfer Mode Results (first 3 items) ---") + for i, mem_list in enumerate(fine_memories[:3]): + for j, mem_item in enumerate(mem_list[:2]): # Show first 2 items from each list + print(f" [{i}][{j}] {mem_item.memory[:100]}...") + + # 7. Example of processing documents (only in fine mode) + print("\n=== Processing Documents (Fine Mode Only) ===") # Example document paths (you should replace these with actual document paths) doc_paths = [ "examples/mem_reader/text1.txt", "examples/mem_reader/text2.txt", ] - # 6. Acquiring memories from documents - doc_memory = reader.get_memory( - doc_paths, - "doc", - info={ - "user_id": "1111", - "session_id": "2222", - }, - ) - print("\nDocument Memory:\n", doc_memory) - end_time = time.time() - print(f"The runtime is {end_time - start_time} seconds.") + + try: + # 6. 
Acquiring memories from documents + doc_memory = reader.get_memory( + doc_paths, + "doc", + info={ + "user_id": "1111", + "session_id": "2222", + }, + mode="fine", + ) + print( + f"\n📄 Document Memory generated {sum(len(mem_list) for mem_list in doc_memory)} items" + ) + except Exception as e: + print(f"⚠️ Document processing failed: {e}") + print(" (This is expected if document files don't exist)") + + print("\n🎯 Summary:") + print(f" • Fast mode: {fast_time:.2f}s - Quick processing, no LLM calls") + print(f" • Fine mode: {fine_time:.2f}s - Full LLM processing for better understanding") + print(" • Use fast mode for: Real-time applications, high-throughput scenarios") + print(" • Use fine mode for: Quality analysis, detailed memory extraction") if __name__ == "__main__": diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py index 4de0cf32b..080962482 100644 --- a/src/memos/chunkers/sentence_chunker.py +++ b/src/memos/chunkers/sentence_chunker.py @@ -28,7 +28,7 @@ def __init__(self, config: SentenceChunkerConfig): ) logger.info(f"Initialized SentenceChunker with config: {config}") - def chunk(self, text: str) -> list[Chunk]: + def chunk(self, text: str) -> list[str] | list[Chunk]: """Chunk the given text into smaller chunks based on sentences.""" chonkie_chunks = self.chunker.chunk(text) diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index 39586081c..2d6155ec2 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -28,13 +28,11 @@ class BaseSchedulerConfig(BaseConfig): thread_pool_max_workers: int = Field( default=DEFAULT_THREAD_POOL_MAX_WORKERS, gt=1, - lt=20, description=f"Maximum worker threads in pool (default: {DEFAULT_THREAD_POOL_MAX_WORKERS})", ) consume_interval_seconds: float = Field( default=DEFAULT_CONSUME_INTERVAL_SECONDS, gt=0, - le=60, description=f"Interval for consuming messages from queue in seconds (default: 
{DEFAULT_CONSUME_INTERVAL_SECONDS})", ) auth_config_path: str | None = Field( diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 237450e15..2c3a715f7 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -179,6 +179,11 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): ), ) + mode: str | None = Field( + default="sync", + description=("whether use asynchronous mode in memory add"), + ) + class SimpleTreeTextMemoryConfig(TreeTextMemoryConfig): """Simple tree text memory configuration class.""" diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 9a74373d7..12b493e58 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -440,20 +440,22 @@ def remove_oldest_memory( memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). keep_latest (int): Number of latest WorkingMemory entries to keep. """ - optional_condition = "" - - user_name = user_name if user_name else self.config.user_name - - optional_condition = f"AND n.user_name = '{user_name}'" - query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE n.memory_type = '{memory_type}' - {optional_condition} - ORDER BY n.updated_at DESC - OFFSET {int(keep_latest)} - DETACH DELETE n - """ - self.execute_query(query) + try: + user_name = user_name if user_name else self.config.user_name + optional_condition = f"AND n.user_name = '{user_name}'" + count = self.count_nodes(memory_type, user_name) + if count > keep_latest: + delete_query = f""" + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) + WHERE n.memory_type = '{memory_type}' + {optional_condition} + ORDER BY n.updated_at DESC + OFFSET {int(keep_latest)} + DETACH DELETE n + """ + self.execute_query(delete_query) + except Exception as e: + logger.warning(f"Delete old mem error: {e}") @timed def add_node( @@ -1175,7 +1177,6 @@ def get_grouped_counts( MATCH (n /*+ INDEX(idx_memory_user_name) */) {where_clause} RETURN {", 
".join(return_fields)}, COUNT(n) AS count - GROUP BY {", ".join(group_by_fields)} """ result = self.execute_query(gql) # Pure GQL string execution @@ -1620,7 +1621,13 @@ def _create_basic_property_indexes(self) -> None: Create standard B-tree indexes on user_name when use Shared Database Multi-Tenant Mode. """ - fields = ["status", "memory_type", "created_at", "updated_at", "user_name"] + fields = [ + "status", + "memory_type", + "created_at", + "updated_at", + "user_name", + ] for field in fields: index_name = f"idx_memory_{field}" diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 55db60ed2..f51b3465d 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -669,7 +669,7 @@ def search_by_embedding( vector (list[float]): The embedding vector representing query semantics. top_k (int): Number of top similar nodes to retrieve. scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory'). - status (str, optional): Node status filter (e.g., 'active', 'archived'). + status (str, optional): Node status filter (e.g., 'activated', 'archived'). If provided, restricts results to nodes with matching status. threshold (float, optional): Minimum similarity score threshold (0 ~ 1). search_filter (dict, optional): Additional metadata filters for search results. 
diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 958cc140c..0010897c0 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -17,6 +17,7 @@ from memos.mem_scheduler.schemas.general_schemas import ( ADD_LABEL, ANSWER_LABEL, + MEM_READ_LABEL, QUERY_LABEL, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem @@ -70,6 +71,7 @@ def __init__(self, config: MOSConfig, user_manager: UserManager | None = None): if self.enable_mem_scheduler: self._mem_scheduler = self._initialize_mem_scheduler() self._mem_scheduler.mem_cubes = self.mem_cubes + self._mem_scheduler.mem_reader = self.mem_reader else: self._mem_scheduler: GeneralScheduler = None @@ -681,6 +683,12 @@ def add( logger.info( f"time add: get mem_cube_id check in mem_cubes time user_id: {target_user_id} time is: {time.time() - time_start_0}" ) + sync_mode = self.mem_cubes[mem_cube_id].text_mem.mode + if sync_mode == "async": + assert self.mem_scheduler is not None, ( + "Mem-Scheduler must be working when use asynchronous memory adding." 
+ ) + logger.debug(f"Mem-reader mode is: {sync_mode}") time_start_1 = time.time() if ( (messages is not None) @@ -690,6 +698,7 @@ def add( logger.info( f"time add: messages is not None and enable_textual_memory and text_mem is not None time user_id: {target_user_id} time is: {time.time() - time_start_1}" ) + if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text": add_memory = [] metadata = TextualMemoryMetadata( @@ -707,21 +716,30 @@ def add( messages_list, type="chat", info={"user_id": target_user_id, "session_id": target_session_id}, + mode="fast" if sync_mode == "async" else "fine", ) logger.info( f"time add: get mem_reader time user_id: {target_user_id} time is: {time.time() - time_start_2}" ) - mem_ids = [] - for mem in memories: - mem_id_list: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(mem) - mem_ids.extend(mem_id_list) - logger.info( - f"Added memory user {target_user_id} to memcube {mem_cube_id}: {mem_id_list}" - ) - + memories_flatten = [m for m_list in memories for m in m_list] + mem_ids: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(memories_flatten) + logger.info( + f"Added memory user {target_user_id} to memcube {mem_cube_id}: {mem_ids}" + ) # submit messages for scheduler if self.enable_mem_scheduler and self.mem_scheduler is not None: mem_cube = self.mem_cubes[mem_cube_id] + if sync_mode == "async": + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) + message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, @@ -749,10 +767,12 @@ def add( messages_list = [ [{"role": "user", "content": memory_content}] ] # for only user-str input and convert message + memories = self.mem_reader.get_memory( messages_list, type="chat", info={"user_id": target_user_id, "session_id": target_session_id}, + 
mode="fast" if sync_mode == "async" else "fine", ) mem_ids = [] @@ -766,6 +786,16 @@ def add( # submit messages for scheduler if self.enable_mem_scheduler and self.mem_scheduler is not None: mem_cube = self.mem_cubes[mem_cube_id] + if sync_mode == "async": + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, diff --git a/src/memos/mem_reader/base.py b/src/memos/mem_reader/base.py index f092c3870..3095a0bc6 100644 --- a/src/memos/mem_reader/base.py +++ b/src/memos/mem_reader/base.py @@ -18,10 +18,17 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: @abstractmethod def get_memory( - self, scene_data: list, type: str, info: dict[str, Any] + self, scene_data: list, type: str, info: dict[str, Any], mode: str = "fast" ) -> list[list[TextualMemoryItem]]: """Various types of memories extracted from scene_data""" @abstractmethod def transform_memreader(self, data: dict) -> list[TextualMemoryItem]: """Transform the memory data into a list of TextualMemoryItem objects.""" + + @abstractmethod + def fine_transfer_simple_mem( + self, input_memories: list[list[TextualMemoryItem]], type: str + ) -> list[list[TextualMemoryItem]]: + """Fine Transform TextualMemoryItem List into another list of + TextualMemoryItem objects via calling llm to better understand users.""" diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index b439cb2b2..9f5eb9832 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -3,6 +3,7 @@ import json import os import re +import traceback from abc import ABC from typing import Any @@ -41,6 +42,26 @@ "doc": {"en": SIMPLE_STRUCT_DOC_READER_PROMPT, "zh": 
SIMPLE_STRUCT_DOC_READER_PROMPT_ZH}, } +try: + import tiktoken + + try: + _ENC = tiktoken.encoding_for_model("gpt-4o-mini") + except Exception: + _ENC = tiktoken.get_encoding("cl100k_base") + + def _count_tokens_text(s: str) -> int: + return len(_ENC.encode(s or "")) +except Exception: + # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars + def _count_tokens_text(s: str) -> int: + if not s: + return 0 + zh_chars = re.findall(r"[\u4e00-\u9fff]", s) + zh = len(zh_chars) + rest = len(s) - zh + return zh + max(1, rest // 4) + def detect_lang(text): try: @@ -112,6 +133,14 @@ def _build_node(idx, message, info, scene_file, llm, parse_json_result, embedder return None +def _derive_key(text: str, max_len: int = 80) -> str: + """default key when without LLM: first max_len words""" + if not text: + return "" + sent = re.split(r"[。!?!?]\s*|\n", text.strip())[0] + return (sent[:max_len]).strip() + + class SimpleStructMemReader(BaseMemReader, ABC): """Naive implementation of MemReader.""" @@ -126,27 +155,50 @@ def __init__(self, config: SimpleStructMemReaderConfig): self.llm = LLMFactory.from_config(config.llm) self.embedder = EmbedderFactory.from_config(config.embedder) self.chunker = ChunkerFactory.from_config(config.chunker) + self.memory_max_length = 8000 + # Use token-based windowing; default to ~5000 tokens if not configured + self.chat_window_max_tokens = getattr(self.config, "chat_window_max_tokens", 1024) + self._count_tokens = _count_tokens_text + + def _make_memory_item( + self, + value: str, + info: dict, + memory_type: str, + tags: list[str] | None = None, + key: str | None = None, + sources: list | None = None, + background: str = "", + type_: str = "fact", + confidence: float = 0.99, + ) -> TextualMemoryItem: + """construct memory item""" + return TextualMemoryItem( + memory=value, + metadata=TreeNodeTextualMemoryMetadata( + user_id=info.get("user_id", ""), + session_id=info.get("session_id", ""), + memory_type=memory_type, + status="activated", 
+ tags=tags or [], + key=key if key is not None else _derive_key(value), + embedding=self.embedder.embed([value])[0], + usage=[], + sources=sources or [], + background=background, + confidence=confidence, + type=type_, + ), + ) - @timed - def _process_chat_data(self, scene_data_info, info): - mem_list = [] - for item in scene_data_info: - if "chat_time" in item: - mem = item["role"] + ": " + f"[{item['chat_time']}]: " + item["content"] - mem_list.append(mem) - else: - mem = item["role"] + ":" + item["content"] - mem_list.append(mem) - lang = detect_lang("\n".join(mem_list)) + def _get_llm_response(self, mem_str: str) -> dict: + lang = detect_lang(mem_str) template = PROMPT_DICT["chat"][lang] examples = PROMPT_DICT["chat"][f"{lang}_example"] - - prompt = template.replace("${conversation}", "\n".join(mem_list)) + prompt = template.replace("${conversation}", mem_str) if self.config.remove_prompt_example: prompt = prompt.replace(examples, "") - messages = [{"role": "user", "content": prompt}] - try: response_text = self.llm.generate(messages) response_json = self.parse_json_result(response_text) @@ -155,15 +207,111 @@ def _process_chat_data(self, scene_data_info, info): response_json = { "memory list": [ { - "key": "\n".join(mem_list)[:10], + "key": mem_str[:10], "memory_type": "UserMemory", - "value": "\n".join(mem_list), + "value": mem_str, "tags": [], } ], - "summary": "\n".join(mem_list), + "summary": mem_str, } + return response_json + def _iter_chat_windows(self, scene_data_info, max_tokens=None, overlap=200): + """ + use token counter to get a slide window generator + """ + max_tokens = max_tokens or self.chat_window_max_tokens + buf, sources, start_idx = [], [], 0 + cur_text = "" + + for idx, item in enumerate(scene_data_info): + role = item.get("role", "") + content = item.get("content", "") + chat_time = item.get("chat_time", None) + parts = [] + if role and str(role).lower() != "mix": + parts.append(f"{role}: ") + if chat_time: + 
parts.append(f"[{chat_time}]: ") + prefix = "".join(parts) + line = f"{prefix}{content}\n" + + if self._count_tokens(cur_text + line) > max_tokens and cur_text: + text = "".join(buf) + yield {"text": text, "sources": sources.copy(), "start_idx": start_idx} + while buf and self._count_tokens("".join(buf)) > overlap: + buf.pop(0) + sources.pop(0) + start_idx = idx + cur_text = "".join(buf) + + buf.append(line) + sources.append({"type": "chat", "index": idx, "role": role, "chat_time": chat_time}) + cur_text = "".join(buf) + + if buf: + yield {"text": "".join(buf), "sources": sources.copy(), "start_idx": start_idx} + + @timed + def _process_chat_data(self, scene_data_info, info, **kwargs): + mode = kwargs.get("mode", "fine") + windows = list(self._iter_chat_windows(scene_data_info)) + + if mode == "fast": + logger.debug("Using unified Fast Mode") + + def _build_fast_node(w): + text = w["text"] + roles = {s.get("role", "") for s in w["sources"] if s.get("role")} + mem_type = "UserMemory" if roles == {"user"} else "LongTermMemory" + tags = ["mode:fast"] + return self._make_memory_item( + value=text, info=info, memory_type=mem_type, tags=tags, sources=w["sources"] + ) + + with ContextThreadPoolExecutor(max_workers=8) as ex: + futures = {ex.submit(_build_fast_node, w): i for i, w in enumerate(windows)} + results = [None] * len(futures) + for fut in concurrent.futures.as_completed(futures): + i = futures[fut] + try: + node = fut.result() + if node: + results[i] = node + except Exception as e: + logger.error(f"[ChatFast] error: {e}") + chat_nodes = [r for r in results if r] + return chat_nodes + else: + logger.debug("Using unified Fine Mode") + chat_read_nodes = [] + for w in windows: + resp = self._get_llm_response(w["text"]) + for m in resp.get("memory list", []): + try: + memory_type = ( + m.get("memory_type", "LongTermMemory") + .replace("长期记忆", "LongTermMemory") + .replace("用户记忆", "UserMemory") + ) + node = self._make_memory_item( + value=m.get("value", ""), + 
info=info, + memory_type=memory_type, + tags=m.get("tags", []), + key=m.get("key", ""), + sources=w["sources"], + background=resp.get("summary", ""), + ) + chat_read_nodes.append(node) + except Exception as e: + logger.error(f"[ChatFine] parse error: {e}") + return chat_read_nodes + + def _process_transfer_chat_data(self, raw_node: TextualMemoryItem): + raw_memory = raw_node.memory + response_json = self._get_llm_response(raw_memory) chat_read_nodes = [] for memory_i_raw in response_json.get("memory list", []): try: @@ -172,28 +320,23 @@ def _process_chat_data(self, scene_data_info, info): .replace("长期记忆", "LongTermMemory") .replace("用户记忆", "UserMemory") ) - if memory_type not in ["LongTermMemory", "UserMemory"]: memory_type = "LongTermMemory" - - node_i = TextualMemoryItem( - memory=memory_i_raw.get("value", ""), - metadata=TreeNodeTextualMemoryMetadata( - user_id=info.get("user_id"), - session_id=info.get("session_id"), - memory_type=memory_type, - status="activated", - tags=memory_i_raw.get("tags", []) - if type(memory_i_raw.get("tags", [])) is list - else [], - key=memory_i_raw.get("key", ""), - embedding=self.embedder.embed([memory_i_raw.get("value", "")])[0], - usage=[], - sources=scene_data_info, - background=response_json.get("summary", ""), - confidence=0.99, - type="fact", - ), + node_i = self._make_memory_item( + value=memory_i_raw.get("value", ""), + info={ + "user_id": raw_node.metadata.user_id, + "session_id": raw_node.metadata.session_id, + }, + memory_type=memory_type, + tags=memory_i_raw.get("tags", []) + if isinstance(memory_i_raw.get("tags", []), list) + else [], + key=memory_i_raw.get("key", ""), + sources=raw_node.metadata.sources, + background=response_json.get("summary", ""), + type_="fact", + confidence=0.99, ) chat_read_nodes.append(node_i) except Exception as e: @@ -202,7 +345,7 @@ def _process_chat_data(self, scene_data_info, info): return chat_read_nodes def get_memory( - self, scene_data: list, type: str, info: dict[str, Any] + self, 
def fine_transfer_simple_mem(
    self, input_memories: list[TextualMemoryItem], type: str
) -> list[list[TextualMemoryItem]]:
    """Re-process fast-mode memory items into fine-grained memories.

    Each input item is handed to the type-specific transfer routine
    (LLM-backed for chat) on a context-propagating thread pool; results
    are collected as they complete, failures are logged and skipped.

    Args:
        input_memories: Fast-mode items to refine.
        type: Scene type; "chat" uses the chat transfer path, "doc" and
            any unknown value fall back to the doc transfer path.

    Returns:
        One list of refined TextualMemoryItem per successfully processed
        input item. Empty input yields an empty list.
    """
    if not input_memories:
        return []

    # "doc" and unknown types share the doc transfer path, so a single
    # fallback branch suffices (the previous elif/else were identical).
    if type == "chat":
        processing_func = self._process_transfer_chat_data
    else:
        processing_func = self._process_transfer_doc_data

    memory_list: list[list[TextualMemoryItem]] = []
    # ContextThreadPoolExecutor propagates request context into workers.
    with ContextThreadPoolExecutor() as executor:
        futures = [
            executor.submit(processing_func, scene_data_info)
            for scene_data_info in input_memories
        ]
        for future in concurrent.futures.as_completed(futures):
            try:
                res_memory = future.result()
                if res_memory is not None:
                    memory_list.append(res_memory)
            except Exception as e:
                logger.error(f"Task failed with exception: {e}")
                logger.error(traceback.format_exc())
    return memory_list
def parse_json_result(self, response_text: str) -> dict:
    """Best-effort extraction of a JSON object from an LLM response.

    Strategy, in order: prefer the body of a ``` / ```json fence, drop
    any prose before the first "{", try a straight parse, truncate at the
    last closing bracket (drops trailing junk), then balance-close
    unterminated strings/brackets.

    Args:
        response_text: Raw model output, possibly fenced or truncated.

    Returns:
        The parsed dict, or {} when nothing parseable is found. Never
        raises.
    """
    s = (response_text or "").strip()

    # Prefer the contents of a fenced code block if one is present.
    fence = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I)
    s = (fence.group(1) if fence else s.replace("```", "")).strip()

    # Discard leading prose; without an opening brace there is no object.
    start = s.find("{")
    if start == -1:
        return {}
    s = s[start:].strip()

    try:
        return json.loads(s)
    except json.JSONDecodeError:
        pass

    # Trailing junk after the object: cut at the last closing bracket.
    end = max(s.rfind("}"), s.rfind("]"))
    if end != -1:
        try:
            return json.loads(s[: end + 1])
        except json.JSONDecodeError:
            pass

    def _balance_close(t: str) -> str:
        # Close an unterminated string, then close brackets in LIFO order.
        # (The previous version appended all "}" before all "]", which
        # produced invalid JSON for truncated arrays nested in objects.)
        closers: list[str] = []
        in_str = esc = False
        for ch in t:
            if in_str:
                if esc:
                    esc = False
                elif ch == "\\":
                    esc = True
                elif ch == '"':
                    in_str = False
            elif ch == '"':
                in_str = True
            elif ch == "{":
                closers.append("}")
            elif ch == "[":
                closers.append("]")
            elif ch in "}]" and closers and closers[-1] == ch:
                closers.pop()
        if in_str:
            t += '"'
        return t + "".join(reversed(closers))

    try:
        return json.loads(_balance_close(s))
    except json.JSONDecodeError as e:
        if "Invalid \\escape" in str(e):
            # Some models emit raw backslashes; escape and retry once.
            # Guarded so a second failure still returns {} as promised.
            try:
                return json.loads(s.replace("\\", "\\\\"))
            except json.JSONDecodeError:
                pass
        logger.error(
            f"[JSONParse] Failed to decode JSON: {e}\nTail: Raw {response_text} \
            json: {s}"
        )
        return {}
def mem_scheduler_wait(
    self, timeout: float = 180.0, poll: float = 0.1, log_every: float = 0.01
) -> bool:
    """
    Uses EWMA throughput, detects leaked `unfinished_tasks`, and waits for dispatcher.
    """
    deadline = time.monotonic() + timeout

    # --- helpers (local, no external deps) ---
    def _unfinished() -> int:
        """Prefer `unfinished_tasks`; fallback to `qsize()`."""
        try:
            leftover = getattr(self.memos_message_queue, "unfinished_tasks", None)
            if leftover is not None:
                return int(leftover)
        except Exception:
            pass
        try:
            return int(self.memos_message_queue.qsize())
        except Exception:
            return 0

    def _fmt_eta(seconds: float | None) -> str:
        """Format seconds to human-readable string."""
        # NaN check via self-inequality; None/inf also mean "unknown".
        if seconds is None or seconds != seconds or seconds == float("inf"):
            return "unknown"
        total = max(0, int(seconds))
        hours, total = divmod(total, 3600)
        minutes, secs = divmod(total, 60)
        if hours > 0:
            return f"{hours:d}h{minutes:02d}m{secs:02d}s"
        if minutes > 0:
            return f"{minutes:d}m{secs:02d}s"
        return f"{secs:d}s"

    # --- EWMA throughput state (tasks/s) ---
    alpha = 0.3
    rate = 0.0
    last_t: float | None = None
    last_done = 0

    # --- dynamic totals & stuck detection ---
    init_unfinished = _unfinished()
    done_total = 0
    last_unfinished = None
    stuck_ticks = 0
    next_log = 0.0

    while True:
        # 1) read counters
        curr_unfinished = _unfinished()
        try:
            qsz = int(self.memos_message_queue.qsize())
        except Exception:
            qsz = -1

        pend = run = 0
        use_dispatcher = self.enable_parallel_dispatch and self.dispatcher is not None
        stats_fn = getattr(self.dispatcher, "stats", None)
        if use_dispatcher and callable(stats_fn):
            try:
                # expected: {'pending':int,'running':int,'done':int?,'rate':float?}
                snapshot = stats_fn()
                pend = int(snapshot.get("pending", 0))
                run = int(snapshot.get("running", 0))
            except Exception:
                pass

        # 2) dynamic total (allows new tasks queued while waiting)
        total_now = max(init_unfinished, done_total + curr_unfinished)
        done_total = max(0, total_now - curr_unfinished)

        # 3) update EWMA throughput
        now = time.monotonic()
        if last_t is not None:
            dt = max(1e-6, now - last_t)
            dc = max(0, done_total - last_done)
            inst = dc / dt
            rate = inst if rate == 0.0 else alpha * inst + (1 - alpha) * rate
        last_t = now
        last_done = done_total

        eta = None if rate <= 1e-9 else (curr_unfinished / rate)

        # 4) progress log (throttled)
        if now >= next_log:
            print(
                f"[mem_scheduler_wait] remaining≈{curr_unfinished} | throughput≈{rate:.2f} msg/s | ETA≈{_fmt_eta(eta)} "
                f"| qsize={qsz} pending={pend} running={run}"
            )
            next_log = now + max(0.2, log_every)

        # 5) exit / stuck detection
        idle_dispatcher = (pend == 0 and run == 0) if use_dispatcher else True
        if curr_unfinished == 0:
            break
        if qsz == 0 and idle_dispatcher:
            # curr_unfinished > 0 is guaranteed here (break above).
            stuck_ticks = stuck_ticks + 1 if last_unfinished == curr_unfinished else 0
        else:
            stuck_ticks = 0
        last_unfinished = curr_unfinished

        if stuck_ticks >= 3:
            logger.warning(
                "mem_scheduler_wait: detected leaked 'unfinished_tasks' -> treating queue as drained"
            )
            break

        if now >= deadline:
            logger.warning("mem_scheduler_wait: queue did not drain before timeout")
            return False

        time.sleep(poll)

    # 6) wait dispatcher (second stage)
    remaining = max(0.0, deadline - time.monotonic())
    if self.enable_parallel_dispatch and self.dispatcher is not None:
        try:
            ok = self.dispatcher.join(timeout=remaining if remaining > 0 else 0)
        except TypeError:
            # Older dispatchers take no timeout argument.
            ok = self.dispatcher.join()
        if not ok:
            logger.warning("mem_scheduler_wait: dispatcher did not complete before timeout")
            return False

    return True
def put(self, item: T, block: bool = False, timeout: float | None = None) -> None:
    """Put an item into the queue, dropping the oldest item when full.

    IMPORTANT: a dropped item will never be processed, so we must also
    call `task_done()` for it — otherwise the internal `unfinished_tasks`
    counter leaks and `join()` blocks forever.

    Args:
        item: The item to be put into the queue
        block: Forwarded to Queue.put
        timeout: Forwarded to Queue.put
    """
    try:
        # First try non-blocking put
        super().put(item, block=block, timeout=timeout)
    except Full:
        # Remove the oldest item to make room. task_done() is nested so it
        # only runs when an item was actually removed — previously it ran
        # even when get_nowait() raised Empty, over-decrementing
        # unfinished_tasks.
        with suppress(Empty):
            self.get_nowait()
            with suppress(ValueError):
                self.task_done()
        # Retry putting the new item
        super().put(item, block=block, timeout=timeout)

def get_queue_content_without_pop(self) -> list[T]:
    """Return a copy of the queue's contents without modifying it."""
    # Hold the mutex so the snapshot is consistent under concurrency.
    with self.mutex:
        return list(self.queue)

def clear(self) -> None:
    """Remove all items from the queue.

    This operation is thread-safe.

    IMPORTANT: also decrements `unfinished_tasks` once per cleared item,
    since those tasks will never be processed.
    """
    with self.mutex:
        dropped = len(self.queue)
        self.queue.clear()
    # task_done() must run AFTER releasing the mutex: Queue.task_done()
    # acquires the condition bound to this same (non-reentrant) lock, so
    # calling it while holding the lock would self-deadlock.
    for _ in range(dropped):
        with suppress(ValueError):
            self.task_done()
def _mem_read_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
    """Handle MEM_READ messages: refine raw memories via mem_reader."""
    logger.info(f"Messages {messages} assigned to {MEM_READ_LABEL} handler.")
    self._consume_mem_messages(messages, self._process_memories_with_reader)

def _mem_reorganize_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
    """Handle MEM_ORGANIZE messages: trim and refresh memory structures."""
    # Fixed: this handler previously logged MEM_READ_LABEL.
    logger.info(f"Messages {messages} assigned to {MEM_ORGANIZE_LABEL} handler.")
    self._consume_mem_messages(messages, self._process_memories_with_reorganize)

def _consume_mem_messages(self, messages, processor) -> None:
    """Shared fan-out for mem_read / mem_organize messages.

    Parses each message's memory-id payload, validates the cube's text
    memory, and runs `processor` per message on a small thread pool.
    Failures are logged, never raised.

    Args:
        messages: Schedule messages whose `content` is a JSON list of
            memory IDs (or an already-parsed list).
        processor: Bound method taking (mem_ids, user_id, mem_cube_id,
            mem_cube, text_mem) keyword arguments.
    """
    if not messages:
        # min(8, 0) would be an invalid ThreadPoolExecutor max_workers.
        return

    def process_message(message: ScheduleMessageItem):
        try:
            user_id = message.user_id
            mem_cube_id = message.mem_cube_id
            mem_cube = message.mem_cube
            content = message.content

            # Parse the memory IDs from content
            mem_ids = json.loads(content) if isinstance(content, str) else content
            if not mem_ids:
                return

            logger.info(
                f"Processing mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}"
            )

            # Only TreeTextMemory supports this processing path.
            text_mem = mem_cube.text_mem
            if not isinstance(text_mem, TreeTextMemory):
                logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}")
                return

            processor(
                mem_ids=mem_ids,
                user_id=user_id,
                mem_cube_id=mem_cube_id,
                mem_cube=mem_cube,
                text_mem=text_mem,
            )

            logger.info(
                f"Successfully processed mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}"
            )
        except Exception as e:
            logger.error(f"Error processing mem_read message: {e}", exc_info=True)

    with concurrent.futures.ThreadPoolExecutor(max_workers=min(8, len(messages))) as executor:
        futures = [executor.submit(process_message, msg) for msg in messages]
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                logger.error(f"Thread task failed: {e}", exc_info=True)

def _process_memories_with_reader(
    self,
    mem_ids: list[str],
    user_id: str,
    mem_cube_id: str,
    mem_cube: GeneralMemCube,
    text_mem: TreeTextMemory,
) -> None:
    """
    Process memories using mem_reader for enhanced memory processing.

    Fetches the raw items, refines them via mem_reader, stores the refined
    versions, then hard-deletes the raw items and refreshes the store.

    Args:
        mem_ids: List of memory IDs to process
        user_id: User ID
        mem_cube_id: Memory cube ID
        mem_cube: Memory cube instance
        text_mem: Text memory instance
    """
    try:
        # mem_reader is injected by MOSCore; without it we skip silently.
        if not hasattr(self, "mem_reader") or self.mem_reader is None:
            logger.warning(
                "mem_reader not available in scheduler, skipping enhanced processing"
            )
            return

        # Get the original memory items (already-deleted ones are skipped).
        memory_items = []
        for mem_id in mem_ids:
            try:
                memory_items.append(text_mem.get(mem_id))
            except Exception as e:
                logger.warning(f"Failed to get memory {mem_id}: {e}")

        if not memory_items:
            logger.warning("No valid memory items found for processing")
            return

        logger.info(f"Processing {len(memory_items)} memories with mem_reader")
        try:
            processed_memories = self.mem_reader.fine_transfer_simple_mem(
                memory_items,
                type="chat",
            )
        except Exception as e:
            logger.warning(f"{e}: Fail to transfer mem: {memory_items}")
            processed_memories = []

        if processed_memories:
            # mem_reader returns a list of lists; flatten before storing.
            flattened_memories = [
                item for memory_list in processed_memories for item in memory_list
            ]
            logger.info(f"mem_reader processed {len(flattened_memories)} enhanced memories")
            if flattened_memories:
                enhanced_mem_ids = text_mem.add(flattened_memories)
                logger.info(
                    f"Added {len(enhanced_mem_ids)} enhanced memories: {enhanced_mem_ids}"
                )
            else:
                logger.info("No enhanced memories generated by mem_reader")
        else:
            logger.info("mem_reader returned no processed memories")

        # Raw items are superseded by the refined ones: remove them.
        text_mem.delete(mem_ids)
        logger.info("Delete raw mem_ids")
        text_mem.memory_manager.remove_and_refresh_memory()
        logger.info("Remove and Refresh Memories")
        logger.debug(f"Finished add {user_id} memory: {mem_ids}")

    except Exception:
        logger.error(
            f"Error in _process_memories_with_reader: {traceback.format_exc()}", exc_info=True
        )

def _process_memories_with_reorganize(
    self,
    mem_ids: list[str],
    user_id: str,
    mem_cube_id: str,
    mem_cube: GeneralMemCube,
    text_mem: TreeTextMemory,
) -> None:
    """
    Process memories using mem_reorganize for enhanced memory processing.

    Verifies the items exist, then trims oversized memory types and
    refreshes counters via the memory manager.

    Args:
        mem_ids: List of memory IDs to process
        user_id: User ID
        mem_cube_id: Memory cube ID
        mem_cube: Memory cube instance
        text_mem: Text memory instance
    """
    try:
        # mem_reader is injected by MOSCore; without it we skip silently.
        if not hasattr(self, "mem_reader") or self.mem_reader is None:
            logger.warning(
                "mem_reader not available in scheduler, skipping enhanced processing"
            )
            return

        # Get the original memory items (already-deleted ones are skipped).
        memory_items = []
        for mem_id in mem_ids:
            try:
                memory_items.append(text_mem.get(mem_id))
            except Exception as e:
                logger.warning(f"Failed to get memory {mem_id}: {e}")

        if not memory_items:
            logger.warning("No valid memory items found for processing")
            return

        logger.info(f"Processing {len(memory_items)} memories with mem_reader")
        text_mem.memory_manager.remove_and_refresh_memory()
        logger.info("Remove and Refresh Memories")
        logger.debug(f"Finished add {user_id} memory: {mem_ids}")

    except Exception:
        # Fixed: previously reported itself as _process_memories_with_reader.
        logger.error(
            f"Error in _process_memories_with_reorganize: {traceback.format_exc()}",
            exc_info=True,
        )
BaseTextMemory(BaseMemory): """Base class for all textual memory implementations.""" + # Default mode configuration - can be overridden by subclasses + mode: str = "sync" # Default mode: 'async' or 'sync' + @abstractmethod def __init__(self, config: BaseTextMemoryConfig): """Initialize memory with the given configuration.""" diff --git a/src/memos/memories/textual/general.py b/src/memos/memories/textual/general.py index 9793224b5..d71a86d2e 100644 --- a/src/memos/memories/textual/general.py +++ b/src/memos/memories/textual/general.py @@ -26,6 +26,8 @@ class GeneralTextMemory(BaseTextMemory): def __init__(self, config: GeneralTextMemoryConfig): """Initialize memory with the given configuration.""" + # Set mode from class default or override if needed + self.mode = getattr(self.__class__, "mode", "sync") self.config: GeneralTextMemoryConfig = config self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config( config.extractor_llm diff --git a/src/memos/memories/textual/naive.py b/src/memos/memories/textual/naive.py index f8684729a..7bc49e767 100644 --- a/src/memos/memories/textual/naive.py +++ b/src/memos/memories/textual/naive.py @@ -61,6 +61,8 @@ class NaiveTextMemory(BaseTextMemory): def __init__(self, config: NaiveTextMemoryConfig): """Initialize memory with the given configuration.""" + # Set mode from class default or override if needed + self.mode = getattr(self.__class__, "mode", "sync") self.config = config self.extractor_llm = LLMFactory.from_config(config.extractor_llm) self.memories = [] diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 0048f4a59..fccd83fa6 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -2,7 +2,6 @@ import os import shutil import tempfile -import time from datetime import datetime from pathlib import Path @@ -33,28 +32,17 @@ class TreeTextMemory(BaseTextMemory): def __init__(self, config: TreeTextMemoryConfig): """Initialize memory with 
the given configuration.""" - time_start = time.time() + # Set mode from class default or override if needed + self.mode = config.mode self.config: TreeTextMemoryConfig = config self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config( config.extractor_llm ) - logger.info(f"time init: extractor_llm time is: {time.time() - time_start}") - - time_start_ex = time.time() self.dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config( config.dispatcher_llm ) - logger.info(f"time init: dispatcher_llm time is: {time.time() - time_start_ex}") - - time_start_em = time.time() self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder) - logger.info(f"time init: embedder time is: {time.time() - time_start_em}") - - time_start_gs = time.time() self.graph_store: Neo4jGraphDB = GraphStoreFactory.from_config(config.graph_db) - logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}") - - time_start_rr = time.time() if config.reranker is None: default_cfg = RerankerConfigFactory.model_validate( { @@ -68,10 +56,7 @@ def __init__(self, config: TreeTextMemoryConfig): self.reranker = RerankerFactory.from_config(default_cfg) else: self.reranker = RerankerFactory.from_config(config.reranker) - logger.info(f"time init: reranker time is: {time.time() - time_start_rr}") self.is_reorganize = config.reorganize - - time_start_mm = time.time() self.memory_manager: MemoryManager = MemoryManager( self.graph_store, self.embedder, @@ -84,8 +69,6 @@ def __init__(self, config: TreeTextMemoryConfig): }, is_reorganize=self.is_reorganize, ) - logger.info(f"time init: memory_manager time is: {time.time() - time_start_mm}") - time_start_ir = time.time() # Create internet retriever if configured self.internet_retriever = None if config.internet_retriever is not None: @@ -97,19 +80,13 @@ def __init__(self, config: TreeTextMemoryConfig): ) else: logger.info("No internet retriever configured") - logger.info(f"time init: 
def add(
    self, memories: list[TextualMemoryItem], user_name: str | None = None, mode: str = "sync"
) -> list[str]:
    """
    Add new memories in parallel to different memory types.

    Args:
        memories: Items to persist (WorkingMemory / LongTermMemory /
            UserMemory routing happens in _process_memory).
        user_name: Scope for per-user storage and trimming.
        mode: "sync" trims oversized memory types and refreshes counters
            inline; any other value defers that maintenance (e.g. to
            remove_and_refresh_memory).

    Returns:
        IDs of the graph nodes actually added.
    """
    if not memories:
        # Nothing added, so nothing to trim or refresh.
        return []

    added_ids: list[str] = []
    with ContextThreadPoolExecutor(max_workers=20) as executor:
        futures = {executor.submit(self._process_memory, m, user_name): m for m in memories}
        for future in as_completed(futures, timeout=60):
            try:
                added_ids.extend(future.result())
            except Exception as e:
                logger.exception("Memory processing error: ", exc_info=e)

    if mode == "sync":
        for mem_type in ["WorkingMemory", "LongTermMemory", "UserMemory"]:
            try:
                # Fixed: previously always trimmed "WorkingMemory" while
                # using each mem_type's keep_latest, so LongTermMemory and
                # UserMemory were never trimmed and WorkingMemory was
                # trimmed against the wrong budgets.
                self.graph_store.remove_oldest_memory(
                    memory_type=mem_type,
                    keep_latest=self.memory_size[mem_type],
                    user_name=user_name,
                )
            except Exception:
                logger.warning(f"Remove {mem_type} error: {traceback.format_exc()}")

        self._refresh_memory_size(user_name=user_name)
    return added_ids
""" - ids = [] - - # Add to WorkingMemory do not return working_id - self._add_memory_to_db(memory, "WorkingMemory", user_name) + ids: list[str] = [] + futures = [] + + with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex: + f_working = ex.submit(self._add_memory_to_db, memory, "WorkingMemory", user_name) + futures.append(f_working) + + if memory.metadata.memory_type in ("LongTermMemory", "UserMemory"): + f_graph = ex.submit( + self._add_to_graph_memory, + memory=memory, + memory_type=memory.metadata.memory_type, + user_name=user_name, + ) + futures.append(f_graph) - # Add to LongTermMemory and UserMemory - if memory.metadata.memory_type in ["LongTermMemory", "UserMemory"]: - added_id = self._add_to_graph_memory( - memory=memory, memory_type=memory.metadata.memory_type, user_name=user_name - ) - ids.append(added_id) + for fut in as_completed(futures): + try: + res = fut.result() + if isinstance(res, str) and res: + ids.append(res) + except Exception: + logger.warning("Parallel memory processing failed:\n%s", traceback.format_exc()) return ids @@ -157,7 +172,6 @@ def _add_memory_to_db( # Insert node into graph self.graph_store.add_node(working_memory.id, working_memory.memory, metadata, user_name) - return working_memory.id def _add_to_graph_memory( self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None @@ -268,6 +282,31 @@ def _ensure_structure_path( # Step 3: Return this structure node ID as the parent_id return node_id + def remove_and_refresh_memory(self): + self._cleanup_memories_if_needed() + self._refresh_memory_size() + + def _cleanup_memories_if_needed(self) -> None: + """ + Only clean up memories if we're close to or over the limit. + This reduces unnecessary database operations. 
+ """ + cleanup_threshold = 0.8 # Clean up when 80% full + + for memory_type, limit in self.memory_size.items(): + current_count = self.current_memory_size.get(memory_type, 0) + threshold = int(limit * cleanup_threshold) + + # Only clean up if we're at or above the threshold + if current_count >= threshold: + try: + self.graph_store.remove_oldest_memory( + memory_type=memory_type, keep_latest=limit + ) + logger.debug(f"Cleaned up {memory_type}: {current_count} -> {limit}") + except Exception: + logger.warning(f"Remove {memory_type} error: {traceback.format_exc()}") + def wait_reorganizer(self): """ Wait for the reorganizer to finish processing all messages. diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index d4cfcf501..c1ade3021 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -197,6 +197,7 @@ def _vector_recall( memory_scope: str, top_k: int = 20, max_num: int = 3, + status: str = "activated", cube_name: str | None = None, search_filter: dict | None = None, user_name: str | None = None, @@ -213,6 +214,7 @@ def search_single(vec, filt=None): self.graph_store.search_by_embedding( vector=vec, top_k=top_k, + status=status, scope=memory_scope, cube_name=cube_name, search_filter=filt, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 05db56f53..96c6c97f1 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -134,7 +134,11 @@ def _parse_task( related_nodes = [ self.graph_store.get_node(n["id"]) for n in self.graph_store.search_by_embedding( - query_embedding, top_k=top_k, search_filter=search_filter, user_name=user_name + query_embedding, + top_k=top_k, + status="activated", + 
search_filter=search_filter, + user_name=user_name, ) ] memories = [] diff --git a/tests/memories/textual/test_tree.py b/tests/memories/textual/test_tree.py index f3e662992..772a79d78 100644 --- a/tests/memories/textual/test_tree.py +++ b/tests/memories/textual/test_tree.py @@ -66,7 +66,7 @@ def test_add_calls_manager(mock_tree_text_memory): metadata=TreeNodeTextualMemoryMetadata(updated_at=None), ) mock_tree_text_memory.add([mock_item]) - mock_tree_text_memory.memory_manager.add.assert_called_once() + mock_tree_text_memory.memory_manager.add.assert_called_once_with([mock_item], mode="sync") def test_get_working_memory_sorted(mock_tree_text_memory): @@ -161,4 +161,4 @@ def test_add_returns_ids(mock_tree_text_memory): result = mock_tree_text_memory.add(mock_items) assert result == dummy_ids - mock_tree_text_memory.memory_manager.add.assert_called_once_with(mock_items) + mock_tree_text_memory.memory_manager.add.assert_called_once_with(mock_items, mode="sync")