87 changes: 87 additions & 0 deletions examples/supported_llms/openai_async_custom_example.py
@@ -0,0 +1,87 @@
import asyncio
import os

import dotenv
from openai import AsyncOpenAI

from memori import Memori

# Load environment variables from .env file
dotenv.load_dotenv()

api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model = os.getenv("OPENAI_MODEL", "gpt-4")

client = AsyncOpenAI(api_key=api_key, base_url=base_url)

print("Initializing Memori with OpenAI...")
openai_memory = Memori(
database_connect="sqlite:///openai_custom_demo.db",
conscious_ingest=True,
auto_ingest=True,
verbose=True,
api_key=api_key,
base_url=base_url,
model=model,
)

print("Enabling memory tracking...")
openai_memory.enable()

print(f"Memori OpenAI Example - Chat with {model} while memory is being tracked")
print("Type 'exit' or press Ctrl+C to quit")
print("-" * 50)

use_stream = True


async def main():
while True:
user_input = input("User: ")
if not user_input.strip():
continue

if user_input.lower() == "exit":
print("Goodbye!")
break

print("Processing your message with memory tracking...")

response = await client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": user_input}],
stream=use_stream,
)

if use_stream:
full_response = ""
async for chunk in response:
if chunk.choices[0].delta.content is not None:
content = chunk.choices[0].delta.content
full_response += content
print(content, end="", flush=True)  # Display tokens in real time
print()  # Newline after the stream finishes
# async def finalize_callback(final_response, _context):
# print(chunks)
# """Callback to record conversation when streaming completes."""
# if final_response is not None:
# print(f"AI: {final_response.choices[0].message.content}")
# print() # Add blank line for readability

# create_openai_streaming_proxy(
# stream=response,
# finalize_callback=finalize_callback
# )
else:
print(f"AI: {response.choices[0].message.content}")
print() # Add blank line for readability


if __name__ == "__main__":
try:
asyncio.run(main())
except (EOFError, KeyboardInterrupt):
print("\nExiting...")
except Exception as e:
print(f"Error: {e}")
6 changes: 3 additions & 3 deletions memori/agents/memory_agent.py
@@ -184,9 +184,9 @@ async def process_conversation_async(
CONVERSATION CONTEXT:
- Session: {context.session_id}
- Model: {context.model_used}
- User Projects: {', '.join(context.current_projects) if context.current_projects else 'None specified'}
- Relevant Skills: {', '.join(context.relevant_skills) if context.relevant_skills else 'None specified'}
- Topic Thread: {context.topic_thread or 'General conversation'}
- User Projects: {", ".join(context.current_projects) if context.current_projects else "None specified"}
- Relevant Skills: {", ".join(context.relevant_skills) if context.relevant_skills else "None specified"}
- Topic Thread: {context.topic_thread or "General conversation"}
"""

# Try structured outputs first, fall back to manual parsing
10 changes: 6 additions & 4 deletions memori/agents/retrieval_agent.py
@@ -460,7 +460,9 @@ def _execute_category_search(

filtered_results = []
for i, result in enumerate(all_results):
logger.debug(f"Processing result {i+1}/{len(all_results)}: {type(result)}")
logger.debug(
f"Processing result {i + 1}/{len(all_results)}: {type(result)}"
)

# Extract category from processed_data if it's stored as JSON
try:
@@ -531,7 +533,7 @@ def _execute_category_search(
logger.debug("No category found in result")

except Exception as e:
logger.debug(f"Error processing result {i+1}: {e}")
logger.debug(f"Error processing result {i + 1}: {e}")
continue

logger.debug(
@@ -817,7 +819,7 @@ async def execute_search_async(
self._execute_keyword_search,
search_plan,
db_manager,
namespace,
# namespace,
limit,
)
)
@@ -833,7 +835,7 @@ async def execute_search_async(
self._execute_category_search,
search_plan,
db_manager,
namespace,
# namespace,
limit,
)
)
58 changes: 29 additions & 29 deletions memori/core/memory.py
@@ -692,7 +692,7 @@ def _copy_memory_to_short_term_sync(self, memory_row: tuple) -> bool:

# SECURITY FIX: Use ORM methods instead of raw SQL to prevent injection
# Check for exact match or conscious-prefixed memories
from sqlalchemy import or_, text
from sqlalchemy import or_

from memori.database.models import ShortTermMemory

@@ -722,33 +722,33 @@ def _copy_memory_to_short_term_sync(self, memory_row: tuple) -> bool:
)

# Insert directly into short-term memory with conscious_context category
connection.execute(
text(
"""INSERT INTO short_term_memory (
memory_id, processed_data, importance_score, category_primary,
retention_type, user_id, assistant_id, session_id, created_at, expires_at,
searchable_content, summary, is_permanent_context
) VALUES (:memory_id, :processed_data, :importance_score, :category_primary,
:retention_type, :user_id, :assistant_id, :session_id, :created_at, :expires_at,
:searchable_content, :summary, :is_permanent_context)"""
),
{
"memory_id": short_term_id,
"processed_data": processed_data,
"importance_score": importance_score,
"category_primary": "conscious_context",
"retention_type": "permanent",
"user_id": self.user_id or "default",
"assistant_id": self.assistant_id,
"session_id": self.session_id or "default",
"created_at": datetime.now().isoformat(),
"expires_at": None,
"searchable_content": searchable_content,
"summary": summary,
"is_permanent_context": True,
},
)
connection.commit()
# connection.execute(
# text(
# """INSERT INTO short_term_memory (
# memory_id, processed_data, importance_score, category_primary,
# retention_type, user_id, assistant_id, session_id, created_at, expires_at,
# searchable_content, summary, is_permanent_context
# ) VALUES (:memory_id, :processed_data, :importance_score, :category_primary,
# :retention_type, :user_id, :assistant_id, :session_id, :created_at, :expires_at,
# :searchable_content, :summary, :is_permanent_context)"""
# ),
# {
# "memory_id": short_term_id,
# "processed_data": processed_data,
# "importance_score": importance_score,
# "category_primary": "conscious_context",
# "retention_type": "permanent",
# "user_id": self.user_id or "default",
# "assistant_id": self.assistant_id,
# "session_id": self.session_id or "default",
# "created_at": datetime.now().isoformat(),
# "expires_at": None,
# "searchable_content": searchable_content,
# "summary": summary,
# "is_permanent_context": True,
# },
# )
# connection.commit()

logger.debug(
f"Conscious-ingest: Copied memory {memory_id} to short-term as {short_term_id}"
@@ -2092,7 +2092,7 @@ def record_conversation(

# Generate ID and timestamp
chat_id = str(uuid.uuid4())
timestamp = datetime.now()
# timestamp = datetime.now()

try:
# Store conversation
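The ORM replacement promised by the `SECURITY FIX` comment above is not visible in this hunk — the raw-SQL `INSERT` is only commented out. A sketch of what the ORM-based insert might look like, assuming the `ShortTermMemory` model (imported above) mirrors the columns of the raw statement and that `session` is the active SQLAlchemy session:

```python
# Sketch (assumed, not shown in this diff): ORM equivalent of the
# commented-out raw INSERT, reusing the same values and column names.
short_term = ShortTermMemory(
    memory_id=short_term_id,
    processed_data=processed_data,
    importance_score=importance_score,
    category_primary="conscious_context",
    retention_type="permanent",
    user_id=self.user_id or "default",
    assistant_id=self.assistant_id,
    session_id=self.session_id or "default",
    created_at=datetime.now(),
    expires_at=None,
    searchable_content=searchable_content,
    summary=summary,
    is_permanent_context=True,
)
session.add(short_term)
session.commit()
```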
2 changes: 1 addition & 1 deletion memori/database/auto_creator.py
@@ -310,7 +310,7 @@ def _create_mysql_database(self, components: dict[str, str]) -> None:
# Database name is already validated, so this is safe
conn.execute(
text(
f'CREATE DATABASE `{components["database"]}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci'
f"CREATE DATABASE `{components['database']}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
)
)
conn.commit()
1 change: 0 additions & 1 deletion memori/database/connectors/postgres_connector.py
@@ -243,7 +243,6 @@ def execute_transaction(self, queries: list[tuple]) -> bool:
try:
with self.get_connection() as conn:
with conn.cursor() as cursor:

for query, params in queries:
if params:
cursor.execute(query, params)
8 changes: 3 additions & 5 deletions memori/database/search_service.py
@@ -1331,7 +1331,7 @@ def get_list_metadata(
short_users = short_query.all()
long_users = long_query.all()
all_users = set([u[0] for u in short_users] + [u[0] for u in long_users])
metadata["available_filters"]["user_ids"] = sorted(list(all_users))
metadata["available_filters"]["user_ids"] = sorted(all_users)

# Get distinct assistant_ids
base_short_query = self.session.query(
@@ -1363,9 +1363,7 @@ def get_list_metadata(
[a[0] for a in short_assistants if a[0]]
+ [a[0] for a in long_assistants if a[0]]
)
metadata["available_filters"]["assistant_ids"] = sorted(
list(all_assistants)
)
metadata["available_filters"]["assistant_ids"] = sorted(all_assistants)

# Get distinct session_ids
short_sessions_query = self.session.query(
@@ -1390,7 +1388,7 @@ def get_list_metadata(
[s[0] for s in short_sessions if s[0]]
+ [s[0] for s in long_sessions if s[0]]
)
metadata["available_filters"]["session_ids"] = sorted(list(all_sessions))
metadata["available_filters"]["session_ids"] = sorted(all_sessions)

# Get counts
short_count_query = self.session.query(ShortTermMemory)
47 changes: 45 additions & 2 deletions memori/integrations/openai_integration.py
@@ -40,6 +40,8 @@

from loguru import logger

from ..utils.streaming_proxy import create_openai_streaming_proxy

# Global registry of enabled Memori instances
_enabled_memori_instances = []

@@ -278,6 +280,12 @@ def patched_process_response(
**kwargs,
)

if stream:
# Record streaming conversation for enabled Memori instances
result = cls._stream_record_conversation_for_enabled_instances(
options, result, client_type
)

# Record conversation for enabled Memori instances
if not stream: # Don't record streaming here - handle separately
cls._record_conversation_for_enabled_instances(
@@ -336,6 +344,12 @@ async def patched_async_process_response(
**kwargs,
)

# Record streaming conversation for enabled Memori instances
if stream:
result = cls._stream_record_conversation_for_enabled_instances(
options, result, client_type
)

# Record conversation for enabled Memori instances
if not stream:
cls._record_conversation_for_enabled_instances(
@@ -350,9 +364,9 @@
if original_prepare_key in cls._original_methods:
original_prepare = cls._original_methods[original_prepare_key]

def patched_async_prepare_options(self, options):
async def patched_async_prepare_options(self, options):
# Call original method first
options = original_prepare(self, options)
options = await original_prepare(self, options)

# Inject context for enabled Memori instances
options = cls._inject_context_for_enabled_instances(
@@ -479,6 +493,35 @@ def _is_internal_agent_call(cls, json_data):
logger.debug(f"Failed to check internal agent call: {e}")
return False

@classmethod
def _stream_record_conversation_for_enabled_instances(
cls, options, response, client_type
):
"""
Wrap streaming response to record conversation for enabled Memori instances.
"""
try:
if response is None:
return response

# Define finalize callback to record conversation after streaming completes
async def finalize_callback(final_response, context_data):
options, client_type = context_data
if final_response is not None:
cls._record_conversation_for_enabled_instances(
options, final_response, client_type
)

# Create streaming proxy
return create_openai_streaming_proxy(
stream=response,
finalize_callback=finalize_callback,
context_data=(options, client_type),
)
except Exception as e:
logger.error(f"Failed to wrap streaming conversation: {e}")
return response

@classmethod
def _record_conversation_for_enabled_instances(cls, options, response, client_type):
"""Record conversation for the active Memori instance (or all enabled instances for backward compatibility)."""
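`memori/utils/streaming_proxy.py` itself is not part of this diff. Judging from the call sites above, `create_openai_streaming_proxy` must return a wrapper that yields chunks unchanged, accumulates them, and fires `finalize_callback(final_response, context_data)` exactly once when the stream is exhausted. A minimal sketch of the async side of such a proxy (class and helper names are assumptions, not the actual implementation):

```python
# Sketch only — assumed shape, not the actual memori implementation.
class AsyncStreamingProxy:
    """Pass chunks through while accumulating them for a finalize callback."""

    def __init__(self, stream, finalize_callback, context_data=None):
        self._stream = stream
        self._finalize = finalize_callback
        self._context = context_data
        self._chunks = []
        self._finalized = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            chunk = await self._stream.__anext__()
        except StopAsyncIteration:
            if not self._finalized:
                self._finalized = True
                # assemble_final_response is a hypothetical helper that
                # stitches the accumulated deltas into a response object.
                final = assemble_final_response(self._chunks)
                await self._finalize(final, self._context)
            raise
        self._chunks.append(chunk)
        return chunk
```

Since the synchronous `patched_process_response` routes streaming results through the same helper, the real module presumably also ships a synchronous twin implementing `__iter__`/`__next__`.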
7 changes: 7 additions & 0 deletions memori/utils/__init__.py
@@ -47,6 +47,11 @@
RetentionType,
)

# Streaming utilities
from .streaming_proxy import (
create_openai_streaming_proxy,
)

# Validation utilities
from .validators import DataValidator, MemoryValidator

@@ -90,4 +95,6 @@
# Logging
"LoggingManager",
"get_logger",
# Streaming
"create_openai_streaming_proxy",
]
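With the export in place, the proxy can also be used directly, outside the patched client paths. A hedged usage sketch, reusing the `client` and `model` from the example file above (`record_done` is a stand-in callback, not part of the library):

```python
from memori.utils import create_openai_streaming_proxy

async def record_done(final_response, context):
    # Stand-in callback: persist or log the completed conversation here.
    print("stream finished for:", context)

stream = await client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
wrapped = create_openai_streaming_proxy(
    stream=stream,
    finalize_callback=record_done,
    context_data="demo-session",
)
async for chunk in wrapped:
    pass  # consume chunks as usual; recording fires automatically at the end
```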