14 changes: 7 additions & 7 deletions README.md
@@ -232,10 +232,10 @@ lightmem = LightMemory.from_config(config_dict)
session = {
"timestamp": "2025-01-10",
"turns": [
[
{"role": "user", "content": "My favorite ice cream flavor is pistachio, and my dog's name is Rex."},
{"role": "assistant", "content": "Got it. Pistachio is a great choice."}],
]
[
{"role": "user", "content": "My favorite ice cream flavor is pistachio, and my dog's name is Rex.", "speaker_name": "John","speaker_id": "speaker_a"},
{"role": "assistant", "content": "Got it. Pistachio is a great choice.", "speaker_name": "Assistant", "speaker_id": "speaker_b"}],
]
}
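
For clarity, the full session format after this change, reconstructed from the added lines above (values purely illustrative): each turn message now carries optional speaker_name and speaker_id fields alongside role and content.

session = {
    "timestamp": "2025-01-10",
    "turns": [
        [
            {
                "role": "user",
                "content": "My favorite ice cream flavor is pistachio, and my dog's name is Rex.",
                "speaker_name": "John",      # new field in this PR
                "speaker_id": "speaker_a",   # new field in this PR
            },
            {
                "role": "assistant",
                "content": "Got it. Pistachio is a great choice.",
                "speaker_name": "Assistant",
                "speaker_id": "speaker_b",
            },
        ],
    ],
}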


@@ -377,10 +377,10 @@ We welcome contributions from the community! If you'd like to contribute, please
</a>
</td>
<td align="center" width="150">
<a href="https://github.com/usememos/memos">
<img src="https://avatars.githubusercontent.com/usememos" width="80" style="border-radius:50%" alt="Memos"/>
<a href="https://github.com/MemTensor/MemOS">
<img src="https://avatars.githubusercontent.com/MemTensor" width="80" style="border-radius:50%" alt="MemOS"/>
<br />
<sub><b>Memos</b></sub>
<sub><b>MemOS</b></sub>
</a>
</td>
<td align="center" width="150">
2 changes: 2 additions & 0 deletions src/lightmem/configs/memory_manager/base_config.py
@@ -13,6 +13,7 @@ def __init__(
max_tokens: int = 2000,
top_p: float = 0.1,
top_k: int = 1,
include_topic_summary: bool = False,
enable_vision: bool = False,
vision_details: Optional[str] = "auto",
# Openai specific
@@ -31,6 +32,7 @@ def __init__(
self.max_tokens = max_tokens
self.top_p = top_p
self.top_k = top_k
self.include_topic_summary = include_topic_summary
self.enable_vision = enable_vision
self.vision_details = vision_details
# Openai specific
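For context, a minimal sketch of a memory-manager config carrying the new flag. The keys mirror the constructor arguments visible in this hunk; how the dict is ultimately consumed (e.g. via LightMemory.from_config) is an assumption, not shown in this diff.

# Hypothetical memory-manager config section; keys mirror the constructor
# arguments shown in the hunk above, the surrounding plumbing is assumed.
memory_manager_config = {
    "max_tokens": 2000,
    "top_p": 0.1,
    "top_k": 1,
    "include_topic_summary": True,  # new flag introduced by this PR (defaults to False)
    "enable_vision": False,
    "vision_details": "auto",
}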
2 changes: 1 addition & 1 deletion src/lightmem/configs/text_embedder/base.py
@@ -8,7 +8,7 @@ class TextEmbedderConfig(BaseModel):
description="The embedding model or Deployment platform (e.g., 'openai', 'huggingface')"
)

_model_list: ClassVar[List[str]] = ["huggingface"]
_model_list: ClassVar[List[str]] = ["huggingface", "openai"]

configs: Optional[Union[BaseTextEmbedderConfig, Dict[str, Any]]] = Field(
default=None,
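With "openai" added to _model_list, a TextEmbedderConfig selecting the OpenAI platform now validates. A minimal sketch, assuming the nested configs keys match the attributes the new TextEmbedderOpenAI reads later in this PR (model, api_key, openai_base_url, embedding_dims):

from lightmem.configs.text_embedder.base import TextEmbedderConfig

# Sketch only: the nested keys are assumptions based on what TextEmbedderOpenAI
# reads; configs may also be passed as a BaseTextEmbedderConfig instance.
embedder_config = TextEmbedderConfig(
    model="openai",
    configs={
        "model": "text-embedding-3-small",
        "api_key": "sk-...",
        "openai_base_url": None,      # or a custom OpenAI-compatible endpoint
        "embedding_dims": 1536,
    },
)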
2 changes: 1 addition & 1 deletion src/lightmem/factory/memory_buffer/short_term_memory.py
@@ -7,7 +7,7 @@ def __init__(self, max_tokens: int = 2000, tokenizer: Optional[Any] = None):
self.tokenizer = resolve_tokenizer(tokenizer)
self.buffer: List[List[Dict[str, Any]]] = []
self.token_count: int = 0

print(f"ShortMemBufferManager initialized with max_tokens={self.max_tokens}")
def _count_tokens(self, messages: List[Dict[str, Any]], messages_use: str) -> int:
role_map = {
"user_only": ["user"],
99 changes: 70 additions & 29 deletions src/lightmem/factory/memory_manager/openai.py
@@ -1,6 +1,7 @@
import concurrent
from collections import defaultdict
from openai import OpenAI
from typing import List, Dict, Optional, Literal
from typing import List, Dict, Optional, Literal, Any
import json, os, warnings
import httpx
from lightmem.configs.memory_manager.base_config import BaseMemoryManagerConfig
@@ -120,28 +121,41 @@ def generate_response(
params["tool_choice"] = tool_choice

response = self.client.chat.completions.create(**params)
return self._parse_response(response, tools)

usage_info = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
}
parsed_response = self._parse_response(response, tools)

return parsed_response, usage_info

def meta_text_extract(
self,
system_prompt: str,
extract_list: List[List[List[Dict]]],
messages_use: Literal["user_only", "assistant_only", "hybrid"] = "user_only"
timestamps_list: Optional[List[List[List[str]]]] = None,
weekday_list: Optional[List[List[List[str]]]] = None,
messages_use: Literal["user_only", "assistant_only", "hybrid"] = "user_only",
topic_id_mapping: Optional[List[List[int]]] = None
) -> List[Optional[Dict]]:
"""
Extract metadata from text segments using parallel processing.

Args:
system_prompt: The system prompt for metadata generation
all_segments: List of message segments to process
extract_list: List of message segments to process
timestamps_list: Optional list of timestamps (reserved for future use)
weekday_list: Optional list of weekdays (reserved for future use)
messages_use: Strategy for which messages to use
topic_id_mapping: For each API call, the global topic IDs

Returns:
List of extracted metadata results, None for failed segments
"""
if not extract_list:
return []

def concatenate_messages(segment: List[Dict], messages_use: str) -> str:
"""Concatenate messages based on usage strategy"""
role_filter = {
@@ -161,43 +175,69 @@ def concatenate_messages(segment: List[Dict], messages_use: str) -> str:
sequence_id = mes["sequence_number"]
role = mes["role"]
content = mes.get("content", "")
message_lines.append(f"{sequence_id}.{role}: {content}")

speaker_name = mes.get("speaker_name", "")
time_stamp = mes.get("time_stamp", "")
weekday = mes.get("weekday", "")

time_prefix = ""
if time_stamp and weekday:
time_prefix = f"[{time_stamp}, {weekday}] "

if speaker_name != 'Unknown':
message_lines.append(f"{time_prefix}{sequence_id//2}.{speaker_name}: {content}")
else:
message_lines.append(f"{time_prefix}{sequence_id//2}.{role}: {content}")

return "\n".join(message_lines)

max_workers = min(len(extract_list), 5)

def process_segment_wrapper(api_call_segments: List[List[Dict]]):
"""Process one API call (multiple topic segments inside)"""
def process_segment_wrapper(args):
api_call_idx, api_call_segments = args
try:
user_prompt_parts = []
for idx, topic_segment in enumerate(api_call_segments, start=1):
user_prompt_parts: List[str] = []

global_topic_ids: List[int] = []
if topic_id_mapping and api_call_idx < len(topic_id_mapping):
global_topic_ids = topic_id_mapping[api_call_idx]

for topic_idx, topic_segment in enumerate(api_call_segments):
if topic_idx < len(global_topic_ids):
global_topic_id = global_topic_ids[topic_idx]
else:
global_topic_id = topic_idx + 1

topic_text = concatenate_messages(topic_segment, messages_use)
user_prompt_parts.append(f"--- Topic {idx} ---\n{topic_text}")
user_prompt_parts.append(f"--- Topic {global_topic_id} ---\n{topic_text}")

print(f"User prompt for API call {api_call_idx}:\n" + "\n".join(user_prompt_parts))
user_prompt = "\n".join(user_prompt_parts)

messages = [
metadata_messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
{"role": "user", "content": user_prompt},
]
raw_response = self.generate_response(
messages=messages,
response_format={"type": "json_object"}

raw_response, usage_info = self.generate_response(
messages=metadata_messages,
response_format={"type": "json_object"},
)
cleaned_result = clean_response(raw_response)
metadata_facts = clean_response(raw_response)

return {
"input_prompt": messages,
"input_prompt": metadata_messages,
"output_prompt": raw_response,
"cleaned_result": cleaned_result
"cleaned_result": metadata_facts,
"usage": usage_info,
}

except Exception as e:
print(f"Error processing API call: {e}")
print(f"Error processing API call {api_call_idx}: {e}")
return None

with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
try:
results = list(executor.map(process_segment_wrapper, extract_list))
results = list(executor.map(process_segment_wrapper, enumerate(extract_list)))
except Exception as e:
print(f"Error in parallel processing: {e}")
results = [None] * len(extract_list)
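
To make the nested shapes concrete, an illustrative example of the inputs this method expects: extract_list groups topic segments per API call, and topic_id_mapping supplies the matching global topic IDs (values below are made up).

# Illustrative shapes only: two API calls, the first carrying two topic
# segments and the second carrying one; topic_id_mapping gives their global IDs.
extract_list = [
    [  # API call 0
        [{"sequence_number": 0, "role": "user", "content": "...", "speaker_name": "John"},
         {"sequence_number": 1, "role": "assistant", "content": "..."}],
        [{"sequence_number": 2, "role": "user", "content": "..."}],
    ],
    [  # API call 1
        [{"sequence_number": 4, "role": "user", "content": "..."}],
    ],
]
topic_id_mapping = [[1, 2], [3]]  # global topic IDs for each API call's segments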
@@ -218,15 +258,16 @@ def _call_update_llm(self, system_prompt, target_entry, candidate_sources):
{"role": "user", "content": user_prompt}
]

response_text = self.generate_response(
response_text, usage_info = self.generate_response(
messages=messages,
response_format={"type": "json_object"}
)

try:
result = json.loads(response_text)
if "action" not in result:
return {"action": "ignore"}
result = {"action": "ignore"}
result["usage"] = usage_info
return result
except Exception:
return {"action": "ignore"}
return {"action": "ignore", "usage": usage_info if 'usage_info' in locals() else None}
1 change: 1 addition & 0 deletions src/lightmem/factory/text_embedder/factory.py
@@ -5,6 +5,7 @@
class TextEmbedderFactory:
_MODEL_MAPPING: Dict[str, str] = {
"huggingface": "lightmem.factory.text_embedder.huggingface.TextEmbedderHuggingface",
"openai": "lightmem.factory.text_embedder.openai.TextEmbedderOpenAI",
}

@classmethod
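The mapping value is a dotted import path, so the factory presumably resolves it dynamically at runtime. A rough sketch of that pattern; the factory's actual resolution method is not visible in this hunk, so the function name and signature below are assumptions.

import importlib

# Assumed resolution logic for _MODEL_MAPPING; the real factory method name
# and signature are not shown in this hunk.
def load_embedder_class(provider: str):
    dotted_path = TextEmbedderFactory._MODEL_MAPPING[provider]
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

# load_embedder_class("openai") would yield TextEmbedderOpenAI.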
17 changes: 15 additions & 2 deletions src/lightmem/factory/text_embedder/huggingface.py
@@ -7,12 +7,16 @@
class TextEmbedderHuggingface:
def __init__(self, config: Optional[BaseTextEmbedderConfig] = None):
self.config = config
self.total_calls = 0
self.total_tokens = 0
if config.huggingface_base_url:
self.client = OpenAI(base_url=config.huggingface_base_url)
self.use_api = True
else:
self.config.model = config.model or "all-MiniLM-L6-v2"
self.model = SentenceTransformer(config.model, **config.model_kwargs)
self.config.embedding_dims = config.embedding_dims or self.model.get_sentence_embedding_dimension()
self.use_api = False

@classmethod
def from_config(cls, config):
@@ -39,11 +43,20 @@ def embed(self, text):
Returns:
list: The embedding vector.
"""
self.total_calls += 1
if self.config.huggingface_base_url:
return self.client.embeddings.create(input=text, model="tei").data[0].embedding
response = self.client.embeddings.create(input=text, model="tei")
self.total_tokens += getattr(response.usage, 'total_tokens', 0)
return response.data[0].embedding
else:
result = self.model.encode(text, convert_to_numpy=True)
if isinstance(result, np.ndarray):
return result.tolist()
else:
return result
return result

def get_stats(self):
return {
"total_calls": self.total_calls,
"total_tokens": self.total_tokens,
}
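
The embedder now counts calls, and tokens when routed through an API. A short usage sketch of the new counters on the local (non-API) path; constructing the config this way is an assumption based on the attributes read in __init__.

from lightmem.configs.text_embedder.base_config import BaseTextEmbedderConfig
from lightmem.factory.text_embedder.huggingface import TextEmbedderHuggingface

# Sketch: local SentenceTransformer path of the updated embedder; the config
# construction is an assumption based on the attributes read in __init__.
config = BaseTextEmbedderConfig(model="all-MiniLM-L6-v2")
embedder = TextEmbedderHuggingface.from_config(config)

vec = embedder.embed("My dog's name is Rex.")
print(len(vec))              # embedding dimension
print(embedder.get_stats())  # {"total_calls": 1, "total_tokens": 0} on the local path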
53 changes: 53 additions & 0 deletions src/lightmem/factory/text_embedder/openai.py
@@ -0,0 +1,53 @@
from openai import OpenAI
from typing import Optional, List, Union
import os
import httpx
from lightmem.configs.text_embedder.base_config import BaseTextEmbedderConfig


class TextEmbedderOpenAI:
def __init__(self, config: Optional[BaseTextEmbedderConfig] = None):
self.config = config
self.model = getattr(config, "model", None) or "text-embedding-3-small"
http_client = httpx.Client(verify=False)
api_key = self.config.api_key
base_url = self.config.openai_base_url
self.client = OpenAI(
api_key=api_key,
base_url=base_url,
http_client=http_client
)
self.total_calls = 0
self.total_tokens = 0

@classmethod
def from_config(cls, config: BaseTextEmbedderConfig):
return cls(config)

def embed(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
def preprocess(t):
return str(t).replace("\n", " ")

api_params = {"model": self.config.model}
api_params["dimensions"] = self.config.embedding_dims

if isinstance(text, list):
if len(text) == 0:
return []
inputs = [preprocess(x) for x in text]
resp = self.client.embeddings.create(input=inputs, **api_params)
self.total_calls += 1
self.total_tokens += resp.usage.total_tokens
return [item.embedding for item in resp.data]
else:
preprocessed = preprocess(text)
resp = self.client.embeddings.create(input=[preprocessed], **api_params)
self.total_calls += 1
self.total_tokens += resp.usage.total_tokens
return resp.data[0].embedding

def get_stats(self):
return {
"total_calls": self.total_calls,
"total_tokens": self.total_tokens,
}
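
A minimal usage sketch of the new embedder; the config values are placeholders, and constructing BaseTextEmbedderConfig this way is an assumption based on the attributes the class reads (model, api_key, openai_base_url, embedding_dims).

from lightmem.configs.text_embedder.base_config import BaseTextEmbedderConfig
from lightmem.factory.text_embedder.openai import TextEmbedderOpenAI

# Sketch only: config values are placeholders and the constructor call is assumed.
config = BaseTextEmbedderConfig(
    model="text-embedding-3-small",
    api_key="sk-...",
    openai_base_url=None,
    embedding_dims=1536,
)
embedder = TextEmbedderOpenAI.from_config(config)

single = embedder.embed("My favorite ice cream flavor is pistachio.")  # List[float]
batch = embedder.embed(["sentence one", "sentence two"])               # List[List[float]]
print(embedder.get_stats())  # {"total_calls": 2, "total_tokens": ...}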