diff --git a/examples/supported_llms/openai_async_custom_example.py b/examples/supported_llms/openai_async_custom_example.py
new file mode 100644
index 00000000..fb057815
--- /dev/null
+++ b/examples/supported_llms/openai_async_custom_example.py
@@ -0,0 +1,76 @@
+import asyncio
+import os
+
+import dotenv
+from openai import AsyncOpenAI
+
+from memori import Memori
+
+# Load environment variables from .env file
+dotenv.load_dotenv()
+
+api_key = os.getenv("OPENAI_API_KEY")
+base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
+model = os.getenv("OPENAI_MODEL", "gpt-4")
+
+client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+
+print("Initializing Memori with OpenAI...")
+openai_memory = Memori(
+    database_connect="sqlite:///openai_custom_demo.db",
+    conscious_ingest=True,
+    auto_ingest=True,
+    verbose=True,
+    api_key=api_key,
+    base_url=base_url,
+    model=model,
+)
+
+print("Enabling memory tracking...")
+openai_memory.enable()
+
+print(f"Memori OpenAI Example - Chat with {model} while memory is being tracked")
+print("Type 'exit' or press Ctrl+C to quit")
+print("-" * 50)
+
+use_stream = True
+
+
+async def main():
+    while True:
+        user_input = input("User: ")
+        if not user_input.strip():
+            continue
+
+        if user_input.lower() == "exit":
+            print("Goodbye!")
+            break
+
+        print("Processing your message with memory tracking...")
+
+        response = await client.chat.completions.create(
+            model=model,
+            messages=[{"role": "user", "content": user_input}],
+            stream=use_stream,
+        )
+
+        if use_stream:
+            full_response = ""
+            async for chunk in response:
+                if chunk.choices[0].delta.content is not None:
+                    content = chunk.choices[0].delta.content
+                    full_response += content
+                    print(content, end="", flush=True)  # Print tokens as they arrive
+            print()  # Newline after the streamed reply
+        else:
+            print(f"AI: {response.choices[0].message.content}")
+            print()  # Add blank line for readability
+
+
+if __name__ == "__main__":
+    try:
+        asyncio.run(main())
+    except (EOFError, KeyboardInterrupt):
+        print("\nExiting...")
+    except Exception as e:
+        print(f"Error: {e}")
diff --git a/memori/agents/memory_agent.py b/memori/agents/memory_agent.py
index 94c9e1cc..30617757 100644
--- a/memori/agents/memory_agent.py
+++ b/memori/agents/memory_agent.py
@@ -184,9 +184,9 @@ async def process_conversation_async(
 CONVERSATION CONTEXT:
 - Session: {context.session_id}
 - Model: {context.model_used}
-- User Projects: {', '.join(context.current_projects) if context.current_projects else 'None specified'}
-- Relevant Skills: {', '.join(context.relevant_skills) if context.relevant_skills else 'None specified'}
-- Topic Thread: {context.topic_thread or 'General conversation'}
+- User Projects: {", ".join(context.current_projects) if context.current_projects else "None specified"}
+- Relevant Skills: {", ".join(context.relevant_skills) if context.relevant_skills else "None specified"}
+- Topic Thread: {context.topic_thread or "General conversation"}
 """
 
         # Try structured outputs first, fall back to manual parsing
diff --git a/memori/agents/retrieval_agent.py b/memori/agents/retrieval_agent.py
index d197486c..ef8d7c49 100644
--- a/memori/agents/retrieval_agent.py
+++
b/memori/agents/retrieval_agent.py @@ -460,7 +460,9 @@ def _execute_category_search( filtered_results = [] for i, result in enumerate(all_results): - logger.debug(f"Processing result {i+1}/{len(all_results)}: {type(result)}") + logger.debug( + f"Processing result {i + 1}/{len(all_results)}: {type(result)}" + ) # Extract category from processed_data if it's stored as JSON try: @@ -531,7 +533,7 @@ def _execute_category_search( logger.debug("No category found in result") except Exception as e: - logger.debug(f"Error processing result {i+1}: {e}") + logger.debug(f"Error processing result {i + 1}: {e}") continue logger.debug( @@ -817,7 +819,7 @@ async def execute_search_async( self._execute_keyword_search, search_plan, db_manager, - namespace, + # namespace, limit, ) ) @@ -833,7 +835,7 @@ async def execute_search_async( self._execute_category_search, search_plan, db_manager, - namespace, + # namespace, limit, ) ) diff --git a/memori/core/memory.py b/memori/core/memory.py index 3df77abc..8c649820 100644 --- a/memori/core/memory.py +++ b/memori/core/memory.py @@ -692,7 +692,7 @@ def _copy_memory_to_short_term_sync(self, memory_row: tuple) -> bool: # SECURITY FIX: Use ORM methods instead of raw SQL to prevent injection # Check for exact match or conscious-prefixed memories - from sqlalchemy import or_, text + from sqlalchemy import or_ from memori.database.models import ShortTermMemory @@ -722,33 +722,33 @@ def _copy_memory_to_short_term_sync(self, memory_row: tuple) -> bool: ) # Insert directly into short-term memory with conscious_context category - connection.execute( - text( - """INSERT INTO short_term_memory ( - memory_id, processed_data, importance_score, category_primary, - retention_type, user_id, assistant_id, session_id, created_at, expires_at, - searchable_content, summary, is_permanent_context - ) VALUES (:memory_id, :processed_data, :importance_score, :category_primary, - :retention_type, :user_id, :assistant_id, :session_id, :created_at, :expires_at, - :searchable_content, :summary, :is_permanent_context)""" - ), - { - "memory_id": short_term_id, - "processed_data": processed_data, - "importance_score": importance_score, - "category_primary": "conscious_context", - "retention_type": "permanent", - "user_id": self.user_id or "default", - "assistant_id": self.assistant_id, - "session_id": self.session_id or "default", - "created_at": datetime.now().isoformat(), - "expires_at": None, - "searchable_content": searchable_content, - "summary": summary, - "is_permanent_context": True, - }, - ) - connection.commit() + # connection.execute( + # text( + # """INSERT INTO short_term_memory ( + # memory_id, processed_data, importance_score, category_primary, + # retention_type, user_id, assistant_id, session_id, created_at, expires_at, + # searchable_content, summary, is_permanent_context + # ) VALUES (:memory_id, :processed_data, :importance_score, :category_primary, + # :retention_type, :user_id, :assistant_id, :session_id, :created_at, :expires_at, + # :searchable_content, :summary, :is_permanent_context)""" + # ), + # { + # "memory_id": short_term_id, + # "processed_data": processed_data, + # "importance_score": importance_score, + # "category_primary": "conscious_context", + # "retention_type": "permanent", + # "user_id": self.user_id or "default", + # "assistant_id": self.assistant_id, + # "session_id": self.session_id or "default", + # "created_at": datetime.now().isoformat(), + # "expires_at": None, + # "searchable_content": searchable_content, + # "summary": summary, + # 
"is_permanent_context": True, + # }, + # ) + # connection.commit() logger.debug( f"Conscious-ingest: Copied memory {memory_id} to short-term as {short_term_id}" @@ -2092,7 +2092,7 @@ def record_conversation( # Generate ID and timestamp chat_id = str(uuid.uuid4()) - timestamp = datetime.now() + # timestamp = datetime.now() try: # Store conversation diff --git a/memori/database/auto_creator.py b/memori/database/auto_creator.py index 8431f4d9..1c0ae186 100644 --- a/memori/database/auto_creator.py +++ b/memori/database/auto_creator.py @@ -310,7 +310,7 @@ def _create_mysql_database(self, components: dict[str, str]) -> None: # Database name is already validated, so this is safe conn.execute( text( - f'CREATE DATABASE `{components["database"]}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci' + f"CREATE DATABASE `{components['database']}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci" ) ) conn.commit() diff --git a/memori/database/connectors/postgres_connector.py b/memori/database/connectors/postgres_connector.py index 2412bdbc..9f5f7e97 100644 --- a/memori/database/connectors/postgres_connector.py +++ b/memori/database/connectors/postgres_connector.py @@ -243,7 +243,6 @@ def execute_transaction(self, queries: list[tuple]) -> bool: try: with self.get_connection() as conn: with conn.cursor() as cursor: - for query, params in queries: if params: cursor.execute(query, params) diff --git a/memori/database/search_service.py b/memori/database/search_service.py index 515c2e13..6489c83b 100644 --- a/memori/database/search_service.py +++ b/memori/database/search_service.py @@ -1331,7 +1331,7 @@ def get_list_metadata( short_users = short_query.all() long_users = long_query.all() all_users = set([u[0] for u in short_users] + [u[0] for u in long_users]) - metadata["available_filters"]["user_ids"] = sorted(list(all_users)) + metadata["available_filters"]["user_ids"] = sorted(all_users) # Get distinct assistant_ids base_short_query = self.session.query( @@ -1363,9 +1363,7 @@ def get_list_metadata( [a[0] for a in short_assistants if a[0]] + [a[0] for a in long_assistants if a[0]] ) - metadata["available_filters"]["assistant_ids"] = sorted( - list(all_assistants) - ) + metadata["available_filters"]["assistant_ids"] = sorted(all_assistants) # Get distinct session_ids short_sessions_query = self.session.query( @@ -1390,7 +1388,7 @@ def get_list_metadata( [s[0] for s in short_sessions if s[0]] + [s[0] for s in long_sessions if s[0]] ) - metadata["available_filters"]["session_ids"] = sorted(list(all_sessions)) + metadata["available_filters"]["session_ids"] = sorted(all_sessions) # Get counts short_count_query = self.session.query(ShortTermMemory) diff --git a/memori/integrations/openai_integration.py b/memori/integrations/openai_integration.py index 9f9bfc88..27d9501b 100644 --- a/memori/integrations/openai_integration.py +++ b/memori/integrations/openai_integration.py @@ -40,6 +40,8 @@ from loguru import logger +from ..utils.streaming_proxy import create_openai_streaming_proxy + # Global registry of enabled Memori instances _enabled_memori_instances = [] @@ -278,6 +280,12 @@ def patched_process_response( **kwargs, ) + if stream: + # Record streaming conversation for enabled Memori instances + result = cls._stream_record_conversation_for_enabled_instances( + options, result, client_type + ) + # Record conversation for enabled Memori instances if not stream: # Don't record streaming here - handle separately cls._record_conversation_for_enabled_instances( @@ -336,6 +344,12 @@ async def 
patched_async_process_response( **kwargs, ) + # Record streaming conversation for enabled Memori instances + if stream: + result = cls._stream_record_conversation_for_enabled_instances( + options, result, client_type + ) + # Record conversation for enabled Memori instances if not stream: cls._record_conversation_for_enabled_instances( @@ -350,9 +364,9 @@ async def patched_async_process_response( if original_prepare_key in cls._original_methods: original_prepare = cls._original_methods[original_prepare_key] - def patched_async_prepare_options(self, options): + async def patched_async_prepare_options(self, options): # Call original method first - options = original_prepare(self, options) + options = await original_prepare(self, options) # Inject context for enabled Memori instances options = cls._inject_context_for_enabled_instances( @@ -479,6 +493,35 @@ def _is_internal_agent_call(cls, json_data): logger.debug(f"Failed to check internal agent call: {e}") return False + @classmethod + def _stream_record_conversation_for_enabled_instances( + cls, options, response, client_type + ): + """ + Wrap streaming response to record conversation for enabled Memori instances. + """ + try: + if response is None: + return response + + # Define finalize callback to record conversation after streaming completes + async def finalize_callback(final_response, context_data): + options, client_type = context_data + if final_response is not None: + cls._record_conversation_for_enabled_instances( + options, final_response, client_type + ) + + # Create streaming proxy + return create_openai_streaming_proxy( + stream=response, + finalize_callback=finalize_callback, + context_data=(options, client_type), + ) + except Exception as e: + logger.error(f"Failed to wrap streaming conversation: {e}") + return response + @classmethod def _record_conversation_for_enabled_instances(cls, options, response, client_type): """Record conversation for the active Memori instance (or all enabled instances for backward compatibility).""" diff --git a/memori/utils/__init__.py b/memori/utils/__init__.py index 1fb6a5f9..bb1f30a3 100644 --- a/memori/utils/__init__.py +++ b/memori/utils/__init__.py @@ -47,6 +47,11 @@ RetentionType, ) +# Streaming utilities +from .streaming_proxy import ( + create_openai_streaming_proxy, +) + # Validation utilities from .validators import DataValidator, MemoryValidator @@ -90,4 +95,6 @@ # Logging "LoggingManager", "get_logger", + # Streaming + "create_openai_streaming_proxy", ] diff --git a/memori/utils/query_builder.py b/memori/utils/query_builder.py index a7c2ddba..b785456d 100644 --- a/memori/utils/query_builder.py +++ b/memori/utils/query_builder.py @@ -120,7 +120,7 @@ def build_search_query( query = f""" SELECT *, '{tables[0]}' as memory_type FROM {tables[0]} - WHERE {' AND '.join(where_conditions)} + WHERE {" AND ".join(where_conditions)} ORDER BY importance_score DESC, created_at DESC {self.LIMIT_SYNTAX[self.dialect]} """ @@ -131,13 +131,13 @@ def build_search_query( table_query = f""" SELECT *, '{table}' as memory_type FROM {table} - WHERE {' AND '.join(where_conditions)} + WHERE {" AND ".join(where_conditions)} """ union_parts.append(table_query) query = f""" SELECT * FROM ( - {' UNION ALL '.join(union_parts)} + {" UNION ALL ".join(union_parts)} ) combined ORDER BY importance_score DESC, created_at DESC {self.LIMIT_SYNTAX[self.dialect]} @@ -241,8 +241,8 @@ def build_update_query( query = f""" UPDATE {table} - SET {', '.join(set_conditions)} - WHERE {' AND '.join(where_parts)} + SET {", 
".join(set_conditions)} + WHERE {" AND ".join(where_parts)} """ return query, params @@ -326,7 +326,7 @@ def build_fts_query( FROM memory_search_fts fts LEFT JOIN short_term_memory st ON fts.memory_id = st.memory_id AND fts.memory_type = 'short_term' LEFT JOIN long_term_memory lt ON fts.memory_id = lt.memory_id AND fts.memory_type = 'long_term' - WHERE {' AND '.join(where_conditions)} + WHERE {" AND ".join(where_conditions)} ORDER BY rank, importance_score DESC {self.LIMIT_SYNTAX[self.dialect]} """ @@ -360,7 +360,7 @@ def build_fts_query( ts_rank(COALESCE(to_tsvector('english', st.searchable_content), to_tsvector('english', lt.searchable_content)), plainto_tsquery('english', %s)) as rank FROM short_term_memory st FULL OUTER JOIN long_term_memory lt ON FALSE -- Force separate processing - WHERE {' AND '.join(where_conditions)} + WHERE {" AND ".join(where_conditions)} ORDER BY rank DESC, importance_score DESC {self.LIMIT_SYNTAX[self.dialect]} """ @@ -409,7 +409,7 @@ def build_fts_query( lt.summary, MATCH(lt.searchable_content) AGAINST(%s IN BOOLEAN MODE) as rank FROM long_term_memory lt - WHERE {' AND '.join(where_conditions)} + WHERE {" AND ".join(where_conditions)} ORDER BY rank DESC, importance_score DESC {self.LIMIT_SYNTAX[self.dialect]} """ diff --git a/memori/utils/streaming_proxy.py b/memori/utils/streaming_proxy.py new file mode 100644 index 00000000..02019b0d --- /dev/null +++ b/memori/utils/streaming_proxy.py @@ -0,0 +1,472 @@ +""" +Generic streaming proxy utilities for intercepting and processing streaming responses. + +This module provides reusable classes for wrapping streaming responses, +capturing chunks, and processing them when the stream completes. +""" + +from __future__ import annotations + +import asyncio +import inspect +import time +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import Any + +from loguru import logger +from openai._streaming import AsyncStream, Stream +from openai.types.chat.chat_completion import ( + ChatCompletion, +) +from openai.types.chat.chat_completion import Choice as ChatCompletionChoice +from openai.types.chat.chat_completion_message import ( + ChatCompletionMessage, +) +from openai.types.chat.chat_completion_message import ( + FunctionCall as ChatCompletionFunctionCall, +) +from openai.types.chat.chat_completion_message_function_tool_call import ( + ChatCompletionMessageFunctionToolCall, +) +from openai.types.chat.chat_completion_message_function_tool_call import ( + Function as ChatCompletionToolFunction, +) +from openai.types.completion_usage import CompletionUsage + + +@dataclass +class _FunctionCallAccumulator: + """Accumulates partial function call deltas.""" + + name: str | None = None + argument_parts: list[str] = field(default_factory=list) + + def add(self, delta) -> None: + if delta is None: + return + if getattr(delta, "name", None): + self.name = delta.name + if getattr(delta, "arguments", None): + self.argument_parts.append(delta.arguments) + + def build(self) -> ChatCompletionFunctionCall | None: + if not (self.name or self.argument_parts): + return None + try: + return ChatCompletionFunctionCall( + name=self.name or "", + arguments="".join(self.argument_parts), + ) + except Exception as exc: # pragma: no cover - defensive + logger.debug(f"Failed to build function call: {exc}") + return None + + +@dataclass +class _ToolCallAccumulator: + """Accumulates partial tool call deltas.""" + + index: int + tool_id: str | None = None + tool_type: str | None = None + function_name: 
str | None = None + argument_parts: list[str] = field(default_factory=list) + + def add(self, delta) -> None: + if delta is None: + return + if getattr(delta, "id", None): + self.tool_id = delta.id + if getattr(delta, "type", None): + self.tool_type = delta.type + + function_delta = getattr(delta, "function", None) + if function_delta: + if getattr(function_delta, "name", None): + self.function_name = function_delta.name + if getattr(function_delta, "arguments", None): + self.argument_parts.append(function_delta.arguments) + + def build(self) -> ChatCompletionMessageFunctionToolCall | None: + if not (self.tool_id or self.function_name or self.argument_parts): + return None + + try: + return ChatCompletionMessageFunctionToolCall( + id=self.tool_id or f"tool_call_{self.index}", + type=self.tool_type or "function", + function=ChatCompletionToolFunction( + name=self.function_name or "", + arguments="".join(self.argument_parts), + ), + ) + except Exception as exc: # pragma: no cover - defensive + logger.debug(f"Failed to build tool call: {exc}") + return None + + +@dataclass +class _ChoiceAccumulator: + """Holds accumulated data for a single streamed choice.""" + + index: int + role: str | None = None + content_parts: list[str] = field(default_factory=list) + refusal_parts: list[str] = field(default_factory=list) + finish_reason: str | None = None + logprobs: Any = None + function_call: _FunctionCallAccumulator = field( + default_factory=_FunctionCallAccumulator + ) + tool_calls: dict[int, _ToolCallAccumulator] = field(default_factory=dict) + + def add_delta(self, delta) -> None: + if delta is None: + return + + if getattr(delta, "role", None): + self.role = delta.role + if getattr(delta, "content", None): + self.content_parts.append(delta.content) + if getattr(delta, "refusal", None): + self.refusal_parts.append(delta.refusal) + + self.function_call.add(getattr(delta, "function_call", None)) + + tool_deltas = getattr(delta, "tool_calls", None) or [] + for tool_delta in tool_deltas: + tool_acc = self.tool_calls.setdefault( + getattr(tool_delta, "index", len(self.tool_calls)), + _ToolCallAccumulator(index=getattr(tool_delta, "index", 0)), + ) + tool_acc.add(tool_delta) + + def build_message(self) -> ChatCompletionMessage | None: + try: + message_kwargs = { + "role": self.role or "assistant", + } + + if self.content_parts: + message_kwargs["content"] = "".join(self.content_parts) + if self.refusal_parts: + message_kwargs["refusal"] = "".join(self.refusal_parts) + + built_tool_calls = [ + tool_call + for tool_call in ( + tool_acc.build() for _, tool_acc in sorted(self.tool_calls.items()) + ) + if tool_call + ] + if built_tool_calls: + message_kwargs["tool_calls"] = built_tool_calls + + function_call = self.function_call.build() + if function_call: + message_kwargs["function_call"] = function_call + + return ChatCompletionMessage(**message_kwargs) + except Exception as exc: # pragma: no cover - defensive + logger.debug(f"Failed to build chat completion message: {exc}") + return None + + +class _ChatCompletionStreamAggregator: + """Aggregates OpenAI chat completion chunks into a final response.""" + + def __init__(self) -> None: + self._choices: dict[int, _ChoiceAccumulator] = {} + self._has_chunks = False + self._id: str | None = None + self._created: int | None = None + self._model: str | None = None + self._service_tier: str | None = None + self._system_fingerprint: str | None = None + self._usage: CompletionUsage | None = None + + def add_chunk(self, chunk: Any) -> None: + if chunk is None: + 
return + + try: + self._has_chunks = True + + if getattr(chunk, "id", None) and not self._id: + self._id = chunk.id + if getattr(chunk, "created", None) and not self._created: + self._created = chunk.created + if getattr(chunk, "model", None) and not self._model: + self._model = chunk.model + + if getattr(chunk, "service_tier", None): + self._service_tier = chunk.service_tier + if getattr(chunk, "system_fingerprint", None): + self._system_fingerprint = chunk.system_fingerprint + if getattr(chunk, "usage", None): + self._usage = chunk.usage + + for choice in getattr(chunk, "choices", []) or []: + index = getattr(choice, "index", 0) + accumulator = self._choices.setdefault( + index, _ChoiceAccumulator(index=index) + ) + accumulator.add_delta(getattr(choice, "delta", None)) + + if getattr(choice, "finish_reason", None): + accumulator.finish_reason = choice.finish_reason + if getattr(choice, "logprobs", None): + accumulator.logprobs = choice.logprobs + + except Exception as exc: # pragma: no cover - defensive + logger.debug(f"Failed to aggregate streaming chunk: {exc}") + + def build(self) -> ChatCompletion | None: + if not self._has_chunks or not self._choices: + return None + + try: + choices: list[ChatCompletionChoice] = [] + for index, accumulator in sorted(self._choices.items()): + message = accumulator.build_message() + if message is None: + continue + + choice_kwargs = { + "index": index, + "message": message, + "finish_reason": accumulator.finish_reason or "stop", + } + + if accumulator.logprobs is not None: + choice_kwargs["logprobs"] = accumulator.logprobs + + choices.append(ChatCompletionChoice(**choice_kwargs)) + + if not choices: + return None + + chat_kwargs = { + "id": self._id or "streaming_response", + "choices": choices, + "created": self._created or int(time.time()), + "model": self._model or "unknown", + "object": "chat.completion", + } + + if self._service_tier is not None: + chat_kwargs["service_tier"] = self._service_tier + if self._system_fingerprint is not None: + chat_kwargs["system_fingerprint"] = self._system_fingerprint + if self._usage is not None: + chat_kwargs["usage"] = self._usage + + return ChatCompletion(**chat_kwargs) + except Exception as exc: # pragma: no cover - defensive + logger.debug(f"Failed to build aggregated chat completion: {exc}") + return None + + +def _execute_finalize_callback_sync( + callback: Callable[[Any, Any], Awaitable[None] | None] | None, + final_response: Any, + context_data: Any, +) -> None: + if callback is None: + return + + try: + result = callback(final_response, context_data) + if inspect.isawaitable(result): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + asyncio.run(result) + else: + loop.create_task(result) + except Exception as exc: # pragma: no cover - defensive + logger.error(f"Streaming finalize callback failed: {exc}") + + +async def _execute_finalize_callback_async( + callback: Callable[[Any, Any], Awaitable[None] | None] | None, + final_response: Any, + context_data: Any, +) -> None: + if callback is None: + return + + try: + result = callback(final_response, context_data) + if inspect.isawaitable(result): + await result + except Exception as exc: # pragma: no cover - defensive + logger.error(f"Streaming finalize callback failed: {exc}") + + +class _SyncOpenAIStreamProxy: + """Proxy for synchronous OpenAI streaming responses.""" + + def __init__( + self, + stream: Stream, + finalize_callback: Callable[[Any, Any], Awaitable[None] | None] | None, + context_data: Any, + ) -> None: + self._stream = 
stream + self._finalize_callback = finalize_callback + self._context_data = context_data + self._aggregator = _ChatCompletionStreamAggregator() + self._final_response: ChatCompletion | None = None + self._finalized = False + + def __getattr__(self, item: str) -> Any: + return getattr(self._stream, item) + + def __iter__(self) -> _SyncOpenAIStreamProxy: + return self + + def __next__(self) -> Any: + try: + chunk = next(self._stream) + except StopIteration: + self._finalize() + raise + else: + self._aggregator.add_chunk(chunk) + return chunk + + def __enter__(self) -> _SyncOpenAIStreamProxy: + if hasattr(self._stream, "__enter__"): + self._stream.__enter__() + return self + + def __exit__(self, exc_type, exc, exc_tb) -> None: + try: + if hasattr(self._stream, "__exit__"): + self._stream.__exit__(exc_type, exc, exc_tb) + finally: + self._finalize() + + def close(self) -> None: + try: + if hasattr(self._stream, "close"): + self._stream.close() + finally: + self._finalize() + + def _finalize(self) -> None: + if self._finalized: + return + self._finalized = True + self._final_response = self._aggregator.build() + _execute_finalize_callback_sync( + self._finalize_callback, self._final_response, self._context_data + ) + + @property + def final_response(self) -> ChatCompletion | None: + if self._final_response is None: + self._final_response = self._aggregator.build() + return self._final_response + + +class _AsyncOpenAIStreamProxy: + """Proxy for asynchronous OpenAI streaming responses.""" + + def __init__( + self, + stream: AsyncStream, + finalize_callback: Callable[[Any, Any], Awaitable[None] | None] | None, + context_data: Any, + ) -> None: + self._stream = stream + self._finalize_callback = finalize_callback + self._context_data = context_data + self._aggregator = _ChatCompletionStreamAggregator() + self._final_response: ChatCompletion | None = None + self._finalized = False + + def __getattr__(self, item: str) -> Any: + return getattr(self._stream, item) + + def __aiter__(self) -> _AsyncOpenAIStreamProxy: + return self + + async def __anext__(self) -> Any: + try: + chunk = await self._stream.__anext__() + except StopAsyncIteration: + await self._finalize() + raise + else: + self._aggregator.add_chunk(chunk) + return chunk + + async def __aenter__(self) -> _AsyncOpenAIStreamProxy: + if hasattr(self._stream, "__aenter__"): + await self._stream.__aenter__() + return self + + async def __aexit__(self, exc_type, exc, exc_tb) -> None: + try: + if hasattr(self._stream, "__aexit__"): + await self._stream.__aexit__(exc_type, exc, exc_tb) + finally: + await self._finalize() + + async def aclose(self) -> None: + try: + if hasattr(self._stream, "aclose"): + await self._stream.aclose() + finally: + await self._finalize() + + async def _finalize(self) -> None: + if self._finalized: + return + self._finalized = True + self._final_response = self._aggregator.build() + await _execute_finalize_callback_async( + self._finalize_callback, self._final_response, self._context_data + ) + + @property + def final_response(self) -> ChatCompletion | None: + if self._final_response is None: + self._final_response = self._aggregator.build() + return self._final_response + + +# Convenience function for creating OpenAI streaming proxies +def create_openai_streaming_proxy( + stream: Stream | AsyncStream, + finalize_callback: Callable[[Any, Any], Awaitable[None] | None] | None = None, + context_data: Any = None, +) -> Stream | AsyncStream: + """ + Create a StreamingProxy specialized for OpenAI streaming responses. 
+ Args: + stream: The original OpenAI streaming response (Stream or AsyncStream). + finalize_callback: An optional async callback to be called when the stream + completes. + context_data: Optional context provided to the callback. + Returns: + Stream or AsyncStream + """ + + if finalize_callback is None: + return stream + + if isinstance(stream, AsyncStream): + return _AsyncOpenAIStreamProxy(stream, finalize_callback, context_data) + + if isinstance(stream, Stream): + return _SyncOpenAIStreamProxy(stream, finalize_callback, context_data) + + logger.warning( + "create_openai_streaming_proxy received an unsupported stream type: %s", + type(stream), + ) + return stream diff --git a/memori/utils/transaction_manager.py b/memori/utils/transaction_manager.py index f4ac7f35..e4c5f30a 100644 --- a/memori/utils/transaction_manager.py +++ b/memori/utils/transaction_manager.py @@ -77,7 +77,7 @@ def transaction( ): """Context manager for database transactions with proper error handling""" - transaction_id = f"txn_{int(time.time()*1000)}" + transaction_id = f"txn_{int(time.time() * 1000)}" start_time = time.time() try: diff --git a/tests/comprehensive_database_comparison.py b/tests/comprehensive_database_comparison.py index 1d685b33..e407acff 100644 --- a/tests/comprehensive_database_comparison.py +++ b/tests/comprehensive_database_comparison.py @@ -49,11 +49,11 @@ def get_database_connections(): def test_database_comprehensive(db_name, connection_string, test_name): """Comprehensive database test with all features""" - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"🧪 Testing {test_name}") print(f"Database: {db_name}") print(f"Connection: {connection_string}") - print(f"{'='*60}") + print(f"{'=' * 60}") try: from memori import Memori @@ -134,7 +134,7 @@ def test_database_comprehensive(db_name, connection_string, test_name): test_results["insert_time"] = time.time() - start_time print( - f" ⏱️ Insert time: {test_results['insert_time']:.3f}s ({test_results['insert_time']/len(test_messages):.3f}s per record)" + f" ⏱️ Insert time: {test_results['insert_time']:.3f}s ({test_results['insert_time'] / len(test_messages):.3f}s per record)" ) # Test 2: Data retrieval performance diff --git a/tests/litellm_support/litellm_test_suite.py b/tests/litellm_support/litellm_test_suite.py index 74f316c9..f98fc8b9 100644 --- a/tests/litellm_support/litellm_test_suite.py +++ b/tests/litellm_support/litellm_test_suite.py @@ -28,12 +28,12 @@ def run_test_scenario(test_name, conscious_ingest, auto_ingest, test_inputs): auto_ingest: Boolean for auto_ingest parameter test_inputs: List of test inputs to process """ - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"Running Test: {test_name}") print( f"Configuration: conscious_ingest={conscious_ingest}, auto_ingest={auto_ingest}" ) - print(f"{'='*60}\n") + print(f"{'=' * 60}\n") # Create database directory for this test db_dir = f"test_databases/{test_name}" diff --git a/tests/mysql_support/compare_databases.py b/tests/mysql_support/compare_databases.py index fd3efc2e..fccc9e28 100644 --- a/tests/mysql_support/compare_databases.py +++ b/tests/mysql_support/compare_databases.py @@ -16,10 +16,10 @@ def test_database_performance(db_type, connection_string, test_name): """Test database performance with various operations""" - print(f"\n{'='*50}") + print(f"\n{'=' * 50}") print(f"🧪 Testing {test_name}") print(f"Connection: {connection_string}") - print(f"{'='*50}") + print(f"{'=' * 50}") try: from memori import Memori @@ -60,7 +60,7 @@ def test_database_performance(db_type, 
connection_string, test_name): insert_time = time.time() - start_time print( - f" ⏱️ Insert time: {insert_time:.3f}s ({insert_time/10:.3f}s per record)" + f" ⏱️ Insert time: {insert_time:.3f}s ({insert_time / 10:.3f}s per record)" ) # Test 2: Data retrieval @@ -163,16 +163,16 @@ def main(): # Results comparison if len(results) >= 2: - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("📊 PERFORMANCE COMPARISON RESULTS") - print(f"{'='*60}") + print(f"{'=' * 60}") sqlite_result = next((r for r in results if r["db_type"] == "SQLite"), None) mysql_result = next((r for r in results if r["db_type"] == "MySQL"), None) if sqlite_result and mysql_result: print(f"{'Metric':<20} {'SQLite':<15} {'MySQL':<15} {'Winner':<10}") - print(f"{'-'*60}") + print(f"{'-' * 60}") # Insert performance sqlite_insert = sqlite_result["insert_time"] diff --git a/tests/mysql_support/litellm_mysql_test_suite.py b/tests/mysql_support/litellm_mysql_test_suite.py index 7918dcb8..b6848ac7 100644 --- a/tests/mysql_support/litellm_mysql_test_suite.py +++ b/tests/mysql_support/litellm_mysql_test_suite.py @@ -42,12 +42,12 @@ def run_mysql_test_scenario(test_name, conscious_ingest, auto_ingest, test_input auto_ingest: Boolean for auto_ingest parameter (None to omit) test_inputs: List of test inputs to process """ - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"🧪 Running MySQL Test: {test_name}") print( f"Configuration: conscious_ingest={conscious_ingest}, auto_ingest={auto_ingest}" ) - print(f"{'='*60}\n") + print(f"{'=' * 60}\n") # Create database directory for this test db_dir = f"mysql_test_databases/{test_name}" diff --git a/tests/openai/openai_support/openai_test_suite.py b/tests/openai/openai_support/openai_test_suite.py index 1325f26b..4f47ee65 100644 --- a/tests/openai/openai_support/openai_test_suite.py +++ b/tests/openai/openai_support/openai_test_suite.py @@ -33,12 +33,12 @@ def run_test_scenario(test_name, conscious_ingest, auto_ingest, test_inputs): auto_ingest: Boolean for auto_ingest parameter test_inputs: List of test inputs to process """ - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"Running OpenAI Test: {test_name}") print( f"Configuration: conscious_ingest={conscious_ingest}, auto_ingest={auto_ingest}" ) - print(f"{'='*60}\n") + print(f"{'=' * 60}\n") # Create database directory for this test root_dir = os.getcwd() diff --git a/tests/openai_support/openai_test.py b/tests/openai_support/openai_test.py index b0c56949..0fea0a78 100644 --- a/tests/openai_support/openai_test.py +++ b/tests/openai_support/openai_test.py @@ -46,7 +46,7 @@ def run_openai_test_scenario( test_inputs: List of test inputs to process openai_config: OpenAI configuration dictionary """ - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"Running OpenAI Test: {test_name}") print( f"Configuration: conscious_ingest={conscious_ingest}, auto_ingest={auto_ingest}" @@ -56,7 +56,7 @@ def run_openai_test_scenario( print(f"Base URL: {openai_config['base_url']}") if openai_config["organization"]: print(f"Organization: {openai_config['organization']}") - print(f"{'='*60}\n") + print(f"{'=' * 60}\n") # Create database directory for this test db_dir = f"test_databases_openai/{test_name}" @@ -166,9 +166,9 @@ def run_openai_test_scenario( print(f"\n✓ OpenAI Test '{test_name}' completed.") print(f" Database saved at: {db_path}") - total = max(1, len(test_inputs)) # Prevent divide-by-zero + # total = max(1, len(test_inputs)) # Prevent divide-by-zero print( - f" Success rate: {success_count}/{len(test_inputs)} 
({100*success_count/total:.1f}%)\n" + f" Success rate: {success_count}/{len(test_inputs)} ({100 * success_count / len(test_inputs):.1f}%)\n" ) return success_count > 0 diff --git a/tests/openai_support/openai_test_suite.py b/tests/openai_support/openai_test_suite.py index 9b466dc6..c94e0058 100644 --- a/tests/openai_support/openai_test_suite.py +++ b/tests/openai_support/openai_test_suite.py @@ -144,7 +144,7 @@ def display_summary( successful_providers += 1 print( - f"\n🏆 Success Rate: {successful_providers}/{total_providers} providers ({100*successful_providers/total_providers:.1f}%)" + f"\n🏆 Success Rate: {successful_providers}/{total_providers} providers ({100 * successful_providers / total_providers:.1f}%)" ) # Database Statistics diff --git a/tests/openai_support/test_streaming_proxy.py b/tests/openai_support/test_streaming_proxy.py new file mode 100644 index 00000000..ec2a9820 --- /dev/null +++ b/tests/openai_support/test_streaming_proxy.py @@ -0,0 +1,140 @@ +import asyncio +from types import SimpleNamespace + +import pytest +from openai._streaming import AsyncStream, Stream +from openai.types.completion_usage import CompletionUsage + +from memori.utils.streaming_proxy import create_openai_streaming_proxy + + +class DummyStream(Stream): + """Minimal synchronous OpenAI stream stub for testing.""" + + def __init__(self, chunks): + self._chunks = iter(chunks) + self.response = SimpleNamespace(close=lambda: None) + + def __iter__(self): + return self + + def __next__(self): + return next(self._chunks) + + def close(self): + self.response.close() + + +class DummyAsyncStream(AsyncStream): + """Minimal asynchronous OpenAI stream stub for testing.""" + + def __init__(self, chunks): + self._chunks = iter(chunks) + self.response = SimpleNamespace(aclose=lambda: None) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._chunks) + except StopIteration as exc: # pragma: no cover - mirrors real behaviour + raise StopAsyncIteration from exc + + async def aclose(self): + await asyncio.sleep(0) + + +def _make_chunk( + *, + content: str | None = None, + finish_reason: str | None = None, + role: str | None = None, + usage: CompletionUsage | None = None, +): + delta_attrs = {} + if content is not None: + delta_attrs["content"] = content + if role is not None: + delta_attrs["role"] = role + + choice = SimpleNamespace( + index=0, + delta=SimpleNamespace(**delta_attrs) if delta_attrs else None, + finish_reason=finish_reason, + logprobs=None, + ) + + return SimpleNamespace( + id="chatcmpl-test", + created=123, + model="gpt-4o", + choices=[choice], + service_tier="scale" if finish_reason else None, + system_fingerprint="fingerprint-xyz", + usage=usage, + ) + + +def test_sync_streaming_proxy_aggregates_chunks_and_invokes_finalize(): + usage = CompletionUsage(prompt_tokens=5, completion_tokens=7, total_tokens=12) + chunks = [ + _make_chunk(content="Hello", role="assistant"), + _make_chunk(content=" world", finish_reason="stop", usage=usage), + ] + + captured = {} + + async def finalize_callback(final_response, context): + captured["response"] = final_response + captured["context"] = context + + proxy = create_openai_streaming_proxy( + DummyStream(chunks), + finalize_callback=finalize_callback, + context_data={"req": 1}, + ) + + emitted_chunks = list(proxy) + + assert emitted_chunks == chunks + assert captured["context"] == {"req": 1} + final = captured["response"] + assert final is not None + assert final.model == "gpt-4o" + assert final.choices[0].message.content == 
"Hello world" + assert final.choices[0].finish_reason == "stop" + assert final.usage and final.usage.total_tokens == 12 + + +@pytest.mark.asyncio +async def test_async_streaming_proxy_invokes_async_finalize_callback(): + usage = CompletionUsage(prompt_tokens=3, completion_tokens=4, total_tokens=7) + chunks = [ + _make_chunk(content="Streaming", role="assistant"), + _make_chunk(content=" done", finish_reason="stop", usage=usage), + ] + + captured = {} + + async def finalize_callback(final_response, context): + captured["response"] = final_response + captured["context"] = context + + proxy = create_openai_streaming_proxy( + DummyAsyncStream(chunks), + finalize_callback=finalize_callback, + context_data=("ctx",), + ) + + emitted_chunks = [] + async for chunk in proxy: + emitted_chunks.append(chunk) + + assert emitted_chunks == chunks + assert captured["context"] == ("ctx",) + final = captured["response"] + assert final is not None + assert final.choices[0].message.content == "Streaming done" + assert final.choices[0].finish_reason == "stop" + assert final.usage and final.usage.total_tokens == 7