diff --git a/experiments/run_lightmen_qwen_server.py b/experiments/run_lightmen_qwen_server.py
new file mode 100644
index 0000000..0d00914
--- /dev/null
+++ b/experiments/run_lightmen_qwen_server.py
@@ -0,0 +1,228 @@
+from openai import OpenAI
+import json
+from tqdm import tqdm
+import datetime
+import time
+import os
+from lightmem.memory.lightmem import LightMemory
+
+# ============ API Configuration ============
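+# NOTE: "empty" works as a placeholder key below because OpenAI-compatible servers
+# such as vLLM typically accept any key unless one is enforced server-side.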
+API_KEY = "empty"
+API_BASE_URL = "your_qwen_server_url"
+LLM_MODEL = "Qwen/Qwen-chat"
+JUDGE_MODEL = "Qwen/Qwen-chat"
+
+# ============ Model Paths ============
+LLMLINGUA_MODEL_PATH = "/your/path/to/models/llmlingua-2-bert-base-multilingual-cased-meetingbank"
+EMBEDDING_MODEL_PATH = "/your/path/to/models/all-MiniLM-L6-v2"
+EMBEDDING_DIMS = 384
+# ============ Data Configuration ============
+DATA_PATH = "/your/path/to/LightMem/experiments/sample_data.json"
+RESULTS_DIR = "./results"
+QDRANT_DATA_DIR = "./qdrant_data"
+
+
+def get_anscheck_prompt(task, question, answer, response, abstention=False):
+    if not abstention:
+        if task in ["single-session-user", "single-session-assistant", "multi-session"]:
+            template = "I will give you a question, a correct answer, and a response from a model. Please answer yes if the response contains the correct answer. Otherwise, answer no. If the response is equivalent to the correct answer or contains all the intermediate steps to get the correct answer, you should also answer yes. If the response only contains a subset of the information required by the answer, answer no. \n\nQuestion: {}\n\nCorrect Answer: {}\n\nModel Response: {}\n\nIs the model response correct? Answer yes or no only."
+            prompt = template.format(question, answer, response)
+        elif task == "temporal-reasoning":
+            template = "I will give you a question, a correct answer, and a response from a model. Please answer yes if the response contains the correct answer. Otherwise, answer no. If the response is equivalent to the correct answer or contains all the intermediate steps to get the correct answer, you should also answer yes. If the response only contains a subset of the information required by the answer, answer no. In addition, do not penalize off-by-one errors for the number of days. If the question asks for the number of days/weeks/months, etc., and the model makes off-by-one errors (e.g., predicting 19 days when the answer is 18), the model's response is still correct. \n\nQuestion: {}\n\nCorrect Answer: {}\n\nModel Response: {}\n\nIs the model response correct? Answer yes or no only."
+            prompt = template.format(question, answer, response)
+        elif task == "knowledge-update":
+            template = "I will give you a question, a correct answer, and a response from a model. Please answer yes if the response contains the correct answer. Otherwise, answer no. If the response contains some previous information along with an updated answer, the response should be considered as correct as long as the updated answer is the required answer.\n\nQuestion: {}\n\nCorrect Answer: {}\n\nModel Response: {}\n\nIs the model response correct? Answer yes or no only."
+            prompt = template.format(question, answer, response)
+        elif task == "single-session-preference":
+            template = "I will give you a question, a rubric for desired personalized response, and a response from a model. Please answer yes if the response satisfies the desired response. Otherwise, answer no. The model does not need to reflect all the points in the rubric. The response is correct as long as it recalls and utilizes the user's personal information correctly.\n\nQuestion: {}\n\nRubric: {}\n\nModel Response: {}\n\nIs the model response correct? Answer yes or no only."
+            prompt = template.format(question, answer, response)
+        else:
+            raise NotImplementedError
+    else:
+        template = "I will give you an unanswerable question, an explanation, and a response from a model. Please answer yes if the model correctly identifies the question as unanswerable. The model could say that the information is incomplete, or some other information is given but the asked information is not.\n\nQuestion: {}\n\nExplanation: {}\n\nModel Response: {}\n\nDoes the model correctly identify the question as unanswerable? Answer yes or no only."
+        prompt = template.format(question, answer, response)
+    return prompt
+
+
+def true_or_false(response):
+    if response is None:
+        return False
+    normalized = str(response).strip().lower()
+    if not normalized:
+        return False
+    first_line = normalized.splitlines()[0].strip()
+    tokens = first_line.replace(".", "").replace("!", "").replace(":", "").replace(";", "").split()
+    if not tokens:
+        return False
+    head = tokens[0]
+    if head in ("yes", "y"):
+        return True
+    if head in ("no", "n"):
+        return False
+    if "yes" in first_line:
+        return True
+    if "no" in first_line:
+        return False
+    return False
+
+
+class LLMModel:
+    def __init__(self, model_name, api_key, base_url):
+        self.name = model_name
+        self.api_key = api_key
+        self.base_url = base_url
+        self.max_tokens = 2000
+        self.temperature = 0.0
+        self.top_p = 0.8
+        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+
+    def call(self, messages: list, **kwargs):
+        max_retries = kwargs.get("max_retries", 3)
+
+        for attempt in range(max_retries):
+            try:
+                completion = self.client.chat.completions.create(
+                    model=self.name,
+                    messages=messages,
+                    max_tokens=self.max_tokens,
+                    temperature=self.temperature,
+                    top_p=self.top_p,
+                    stream=False,
+                )
+                response = completion.choices[0].message.content
+                print(response)
+                return response
+            except Exception as e:
+                if attempt == max_retries - 1:
+                    raise
+                # Back off briefly before retrying transient API failures.
+                print(f"LLM call failed (attempt {attempt + 1}/{max_retries}): {e}")
+                time.sleep(2**attempt)
+
+
+def load_lightmem(collection_name):
+    config = {
+        "pre_compress": True,
+        "pre_compressor": {
+            "model_name": "llmlingua-2",
+            "configs": {
+                "llmlingua_config": {
+                    "model_name": LLMLINGUA_MODEL_PATH,
+                    "device_map": "cpu",
+                    "use_llmlingua2": True,
+                },
+            },
+        },
+        "topic_segment": True,
+        "precomp_topic_shared": True,
+        "topic_segmenter": {
+            "model_name": "llmlingua-2",
+        },
+        "messages_use": "user_only",
+        "metadata_generate": True,
+        "text_summary": True,
+        "memory_manager": {
+            "model_name": "openai",
+            "configs": {
+                "model": LLM_MODEL,
+                "api_key": API_KEY,
+                "max_tokens": 16000,
+                "openai_base_url": API_BASE_URL,
+            },
+        },
+        "extract_threshold": 0.1,
+        "index_strategy": "embedding",
+        "text_embedder": {
+            "model_name": "huggingface",
+            "configs": {
+                "model": EMBEDDING_MODEL_PATH,
+                "embedding_dims": EMBEDDING_DIMS,
+                "model_kwargs": {"device": "cpu"},
+            },
+        },
+        "retrieve_strategy": "embedding",
+        "embedding_retriever": {
+            "model_name": "qdrant",
+            "configs": {
+                "collection_name": collection_name,
+                "embedding_model_dims": EMBEDDING_DIMS,
+                "path": f"{QDRANT_DATA_DIR}/{collection_name}",
+            },
+        },
+        "update": "offline",
+    }
+    lightmem = LightMemory.from_config(config)
+    return lightmem
+
+
+llm_judge = LLMModel(JUDGE_MODEL, API_KEY, API_BASE_URL)
+llm = LLMModel(LLM_MODEL, API_KEY, API_BASE_URL)
+
+with open(DATA_PATH, "r", encoding="utf-8") as f:
+    data = json.load(f)
+data = data[:10]
+
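+# Treated as the no-op payload from lightmem.add_memory(); matching results are not recorded below.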
{"add_input_prompt": [], "add_output_prompt": [], "api_call_nums": 0} + +for item in tqdm(data): + print(item["question_id"]) + lightmem = load_lightmem(collection_name=item["question_id"]) + sessions = item["haystack_sessions"] + timestamps = item["haystack_dates"] + + results_list = [] + + time_start = time.time() + for session, timestamp in zip(sessions, timestamps): + while session and session[0]["role"] != "user": + session.pop(0) + num_turns = len(session) // 2 + for turn_idx in range(num_turns): + turn_messages = session[turn_idx * 2 : turn_idx * 2 + 2] + if len(turn_messages) < 2 or turn_messages[0]["role"] != "user" or turn_messages[1]["role"] != "assistant": + continue + for msg in turn_messages: + msg["time_stamp"] = timestamp + is_last_turn = session is sessions[-1] and turn_idx == num_turns - 1 + result = lightmem.add_memory( + messages=turn_messages, + force_segment=is_last_turn, + force_extract=is_last_turn, + ) + if result != INIT_RESULT: + results_list.append(result) + + time_end = time.time() + construction_time = time_end - time_start + + related_memories = lightmem.retrieve(item["question"], limit=20) + messages = [] + messages.append({"role": "system", "content": "You are a helpful assistant."}) + messages.append( + { + "role": "user", + "content": f"Question time:{item['question_date']} and question:{item['question']}\nPlease answer the question based on the following memories: {str(related_memories)}", + } + ) + generated_answer = llm.call(messages) + + if "abs" in item["question_id"]: + prompt = get_anscheck_prompt( + item["question_type"], item["question"], item["answer"], generated_answer, abstention=True + ) + else: + prompt = get_anscheck_prompt(item["question_type"], item["question"], item["answer"], generated_answer) + messages = [{"role": "user", "content": prompt}] + response = llm_judge.call(messages) + + correct = 1 if true_or_false(response) else 0 + + save_data = { + "question_id": item["question_id"], + "results": results_list, + "construction_time": construction_time, + "generated_answer": generated_answer, + "ground_truth": item["answer"], + "correct": correct, + } + + filename = f"../results/result_{item['question_id']}.json" + os.makedirs(os.path.dirname(filename), exist_ok=True) + with open(filename, "w", encoding="utf-8") as f: + json.dump(save_data, f, ensure_ascii=False, indent=4) diff --git a/experiments/run_lightmen_qwen_server_README.md b/experiments/run_lightmen_qwen_server_README.md new file mode 100644 index 0000000..f1173a0 --- /dev/null +++ b/experiments/run_lightmen_qwen_server_README.md @@ -0,0 +1,137 @@ +# Evaluating Long-Term Memory for LLMs with LightMem + +This project demonstrates a robust system for endowing Large Language Models (LLMs) with long-term memory capabilities using the **LightMem** framework. By integrating a locally hosted Qwen model via an OpenAI-compatible API, this experiment showcases how to overcome the inherent limitations of finite context windows, enabling conversations that are context-aware, persistent, and personalized across multiple sessions. + +The included script runs a series of tests to evaluate the system's performance on three critical long-term memory tasks: +1. **Multi-Session Context Consolidation** +2. **Dynamic Knowledge Updates** +3. 
+
+## ✨ Key Features Demonstrated
+
+- **Cross-Session Information Recall:** The system can synthesize information provided in separate conversations to answer a complex query.
+- **Stateful Knowledge Management:** It correctly identifies and uses the most recent information when user preferences or facts change over time.
+- **Knowledge Boundary Awareness:** The model can recognize when it lacks the necessary information to answer a question and gracefully abstains from responding, preventing factual hallucination.
+
+## 🚀 Getting Started
+
+Follow these steps to set up and run the experiment on your own machine.
+
+### 1. Prerequisites
+
+- Python 3.8+
+- An OpenAI-compatible API server running a model like Qwen. You can set this up using tools like [vLLM](https://github.com/vllm-project/vllm) or [FastChat](https://github.com/lm-sys/FastChat).
+- The required Hugging Face models downloaded locally.
+
+### 2. Installation
+
+1. **Clone the repository:**
+   ```bash
+   git clone https://github.com/zjunlp/LightMem.git
+   cd LightMem
+   ```
+
+2. **Create a Python virtual environment (recommended):**
+   ```bash
+   # Using conda
+   conda create -n lightmem python=3.10
+   conda activate lightmem
+
+   # Or using venv
+   python -m venv venv
+   source venv/bin/activate  # On Windows: venv\Scripts\activate
+   ```
+
+3. **Install dependencies:**
+   ```bash
+   pip install -r requirements.txt
+   ```
+   *(If you don't have a `requirements.txt` file, create one with the following content)*:
+
+   ```
+   lightmem
+   openai
+   tqdm
+   ```
+
+### 3. Configuration
+
+Before running the script, you **must** update the configuration variables at the top of `experiments/run_lightmen_qwen_server.py`:
+
+```python
+# ============ API Configuration ============
+# ⚠️ UPDATE THIS to your LLM server's endpoint
+API_BASE_URL = "your_qwen_server_url"
+
+# ============ Model Paths ============
+# ⚠️ UPDATE THESE paths to where you have downloaded the models
+LLMLINGUA_MODEL_PATH = "/your/path/to/models/llmlingua-2-bert-base-multilingual-cased-meetingbank"
+EMBEDDING_MODEL_PATH = "/your/path/to/models/all-MiniLM-L6-v2"
+
+# ============ Data Configuration ============
+# This points to the sample data file
+DATA_PATH = "/your/path/to/LightMem/experiments/sample_data.json"
+```
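+
+Before launching the full run, you can confirm that the endpoint is reachable and speaks the OpenAI chat-completions protocol. This minimal check reuses the client and model name from the script; replace the base URL with your server's address:
+
+```python
+from openai import OpenAI
+
+client = OpenAI(api_key="empty", base_url="your_qwen_server_url")
+reply = client.chat.completions.create(
+    model="Qwen/Qwen-chat",
+    messages=[{"role": "user", "content": "Say hello."}],
+    max_tokens=16,
+)
+print(reply.choices[0].message.content)
+```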
+
+### 4. Running the Experiment
+
+Run the script from the root directory of the project:
+
+```bash
+# It is recommended to run from the root directory of the project
+python experiments/run_lightmen_qwen_server.py
+```
+
+## 📊 Expected Output
+
+As the script runs, you will see the following:
+
+1. **Console Output:** A `tqdm` progress bar will show the progress through the three test cases. The generated answers from the LLM and the "yes/no" judgment will be printed to the console for each case.
+
+   ```
+   0%|          | 0/3 [00:00
+   ```
 Dict[str, Any]:
                 "microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
                 "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
                 "NousResearch/Llama-2-7b-hf",
-            None
+            None,
         ]
         model_name = v.get("model_name")
         if model_name is not None:
             if model_name not in allowed_models and not os.path.exists(model_name):
                 raise ValueError(
-                    f"model_name must be one of {allowed_models} "
-                    f"or a valid local path (got {model_name})"
+                    f"model_name must be one of {allowed_models} " f"or a valid local path (got {model_name})"
                 )
-        
+
         if "use_llmlingua2" in v and not isinstance(v["use_llmlingua2"], bool):
             raise ValueError("use_llmlingua2 must be a boolean")
-        
+
         return v
-    
+
     @field_validator("llmlingua2_config")
     @classmethod
     def validate_llmlingua2_config(cls, v: Dict[str, Any]) -> Dict[str, Any]:
@@ -59,4 +55,4 @@ def validate_llmlingua2_config(cls, v: Dict[str, Any]) -> Dict[str, Any]:
             raise ValueError("max_batch_size must be a positive integer")
         if not isinstance(v.get("max_force_token"), int) or v["max_force_token"] <= 0:
             raise ValueError("max_force_token must be a positive integer")
-        return v
\ No newline at end of file
+        return v
diff --git a/src/lightmem/factory/memory_manager/openai.py b/src/lightmem/factory/memory_manager/openai.py
index 73f4aad..11faf78 100644
--- a/src/lightmem/factory/memory_manager/openai.py
+++ b/src/lightmem/factory/memory_manager/openai.py
@@ -6,10 +6,8 @@
 from lightmem.configs.memory_manager.base_config import BaseMemoryManagerConfig
 from lightmem.memory.utils import clean_response
 
-model_name_context_windows = {
-    "gpt-4o-mini": 128000 ,
-    "qwen3-30b-a3b-instruct-2507": 128000
-}
+model_name_context_windows = {"gpt-4o-mini": 128000, "qwen3-30b-a3b-instruct-2507": 128000, "Qwen/Qwen-chat": 128000}
+
 
 class OpenaiManager:
     def __init__(self, config: BaseMemoryManagerConfig):
@@ -17,7 +15,7 @@ def __init__(self, config: BaseMemoryManagerConfig):
 
         if not self.config.model:
             self.config.model = "gpt-4o-mini"
-        
+
         self.context_windows = model_name_context_windows[self.config.model]
 
         http_client = httpx.Client(verify=False)
@@ -121,12 +119,12 @@ def generate_response(
 
         response = self.client.chat.completions.create(**params)
         return self._parse_response(response, tools)
-    
+
     def meta_text_extract(
         self,
         system_prompt: str,
         extract_list: List[List[List[Dict]]],
-        messages_use: Literal["user_only", "assistant_only", "hybrid"] = "user_only"
+        messages_use: Literal["user_only", "assistant_only", "hybrid"] = "user_only",
     ) -> List[Optional[Dict]]:
         """
         Extract metadata from text segments using parallel processing.
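
Note on the `model_name_context_windows` change above: `OpenaiManager` resolves the context window with a plain dict lookup, so any model served through it needs an explicit entry, which is why this patch registers `"Qwen/Qwen-chat"`. Illustration (the `.get` fallback is an alternative, not part of this patch):

```python
model_name_context_windows = {"gpt-4o-mini": 128000, "qwen3-30b-a3b-instruct-2507": 128000, "Qwen/Qwen-chat": 128000}

model = "Qwen/Qwen-chat"
window = model_name_context_windows[model]  # KeyError if the model is not registered
window = model_name_context_windows.get(model, 128000)  # tolerant alternative with a default
```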
@@ -141,14 +139,10 @@ def meta_text_extract(
         """
         if not extract_list:
             return []
-        
+
         def concatenate_messages(segment: List[Dict], messages_use: str) -> str:
             """Concatenate messages based on usage strategy"""
-            role_filter = {
-                "user_only": {"user"},
-                "assistant_only": {"assistant"},
-                "hybrid": {"user", "assistant"}
-            }
+            role_filter = {"user_only": {"user"}, "assistant_only": {"assistant"}, "hybrid": {"user", "assistant"}}
 
             if messages_use not in role_filter:
                 raise ValueError(f"Invalid messages_use value: {messages_use}")
@@ -164,7 +158,7 @@ def concatenate_messages(segment: List[Dict], messages_use: str) -> str:
                 message_lines.append(f"{sequence_id}.{role}: {content}")
 
             return "\n".join(message_lines)
-        
+
         max_workers = min(len(extract_list), 5)
 
         def process_segment_wrapper(api_call_segments: List[List[Dict]]):
@@ -177,20 +171,10 @@ def process_segment_wrapper(api_call_segments: List[List[Dict]]):
 
                 user_prompt = "\n".join(user_prompt_parts)
 
-                messages = [
-                    {"role": "system", "content": system_prompt},
-                    {"role": "user", "content": user_prompt}
-                ]
-                raw_response = self.generate_response(
-                    messages=messages,
-                    response_format={"type": "json_object"}
-                )
+                messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
+                raw_response = self.generate_response(messages=messages, response_format={"type": "json_object"})
                 cleaned_result = clean_response(raw_response)
-                return {
-                    "input_prompt": messages,
-                    "output_prompt": raw_response,
-                    "cleaned_result": cleaned_result
-                }
+                return {"input_prompt": messages, "output_prompt": raw_response, "cleaned_result": cleaned_result}
             except Exception as e:
                 print(f"Error processing API call: {e}")
                 return None
@@ -203,25 +187,18 @@ def process_segment_wrapper(api_call_segments: List[List[Dict]]):
             results = [None] * len(extract_list)
 
         return results
-    
+
     def _call_update_llm(self, system_prompt, target_entry, candidate_sources):
         target_memory = target_entry["payload"]["memory"]
         candidate_memories = [c["payload"]["memory"] for c in candidate_sources]
 
-        user_prompt = (
-            f"Target memory:{target_memory}\n"
-            f"Candidate memories:\n" + "\n".join([f"- {m}" for m in candidate_memories])
+        user_prompt = f"Target memory:{target_memory}\n" f"Candidate memories:\n" + "\n".join(
+            [f"- {m}" for m in candidate_memories]
         )
 
-        messages = [
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": user_prompt}
-        ]
+        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
 
-        response_text = self.generate_response(
-            messages=messages,
-            response_format={"type": "json_object"}
-        )
+        response_text = self.generate_response(messages=messages, response_format={"type": "json_object"})
 
         try:
             result = json.loads(response_text)
@@ -229,4 +206,4 @@ def _call_update_llm(self, system_prompt, target_entry, candidate_sources):
                 return {"action": "ignore"}
             return result
         except Exception:
-            return {"action": "ignore"}
\ No newline at end of file
+            return {"action": "ignore"}
diff --git a/src/lightmem/factory/retriever/embeddingretriever/qdrant.py b/src/lightmem/factory/retriever/embeddingretriever/qdrant.py
index 3e13d84..0bb87d2 100644
--- a/src/lightmem/factory/retriever/embeddingretriever/qdrant.py
+++ b/src/lightmem/factory/retriever/embeddingretriever/qdrant.py
@@ -20,9 +20,7 @@
 
 
 class Qdrant:
-    def __init__(
-        self, config: Optional[QdrantConfig] = None
-    ):
+    def __init__(self, config: Optional[QdrantConfig] = None):
         """
         Initialize the Qdrant vector store.
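
The `close()`, `__del__`, `__enter__`, and `__exit__` methods added above make the wrapper usable as a context manager, so the embedded Qdrant client is released deterministically. A usage sketch — `cfg` is a placeholder for a `QdrantConfig` built the way the experiment's retriever config is (`collection_name`, `embedding_model_dims`, `path`):

```python
from lightmem.factory.retriever.embeddingretriever.qdrant import Qdrant

# cfg: a QdrantConfig carrying collection_name, embedding_model_dims, and path (placeholder here).
with Qdrant(cfg) as store:
    points = store.get_all(with_vectors=False, with_payload=True)
    print(f"{len(points)} points stored")
# On exit, __exit__ calls close(), which closes the underlying Qdrant client.
```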
@@ -149,22 +147,26 @@ def search(
             query_filter=query_filter,
             limit=limit,
             with_payload=True,
-            with_vectors=True, 
+            with_vectors=True,
         )
 
         results = []
         for h in hits.points:
             if return_full:
-                results.append({
-                    "id": h.id,
-                    "score": h.score,
-                    "payload": h.payload,
-                })
+                results.append(
+                    {
+                        "id": h.id,
+                        "score": h.score,
+                        "payload": h.payload,
+                    }
+                )
             else:
-                results.append({
-                    "id": h.id,
-                    "score": h.score,
-                })
+                results.append(
+                    {
+                        "id": h.id,
+                        "score": h.score,
+                    }
+                )
         return results
 
     def delete(self, vector_id: int):
@@ -285,7 +287,7 @@ def exists(self, vector_id: str) -> bool:
         except Exception as e:
             logger.error(f"Error checking existence of ID {vector_id}: {e}")
             return False
-    
+
     def get_all(self, with_vectors: bool = True, with_payload: bool = True) -> list:
         """
         Retrieve all points from the collection.
@@ -303,13 +305,33 @@ def get_all(self, with_vectors: bool = True, with_payload: bool = True) -> list:
             result, offset = self.client.scroll(
                 collection_name=self.collection_name,
                 scroll_filter=None,
-                limit=100, 
+                limit=100,
                 with_payload=with_payload,
                 with_vectors=with_vectors,
                 offset=offset,
             )
             all_points.extend([p.model_dump() for p in result])
-            if offset is None: 
+            if offset is None:
                 break
 
         return all_points
 
+    def close(self):
+        """Close the Qdrant client connection."""
+        if hasattr(self, "client") and self.client:
+            try:
+                self.client.close()
+            except Exception:
+                pass
+
+    def __del__(self):
+        """Cleanup on deletion."""
+        self.close()
+
+    def __enter__(self):
+        """Context manager entry."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit."""
+        self.close()
+        return False
diff --git a/src/lightmem/memory/utils.py b/src/lightmem/memory/utils.py
index 4567941..4609675 100644
--- a/src/lightmem/memory/utils.py
+++ b/src/lightmem/memory/utils.py
@@ -23,6 +23,7 @@ class MemoryEntry:
     hit_time: int = 0
     update_queue: List = field(default_factory=list)
 
+
 def clean_response(response: str) -> List[Dict[str, Any]]:
     """
     Cleans the model response by:
@@ -47,11 +48,12 @@ def clean_response(response: str) -> List[Dict[str, Any]]:
         return []
 
 
+
 def assign_sequence_numbers_with_timestamps(extract_list):
     current_index = 0
     timestamps_list = []
     weekday_list = []
-    
+
     for segments in extract_list:
         for seg in segments:
             for message in seg:
@@ -59,9 +61,10 @@ def assign_sequence_numbers_with_timestamps(extract_list):
                 timestamps_list.append(message["time_stamp"])
                 weekday_list.append(message["weekday"])
                 current_index += 1
-    
+
     return extract_list, timestamps_list, weekday_list
 
+
 # TODO:merge into context retriever
 def save_memory_entries(memory_entries, file_path="memory_entries.json"):
     def entry_to_dict(entry):
@@ -106,7 +109,7 @@ def resolve_tokenizer(tokenizer_or_name: Union[str, Any]):
         "gpt-4.1-mini": "o200k_base",
         "gpt-4.1": "o200k_base",
         "gpt-3.5-turbo": "cl100k_base",
-        "qwen3-30b-a3b-instruct-2507": "o200k_base"
+        "qwen3-30b-a3b-instruct-2507": "o200k_base",
+        "Qwen/Qwen-chat": "o200k_base",
     }
 
     if tokenizer_or_name not in model_tokenizer_map: