Skip to content

Commit 81bc1f6

Browse files
committed
refactor(server):重构服务器路由模块以优化内存管理
1 parent 235125a commit 81bc1f6

1 file changed

Lines changed: 28 additions & 61 deletions

File tree

Lines changed: 28 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,62 @@
1-
import os
21
import json
2+
import os
33
import time
44
from fastapi import APIRouter
55
from memos import log
66
from memos.api.product_models import (
7-
BaseResponse,
8-
ChatCompleteRequest,
9-
ChatRequest,
10-
GetMemoryRequest,
117
MemoryCreateRequest,
12-
MemoryResponse,
138
SearchRequest,
14-
SearchResponse,
15-
SimpleResponse,
16-
SuggestionRequest,
17-
SuggestionResponse,
18-
UserRegisterRequest,
19-
UserRegisterResponse,
9+
SearchResponse, MemoryResponse,
2010
)
11+
from memos.chunkers.sentence_chunker import SentenceChunker
12+
from memos.configs.chunker import SentenceChunkerConfig
2113
from memos.configs.embedder import UniversalAPIEmbedderConfig
2214
from memos.configs.graph_db import NebulaGraphDBConfig
2315
from memos.configs.llm import OpenAILLMConfig
2416
from memos.embedders.universal_api import UniversalAPIEmbedder
2517
from memos.graph_dbs.nebular import NebulaGraphDB
2618
from memos.llms.openai import OpenAILLM
27-
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
19+
from memos.mem_reader.simple_struct import SimpleStructMemReader
2820
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
21+
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
2922
from memos.reranker.cosine_local import CosineLocalReranker
30-
from memos.mem_reader.simple_struct import SimpleStructMemReader
31-
from memos.configs.mem_reader import SimpleStructMemReaderConfig
32-
from memos.configs.chunker import ChunkerConfigFactory
33-
from memos.configs.llm import LLMConfigFactory
34-
from memos.configs.embedder import EmbedderConfigFactory
35-
from memos.configs.chunker import SentenceChunkerConfig
36-
from memos.chunkers.sentence_chunker import SentenceChunker
3723

3824
logger = log.get_logger(__name__)
3925
router = APIRouter()
4026

27+
4128
def init_model():
4229
llm = OpenAILLM(
43-
OpenAILLMConfig(model_schema='memos.configs.llm.OpenAILLMConfig', model_name_or_path='gpt-4o',
30+
OpenAILLMConfig(model_schema='memos.configs.llm.OpenAILLMConfig', model_name_or_path='gpt-4o-mini',
4431
temperature=0.8, max_tokens=1024, top_p=0.9, top_k=50, remove_think_prefix=True,
4532
api_key=os.getenv('OPENAI_API_KEY'),
4633
api_base=os.getenv('OPENAI_API_BASE'), extra_body=None))
4734
embedder = UniversalAPIEmbedder(
4835
UniversalAPIEmbedderConfig(model_schema='memos.configs.embedder.UniversalAPIEmbedderConfig',
49-
model_name_or_path='bge-m3', embedding_dims=None, provider='openai',
50-
api_key='EMPTY', base_url=os.getenv('MOS_EMBEDDER_API_BASE')))
36+
model_name_or_path='bge-m3', embedding_dims=None, provider='openai',
37+
api_key='EMPTY', base_url=os.getenv('MOS_EMBEDDER_API_BASE')))
5138

5239
reranker = CosineLocalReranker(level_weights={"topic": 1.0, "concept": 1.0, "fact": 1.0}, level_field='background')
5340

