Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/memos/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def get_message(
)
response.raise_for_status()
response_data = response.json()

return MemOSGetMessagesResponse(**response_data)
except Exception as e:
logger.error(f"Failed to get messages (retry {retry + 1}/3): {e}")
Expand All @@ -74,6 +75,7 @@ def add_message(
)
response.raise_for_status()
response_data = response.json()

return MemOSAddResponse(**response_data)
except Exception as e:
logger.error(f"Failed to add memory (retry {retry + 1}/3): {e}")
Expand Down Expand Up @@ -102,6 +104,7 @@ def search_memory(
)
response.raise_for_status()
response_data = response.json()

return MemOSSearchResponse(**response_data)
except Exception as e:
logger.error(f"Failed to search memory (retry {retry + 1}/3): {e}")
Expand Down
12 changes: 12 additions & 0 deletions src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,13 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
"embedder": APIConfig.get_embedder_config(),
"internet_retriever": internet_config,
"reranker": APIConfig.get_reranker_config(),
"reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower()
== "true",
"memory_size": {
"WorkingMemory": os.getenv("NEBULAR_WORKING_MEMORY", 20),
"LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6),
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
},
},
},
"act_mem": {}
Expand Down Expand Up @@ -575,6 +582,11 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
"reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower()
== "true",
"internet_retriever": internet_config,
"memory_size": {
"WorkingMemory": os.getenv("NEBULAR_WORKING_MEMORY", 20),
"LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6),
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
},
},
},
"act_mem": {}
Expand Down
2 changes: 1 addition & 1 deletion src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class GetMessagesData(BaseModel):
"""Data model for get messages response based on actual API."""

message_detail_list: list[MessageDetail] = Field(
default_factory=list, alias="memory_detail_list", description="List of message details"
default_factory=list, alias="message_detail_list", description="List of message details"
)


Expand Down
88 changes: 64 additions & 24 deletions src/memos/graph_dbs/nebular.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,19 @@ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "N
client = cls._CLIENT_CACHE.get(key)
if client is None:
# Connection setting

tmp_client = NebulaClient(
hosts=cfg.uri,
username=cfg.user,
password=cfg.password,
session_config=SessionConfig(graph=None),
session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000),
)
try:
cls._ensure_space_exists(tmp_client, cfg)
finally:
tmp_client.close()

conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
if conn_conf is None:
conn_conf = ConnectionConfig.from_defults(
Expand Down Expand Up @@ -318,6 +331,7 @@ def __init__(self, config: NebulaGraphDBConfig):
}
"""

assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED"
self.config = config
self.db_name = config.space
self.user_name = config.user_name
Expand Down Expand Up @@ -429,15 +443,21 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None:
if not self.config.use_multi_db and self.config.user_name:
optional_condition = f"AND n.user_name = '{self.config.user_name}'"

query = f"""
MATCH (n@Memory)
WHERE n.memory_type = '{memory_type}'
{optional_condition}
ORDER BY n.updated_at DESC
OFFSET {keep_latest}
DETACH DELETE n
"""
self.execute_query(query)
count = self.count_nodes(memory_type)

if count > keep_latest:
delete_query = f"""
MATCH (n@Memory)
WHERE n.memory_type = '{memory_type}'
{optional_condition}
ORDER BY n.updated_at DESC
OFFSET {keep_latest}
DETACH DELETE n
"""
try:
self.execute_query(delete_query)
except Exception as e:
logger.warning(f"Delete old mem error: {e}")

@timed
def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
Expand Down Expand Up @@ -597,14 +617,19 @@ def get_memory_count(self, memory_type: str) -> int:
return -1

@timed
def count_nodes(self, scope: str) -> int:
query = f"""
MATCH (n@Memory)
WHERE n.memory_type = "{scope}"
"""
def count_nodes(self, scope: str | None = None) -> int:
query = "MATCH (n@Memory)"
conditions = []

if scope:
conditions.append(f'n.memory_type = "{scope}"')
if not self.config.use_multi_db and self.config.user_name:
user_name = self.config.user_name
query += f"\nAND n.user_name = '{user_name}'"
conditions.append(f"n.user_name = '{user_name}'")

if conditions:
query += "\nWHERE " + " AND ".join(conditions)

query += "\nRETURN count(n) AS count"

result = self.execute_query(query)
Expand Down Expand Up @@ -985,8 +1010,7 @@ def search_by_embedding(
dim = len(vector)
vector_str = ",".join(f"{float(x)}" for x in vector)
gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])"

where_clauses = []
where_clauses = [f"n.{self.dim_field} IS NOT NULL"]
if scope:
where_clauses.append(f'n.memory_type = "{scope}"')
if status:
Expand All @@ -1008,15 +1032,12 @@ def search_by_embedding(
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

gql = f"""
MATCH (n@Memory)
let a = {gql_vector}
MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
{where_clause}
ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC
APPROXIMATE
ORDER BY inner_product(n.{self.dim_field}, a) DESC
LIMIT {top_k}
OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }}
RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score
"""

RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score"""
try:
result = self.execute_query(gql)
except Exception as e:
Expand Down Expand Up @@ -1471,6 +1492,25 @@ def merge_nodes(self, id1: str, id2: str) -> str:
"""
raise NotImplementedError

@classmethod
def _ensure_space_exists(cls, tmp_client, cfg):
    """Best-effort guarantee that the configured graph (space) exists.

    Runs against a short-lived client before any session is bound to the
    graph. All failures are logged and swallowed — callers proceed either
    way. NOTE(review): the graph type `MemOSBgeM3Type` is hard-coded here;
    confirm it stays in sync with `_ensure_database_exists`.
    """
    db_name = getattr(cfg, "space", None)
    if not db_name:
        logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.")
        return

    try:
        shown = tmp_client.execute("SHOW GRAPHS;")
        # Each result row's first column holds a graph name.
        present = {record.values()[0].as_string() for record in shown}
        if db_name in present:
            logger.debug(f"Graph `{db_name}` already exists.")
        else:
            tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type;")
            logger.info(f"✅ Graph `{db_name}` created before session binding.")
    except Exception:
        logger.exception("[NebulaGraphDBSync] Failed to ensure space exists")

@timed
def _ensure_database_exists(self):
graph_type_name = "MemOSBgeM3Type"
Expand Down
4 changes: 2 additions & 2 deletions src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,10 @@ def load(self, dir: str) -> None:
except Exception as e:
logger.error(f"An error occurred while loading memories: {e}")

def dump(self, dir: str) -> None:
def dump(self, dir: str, include_embedding: bool = False) -> None:
"""Dump memories to os.path.join(dir, self.config.memory_filename)"""
try:
json_memories = self.graph_store.export_graph()
json_memories = self.graph_store.export_graph(include_embedding=include_embedding)

os.makedirs(dir, exist_ok=True)
memory_file = os.path.join(dir, self.config.memory_filename)
Expand Down
44 changes: 24 additions & 20 deletions src/memos/memories/textual/tree_text_memory/organize/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
"LongTermMemory": 1500,
"UserMemory": 480,
}
logger.info(f"MemorySize is {self.memory_size}")
self._threshold = threshold
self.is_reorganize = is_reorganize
self.reorganizer = GraphStructureReorganizer(
Expand All @@ -66,30 +67,33 @@ def add(self, memories: list[TextualMemoryItem]) -> list[str]:
except Exception as e:
logger.exception("Memory processing error: ", exc_info=e)

try:
self.graph_store.remove_oldest_memory(
memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"]
)
except Exception:
logger.warning(f"Remove WorkingMemory error: {traceback.format_exc()}")

try:
self.graph_store.remove_oldest_memory(
memory_type="LongTermMemory", keep_latest=self.memory_size["LongTermMemory"]
)
except Exception:
logger.warning(f"Remove LongTermMemory error: {traceback.format_exc()}")

try:
self.graph_store.remove_oldest_memory(
memory_type="UserMemory", keep_latest=self.memory_size["UserMemory"]
)
except Exception:
logger.warning(f"Remove UserMemory error: {traceback.format_exc()}")
# Only clean up if we're close to or over the limit
self._cleanup_memories_if_needed()

self._refresh_memory_size()
return added_ids

def _cleanup_memories_if_needed(self) -> None:
    """Trim each memory type back to its configured cap, but only once the
    tracked count reaches 80% of that cap.

    Skipping the delete query while well under the cap avoids a database
    round-trip on every add. Per-type failures are logged and do not stop
    the remaining types from being processed.
    """
    fill_ratio_trigger = 0.8  # start trimming once a type is 80% full

    for mem_type, cap in self.memory_size.items():
        tracked = self.current_memory_size.get(mem_type, 0)
        if tracked < int(cap * fill_ratio_trigger):
            continue  # comfortably under the cap — skip the DB call

        try:
            self.graph_store.remove_oldest_memory(
                memory_type=mem_type, keep_latest=cap
            )
            logger.debug(f"Cleaned up {mem_type}: {tracked} -> {cap}")
        except Exception:
            logger.warning(f"Remove {mem_type} error: {traceback.format_exc()}")

def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None:
"""
Replace WorkingMemory
Expand Down