5441
graph_store = NebulaGraphDB(
5542
NebulaGraphDBConfig(model_schema='memos.configs.graph_db.NebulaGraphDBConfig',
5643
uri=json.loads(os.getenv('NEBULAR_HOSTS')),
57-
user=os.getenv('NEBULAR_USER'), password=os.getenv('NEBULAR_PASSWORD'), space=os.getenv('NEBULAR_SPACE'),
44+
user=os.getenv('NEBULAR_USER'), password=os.getenv('NEBULAR_PASSWORD'),
45+
space=os.getenv('NEBULAR_SPACE'),
5846
auto_create=True, max_client=1000, embedding_dimension=1024))
5947
search_obj = Searcher(llm, graph_store, embedder, reranker, internet_retriever=None, moscube=False)
6048
chunker = SentenceChunker(
61-
SentenceChunkerConfig(
62-
model_schema='memos.configs.chunker.SentenceChunkerConfig',
63-
tokenizer_or_token_counter="gpt2",
64-
chunk_size=512,
65-
chunk_overlap=128,
66-
min_sentences_per_chunk=1,
67-
)
68-
)
69-
mem_reader = SimpleStructMemReader(
70-
llm,
71-
embedder,
72-
chunker
73-
)
74-
memory_add_obj = MemoryManager(
75-
graph_store,
76-
embedder,
77-
llm,
78-
memory_size={
79-
"WorkingMemory": 20,
80-
"LongTermMemory": 1500,
81-
"UserMemory": 480,
82-
},
83-
is_reorganize=False
84-
)
49+
SentenceChunkerConfig(model_schema='memos.configs.chunker.SentenceChunkerConfig',
50+
tokenizer_or_token_counter="gpt2", chunk_size=512, chunk_overlap=128,
51+
min_sentences_per_chunk=1))
52+
mem_reader = SimpleStructMemReader(llm, embedder, chunker)
53+
memory_add_obj = MemoryManager(graph_store, embedder, llm,
54+
memory_size={"WorkingMemory": 20, "LongTermMemory": 1500, "UserMemory": 480},
55+
is_reorganize=False)
8556

8657
return search_obj, memory_add_obj, mem_reader
8758

59+
8860
search_obj, memory_add_obj, mem_reader = init_model()
8961

9062

@@ -95,30 +67,25 @@ def search_memories(search_req: SearchRequest):
9567
# user_id = f"memos{search_req.user_id.replace('-', '')}"
9668
user_id = search_req.user_id
9769
res = search_obj.search(query=search_req.query, user_id=user_id, top_k=search_req.top_k
98-
, mode="fast", search_filter=None,
99-
info={'user_id': user_id, 'session_id': 'root_session', 'chat_history': []})
70+
, mode="fast", search_filter=None,
71+
info={'user_id': user_id, 'session_id': 'root_session', 'chat_history': []})
10072
res = {"d": res}
10173
# print(res)
10274
return SearchResponse(message="Search completed successfully", data=res)
10375

10476

105-
@router.post("/add", summary="add memories", response_model=SearchResponse)
77+
@router.post("/add", summary="add memories", response_model=MemoryResponse)
10678
def add_memories(add_req: MemoryCreateRequest):
10779
"""Add memories for a specific user."""
10880
time_start = time.time()
109-
11081
memories = mem_reader.get_memory(
11182
[add_req.messages],
11283
type="chat",
113-
info={"user_id": add_req.user_id, "session_id": add_req.session_id},)[0]
114-
logger.info(
115-
f"time add: get mem_reader time user_id: {add_req.user_id} time is: {time.time() - time_start:.2f}s"
116-
)
117-
data = []
118-
84+
info={"user_id": add_req.user_id, "session_id": add_req.session_id})[0]
85+
logger.info(f"time add: get mem_reader time user_id: {add_req.user_id} time is: {time.time() - time_start:.2f}s")
11986
mem_id_list: list[str] = memory_add_obj.add(memories, user_name=add_req.user_id)
12087
logger.info(f"Added memory for user {add_req.user_id} in session {add_req.session_id}: {mem_id_list}")
121-
88+
data = []
12289
for m_id, m in zip(mem_id_list, memories):
123-
data.append({'memory': m.memory, 'mem_ids': m_id, 'memory_type': m.metadata.memory_type})
124-
return SearchResponse(message="Memory added successfully", data=data)
90+
data.append({'memory': m.memory, 'memory_id': m_id, 'memory_type': m.metadata.memory_type})
91+
return MemoryResponse(message="Memory added successfully", data=data)

0 commit comments

Comments
 (0)