diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 519299fd3d..e17c87ba60 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -266,7 +266,7 @@ The memory module in AgentScope currently supports: - **In-memory storage**: For lightweight, temporary memory needs - **Relational databases via SQLAlchemy**: For persistent, structured data storage -- **NoSQL databases**: For flexible schema requirements (e.g., Redis) +- **NoSQL databases**: For flexible schema requirements (e.g., Redis, Tablestore) **⚠️ Important Notice:** diff --git a/CONTRIBUTING_zh.md b/CONTRIBUTING_zh.md index 59ac4e397b..0960a475a6 100644 --- a/CONTRIBUTING_zh.md +++ b/CONTRIBUTING_zh.md @@ -261,7 +261,7 @@ AgentScope 的记忆模块目前支持: - **内存存储**:用于轻量级的临时记忆需求 - **通过 SQLAlchemy 支持关系型数据库**:用于持久化的结构化数据存储 -- **NoSQL 数据库**:用于灵活的模式需求(例如 Redis) +- **NoSQL 数据库**:用于灵活的模式需求(例如 Redis、表格存储) **⚠️ 请注意:** diff --git a/docs/tutorial/en/src/task_memory.py b/docs/tutorial/en/src/task_memory.py index 8b9aedb377..00f48058a8 100644 --- a/docs/tutorial/en/src/task_memory.py +++ b/docs/tutorial/en/src/task_memory.py @@ -37,6 +37,8 @@ - An asynchronous SQLAlchemy-based implementation of memory storage, which supports various databases such as SQLite, PostgreSQL, MySQL, etc. * - ``RedisMemory`` - A Redis-based implementation of memory storage. + * - ``TablestoreMemory`` + - An Alibaba Cloud Tablestore-based implementation of memory storage, enabling persistent and searchable memory across distributed environments. .. tip:: If you're interested in contributing new memory storage implementations, please refer to the `Contribution Guide `_. @@ -410,6 +412,122 @@ async def redis_memory_example() -> None: # await client.aclose() # # +# Tablestore Memory +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# AgentScope also provides a memory implementation based on +# `Alibaba Cloud Tablestore `_, +# a fully managed NoSQL database service. 
``TablestoreMemory`` enables +# persistent and searchable memory across distributed environments, with +# built-in support for multi-user and multi-session isolation. +# +# First, install the required packages: +# +# .. code-block:: bash +# +# pip install tablestore tablestore-for-agent-memory +# +# Then, you can initialize the Tablestore memory as follows: +# +# .. code-block:: python +# :caption: Tablestore Memory Basic Usage +# +# import asyncio +# from agentscope.memory import TablestoreMemory +# from agentscope.message import Msg +# +# +# async def tablestore_memory_example(): +# # Create the Tablestore memory +# memory = TablestoreMemory( +# end_point="https://your-instance.cn-hangzhou.ots.aliyuncs.com", +# instance_name="your-instance-name", +# access_key_id="your-access-key-id", +# access_key_secret="your-access-key-secret", +# # Optionally specify user_id and session_id +# user_id="user_1", +# session_id="session_1", +# ) +# +# # Add a message to the memory +# await memory.add( +# Msg("Alice", "Generate a report about AgentScope", "user"), +# ) +# +# # Add a hint message with the mark "hint" +# await memory.add( +# Msg( +# "system", +# "Create a plan first to collect information and " +# "generate the report step by step.", +# "system", +# ), +# marks="hint", +# ) +# +# # Retrieve messages with the mark "hint" +# msgs = await memory.get_memory(mark="hint") +# for msg in msgs: +# print(f"- {msg}") +# +# # Close the Tablestore client connection when done +# await memory.close() +# +# +# asyncio.run(tablestore_memory_example()) +# +# The ``TablestoreMemory`` can also be used as an async context manager: +# +# .. 
code-block:: python +# :caption: Tablestore Memory as Async Context Manager +# +# async with TablestoreMemory( +# end_point="https://your-instance.cn-hangzhou.ots.aliyuncs.com", +# instance_name="your-instance-name", +# access_key_id="your-access-key-id", +# access_key_secret="your-access-key-secret", +# user_id="user_1", +# session_id="session_1", +# ) as memory: +# await memory.add( +# Msg("Alice", "Generate a report about AgentScope", "user"), +# ) +# +# msgs = await memory.get_memory() +# for msg in msgs: +# print(f"- {msg}") +# +# Similarly, ``TablestoreMemory`` can be used in production environments with FastAPI: +# +# .. code-block:: python +# :caption: Tablestore Memory in FastAPI +# +# import os +# from fastapi import FastAPI +# from agentscope.memory import TablestoreMemory +# from agentscope.message import Msg +# +# +# app = FastAPI() +# +# +# @app.post("/chat_endpoint") +# async def chat_endpoint(user_id: str, session_id: str, input: str): +# """A chat endpoint using Tablestore memory.""" +# memory = TablestoreMemory( +# end_point=os.environ["TABLESTORE_ENDPOINT"], +# instance_name=os.environ["TABLESTORE_INSTANCE_NAME"], +# access_key_id=os.environ["TABLESTORE_ACCESS_KEY_ID"], +# access_key_secret=os.environ["TABLESTORE_ACCESS_KEY_SECRET"], +# user_id=user_id, +# session_id=session_id, +# ) +# +# # Use the memory with your agent +# ... +# +# # Close the Tablestore client connection when done +# await memory.close() +# # # Customizing Memory # ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/tutorial/zh_CN/src/task_memory.py b/docs/tutorial/zh_CN/src/task_memory.py index e905f2cac6..db98df8a9f 100644 --- a/docs/tutorial/zh_CN/src/task_memory.py +++ b/docs/tutorial/zh_CN/src/task_memory.py @@ -29,6 +29,8 @@ - 基于异步 SQLAlchemy 的记忆存储实现,支持如 SQLite、PostgreSQL、MySQL 等多种关系数据库。 * - ``RedisMemory`` - 基于 Redis 的记忆存储实现。 + * - ``TablestoreMemory`` + - 基于阿里云表格存储(Tablestore)的记忆存储实现,支持分布式环境下的持久化和可搜索记忆。 .. 
tip:: 如果您有兴趣贡献新的记忆存储实现,请参考 `贡献指南 `_。 @@ -397,6 +399,121 @@ async def redis_memory_example() -> None: # await client.aclose() # # +# 表格存储记忆(Tablestore Memory) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# AgentScope 还提供了基于 +# `阿里云表格存储(Tablestore) `_ +# 的记忆实现。``TablestoreMemory`` 支持分布式环境下的持久化和可搜索记忆, +# 并内置多用户和多会话隔离。 +# +# 首先,安装所需的依赖包: +# +# .. code-block:: bash +# +# pip install tablestore tablestore-for-agent-memory +# +# 然后,可以按如下方式初始化 Tablestore 记忆: +# +# .. code-block:: python +# :caption: Tablestore 记忆基本用法 +# +# import asyncio +# from agentscope.memory import TablestoreMemory +# from agentscope.message import Msg +# +# +# async def tablestore_memory_example(): +# # 创建 Tablestore 记忆 +# memory = TablestoreMemory( +# end_point="https://your-instance.cn-hangzhou.ots.aliyuncs.com", +# instance_name="your-instance-name", +# access_key_id="your-access-key-id", +# access_key_secret="your-access-key-secret", +# # 可选地指定 user_id 和 session_id +# user_id="user_1", +# session_id="session_1", +# ) +# +# # 向记忆中添加消息 +# await memory.add( +# Msg("Alice", "生成一份关于AgentScope的报告", "user"), +# ) +# +# # 添加一条带有标记"hint"的提示消息 +# await memory.add( +# Msg( +# "system", +# "首先创建一个计划来收集信息," +# "然后逐步生成报告。", +# "system", +# ), +# marks="hint", +# ) +# +# # 检索带有标记"hint"的消息 +# msgs = await memory.get_memory(mark="hint") +# for msg in msgs: +# print(f"- {msg}") +# +# # 完成后关闭 Tablestore 客户端连接 +# await memory.close() +# +# +# asyncio.run(tablestore_memory_example()) +# +# ``TablestoreMemory`` 也可以用作异步上下文管理器: +# +# .. 
code-block:: python +# :caption: Tablestore 记忆作为异步上下文管理器 +# +# async with TablestoreMemory( +# end_point="https://your-instance.cn-hangzhou.ots.aliyuncs.com", +# instance_name="your-instance-name", +# access_key_id="your-access-key-id", +# access_key_secret="your-access-key-secret", +# user_id="user_1", +# session_id="session_1", +# ) as memory: +# await memory.add( +# Msg("Alice", "生成一份关于AgentScope的报告", "user"), +# ) +# +# msgs = await memory.get_memory() +# for msg in msgs: +# print(f"- {msg}") +# +# 同样,``TablestoreMemory`` 也可以在生产环境中与 FastAPI 一起使用: +# +# .. code-block:: python +# :caption: FastAPI 中使用 Tablestore 记忆 +# +# import os +# from fastapi import FastAPI +# from agentscope.memory import TablestoreMemory +# from agentscope.message import Msg +# +# +# app = FastAPI() +# +# +# @app.post("/chat_endpoint") +# async def chat_endpoint(user_id: str, session_id: str, input: str): +# """使用 Tablestore 记忆的聊天端点。""" +# memory = TablestoreMemory( +# end_point=os.environ["TABLESTORE_ENDPOINT"], +# instance_name=os.environ["TABLESTORE_INSTANCE_NAME"], +# access_key_id=os.environ["TABLESTORE_ACCESS_KEY_ID"], +# access_key_secret=os.environ["TABLESTORE_ACCESS_KEY_SECRET"], +# user_id=user_id, +# session_id=session_id, +# ) +# +# # 使用记忆与智能体交互 +# ... 
+# +# # 完成后关闭 Tablestore 客户端连接 +# await memory.close() +# # # 自定义记忆(Customizing Memory) # ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/pyproject.toml b/pyproject.toml index e3615dd695..ff2c085b26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,10 @@ tokens = [ # ------------ Memory ------------ redis_memory = ["redis"] +tablestore_memory = [ + "tablestore-for-agent-memory", + "tablestore", +] mem0ai = [ "mem0ai<=1.0.3", @@ -83,6 +87,7 @@ memory = [ "agentscope[redis_memory]", "agentscope[mem0ai]", "agentscope[reme]", + "agentscope[tablestore_memory]", ] # ------------ RAG ------------ @@ -171,6 +176,7 @@ dev = [ # Development tools "pre-commit", "pytest", + "pytest-asyncio", "pytest-forked", "sphinx-gallery", "furo", diff --git a/src/agentscope/memory/__init__.py b/src/agentscope/memory/__init__.py index 8af080ba09..abf665bbcd 100644 --- a/src/agentscope/memory/__init__.py +++ b/src/agentscope/memory/__init__.py @@ -6,6 +6,7 @@ InMemoryMemory, RedisMemory, AsyncSQLAlchemyMemory, + TablestoreMemory, ) from ._long_term_memory import ( LongTermMemoryBase, @@ -22,6 +23,7 @@ "InMemoryMemory", "RedisMemory", "AsyncSQLAlchemyMemory", + "TablestoreMemory", # Long-term memory "LongTermMemoryBase", "Mem0LongTermMemory", diff --git a/src/agentscope/memory/_working_memory/__init__.py b/src/agentscope/memory/_working_memory/__init__.py index a092f0e58e..daa727b9de 100644 --- a/src/agentscope/memory/_working_memory/__init__.py +++ b/src/agentscope/memory/_working_memory/__init__.py @@ -7,10 +7,12 @@ from ._in_memory_memory import InMemoryMemory from ._redis_memory import RedisMemory from ._sqlalchemy_memory import AsyncSQLAlchemyMemory +from ._tablestore_memory import TablestoreMemory __all__ = [ "MemoryBase", "InMemoryMemory", "RedisMemory", "AsyncSQLAlchemyMemory", + "TablestoreMemory", ] diff --git a/src/agentscope/memory/_working_memory/_tablestore_memory.py b/src/agentscope/memory/_working_memory/_tablestore_memory.py new file mode 100644 index 0000000000..ac90074c34 
# -*- coding: utf-8 -*-
"""The Tablestore-based working memory implementation for agentscope."""
import asyncio
import copy
import json
from typing import Any, Optional


from ..._logging import logger
from ...message import Msg
from ._base import MemoryBase


class TablestoreMemory(MemoryBase):
    """A Tablestore-based working memory implementation using
    ``tablestore_for_agent_memory``'s ``AsyncKnowledgeStore``.

    Every message is stored as one document:

    - ``document_id`` is ``{msg.id}:::{session_id}``;
    - ``tenant_id`` is the ``user_id`` given to the constructor;
    - ``text`` is the JSON-serialized ``Msg``;
    - the marks are stored in the ``marks_json`` metadata field as a
      JSON-encoded list of strings.

    The knowledge store is created lazily on the first operation, so
    constructing an instance performs no table initialization.

    The instance can be used as an async context manager; ``close`` is
    awaited on exit.
    """

    _SEARCH_INDEX_NAME = "agentscope_memory_search_index"

    _DOCUMENT_ID_SEPARATOR = ":::"

    def __init__(
        self,
        end_point: str,
        instance_name: str,
        access_key_id: str,
        access_key_secret: str,
        user_id: str = "default",
        session_id: str = "default",
        sts_token: Optional[str] = None,
        table_name: str = "agentscope_memory",
        text_field: str = "text",
        embedding_field: str = "embedding",
        vector_dimension: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize the Tablestore memory.

        Args:
            end_point (`str`):
                The endpoint of the Tablestore instance.
            instance_name (`str`):
                The name of the Tablestore instance.
            access_key_id (`str`):
                The access key ID for authentication.
            access_key_secret (`str`):
                The access key secret for authentication.
            user_id (`str`, defaults to ``"default"``):
                The user ID for multi-tenant isolation.
            session_id (`str`, defaults to ``"default"``):
                The session ID for session-level isolation.
            sts_token (`str | None`, optional):
                The STS token for temporary credentials. An empty string
                is treated the same as ``None``.
            table_name (`str`, defaults to ``"agentscope_memory"``):
                The table name for storing memory documents.
            text_field (`str`, defaults to ``"text"``):
                The field name for text content in Tablestore.
            embedding_field (`str`, defaults to ``"embedding"``):
                The field name for embedding vectors in Tablestore.
            vector_dimension (`int`, defaults to ``0``):
                The dimension of the embedding vectors. Set to ``0``
                if not using vector search.
            **kwargs (`Any`):
                Additional keyword arguments passed to the
                ``AsyncKnowledgeStore``.

        Raises:
            `ImportError`:
                If ``tablestore`` or ``tablestore-for-agent-memory`` is
                not installed.
        """
        super().__init__()

        try:
            from tablestore import (
                AsyncOTSClient as AsyncTablestoreClient,
                WriteRetryPolicy,
                FieldSchema,
                FieldType,
            )

            # Imported only to fail fast with the helpful message below;
            # the submodules are imported lazily where they are used.
            import tablestore_for_agent_memory  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "The 'tablestore' and 'tablestore-for-agent-memory' packages "
                "are required for TablestoreMemory. Please install them via "
                "'pip install tablestore tablestore-for-agent-memory'.",
            ) from exc

        self._user_id = user_id
        self._session_id = session_id
        self._table_name = table_name
        self._text_field = text_field
        self._embedding_field = embedding_field
        self._vector_dimension = vector_dimension

        self._tablestore_client = AsyncTablestoreClient(
            end_point=end_point,
            access_key_id=access_key_id,
            access_key_secret=access_key_secret,
            instance_name=instance_name,
            # An empty token is normalized to None.
            sts_token=sts_token or None,
            retry_policy=WriteRetryPolicy(),
        )

        self._search_index_schema = [
            FieldSchema("document_id", FieldType.KEYWORD),
            FieldSchema("tenant_id", FieldType.KEYWORD),
            FieldSchema("session_id", FieldType.KEYWORD),
            FieldSchema("marks_json", FieldType.KEYWORD, is_array=True),
        ]

        self._knowledge_store = None
        self._knowledge_store_kwargs = kwargs
        self._initialized = False
        self._init_lock = asyncio.Lock()

    async def __aenter__(self) -> "TablestoreMemory":
        """Enter the async context and return this memory instance."""
        return self

    async def __aexit__(
        self,
        exc_type: Any,
        exc_value: Any,
        traceback: Any,
    ) -> None:
        """Leave the async context and close the store."""
        await self.close()

    async def _ensure_initialized(self) -> None:
        """Lazily initialize the knowledge store on first use.

        The ``_initialized`` flag is checked once without the lock (fast
        path) and once more after acquiring ``_init_lock``, so concurrent
        coroutines create the store and call ``init_table`` only once.
        """
        if self._initialized:
            return
        async with self._init_lock:
            if self._initialized:
                return

            from tablestore_for_agent_memory.knowledge.async_knowledge_store import (  # noqa: E501
                AsyncKnowledgeStore,
            )

            self._knowledge_store = AsyncKnowledgeStore(
                tablestore_client=self._tablestore_client,
                vector_dimension=self._vector_dimension,
                table_name=self._table_name,
                search_index_name=self._SEARCH_INDEX_NAME,
                # Deep copy so the store gets its own schema objects.
                search_index_schema=copy.deepcopy(self._search_index_schema),
                text_field=self._text_field,
                embedding_field=self._embedding_field,
                enable_multi_tenant=True,
                **self._knowledge_store_kwargs,
            )

            await self._knowledge_store.init_table()
            self._initialized = True

    def _make_document_id(self, msg_id: str) -> str:
        """Convert a message ID to a Tablestore document ID.

        The document ID is formatted as ``{msg_id}:::{session_id}`` to
        ensure uniqueness across sessions within the same tenant.

        Args:
            msg_id (`str`):
                The message ID.

        Returns:
            `str`:
                The Tablestore document ID.
        """
        return f"{msg_id}{self._DOCUMENT_ID_SEPARATOR}{self._session_id}"

    def _extract_msg_id(self, document_id: str) -> str:
        """Extract the message ID from a Tablestore document ID.

        The method verifies that ``document_id`` ends with
        ``:::{session_id}``. If it does not, an error is logged and
        the original ``document_id`` is returned as-is.

        Args:
            document_id (`str`):
                The Tablestore document ID in ``{msg_id}:::{session_id}``
                format.

        Returns:
            `str`:
                The original message ID.
        """
        expected_suffix = f"{self._DOCUMENT_ID_SEPARATOR}{self._session_id}"
        if not document_id.endswith(expected_suffix):
            logger.error(
                "Unexpected document_id format: '%s'. "
                "Expected suffix ':::%s'.",
                document_id,
                self._session_id,
            )
            return document_id
        return document_id[: -len(expected_suffix)]

    @staticmethod
    def _parse_marks(metadata: Optional[dict]) -> list[str]:
        """Decode the ``marks_json`` metadata field into a list of marks.

        Args:
            metadata (`dict | None`):
                The document metadata; ``None`` is treated as empty.

        Returns:
            `list[str]`:
                The decoded marks. An empty list is returned when the
                field is missing, cannot be decoded, or does not decode
                to a list.
        """
        marks_json = (metadata or {}).get("marks_json", "[]")
        try:
            parsed = json.loads(marks_json)
        except (json.JSONDecodeError, TypeError):
            return []
        return parsed if isinstance(parsed, list) else []

    def _msg_to_document(self, msg: Msg, marks: list[str]) -> Any:
        """Convert a ``Msg`` to a Tablestore document.

        The document ID is formatted as ``{msg.id}:::{session_id}``.
        ``self._user_id`` is used as ``tenant_id``.

        Args:
            msg (`Msg`):
                The message to convert.
            marks (`list[str]`):
                The marks associated with the message.

        Returns:
            A ``Document`` object for Tablestore.
        """
        from tablestore_for_agent_memory.base.base_knowledge_store import (
            Document as TablestoreDocument,
        )

        # NOTE: ``default=str`` stringifies any value the JSON encoder
        # cannot handle natively.
        text_content = json.dumps(
            msg.to_dict(),
            ensure_ascii=False,
            default=str,
        )

        metadata = {
            "session_id": self._session_id,
            "name": msg.name,
            "role": msg.role,
            "timestamp": msg.timestamp or "",
            "invocation_id": msg.invocation_id or "",
            "marks_json": json.dumps(marks, ensure_ascii=False),
        }

        return TablestoreDocument(
            document_id=self._make_document_id(msg.id),
            text=text_content,
            tenant_id=self._user_id,
            metadata=metadata,
        )

    @staticmethod
    def _document_to_msg_and_marks(document: Any) -> tuple[Msg, list[str]]:
        """Convert a Tablestore document back to a ``Msg`` and marks.

        The ``Msg`` is restored entirely from the JSON-serialized text
        stored in ``document.text``. The ``document_id`` is in
        ``{msg_id}:::{session_id}`` format and is not used for restoring
        the ``Msg``.

        Args:
            document:
                The Tablestore document to convert.

        Returns:
            A tuple of (``Msg``, marks list).
        """
        msg = Msg.from_dict(json.loads(document.text))
        return msg, TablestoreMemory._parse_marks(document.metadata)

    async def add(
        self,
        memories: Msg | list[Msg] | None,
        marks: str | list[str] | None = None,
        allow_duplicates: bool = True,
        **kwargs: Any,
    ) -> None:
        """Add message(s) into the memory storage with the given mark
        (if provided).

        .. note:: The document ID is derived from the message ID and the
         session ID, so two messages with the same ID in one session map
         to the same document.

        Args:
            memories (`Msg | list[Msg] | None`):
                The message(s) to be added. ``None`` and an empty list
                are no-ops.
            marks (`str | list[str] | None`, optional):
                The mark(s) to associate with the message(s). If `None`, no
                mark is associated.
            allow_duplicates (`bool`, defaults to ``True``):
                Whether to allow duplicate messages. If ``False``,
                messages whose ID already exists in the current session
                are skipped.

        Raises:
            `TypeError`:
                If ``marks`` is not a string, a list of strings, or
                ``None``.
        """
        if memories is None:
            return

        if isinstance(memories, Msg):
            memories = [memories]

        # Nothing to store: avoid initializing the table or issuing an
        # empty batch lookup.
        if not memories:
            return

        if marks is None:
            marks_list: list[str] = []
        elif isinstance(marks, str):
            marks_list = [marks]
        elif isinstance(marks, list) and all(
            isinstance(m, str) for m in marks
        ):
            marks_list = marks
        else:
            raise TypeError(
                f"The mark should be a string, a list of strings, or None, "
                f"but got {type(marks)}.",
            )

        await self._ensure_initialized()

        if not allow_duplicates:
            # Filter out messages already stored in this session
            existing_ids = await self._get_existing_msg_ids_in_session(
                [msg.id for msg in memories],
            )
            memories = [msg for msg in memories if msg.id not in existing_ids]
            if not memories:
                return

        await asyncio.gather(
            *(
                self._knowledge_store.put_document(
                    self._msg_to_document(msg, marks_list),
                )
                for msg in memories
            ),
        )

    async def _get_existing_msg_ids_in_session(
        self,
        msg_ids: list[str],
    ) -> set[str]:
        """Get the IDs that actually exist in the current session from the
        provided list.

        Args:
            msg_ids (`list[str]`):
                The list of message IDs to check.

        Returns:
            `set[str]`:
                The set of message IDs that exist in the current session.
        """
        return set(
            await self._get_existing_msg_ids_and_marks_in_session(msg_ids),
        )

    async def _get_existing_msg_ids_and_marks_in_session(
        self,
        msg_ids: list[str],
    ) -> dict[str, list[str]]:
        """Get the IDs and their marks for messages that actually exist in
        the current session from the provided list.

        Args:
            msg_ids (`list[str]`):
                The list of message IDs to check.

        Returns:
            `dict[str, list[str]]`:
                A mapping from message ID to its list of marks, only for
                messages that exist in the current session. Empty when
                ``msg_ids`` is empty.
        """
        # Do not send an empty batch lookup to the store.
        if not msg_ids:
            return {}

        existing_docs = await self._knowledge_store.get_documents(
            document_id_list=[self._make_document_id(mid) for mid in msg_ids],
            tenant_id=self._user_id,
        )

        # Entries that are ``None`` are skipped (treated as not found).
        return {
            self._extract_msg_id(doc.document_id): self._parse_marks(
                doc.metadata,
            )
            for doc in existing_docs
            if doc is not None
        }

    async def _get_all_msg_ids(self) -> set[str]:
        """Get all message IDs currently stored for this user/session."""
        return await self._search_msg_ids_by_marks()

    async def _get_all_msg_ids_and_marks(self) -> dict[str, list[str]]:
        """Get all message IDs and their marks for this user/session.

        Returns:
            `dict[str, list[str]]`:
                A mapping from message ID to its full list of marks.
        """
        return await self._search_msg_ids_and_marks_by_marks()

    def _build_conditions(
        self,
        marks: str | list[str] | None = None,
        exclude_marks: str | list[str] | None = None,
    ) -> list:
        """Build the metadata filter conditions for the current session.

        The ``session_id`` equality condition is always present. A
        ``Filters.In`` condition on ``marks_json`` is added when ``marks``
        is non-empty, and a ``Filters.not_in`` condition when
        ``exclude_marks`` is non-empty.

        Args:
            marks (`str | list[str] | None`, optional):
                A single mark or list of marks to include.
            exclude_marks (`str | list[str] | None`, optional):
                A single mark or list of marks to exclude.

        Returns:
            `list`:
                The filter conditions, to be combined with
                ``Filters.logical_and``.
        """
        from tablestore_for_agent_memory.base.filter import Filters

        if isinstance(marks, str):
            marks = [marks]
        if isinstance(exclude_marks, str):
            exclude_marks = [exclude_marks]

        conditions = [Filters.eq("session_id", self._session_id)]
        if marks:
            conditions.append(Filters.In("marks_json", marks))
        if exclude_marks:
            conditions.append(Filters.not_in("marks_json", exclude_marks))
        return conditions

    async def _search_all_documents(
        self,
        conditions: list,
        meta_data_to_get: list[str],
    ) -> list:
        """Run a paginated search and collect every matching document.

        ``search_documents`` is called repeatedly, following
        ``next_token``, until the store reports no further page.

        Args:
            conditions (`list`):
                The filter conditions, combined with
                ``Filters.logical_and``.
            meta_data_to_get (`list[str]`):
                The metadata fields to fetch for each document.

        Returns:
            `list`:
                The documents of all hits, in the order returned by the
                store.
        """
        from tablestore_for_agent_memory.base.filter import Filters

        # The filter does not change between pages, so build it once.
        metadata_filter = Filters.logical_and(conditions)

        documents: list = []
        next_token = None
        while True:
            result = await self._knowledge_store.search_documents(
                tenant_id=self._user_id,
                metadata_filter=metadata_filter,
                meta_data_to_get=meta_data_to_get,
                next_token=next_token,
            )
            documents.extend(hit.document for hit in result.hits)
            next_token = result.next_token
            if not next_token:
                break
        return documents

    async def _search_msg_ids_by_marks(
        self,
        marks: str | list[str] | None = None,
    ) -> set[str]:
        """Search for message IDs, optionally filtered by marks.

        Delegates to ``_search_msg_ids_and_marks_by_marks`` so only the
        ``marks_json`` metadata is requested instead of full documents.

        Args:
            marks (`str | list[str] | None`, optional):
                A single mark string or list of marks to filter by.
                If provided, returns messages that contain **any** of
                the specified marks. If ``None`` or empty, returns all
                message IDs in the session.

        Returns:
            `set[str]`:
                The set of message IDs matching the filter criteria.
        """
        return set(await self._search_msg_ids_and_marks_by_marks(marks))

    async def _search_msg_ids_and_marks_by_marks(
        self,
        marks: str | list[str] | None = None,
    ) -> dict[str, list[str]]:
        """Search for message IDs and their marks, optionally filtered by
        marks.

        Uses ``Filters.In`` on the ``marks_json`` field (when marks is
        provided) combined with the ``session_id`` filter and the tenant
        ID to query matching documents via ``search_documents``.

        Args:
            marks (`str | list[str] | None`, optional):
                A single mark string or list of marks to filter by.
                If provided, returns messages that contain **any** of
                the specified marks. If ``None`` or empty, returns all
                messages and their marks.

        Returns:
            `dict[str, list[str]]`:
                A mapping from message ID to its full list of marks.
        """
        documents = await self._search_all_documents(
            self._build_conditions(marks),
            meta_data_to_get=["marks_json"],
        )
        # Hits without a document ID are skipped.
        return {
            self._extract_msg_id(doc.document_id): self._parse_marks(
                doc.metadata,
            )
            for doc in documents
            if doc.document_id
        }

    async def _delete_msg_ids(self, msg_ids: set[str]) -> None:
        """Delete the documents of the given message IDs concurrently.

        Args:
            msg_ids (`set[str]`):
                The message IDs (not document IDs) to delete from the
                current session. An empty set is a no-op.
        """
        if not msg_ids:
            return
        await asyncio.gather(
            *(
                self._knowledge_store.delete_document(
                    document_id=self._make_document_id(msg_id),
                    tenant_id=self._user_id,
                )
                for msg_id in msg_ids
            ),
        )

    async def delete(
        self,
        msg_ids: list[str],
        **kwargs: Any,
    ) -> int:
        """Remove message(s) from the storage by their IDs.

        Args:
            msg_ids (`list[str]`):
                The list of message IDs to be removed.

        Returns:
            `int`:
                The number of messages removed. IDs that do not exist in
                the current session are not counted.
        """
        if not msg_ids:
            return 0

        await self._ensure_initialized()

        # Get only the IDs that actually exist in the current session
        existing_ids = await self._get_existing_msg_ids_in_session(msg_ids)
        await self._delete_msg_ids(existing_ids)
        return len(existing_ids)

    async def delete_by_mark(
        self,
        mark: str | list[str],
        **kwargs: Any,
    ) -> int:
        """Remove messages from the memory by their marks.

        Args:
            mark (`str | list[str]`):
                The mark(s) of the messages to be removed. A message is
                removed when it carries **any** of the given marks.

        Raises:
            `TypeError`:
                If the provided mark is not a string or a list of strings.

        Returns:
            `int`:
                The number of messages removed.
        """
        if isinstance(mark, str):
            mark = [mark]

        if not isinstance(mark, list) or not all(
            isinstance(m, str) for m in mark
        ):
            raise TypeError(
                f"The mark should be a string or a list of strings, "
                f"but got {type(mark)}.",
            )

        # An empty mark list would add no mark filter to the search and
        # therefore match (and delete) every message in the session.
        if not mark:
            return 0

        await self._ensure_initialized()

        matched_msg_ids = await self._search_msg_ids_by_marks(mark)
        await self._delete_msg_ids(matched_msg_ids)
        return len(matched_msg_ids)

    async def size(self) -> int:
        """Get the number of messages in the storage.

        Returns:
            `int`:
                The number of messages stored for the current user and
                session.
        """
        await self._ensure_initialized()
        return len(await self._get_all_msg_ids())

    async def clear(self) -> None:
        """Clear the memory content for the current session."""
        await self._ensure_initialized()
        await self._delete_msg_ids(await self._get_all_msg_ids())

    async def get_memory(
        self,
        mark: str | None = None,
        exclude_mark: str | None = None,
        prepend_summary: bool = True,
        **kwargs: Any,
    ) -> list[Msg]:
        """Get the messages from the memory by mark (if provided). Otherwise,
        get all messages.

        .. note:: If `mark` and `exclude_mark` are both provided, the messages
         will be filtered by both arguments, and they should not overlap.

        Args:
            mark (`str | None`, optional):
                The mark to filter messages. If `None`, return all messages.
            exclude_mark (`str | None`, optional):
                The mark to exclude messages. If provided, messages with
                this mark will be excluded from the results.
            prepend_summary (`bool`, defaults to True):
                Whether to prepend the compressed summary as a message

        Raises:
            `TypeError`:
                If ``mark`` or ``exclude_mark`` is neither a string nor
                ``None``.

        Returns:
            `list[Msg]`:
                The list of messages retrieved from the storage, sorted
                by their stored ``timestamp`` metadata.
        """
        if not (mark is None or isinstance(mark, str)):
            raise TypeError(
                f"The mark should be a string or None, but got {type(mark)}.",
            )

        if not (exclude_mark is None or isinstance(exclude_mark, str)):
            raise TypeError(
                f"The exclude_mark should be a string or None, but got "
                f"{type(exclude_mark)}.",
            )

        await self._ensure_initialized()

        all_docs = await self._search_documents_by_marks_and_exclude_marks(
            marks=mark,
            exclude_marks=exclude_mark,
        )

        results: list[Msg] = [
            self._document_to_msg_and_marks(doc)[0] for doc in all_docs
        ]

        if prepend_summary and self._compressed_summary:
            return [
                Msg("user", self._compressed_summary, "user"),
                *results,
            ]

        return results

    async def update_messages_mark(
        self,
        new_mark: str | None,
        old_mark: str | None = None,
        msg_ids: list[str] | None = None,
    ) -> int:
        """A unified method to update marks of messages in the storage (add,
        remove, or change marks).

        - If `msg_ids` is provided, the update will be applied to the messages
          with the specified IDs.
        - If `old_mark` is provided, the update will be applied to the
          messages with the specified old mark. Otherwise, the `new_mark` will
          be added to all messages (or those filtered by `msg_ids`).
        - If `new_mark` is `None`, the mark will be removed from the messages.

        Args:
            new_mark (`str | None`, optional):
                The new mark to set for the messages. If `None`, the mark
                will be removed.
            old_mark (`str | None`, optional):
                The old mark to filter messages. If `None`, this constraint
                is ignored.
            msg_ids (`list[str] | None`, optional):
                The list of message IDs to be updated. If `None`, this
                constraint is ignored.

        Returns:
            `int`:
                The number of messages updated.
        """
        await self._ensure_initialized()

        # Get msg_ids and their marks
        if msg_ids is not None:
            # Get the marks for the provided msg_ids,
            # use msg id to search is faster than using marks
            id_to_marks = (
                await self._get_existing_msg_ids_and_marks_in_session(
                    msg_ids,
                )
            )
        else:
            id_to_marks = await self._search_msg_ids_and_marks_by_marks(
                old_mark,
            )

        # Collect msg_ids that need mark updates
        ids_to_update: dict[str, list[str]] = {}
        for msg_id, current_marks in id_to_marks.items():
            # With an ID-based lookup the old mark was not part of the
            # query, so it is checked here.
            if old_mark is not None and old_mark not in current_marks:
                continue

            updated_marks = list(current_marks)
            changed = False

            if old_mark is not None and old_mark in updated_marks:
                updated_marks.remove(old_mark)
                changed = True
            if new_mark is not None and new_mark not in updated_marks:
                updated_marks.append(new_mark)
                changed = True

            if changed:
                ids_to_update[msg_id] = updated_marks

        if not ids_to_update:
            return 0

        from tablestore_for_agent_memory.base.base_knowledge_store import (
            Document as TablestoreDocument,
        )

        # Each update document carries only the ``marks_json`` metadata.
        await asyncio.gather(
            *(
                self._knowledge_store.update_document(
                    TablestoreDocument(
                        document_id=self._make_document_id(msg_id),
                        tenant_id=self._user_id,
                        metadata={
                            "marks_json": json.dumps(
                                updated_marks,
                                ensure_ascii=False,
                            ),
                        },
                    ),
                )
                for msg_id, updated_marks in ids_to_update.items()
            ),
        )
        return len(ids_to_update)

    async def _search_documents_by_marks_and_exclude_marks(
        self,
        marks: str | list[str] | None = None,
        exclude_marks: str | list[str] | None = None,
    ) -> list:
        """Get all documents filtered by inclusion and/or exclusion marks.

        Dynamically builds filters based on the provided arguments:

        - If ``marks`` is provided, uses ``Filters.In`` to include only
          documents with any of the specified marks.
        - If ``exclude_marks`` is provided, uses ``Filters.not_in`` to
          exclude documents with any of the specified marks.
        - If neither is provided, returns all documents in the session.

        Args:
            marks (`str | list[str] | None`, optional):
                A single mark string or list of marks to include.
                If ``None``, no inclusion filter is applied.
            exclude_marks (`str | list[str] | None`, optional):
                A single mark string or list of marks to exclude.
                If ``None``, no exclusion filter is applied.

        Returns:
            `list`:
                A list of Tablestore documents matching the filter
                criteria, sorted by their ``timestamp`` metadata.
        """
        all_docs = await self._search_all_documents(
            self._build_conditions(marks, exclude_marks),
            meta_data_to_get=[
                self._text_field,
                "name",
                "role",
                "timestamp",
                "marks_json",
                "session_id",
                "invocation_id",
            ],
        )

        # Sort documents by timestamp to maintain message order. The
        # sort is stable, so equal timestamps keep the search order.
        all_docs.sort(
            key=lambda doc: (doc.metadata or {}).get("timestamp", ""),
        )
        return all_docs

    async def close(self) -> None:
        """Close the knowledge store, if one has been created.

        After closing, the next operation re-creates the store lazily.
        """
        # NOTE(review): when the store was never initialized, the
        # underlying ``_tablestore_client`` is not closed here; confirm
        # whether the client holds resources before first use.
        if self._knowledge_store is not None:
            await self._knowledge_store.close()
            self._knowledge_store = None
            self._initialized = False

    def state_dict(self) -> dict:
        """Get the state dictionary for serialization.

        Note: Only the compressed summary is serialized. The actual memory
        content is persisted in Tablestore.
        """
        return {
            "_compressed_summary": self._compressed_summary,
        }

    def load_state_dict(self, state_dict: dict, strict: bool = True) -> None:
        """Load the state dictionary for deserialization.

        Args:
            state_dict (`dict`):
                The state dictionary to load.
            strict (`bool`, defaults to ``True``):
                If ``True``, raises an error if required keys are missing.
                If ``False``, a missing summary is loaded as an empty
                string.

        Raises:
            `KeyError`:
                If ``strict`` is ``True`` and ``"_compressed_summary"``
                is missing from ``state_dict``.
        """
        if strict and "_compressed_summary" not in state_dict:
            raise KeyError(
                "The key '_compressed_summary' is missing in the state "
                "dict of TablestoreMemory.",
            )
        self._compressed_summary = state_dict.get("_compressed_summary", "")
+ """ + + _SESSION_SECONDARY_INDEX_NAME = "agentscope_session_secondary_index" + _SESSION_SEARCH_INDEX_NAME = "agentscope_session_search_index" + _MESSAGE_SECONDARY_INDEX_NAME = "agentscope_message_secondary_index" + _MESSAGE_SEARCH_INDEX_NAME = "agentscope_message_search_index" + + def __init__( + self, + end_point: str, + instance_name: str, + access_key_id: str, + access_key_secret: str, + sts_token: Optional[str] = None, + session_table_name: str = "agentscope_session", + message_table_name: str = "agentscope_message", + **kwargs: Any, + ) -> None: + """Initialize the Tablestore session. + + Args: + end_point (`str`): + The endpoint of the Tablestore instance. + instance_name (`str`): + The name of the Tablestore instance. + access_key_id (`str`): + The access key ID for authentication. + access_key_secret (`str`): + The access key secret for authentication. + sts_token (`str | None`, optional): + The STS token for temporary credentials. + session_table_name (`str`, defaults to + ``"agentscope_session"``): + The table name for storing sessions. + message_table_name (`str`, defaults to + ``"agentscope_message"``): + The table name for storing messages. + **kwargs (`Any`): + Additional keyword arguments passed to the + ``AsyncMemoryStore``. + """ + try: + from tablestore import ( + AsyncOTSClient as AsyncTablestoreClient, + WriteRetryPolicy, + ) + from tablestore_for_agent_memory.memory.async_memory_store import ( + AsyncMemoryStore, + ) + except ImportError as exc: + raise ImportError( + "The 'tablestore' and 'tablestore-for-agent-memory' packages " + "are required for TablestoreSession. 
Please install them via " + "'pip install tablestore tablestore-for-agent-memory'.", + ) from exc + + self._tablestore_client = AsyncTablestoreClient( + end_point=end_point, + access_key_id=access_key_id, + access_key_secret=access_key_secret, + instance_name=instance_name, + sts_token=None if sts_token == "" else sts_token, + retry_policy=WriteRetryPolicy(), + ) + + self._session_table_name = session_table_name + self._message_table_name = message_table_name + self._memory_store: Optional[AsyncMemoryStore] = None + self._memory_store_kwargs = kwargs + self._initialized = False + self._init_lock = asyncio.Lock() + + async def _ensure_initialized(self) -> None: + """Lazily initialize the memory store on first use. + + Uses an ``asyncio.Lock`` to prevent concurrent initialization when + multiple coroutines call this method simultaneously. + """ + if self._initialized: + return + async with self._init_lock: + if self._initialized: + return + + from tablestore_for_agent_memory.memory.async_memory_store import ( + AsyncMemoryStore, + ) + + self._memory_store = AsyncMemoryStore( + tablestore_client=self._tablestore_client, + session_table_name=self._session_table_name, + message_table_name=self._message_table_name, + session_secondary_index_name=( + self._SESSION_SECONDARY_INDEX_NAME + ), + session_search_index_name=self._SESSION_SEARCH_INDEX_NAME, + message_secondary_index_name=( + self._MESSAGE_SECONDARY_INDEX_NAME + ), + message_search_index_name=self._MESSAGE_SEARCH_INDEX_NAME, + **self._memory_store_kwargs, + ) + + await self._memory_store.init_table() + await self._memory_store.init_search_index() + self._initialized = True + + async def save_session_state( + self, + session_id: str, + user_id: str = "", + **state_modules_mapping: StateModule, + ) -> None: + """Save the session state to Tablestore. + + Each state module's ``state_dict()`` is serialized to JSON and stored + in the session table's metadata field under the key + ``"__state__"``. 
+ + Args: + session_id (`str`): + The session id. + user_id (`str`, default to ``""``): + The user ID for the storage. + **state_modules_mapping (`dict[str, StateModule]`): + A dictionary mapping of state module names to their instances. + """ + from tablestore_for_agent_memory.base.base_memory_store import ( + Session as TablestoreSessionModel, + ) + + await self._ensure_initialized() + + state_dicts = { + name: state_module.state_dict() + for name, state_module in state_modules_mapping.items() + } + serialized_state = json.dumps(state_dicts, ensure_ascii=False) + + # Create a session model + tablestore_session = TablestoreSessionModel( + session_id=session_id, + user_id=user_id or "default", + metadata={"__state__": serialized_state}, + ) + await self._memory_store.update_session(tablestore_session) + + logger.info( + "Saved session state to Tablestore for session '%s'.", + session_id, + ) + + async def load_session_state( + self, + session_id: str, + user_id: str = "", + allow_not_exist: bool = True, + **state_modules_mapping: StateModule, + ) -> None: + """Load the session state from Tablestore. + + The state is read from the session table's metadata field + under the key ``"__state__"``. + + Args: + session_id (`str`): + The session id. + user_id (`str`, default to ``""``): + The user ID for the storage. + allow_not_exist (`bool`, defaults to ``True``): + Whether to allow the session to not exist. + **state_modules_mapping (`dict[str, StateModule]`): + The mapping of state modules to be loaded. + """ + await self._ensure_initialized() + + tablestore_session = await self._memory_store.get_session( + user_id=user_id or "default", + session_id=session_id, + ) + + if not tablestore_session: + if allow_not_exist: + logger.info( + "Session '%s' does not exist in Tablestore. 
" + "Skip loading session state.", + session_id, + ) + return + raise ValueError( + f"Failed to load session state because session " + f"'{session_id}' does not exist in Tablestore.", + ) + + state_content = (tablestore_session.metadata or {}).get("__state__") + + if state_content is None: + if allow_not_exist: + logger.info( + "No state data found for session '%s'. " + "Skip loading session state.", + session_id, + ) + return + raise ValueError( + f"Failed to load session state because no state data " + f"found for session '{session_id}'.", + ) + + states = json.loads(state_content) + + for name, state_module in state_modules_mapping.items(): + if name in states: + state_module.load_state_dict(states[name]) + + logger.info( + "Loaded session state from Tablestore for session '%s'.", + session_id, + ) + + async def close(self) -> None: + """Close the Tablestore client connection.""" + if self._memory_store is not None: + await self._memory_store.close() + self._memory_store = None + self._initialized = False + + async def __aenter__(self) -> "TablestoreSession": + """Enter the async context manager. + + Returns: + `TablestoreSession`: + The current ``TablestoreSession`` instance. + """ + await self._ensure_initialized() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: Any, + ) -> None: + """Exit the async context manager and close the connection. + + Args: + exc_type (`type[BaseException] | None`): + The type of the exception. + exc_value (`BaseException | None`): + The exception instance. + traceback (`Any`): + The traceback. 
+ """ + await self.close() diff --git a/tests/session_test.py b/tests/session_test.py index 6abd4f1c42..1ae81ac32f 100644 --- a/tests/session_test.py +++ b/tests/session_test.py @@ -1,8 +1,11 @@ # -*- coding: utf-8 -*- """Session module tests.""" import os -from typing import Union +import sys +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Union from unittest import IsolatedAsyncioTestCase +from unittest.mock import MagicMock, patch from agentscope.agent import ReActAgent, AgentBase from agentscope.formatter import DashScopeChatFormatter @@ -158,3 +161,275 @@ async def asyncTearDown(self) -> None: # close clients await self.session.close() await self._redis.close() + + +@dataclass +class _FakeSession: + """In-memory fake of tablestore_for_agent_memory Session dataclass.""" + + user_id: str + session_id: str + update_time: Optional[int] = None + metadata: Optional[Dict[str, Any]] = field(default_factory=dict) + + +class _FakeMemoryStore: + """In-memory fake of AsyncMemoryStore for testing.""" + + def __init__(self) -> None: + self._sessions: dict[tuple[str, str], _FakeSession] = {} + + async def init_table(self) -> None: + """Initialize the fake memory store.""" + pass # pylint: disable=unnecessary-pass + + async def init_search_index(self) -> None: + """Initialize the fake memory store.""" + pass # pylint: disable=unnecessary-pass + + async def update_session(self, session: _FakeSession) -> None: + """Update a session.""" + key = (session.user_id, session.session_id) + existing = self._sessions.get(key) + if existing: + if session.metadata: + existing.metadata.update(session.metadata) + else: + self._sessions[key] = _FakeSession( + user_id=session.user_id, + session_id=session.session_id, + metadata=dict(session.metadata) if session.metadata else {}, + ) + + async def get_session( + self, + user_id: str, + session_id: str, + ) -> Optional[_FakeSession]: + """Get a session.""" + return self._sessions.get((user_id, session_id)) + + 
async def close(self) -> None: + """Close the fake memory store.""" + self._sessions.clear() + + +def _build_tablestore_mocks() -> tuple[MagicMock, MagicMock]: + """Create mock modules for tablestore and tablestore_for_agent_memory.""" + tablestore_mod = MagicMock() + tablestore_mod.AsyncOTSClient = MagicMock + tablestore_mod.WriteRetryPolicy = MagicMock + + memory_base_mod = MagicMock() + memory_base_mod.Session = _FakeSession + + memory_store_mod = MagicMock() + + agent_memory_mod = MagicMock() + agent_memory_mod.base = MagicMock() + agent_memory_mod.base.base_memory_store = memory_base_mod + agent_memory_mod.memory = MagicMock() + agent_memory_mod.memory.async_memory_store = memory_store_mod + + return tablestore_mod, agent_memory_mod + + +class TablestoreSessionTest( + IsolatedAsyncioTestCase, +): # pylint: disable=protected-access + """Test cases for the TablestoreSession module (with mocked backend).""" + + async def asyncSetUp(self) -> None: + """Set up mock modules and create a TablestoreSession instance.""" + ( + self._tablestore_mod, + self._agent_memory_mod, + ) = _build_tablestore_mocks() + self._fake_store = _FakeMemoryStore() + + self._agent_memory_mod.memory.async_memory_store.AsyncMemoryStore = ( + MagicMock(return_value=self._fake_store) + ) + + self._patches = [ + patch.dict( + sys.modules, + { + "tablestore": self._tablestore_mod, + "tablestore_for_agent_memory": self._agent_memory_mod, + "tablestore_for_agent_memory.base": ( + self._agent_memory_mod.base + ), + "tablestore_for_agent_memory.base.base_memory_store": ( + self._agent_memory_mod.base.base_memory_store + ), + "tablestore_for_agent_memory.memory": ( + self._agent_memory_mod.memory + ), + "tablestore_for_agent_memory.memory.async_memory_store": ( + self._agent_memory_mod.memory.async_memory_store + ), + }, + ), + ] + for patcher in self._patches: + patcher.start() + + from agentscope.session._tablestore_session import TablestoreSession + + self.session = TablestoreSession( + 
end_point="https://fake.endpoint.com", + instance_name="fake_instance", + access_key_id="fake_ak", + access_key_secret="fake_sk", + ) + self.session._memory_store = self._fake_store + self.session._initialized = True + + async def test_save_and_load_session_state(self) -> None: + """Test saving and loading session state round-trip.""" + agent = MyAgent() + await agent.memory.add(Msg("Alice", "Hi!", "user")) + + await self.session.save_session_state( + session_id="sess_1", + user_id="user_1", + agent=agent, + ) + + agent.name = "Changed" + agent.sys_prompt = "Changed prompt" + + await self.session.load_session_state( + session_id="sess_1", + user_id="user_1", + agent=agent, + ) + + self.assertEqual(agent.name, "Friday") + self.assertEqual(agent.sys_prompt, "A helpful assistant.") + + async def test_load_before_save_session_not_exist(self) -> None: + """Test that load before any save reports session not found.""" + agent = MyAgent() + + with self.assertLogs("as", level="INFO") as log_context: + await self.session.load_session_state( + session_id="nonexistent", + user_id="user_1", + allow_not_exist=True, + agent=agent, + ) + + self.assertTrue( + any("does not exist" in msg for msg in log_context.output), + ) + + async def test_load_before_save_raises_when_not_allowed(self) -> None: + """Test that load raises ValueError when allow_not_exist=False.""" + agent = MyAgent() + + with self.assertRaises(ValueError): + await self.session.load_session_state( + session_id="nonexistent", + user_id="user_1", + allow_not_exist=False, + agent=agent, + ) + + async def test_lifecycle_load_save_load(self) -> None: + """Test full lifecycle: load (not exist) -> save -> load (exists).""" + agent = MyAgent() + original_name = agent.name + original_prompt = agent.sys_prompt + + with self.assertLogs("as", level="INFO") as log_before: + await self.session.load_session_state( + session_id="lifecycle_sess", + user_id="user_1", + allow_not_exist=True, + agent=agent, + ) + + self.assertTrue( + 
any("does not exist" in msg for msg in log_before.output), + ) + self.assertEqual(agent.name, original_name) + self.assertEqual(agent.sys_prompt, original_prompt) + + await agent.memory.add(Msg("Bob", "Hello!", "user")) + await self.session.save_session_state( + session_id="lifecycle_sess", + user_id="user_1", + agent=agent, + ) + + agent.name = "Mutated" + agent.sys_prompt = "Mutated prompt" + + await self.session.load_session_state( + session_id="lifecycle_sess", + user_id="user_1", + agent=agent, + ) + + self.assertEqual(agent.name, original_name) + self.assertEqual(agent.sys_prompt, original_prompt) + + async def test_save_overwrites_previous_state(self) -> None: + """Test that saving again overwrites the previous state.""" + agent = MyAgent() + + await self.session.save_session_state( + session_id="overwrite_sess", + user_id="user_1", + agent=agent, + ) + + agent.name = "UpdatedName" + agent.sys_prompt = "Updated prompt" + + await self.session.save_session_state( + session_id="overwrite_sess", + user_id="user_1", + agent=agent, + ) + + agent.name = "Garbage" + agent.sys_prompt = "Garbage" + + await self.session.load_session_state( + session_id="overwrite_sess", + user_id="user_1", + agent=agent, + ) + + self.assertEqual(agent.name, "UpdatedName") + self.assertEqual(agent.sys_prompt, "Updated prompt") + + async def test_close(self) -> None: + """Test that close resets the session state.""" + await self.session.close() + self.assertIsNone(self.session._memory_store) + self.assertFalse(self.session._initialized) + + async def test_async_context_manager(self) -> None: + """Test the async context manager protocol.""" + self.session._initialized = False + self.session._memory_store = None + + self._agent_memory_mod.memory.async_memory_store.AsyncMemoryStore = ( + MagicMock(return_value=self._fake_store) + ) + + async with self.session as session: + self.assertIs(session, self.session) + self.assertTrue(session._initialized) + + 
self.assertIsNone(self.session._memory_store) + self.assertFalse(self.session._initialized) + + async def asyncTearDown(self) -> None: + """Clean up patches.""" + for patcher in self._patches: + patcher.stop() diff --git a/tests/tablestore_memory_ft_test.py b/tests/tablestore_memory_ft_test.py new file mode 100644 index 0000000000..e25c169175 --- /dev/null +++ b/tests/tablestore_memory_ft_test.py @@ -0,0 +1,646 @@ +# -*- coding: utf-8 -*- +"""Functional tests for TablestoreMemory with real Tablestore instance. + +These tests require the following environment variables to be set: + - TABLESTORE_ENDPOINT + - TABLESTORE_INSTANCE_NAME + - TABLESTORE_ACCESS_KEY_ID + - TABLESTORE_ACCESS_KEY_SECRET + +If any of these are missing, the tests will be skipped. +""" +# pylint: disable=protected-access,redefined-outer-name +from __future__ import annotations + +import asyncio +import os +from typing import TYPE_CHECKING + +import pytest +import pytest_asyncio + +from agentscope.message import Msg + +if TYPE_CHECKING: + from agentscope.memory import TablestoreMemory + + +def _get_tablestore_config() -> dict[str, str] | None: + """Get Tablestore configuration from environment variables.""" + endpoint = os.getenv("TABLESTORE_ENDPOINT") + instance_name = os.getenv("TABLESTORE_INSTANCE_NAME") + access_key_id = os.getenv("TABLESTORE_ACCESS_KEY_ID") + access_key_secret = os.getenv("TABLESTORE_ACCESS_KEY_SECRET") + + if not all([endpoint, instance_name, access_key_id, access_key_secret]): + return None + + assert endpoint is not None + assert instance_name is not None + assert access_key_id is not None + assert access_key_secret is not None + + return { + "end_point": endpoint, + "instance_name": instance_name, + "access_key_id": access_key_id, + "access_key_secret": access_key_secret, + } + + +async def _wait_for_index_ready( + memory: "TablestoreMemory", + expected_count: int, +) -> None: + """Wait for the search index to be ready with the expected document count. 
+ + Uses TablestoreHelper to poll the search index until the total count + matches the expected count. + """ + from tablestore_for_agent_memory.util.tablestore_helper import ( + TablestoreHelper, + ) + + await memory._ensure_initialized() + tablestore_client = memory._tablestore_client + table_name = memory._knowledge_store._table_name + index_name = memory._knowledge_store._search_index_name + + await TablestoreHelper.async_wait_search_index_ready( + tablestore_client=tablestore_client, + table_name=table_name, + index_name=index_name, + total_count=expected_count, + ) + + +@pytest.fixture +def tablestore_config() -> dict[str, str]: + """Fixture that provides Tablestore config or skips the test.""" + config = _get_tablestore_config() + if config is None: + pytest.skip( + "Tablestore environment variables not set: " + "TABLESTORE_ENDPOINT, TABLESTORE_INSTANCE_NAME, " + "TABLESTORE_ACCESS_KEY_ID, TABLESTORE_ACCESS_KEY_SECRET", + ) + return config # type: ignore[return-value] + + +@pytest_asyncio.fixture +async def tablestore_memory( # type: ignore[misc] + tablestore_config: dict[str, str], +) -> None: + """Fixture that creates a TablestoreMemory for testing.""" + from agentscope.memory import TablestoreMemory + + memory = TablestoreMemory( + user_id="ft_test_user", + session_id="ft_test_session", + table_name="agentscope_ft_memory", + vector_dimension=4, + **tablestore_config, + ) + + await memory._ensure_initialized() + + # Clean up any existing data before test + await memory.clear() + await _wait_for_index_ready(memory, 0) + + try: + yield memory + finally: + # Clean up after test + await memory.clear() + await _wait_for_index_ready(memory, 0) + await memory.close() + + +@pytest.mark.asyncio +async def test_memory_lifecycle(tablestore_config: dict[str, str]) -> None: + """Test creating and closing a TablestoreMemory.""" + from agentscope.memory import TablestoreMemory + + memory = TablestoreMemory( + user_id="ft_lifecycle_user", + 
session_id="ft_lifecycle_session", + table_name="agentscope_ft_memory", + vector_dimension=4, + **tablestore_config, + ) + + await memory._ensure_initialized() + assert memory._initialized is True + assert memory._knowledge_store is not None + + await memory.close() + assert memory._initialized is False + assert memory._knowledge_store is None + + +@pytest.mark.asyncio +async def test_add_and_get_memory( + tablestore_memory: TablestoreMemory, +) -> None: + """Test adding messages and retrieving them.""" + memory = tablestore_memory + + msg1 = Msg("Alice", "Hello world!", "user") + msg2 = Msg("Bob", "Hi there!", "assistant") + + await memory.add(msg1) + await memory.add(msg2) + await _wait_for_index_ready(memory, 2) + + messages = await memory.get_memory() + assert len(messages) == 2 + + names = {m.name for m in messages} + assert "Alice" in names + assert "Bob" in names + + +@pytest.mark.asyncio +async def test_add_multiple_messages( + tablestore_memory: TablestoreMemory, +) -> None: + """Test adding a list of messages at once.""" + memory = tablestore_memory + + msgs = [Msg("User", f"Message {i}", "user") for i in range(5)] + + await memory.add(msgs) + await _wait_for_index_ready(memory, 5) + + result = await memory.get_memory() + assert len(result) == 5 + + +@pytest.mark.asyncio +async def test_add_no_duplicates( + tablestore_memory: TablestoreMemory, +) -> None: + """Test duplicate messages not added when allow_duplicates=False.""" + memory = tablestore_memory + + msg = Msg("Alice", "Hello!", "user") + + await memory.add(msg) + await _wait_for_index_ready(memory, 1) + + # Try to add the same message again + await memory.add(msg, allow_duplicates=False) + await _wait_for_index_ready(memory, 1) + + result = await memory.get_memory() + assert len(result) == 1 + + +@pytest.mark.asyncio +async def test_add_allow_duplicates( + tablestore_memory: TablestoreMemory, +) -> None: + """Test that duplicate messages are added when allow_duplicates=True.""" + memory = 
tablestore_memory + + msg = Msg("Alice", "Hello!", "user") + + await memory.add(msg) + await _wait_for_index_ready(memory, 1) + + # Add the same message again with allow_duplicates=True + # Since the same msg.id is used as document_id, Tablestore will + # upsert (overwrite) the existing row, so count stays at 1. + await memory.add(msg, allow_duplicates=True) + await _wait_for_index_ready(memory, 1) + + result = await memory.get_memory() + assert len(result) == 1 + + +@pytest.mark.asyncio +async def test_delete_messages(tablestore_memory: TablestoreMemory) -> None: + """Test deleting messages by ID.""" + memory = tablestore_memory + + msg1 = Msg("Alice", "Hello!", "user") + msg2 = Msg("Bob", "Hi!", "assistant") + + await memory.add([msg1, msg2]) + await _wait_for_index_ready(memory, 2) + + # Delete msg1 + deleted = await memory.delete([msg1.id]) + assert deleted == 1 + await _wait_for_index_ready(memory, 1) + + result = await memory.get_memory() + assert len(result) == 1 + assert result[0].name == "Bob" + + +@pytest.mark.asyncio +async def test_delete_nonexistent( + tablestore_memory: TablestoreMemory, +) -> None: + """Test deleting a non-existent message does not raise error.""" + memory = tablestore_memory + + deleted = await memory.delete(["nonexistent_id"]) + assert deleted == 0 + + +@pytest.mark.asyncio +async def test_size(tablestore_memory: TablestoreMemory) -> None: + """Test getting the size of memory.""" + memory = tablestore_memory + + assert await memory.size() == 0 + + msgs = [Msg("User", f"Msg {i}", "user") for i in range(3)] + await memory.add(msgs) + await _wait_for_index_ready(memory, 3) + + assert await memory.size() == 3 + + +@pytest.mark.asyncio +async def test_clear(tablestore_memory: TablestoreMemory) -> None: + """Test clearing all messages.""" + memory = tablestore_memory + + msgs = [Msg("User", f"Msg {i}", "user") for i in range(5)] + await memory.add(msgs) + await _wait_for_index_ready(memory, 5) + + assert await memory.size() == 5 + + 
await memory.clear() + await _wait_for_index_ready(memory, 0) + + assert await memory.size() == 0 + + +@pytest.mark.asyncio +async def test_add_with_marks(tablestore_memory: TablestoreMemory) -> None: + """Test adding messages with marks.""" + memory = tablestore_memory + + msg = Msg("Alice", "Important message", "user") + await memory.add(msg, marks=["important", "review"]) + await _wait_for_index_ready(memory, 1) + + # Get all messages + all_msgs = await memory.get_memory() + assert len(all_msgs) == 1 + + # Get messages with specific mark + important_msgs = await memory.get_memory(mark="important") + assert len(important_msgs) == 1 + + # Get messages with non-matching mark + other_msgs = await memory.get_memory(mark="other") + assert len(other_msgs) == 0 + + +@pytest.mark.asyncio +async def test_get_memory_with_mark_filter( + tablestore_memory: TablestoreMemory, +) -> None: + """Test filtering messages by mark.""" + memory = tablestore_memory + + msg1 = Msg("Alice", "Important", "user") + msg2 = Msg("Bob", "Normal", "assistant") + msg3 = Msg("Charlie", "Also important", "user") + + await memory.add(msg1, marks=["important"]) + await memory.add(msg2, marks=["normal"]) + await memory.add(msg3, marks=["important"]) + await _wait_for_index_ready(memory, 3) + + important_msgs = await memory.get_memory(mark="important") + assert len(important_msgs) == 2 + + normal_msgs = await memory.get_memory(mark="normal") + assert len(normal_msgs) == 1 + assert normal_msgs[0].name == "Bob" + + +@pytest.mark.asyncio +async def test_get_memory_with_exclude_mark( + tablestore_memory: TablestoreMemory, +) -> None: + """Test excluding messages by mark.""" + memory = tablestore_memory + + msg1 = Msg("Alice", "Keep me", "user") + msg2 = Msg("Bob", "Exclude me", "assistant") + + await memory.add(msg1, marks=["keep"]) + await memory.add(msg2, marks=["exclude"]) + await _wait_for_index_ready(memory, 2) + + result = await memory.get_memory(exclude_mark="exclude") + assert len(result) == 1 + 
assert result[0].name == "Alice" + + +@pytest.mark.asyncio +async def test_delete_by_mark(tablestore_memory: TablestoreMemory) -> None: + """Test deleting messages by mark.""" + memory = tablestore_memory + + msg1 = Msg("Alice", "Keep me", "user") + msg2 = Msg("Bob", "Delete me", "assistant") + msg3 = Msg("Charlie", "Delete me too", "user") + + await memory.add(msg1, marks=["keep"]) + await memory.add(msg2, marks=["delete"]) + await memory.add(msg3, marks=["delete"]) + await _wait_for_index_ready(memory, 3) + + deleted = await memory.delete_by_mark("delete") + assert deleted == 2 + await _wait_for_index_ready(memory, 1) + + result = await memory.get_memory() + assert len(result) == 1 + assert result[0].name == "Alice" + + +@pytest.mark.asyncio +async def test_update_messages_mark_add( + tablestore_memory: TablestoreMemory, +) -> None: + """Test adding a mark to messages.""" + memory = tablestore_memory + + msg1 = Msg("Alice", "Hello", "user") + msg2 = Msg("Bob", "Hi", "assistant") + + await memory.add([msg1, msg2]) + await _wait_for_index_ready(memory, 2) + + updated = await memory.update_messages_mark( + msg_ids=[msg1.id], + new_mark="reviewed", + ) + assert updated == 1 + + # Wait for search index to sync metadata update (update_row needs time) + await asyncio.sleep(20) + + # Verify the mark was added + reviewed_msgs = await memory.get_memory(mark="reviewed") + assert len(reviewed_msgs) == 1 + assert reviewed_msgs[0].name == "Alice" + + +@pytest.mark.asyncio +async def test_update_messages_mark_replace( + tablestore_memory: TablestoreMemory, +) -> None: + """Test replacing a mark on messages.""" + memory = tablestore_memory + + msg = Msg("Alice", "Hello", "user") + await memory.add(msg, marks=["draft"]) + await _wait_for_index_ready(memory, 1) + + updated = await memory.update_messages_mark( + msg_ids=[msg.id], + old_mark="draft", + new_mark="final", + ) + assert updated == 1 + + # Wait for search index to sync metadata update (update_row needs time) + await 
asyncio.sleep(20) + + # Old mark should not match + draft_msgs = await memory.get_memory(mark="draft") + assert len(draft_msgs) == 0 + + # New mark should match + final_msgs = await memory.get_memory(mark="final") + assert len(final_msgs) == 1 + + +@pytest.mark.asyncio +async def test_update_messages_mark_remove( + tablestore_memory: TablestoreMemory, +) -> None: + """Test removing a mark from messages.""" + memory = tablestore_memory + + msg = Msg("Alice", "Hello", "user") + await memory.add(msg, marks=["temporary"]) + await _wait_for_index_ready(memory, 1) + + updated = await memory.update_messages_mark( + msg_ids=[msg.id], + old_mark="temporary", + new_mark=None, + ) + assert updated == 1 + + # Wait for search index to sync metadata update (update_row needs time) + await asyncio.sleep(20) + + # Mark should be removed + temp_msgs = await memory.get_memory(mark="temporary") + assert len(temp_msgs) == 0 + + # Message should still exist + all_msgs = await memory.get_memory() + assert len(all_msgs) == 1 + + +@pytest.mark.asyncio +async def test_compressed_summary( + tablestore_memory: TablestoreMemory, +) -> None: + """Test compressed summary functionality.""" + memory = tablestore_memory + + msg = Msg("Alice", "Hello", "user") + await memory.add(msg) + await _wait_for_index_ready(memory, 1) + + # Set compressed summary + await memory.update_compressed_summary("This is a summary of the chat.") + + # Get memory with summary prepended + result = await memory.get_memory(prepend_summary=True) + assert len(result) == 2 + assert result[0].content == "This is a summary of the chat." 
+ assert result[1].name == "Alice" + + # Get memory without summary + result_no_summary = await memory.get_memory(prepend_summary=False) + assert len(result_no_summary) == 1 + + +@pytest.mark.asyncio +async def test_state_dict_roundtrip( + tablestore_memory: TablestoreMemory, +) -> None: + """Test state_dict and load_state_dict roundtrip.""" + memory = tablestore_memory + + await memory.update_compressed_summary("Test summary for state dict") + + state = memory.state_dict() + assert state["_compressed_summary"] == "Test summary for state dict" + + # Create a new memory and load state + from agentscope.memory import TablestoreMemory + + config = _get_tablestore_config() + new_memory = TablestoreMemory( + user_id="ft_state_user", + session_id="ft_state_session", + table_name="agentscope_ft_memory", + vector_dimension=4, + **config, + ) + new_memory.load_state_dict(state) + + assert new_memory._compressed_summary == "Test summary for state dict" + await new_memory.close() + + +@pytest.mark.asyncio +async def test_msg_roundtrip(tablestore_memory: TablestoreMemory) -> None: + """Test that Msg content is preserved through add/get roundtrip.""" + memory = tablestore_memory + + original_msg = Msg( + "Alice", + "Hello with special chars: 你好世界 & ", + "user", + metadata={"key": "value", "number": 42}, + ) + + await memory.clear() + await memory.add(original_msg) + await _wait_for_index_ready(memory, 1) + + result = await memory.get_memory() + assert len(result) == 1 + + restored = result[0] + assert restored.name == "Alice" + assert restored.content == "Hello with special chars: 你好世界 & " + assert restored.role == "user" + assert restored.id == original_msg.id + + +@pytest.mark.asyncio +async def test_isolation_between_users( + tablestore_config: dict[str, str], +) -> None: + """Test that different user_id/session_id combinations are isolated.""" + from agentscope.memory import TablestoreMemory + + memory1 = TablestoreMemory( + user_id="ft_user_A", + session_id="ft_session_1", 
+ table_name="agentscope_ft_memory", + vector_dimension=4, + **tablestore_config, + ) + memory2 = TablestoreMemory( + user_id="ft_user_B", + session_id="ft_session_1", + table_name="agentscope_ft_memory", + vector_dimension=4, + **tablestore_config, + ) + + try: + await memory1._ensure_initialized() + await memory2._ensure_initialized() + + # Clean up + await memory1.clear() + await memory2.clear() + await _wait_for_index_ready(memory1, 0) + await _wait_for_index_ready(memory2, 0) + + # Add messages to memory1 + await memory1.add(Msg("Alice", "Hello from user A", "user")) + await _wait_for_index_ready(memory1, 1) + + # memory2 should not see memory1's messages + result2 = await memory2.get_memory() + assert len(result2) == 0 + + # memory1 should see its own messages + result1 = await memory1.get_memory() + assert len(result1) == 1 + assert result1[0].name == "Alice" + + finally: + await memory1.clear() + await memory2.clear() + await memory1.close() + await memory2.close() + + +@pytest.mark.asyncio +async def test_add_none(tablestore_memory: TablestoreMemory) -> None: + """Test that adding None does nothing.""" + memory = tablestore_memory + + await memory.add(None) + + assert await memory.size() == 0 + + +@pytest.mark.asyncio +async def test_get_memory_preserves_insertion_order( + tablestore_memory: TablestoreMemory, +) -> None: + """Test that get_memory returns messages sorted by timestamp, + preserving the insertion order.""" + memory = tablestore_memory + + msg1 = Msg( + "Alice", + "First message", + "user", + timestamp="2026-01-01 00:00:01.000", + ) + msg2 = Msg( + "Bob", + "Second message", + "assistant", + timestamp="2026-01-01 00:00:02.000", + ) + msg3 = Msg( + "Charlie", + "Third message", + "user", + timestamp="2026-01-01 00:00:03.000", + ) + + # Insert in reverse order to ensure sorting is by timestamp, + # not by insertion order in Tablestore + await memory.add(msg3) + await memory.add(msg1) + await memory.add(msg2) + await _wait_for_index_ready(memory, 
3) + + messages = await memory.get_memory() + assert len(messages) == 3 + + assert messages[0].name == "Alice" + assert messages[1].name == "Bob" + assert messages[2].name == "Charlie" + + assert messages[0].timestamp == "2026-01-01 00:00:01.000" + assert messages[1].timestamp == "2026-01-01 00:00:02.000" + assert messages[2].timestamp == "2026-01-01 00:00:03.000" diff --git a/tests/tablestore_memory_test.py b/tests/tablestore_memory_test.py new file mode 100644 index 0000000000..c42b679eed --- /dev/null +++ b/tests/tablestore_memory_test.py @@ -0,0 +1,583 @@ +# -*- coding: utf-8 -*- +"""Tests for the Tablestore memory implementation.""" +# pylint: disable=protected-access,too-many-public-methods +import json +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock + +from agentscope.memory._working_memory._tablestore_memory import ( + TablestoreMemory, +) +from agentscope.message import Msg + + +def _create_mock_document( + msg: Msg, + marks: "list[str] | None" = None, + user_id: str = "default", + session_id: str = "default", +) -> MagicMock: + """Create a mock Tablestore document from a Msg.""" + if marks is None: + marks = [] + + doc = MagicMock() + doc.document_id = f"{msg.id}:::{session_id}" + doc.text = json.dumps( + msg.to_dict(), + ensure_ascii=False, + default=str, + ) + doc.tenant_id = user_id + doc.metadata = { + "session_id": session_id, + "name": msg.name, + "role": msg.role, + "timestamp": msg.timestamp or "", + "invocation_id": msg.invocation_id or "", + "marks_json": json.dumps(marks, ensure_ascii=False), + } + return doc + + +def _create_memory_with_mocks() -> "TablestoreMemory": + """Create a TablestoreMemory instance with mocked dependencies.""" + memory = object.__new__(TablestoreMemory) + # Initialize StateModule base + from collections import OrderedDict + + memory._module_dict = OrderedDict() + memory._attribute_dict = OrderedDict() + memory._compressed_summary = "" + 
memory.register_state("_compressed_summary") + + memory._user_id = "test_user" + memory._session_id = "test_session" + memory._table_name = "test_memory" + memory._text_field = "text" + memory._embedding_field = "embedding" + memory._tablestore_client = MagicMock() + memory._search_index_schema = [] + memory._knowledge_store = AsyncMock() + memory._knowledge_store_kwargs = {} + memory._initialized = True + return memory + + +class TablestoreMemoryTest(IsolatedAsyncioTestCase): + """Test cases for the Tablestore memory module.""" + + async def asyncSetUp(self) -> None: + """Set up test fixtures.""" + self.memory = _create_memory_with_mocks() + self.msgs = [] + for i in range(10): + msg = Msg("user", f"message {i}", "user") + msg.id = str(i) + self.msgs.append(msg) + + async def test_add_messages(self) -> None: + """Test adding messages to Tablestore memory.""" + # Mock _search_msg_ids_by_marks to return empty set (no duplicates) + self.memory._search_msg_ids_by_marks = AsyncMock(return_value=set()) + + await self.memory.add(self.msgs[:3]) + + # Verify put_document was called 3 times + self.assertEqual( + self.memory._knowledge_store.put_document.call_count, + 3, + ) + + async def test_add_single_message(self) -> None: + """Test adding a single message.""" + self.memory._search_msg_ids_by_marks = AsyncMock(return_value=set()) + + await self.memory.add(self.msgs[0]) + + self.memory._knowledge_store.put_document.assert_called_once() + + async def test_add_none(self) -> None: + """Test adding None does nothing.""" + await self.memory.add(None) + + self.memory._knowledge_store.put_document.assert_not_called() + + async def test_add_with_marks(self) -> None: + """Test adding messages with marks.""" + self.memory._search_msg_ids_by_marks = AsyncMock(return_value=set()) + + await self.memory.add(self.msgs[:2], marks=["important", "todo"]) + + self.assertEqual( + self.memory._knowledge_store.put_document.call_count, + 2, + ) + + # Verify marks are included in the document + 
call_args = self.memory._knowledge_store.put_document.call_args_list + for call in call_args: + doc = call[0][0] + marks = json.loads(doc.metadata["marks_json"]) + self.assertIn("important", marks) + self.assertIn("todo", marks) + + async def test_add_no_duplicates(self) -> None: + """Test that duplicate messages are filtered out.""" + # Mock get_documents to return existing documents for IDs "0" and "1" + existing_docs = [ + MagicMock( + document_id="0:::test_session", + ), + MagicMock( + document_id="1:::test_session", + ), + ] + self.memory._knowledge_store.get_documents = AsyncMock( + return_value=existing_docs, + ) + + await self.memory.add(self.msgs[:5], allow_duplicates=False) + + # Only messages 2, 3, 4 should be added + self.assertEqual( + self.memory._knowledge_store.put_document.call_count, + 3, + ) + + async def test_add_allow_duplicates(self) -> None: + """Test adding with allow_duplicates=True.""" + # When allow_duplicates=True, get_documents should not be called + await self.memory.add(self.msgs[:5], allow_duplicates=True) + + self.memory._knowledge_store.get_documents.assert_not_called() + + # All 5 messages should be added + self.assertEqual( + self.memory._knowledge_store.put_document.call_count, + 5, + ) + + await self.memory.add(self.msgs[:5], allow_duplicates=True) + # All 5 messages should be added + self.assertEqual( + self.memory._knowledge_store.put_document.call_count, + 10, + ) + + async def test_delete_messages(self) -> None: + """Test deleting messages by ID.""" + self.memory._get_existing_msg_ids_in_session = AsyncMock( + return_value={"0"}, + ) + + deleted = await self.memory.delete(msg_ids=["0"]) + + self.assertEqual(deleted, 1) + self.memory._knowledge_store.delete_document.assert_called_once_with( + document_id="0:::test_session", + tenant_id="test_user", + ) + + async def test_delete_nonexistent(self) -> None: + """Test deleting non-existent messages returns 0.""" + self.memory._search_msg_ids_by_marks = 
AsyncMock(return_value=set()) + + deleted = await self.memory.delete(msg_ids=["nonexistent"]) + + self.assertEqual(deleted, 0) + self.memory._knowledge_store.delete_document.assert_not_called() + + async def test_get_memory_all(self) -> None: + """Test getting all messages from memory.""" + docs = [ + _create_mock_document( + self.msgs[i], + user_id="test_user", + session_id="test_session", + ) + for i in range(5) + ] + self.memory._search_documents_by_marks_and_exclude_marks = AsyncMock( + return_value=docs, + ) + + result = await self.memory.get_memory(prepend_summary=False) + + self.assertEqual(len(result), 5) + for i, msg in enumerate(result): + self.assertEqual(msg.id, str(i)) + mock = self.memory._search_documents_by_marks_and_exclude_marks + mock.assert_called_once_with( + marks=None, + exclude_marks=None, + ) + + async def test_get_memory_with_mark_filter(self) -> None: + """Test getting messages filtered by mark.""" + # When mark is provided, _search_documents_by_marks_and_exclude_marks + # is used and only matching docs are returned from the database layer + docs = [ + _create_mock_document(self.msgs[1], marks=["important"]), + _create_mock_document(self.msgs[2], marks=["important", "todo"]), + ] + self.memory._search_documents_by_marks_and_exclude_marks = AsyncMock( + return_value=docs, + ) + + result = await self.memory.get_memory( + mark="important", + prepend_summary=False, + ) + + self.assertEqual(len(result), 2) + self.assertEqual(result[0].id, "1") + self.assertEqual(result[1].id, "2") + mock = self.memory._search_documents_by_marks_and_exclude_marks + mock.assert_called_once_with( + marks="important", + exclude_marks=None, + ) + + async def test_get_memory_with_exclude_mark(self) -> None: + """Test getting messages with excluded mark.""" + # exclude_mark filtering is now done at the database layer + docs = [ + _create_mock_document(self.msgs[0], marks=[]), + _create_mock_document(self.msgs[3], marks=[]), + ] + 
self.memory._search_documents_by_marks_and_exclude_marks = AsyncMock( + return_value=docs, + ) + + result = await self.memory.get_memory( + exclude_mark="important", + prepend_summary=False, + ) + + self.assertEqual(len(result), 2) + self.assertEqual(result[0].id, "0") + self.assertEqual(result[1].id, "3") + mock = self.memory._search_documents_by_marks_and_exclude_marks + mock.assert_called_once_with( + marks=None, + exclude_marks="important", + ) + + async def test_get_memory_with_summary(self) -> None: + """Test that compressed summary is prepended when available.""" + docs = [ + _create_mock_document(self.msgs[0]), + ] + self.memory._search_documents_by_marks_and_exclude_marks = AsyncMock( + return_value=docs, + ) + self.memory._compressed_summary = "Previous conversation summary." + + result = await self.memory.get_memory(prepend_summary=True) + + self.assertEqual(len(result), 2) + self.assertEqual( + result[0].content, + "Previous conversation summary.", + ) + self.assertEqual(result[1].id, "0") + + async def test_size(self) -> None: + """Test getting the size of memory.""" + msg_ids = [MagicMock() for _ in range(7)] + self.memory._search_msg_ids_by_marks = AsyncMock(return_value=msg_ids) + + result = await self.memory.size() + + self.assertEqual(result, 7) + + async def test_clear(self) -> None: + """Test clearing all messages.""" + self.memory._search_msg_ids_by_marks = AsyncMock( + return_value={"msg_0", "msg_1"}, + ) + + await self.memory.clear() + + self.assertEqual( + self.memory._knowledge_store.delete_document.call_count, + 2, + ) + deleted_doc_ids = { + call.kwargs["document_id"] + for call in ( + self.memory._knowledge_store.delete_document.call_args_list + ) + } + self.assertEqual( + deleted_doc_ids, + {"msg_0:::test_session", "msg_1:::test_session"}, + ) + + async def test_clear_empty(self) -> None: + """Test clearing when memory is already empty.""" + self.memory._search_msg_ids_by_marks = AsyncMock(return_value=set()) + + await 
self.memory.clear() + + self.memory._knowledge_store.delete_document.assert_not_called() + + async def test_delete_by_mark(self) -> None: + """Test deleting messages by mark.""" + self.memory._search_msg_ids_by_marks = AsyncMock( + return_value={"1", "2"}, + ) + + deleted = await self.memory.delete_by_mark("important") + + self.assertEqual(deleted, 2) + self.memory._search_msg_ids_by_marks.assert_called_once_with( + ["important"], + ) + self.assertEqual( + self.memory._knowledge_store.delete_document.call_count, + 2, + ) + + async def test_delete_by_mark_list(self) -> None: + """Test deleting messages by multiple marks.""" + self.memory._search_msg_ids_by_marks = AsyncMock( + return_value={"1", "2"}, + ) + + deleted = await self.memory.delete_by_mark(["important", "todo"]) + + self.assertEqual(deleted, 2) + self.memory._search_msg_ids_by_marks.assert_called_once_with( + ["important", "todo"], + ) + + async def test_update_messages_mark_add(self) -> None: + """Test adding a mark to messages.""" + # msg_ids is provided, so _get_existing_msg_ids_and_marks_in_session + # is used + self.memory._get_existing_msg_ids_and_marks_in_session = AsyncMock( + return_value={"0": [], "1": []}, + ) + self.memory._knowledge_store.update_document = AsyncMock() + + updated = await self.memory.update_messages_mark( + msg_ids=["0", "1"], + new_mark="review", + ) + + self.assertEqual(updated, 2) + mock = self.memory._get_existing_msg_ids_and_marks_in_session + mock.assert_called_once_with(["0", "1"]) + self.assertEqual( + self.memory._knowledge_store.update_document.call_count, + 2, + ) + + async def test_update_messages_mark_remove(self) -> None: + """Test removing a mark from messages.""" + # msg_ids is provided, so _get_existing_msg_ids_and_marks_in_session + # is used + self.memory._get_existing_msg_ids_and_marks_in_session = AsyncMock( + return_value={"0": ["important"]}, + ) + self.memory._knowledge_store.update_document = AsyncMock() + + updated = await 
self.memory.update_messages_mark( + msg_ids=["0"], + old_mark="important", + new_mark=None, + ) + + self.assertEqual(updated, 1) + mock = self.memory._get_existing_msg_ids_and_marks_in_session + mock.assert_called_once_with(["0"]) + self.assertEqual( + self.memory._knowledge_store.update_document.call_count, + 1, + ) + + async def test_update_messages_mark_replace(self) -> None: + """Test replacing a mark on messages.""" + # msg_ids is provided, so _get_existing_msg_ids_and_marks_in_session + # is used + self.memory._get_existing_msg_ids_and_marks_in_session = AsyncMock( + return_value={"0": ["important"], "1": ["important"]}, + ) + self.memory._knowledge_store.update_document = AsyncMock() + + updated = await self.memory.update_messages_mark( + msg_ids=["0", "1"], + old_mark="important", + new_mark="archived", + ) + + self.assertEqual(updated, 2) + mock = self.memory._get_existing_msg_ids_and_marks_in_session + mock.assert_called_once_with(["0", "1"]) + self.assertEqual( + self.memory._knowledge_store.update_document.call_count, + 2, + ) + + async def test_state_dict(self) -> None: + """Test state_dict serialization.""" + self.memory._compressed_summary = "Test summary" + + state = self.memory.state_dict() + + self.assertEqual(state["_compressed_summary"], "Test summary") + + async def test_load_state_dict(self) -> None: + """Test load_state_dict deserialization.""" + self.memory.load_state_dict( + { + "_compressed_summary": "Loaded summary", + }, + ) + + self.assertEqual( + self.memory._compressed_summary, + "Loaded summary", + ) + + async def test_close(self) -> None: + """Test closing the Tablestore memory.""" + mock_store = self.memory._knowledge_store + await self.memory.close() + + mock_store.close.assert_called_once() + self.assertIsNone(self.memory._knowledge_store) + self.assertFalse(self.memory._initialized) + + async def test_close_when_not_initialized(self) -> None: + """Test closing when not initialized.""" + self.memory._knowledge_store = None + 
self.memory._initialized = False + + # Should not raise + await self.memory.close() + + async def test_msg_to_document_string_content(self) -> None: + """Test converting a Msg with string content to document.""" + msg = Msg("Alice", "Hello world!", "user") + + doc = self.memory._msg_to_document(msg, ["mark1"]) + + self.assertEqual( + doc.document_id, + f"{msg.id}:::test_session", + ) + # Verify text contains full Msg JSON + msg_dict = json.loads(doc.text) + self.assertEqual(msg_dict["id"], msg.id) + self.assertEqual(msg_dict["name"], "Alice") + self.assertEqual(msg_dict["content"], "Hello world!") + self.assertEqual(msg_dict["role"], "user") + self.assertEqual(doc.tenant_id, "test_user") + self.assertEqual(doc.metadata["name"], "Alice") + self.assertEqual(doc.metadata["role"], "user") + self.assertEqual(doc.metadata["session_id"], "test_session") + marks = json.loads(doc.metadata["marks_json"]) + self.assertIn("mark1", marks) + # Verify msg_json is NOT in metadata + self.assertNotIn("msg_json", doc.metadata) + # Verify old fields are removed + self.assertNotIn("content_json", doc.metadata) + self.assertNotIn("metadata_json", doc.metadata) + + async def test_msg_to_document_list_content(self) -> None: + """Test converting a Msg with list content to document.""" + content = [{"type": "text", "text": "Hello from blocks!"}] + msg = Msg("Bob", content, "assistant") + + doc = self.memory._msg_to_document(msg, []) + + # Verify text contains full Msg JSON with list content + msg_dict = json.loads(doc.text) + self.assertEqual(msg_dict["content"], content) + self.assertEqual(doc.tenant_id, "test_user") + # Verify msg_json is NOT in metadata + self.assertNotIn("msg_json", doc.metadata) + + async def test_document_to_msg_roundtrip(self) -> None: + """Test roundtrip conversion Msg -> Document -> Msg.""" + original_msg = Msg( + "Alice", + "Test content", + "user", + metadata={"key": "value", "number": 42}, + ) + original_marks = ["important", "todo"] + + doc = 
self.memory._msg_to_document(original_msg, original_marks) + ( + restored_msg, + restored_marks, + ) = TablestoreMemory._document_to_msg_and_marks(doc) + + self.assertEqual(restored_msg.name, original_msg.name) + self.assertEqual(restored_msg.content, original_msg.content) + self.assertEqual(restored_msg.role, original_msg.role) + self.assertEqual(restored_msg.id, original_msg.id) + self.assertEqual(restored_msg.metadata, original_msg.metadata) + self.assertListEqual(restored_marks, original_marks) + + async def test_invalid_mark_type(self) -> None: + """Test that invalid mark types raise TypeError.""" + self.memory._search_msg_ids_by_marks = AsyncMock(return_value=set()) + + with self.assertRaises(TypeError): + await self.memory.add(self.msgs[0], marks=123) + + async def test_get_memory_invalid_mark_type(self) -> None: + """Test that invalid mark type in get_memory raises TypeError.""" + with self.assertRaises(TypeError): + await self.memory.get_memory(mark=123) + + async def test_make_document_id(self) -> None: + """Test _make_document_id produces correct format.""" + document_id = self.memory._make_document_id("msg_123") + self.assertEqual(document_id, "msg_123:::test_session") + + async def test_extract_msg_id(self) -> None: + """Test _extract_msg_id extracts msg ID from document ID.""" + msg_id = self.memory._extract_msg_id("msg_123:::test_session") + self.assertEqual(msg_id, "msg_123") + + async def test_extract_msg_id_invalid_suffix(self) -> None: + """Test _extract_msg_id logs error for invalid suffix.""" + with self.assertLogs("as", level="ERROR") as log_context: + msg_id = self.memory._extract_msg_id("msg_123:::wrong_session") + self.assertEqual(msg_id, "msg_123:::wrong_session") + self.assertTrue( + any( + "Unexpected document_id format" in m + for m in log_context.output + ), + ) + + async def test_extract_msg_id_no_separator(self) -> None: + """Test _extract_msg_id logs error when no separator found.""" + with self.assertLogs("as", level="ERROR") as 
log_context: + msg_id = self.memory._extract_msg_id("msg_123") + self.assertEqual(msg_id, "msg_123") + self.assertTrue( + any( + "Unexpected document_id format" in m + for m in log_context.output + ), + ) + + async def test_make_and_extract_roundtrip(self) -> None: + """Test roundtrip of _make_document_id and _extract_msg_id.""" + original_id = "test_msg_id_456" + document_id = self.memory._make_document_id(original_id) + extracted_id = self.memory._extract_msg_id(document_id) + self.assertEqual(extracted_id, original_id) + + async def test_delete_by_mark_invalid_type(self) -> None: + """Test that invalid mark type in delete_by_mark raises TypeError.""" + with self.assertRaises(TypeError): + await self.memory.delete_by_mark(123) diff --git a/tests/tablestore_session_ft_test.py b/tests/tablestore_session_ft_test.py new file mode 100644 index 0000000000..2a4fa5d6fc --- /dev/null +++ b/tests/tablestore_session_ft_test.py @@ -0,0 +1,358 @@ +# -*- coding: utf-8 -*- +"""Functional tests for TablestoreSession with real Tablestore instance. + +These tests require the following environment variables to be set: + - TABLESTORE_ENDPOINT + - TABLESTORE_INSTANCE_NAME + - TABLESTORE_ACCESS_KEY_ID + - TABLESTORE_ACCESS_KEY_SECRET + +If any of these are missing, the tests will be skipped. 
+""" +# pylint: disable=protected-access,redefined-outer-name +from __future__ import annotations + +import os +import unittest +from typing import TYPE_CHECKING + +import pytest +import pytest_asyncio + +from agentscope.memory import InMemoryMemory +from agentscope.message import Msg +from agentscope.module import StateModule + +if TYPE_CHECKING: + from agentscope.session import TablestoreSession + + +def _get_tablestore_config() -> dict[str, str] | None: + """Get Tablestore configuration from environment variables.""" + endpoint = os.getenv("TABLESTORE_ENDPOINT") + instance_name = os.getenv("TABLESTORE_INSTANCE_NAME") + access_key_id = os.getenv("TABLESTORE_ACCESS_KEY_ID") + access_key_secret = os.getenv("TABLESTORE_ACCESS_KEY_SECRET") + + if not all([endpoint, instance_name, access_key_id, access_key_secret]): + return None + + assert endpoint is not None + assert instance_name is not None + assert access_key_id is not None + assert access_key_secret is not None + + return { + "end_point": endpoint, + "instance_name": instance_name, + "access_key_id": access_key_id, + "access_key_secret": access_key_secret, + } + + +class SimpleStateModule(StateModule): + """A simple state module for testing.""" + + def __init__(self) -> None: + super().__init__() + self.name = "test_agent" + self.value = 42 + + def state_dict(self) -> dict: + return {"name": self.name, "value": self.value} + + def load_state_dict(self, state_dict: dict, strict: bool = True) -> None: + self.name = state_dict.get("name", self.name) + self.value = state_dict.get("value", self.value) + + +@pytest.fixture +def tablestore_config() -> dict[str, str]: + """Fixture that provides Tablestore config or skips the test.""" + config = _get_tablestore_config() + if config is None: + pytest.skip( + "Tablestore environment variables not set: " + "TABLESTORE_ENDPOINT, TABLESTORE_INSTANCE_NAME, " + "TABLESTORE_ACCESS_KEY_ID, TABLESTORE_ACCESS_KEY_SECRET", + ) + return config # type: ignore[return-value] + + 
+@pytest_asyncio.fixture +async def tablestore_session( # type: ignore[misc] + tablestore_config: dict[str, str], +) -> None: + """Fixture that creates and yields a TablestoreSession, then closes it.""" + from agentscope.session import TablestoreSession + + session = TablestoreSession( + session_table_name="agentscope_ft_session", + message_table_name="agentscope_ft_message", + **tablestore_config, + ) + + async with session as session_instance: + yield session_instance + + +@pytest.mark.asyncio +async def test_session_lifecycle(tablestore_config: dict[str, str]) -> None: + """Test creating and closing a TablestoreSession.""" + from agentscope.session import TablestoreSession + + session = TablestoreSession( + session_table_name="agentscope_ft_session", + message_table_name="agentscope_ft_message", + **tablestore_config, + ) + + await session._ensure_initialized() + assert session._initialized is True + assert session._memory_store is not None + + await session.close() + assert session._initialized is False + assert session._memory_store is None + + +@pytest.mark.asyncio +async def test_save_and_load_session_state( + tablestore_session: TablestoreSession, +) -> None: + """Test saving and loading session state with a simple state module.""" + session = tablestore_session + session_id = "ft_test_session_save_load" + user_id = "ft_test_user" + + # Create and save state + agent = SimpleStateModule() + agent.name = "Friday" + agent.value = 100 + + await session.save_session_state( + session_id=session_id, + user_id=user_id, + agent=agent, + ) + + # Load state into a new module + loaded_agent = SimpleStateModule() + assert loaded_agent.name == "test_agent" + assert loaded_agent.value == 42 + + await session.load_session_state( + session_id=session_id, + user_id=user_id, + agent=loaded_agent, + ) + + assert loaded_agent.name == "Friday" + assert loaded_agent.value == 100 + + # Cleanup + await session._memory_store.delete_session( + user_id=user_id, + 
session_id=session_id, + ) + + +@pytest.mark.asyncio +async def test_save_overwrites_existing_state( + tablestore_session: TablestoreSession, +) -> None: + """Test that saving state overwrites the previous state.""" + session = tablestore_session + session_id = "ft_test_session_overwrite" + user_id = "ft_test_user" + + # Save initial state + agent = SimpleStateModule() + agent.name = "Version1" + agent.value = 1 + + await session.save_session_state( + session_id=session_id, + user_id=user_id, + agent=agent, + ) + + # Save updated state + agent.name = "Version2" + agent.value = 2 + + await session.save_session_state( + session_id=session_id, + user_id=user_id, + agent=agent, + ) + + # Load and verify the latest state + loaded_agent = SimpleStateModule() + await session.load_session_state( + session_id=session_id, + user_id=user_id, + agent=loaded_agent, + ) + + assert loaded_agent.name == "Version2" + assert loaded_agent.value == 2 + + # Cleanup + await session._memory_store.delete_session( + user_id=user_id, + session_id=session_id, + ) + + +@pytest.mark.asyncio +async def test_load_nonexistent_session_allowed( + tablestore_session: TablestoreSession, +) -> None: + """Test loading a non-existent session with allow_not_exist=True.""" + session = tablestore_session + + agent = SimpleStateModule() + original_name = agent.name + original_value = agent.value + + # Should not raise, state should remain unchanged + await session.load_session_state( + session_id="ft_nonexistent_session_id", + user_id="ft_nonexistent_user", + allow_not_exist=True, + agent=agent, + ) + + assert agent.name == original_name + assert agent.value == original_value + + +@pytest.mark.asyncio +async def test_load_nonexistent_session_disallowed( + tablestore_session: TablestoreSession, +) -> None: + """Test loading a non-existent session with allow_not_exist=False.""" + session = tablestore_session + + agent = SimpleStateModule() + + with pytest.raises(ValueError): + await 
session.load_session_state( + session_id="ft_nonexistent_session_id_strict", + user_id="ft_nonexistent_user_strict", + allow_not_exist=False, + agent=agent, + ) + + +@pytest.mark.asyncio +async def test_save_and_load_multiple_modules( + tablestore_session: TablestoreSession, +) -> None: + """Test saving and loading multiple state modules in one session.""" + session = tablestore_session + session_id = "ft_test_session_multi_modules" + user_id = "ft_test_user" + + # Create multiple modules + agent1 = SimpleStateModule() + agent1.name = "Agent1" + agent1.value = 10 + + agent2 = SimpleStateModule() + agent2.name = "Agent2" + agent2.value = 20 + + await session.save_session_state( + session_id=session_id, + user_id=user_id, + agent1=agent1, + agent2=agent2, + ) + + # Load into new modules + loaded1 = SimpleStateModule() + loaded2 = SimpleStateModule() + + await session.load_session_state( + session_id=session_id, + user_id=user_id, + agent1=loaded1, + agent2=loaded2, + ) + + assert loaded1.name == "Agent1" + assert loaded1.value == 10 + assert loaded2.name == "Agent2" + assert loaded2.value == 20 + + # Cleanup + await session._memory_store.delete_session( + user_id=user_id, + session_id=session_id, + ) + + +@pytest.mark.asyncio +async def test_save_and_load_with_memory_module( + tablestore_session: TablestoreSession, +) -> None: + """Test saving and loading a session that includes an InMemoryMemory.""" + session = tablestore_session + session_id = "ft_test_session_with_memory" + user_id = "ft_test_user" + + # Create a memory module with messages + memory = InMemoryMemory() + await memory.add(Msg("Alice", "Hello!", "user")) + await memory.add(Msg("Bob", "Hi there!", "assistant")) + + await session.save_session_state( + session_id=session_id, + user_id=user_id, + memory=memory, + ) + + # Load into a new memory module + loaded_memory = InMemoryMemory() + await session.load_session_state( + session_id=session_id, + user_id=user_id, + memory=loaded_memory, + ) + + 
loaded_msgs = await loaded_memory.get_memory() + assert len(loaded_msgs) == 2 + assert loaded_msgs[0].name == "Alice" + assert loaded_msgs[0].content == "Hello!" + assert loaded_msgs[1].name == "Bob" + assert loaded_msgs[1].content == "Hi there!" + + # Cleanup + await session._memory_store.delete_session( + user_id=user_id, + session_id=session_id, + ) + + +@pytest.mark.asyncio +async def test_context_manager(tablestore_config: dict[str, str]) -> None: + """Test using TablestoreSession as an async context manager.""" + from agentscope.session import TablestoreSession + + session = TablestoreSession( + session_table_name="agentscope_ft_session", + message_table_name="agentscope_ft_message", + **tablestore_config, + ) + + async with session as session_instance: + assert session_instance._initialized is True + + assert session._initialized is False + assert session._memory_store is None + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tablestore_session_test.py b/tests/tablestore_session_test.py new file mode 100644 index 0000000000..a64b6d2fa7 --- /dev/null +++ b/tests/tablestore_session_test.py @@ -0,0 +1,416 @@ +# -*- coding: utf-8 -*- +"""Tests for the Tablestore session implementation.""" +# pylint: disable=protected-access +from __future__ import annotations + +import json +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock, patch + +from agentscope.memory import InMemoryMemory +from agentscope.message import Msg +from agentscope.module import StateModule +from agentscope.session import TablestoreSession + + +class SimpleStateModule(StateModule): + """A simple state module for testing.""" + + def __init__(self) -> None: + super().__init__() + self.name = "test_agent" + self.value = 42 + self.register_state("name") + self.register_state("value") + + +class TablestoreSessionTest(IsolatedAsyncioTestCase): + """Test cases for the Tablestore session module.""" + + def _create_session_with_mocks(self) -> 
"TablestoreSession": + """Create a TablestoreSession with mocked dependencies.""" + with patch( + "agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + # We can't call the real __init__ because it imports tablestore, + # so we construct the object manually with mocks + session = object.__new__(TablestoreSession) + session._tablestore_client = MagicMock() + session._session_table_name = "test_session" + session._message_table_name = "test_message" + session._memory_store = AsyncMock() + session._memory_store_kwargs = {} + session._initialized = True + return session + + async def test_save_session_state(self) -> None: + """Test saving session state to Tablestore.""" + session = self._create_session_with_mocks() + + session._memory_store.update_session = AsyncMock() + + # Create test state modules + agent = SimpleStateModule() + agent.name = "Friday" + agent.value = 100 + + with patch( + "agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + await session.save_session_state( + session_id="test_session_1", + user_id="user_1", + agent=agent, + ) + + # Verify update_session was called with metadata containing state + session._memory_store.update_session.assert_called_once() + saved_session = session._memory_store.update_session.call_args[0][0] + self.assertEqual(saved_session.session_id, "test_session_1") + self.assertEqual(saved_session.user_id, "user_1") + self.assertIn("__state__", saved_session.metadata) + + saved_state = json.loads(saved_session.metadata["__state__"]) + self.assertEqual(saved_state["agent"]["name"], "Friday") + self.assertEqual(saved_state["agent"]["value"], 100) + + async def test_save_session_state_existing_session(self) -> None: + """Test saving state to an existing session overwrites old state.""" + session = self._create_session_with_mocks() + + session._memory_store.update_session = AsyncMock() + + # First save + 
agent = SimpleStateModule() + agent.name = "OriginalName" + + with patch( + "agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + await session.save_session_state( + session_id="test_session_1", + user_id="user_1", + agent=agent, + ) + + # Second save with updated state + agent.name = "UpdatedName" + + with patch( + "agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + await session.save_session_state( + session_id="test_session_1", + user_id="user_1", + agent=agent, + ) + + # Verify update_session was called twice (upsert semantics) + self.assertEqual( + session._memory_store.update_session.call_count, + 2, + ) + + # Verify the second call contains the updated state + second_call_session = ( + session._memory_store.update_session.call_args_list[1][0][0] + ) + saved_state = json.loads( + second_call_session.metadata["__state__"], + ) + self.assertEqual(saved_state["agent"]["name"], "UpdatedName") + + async def test_load_session_state(self) -> None: + """Test loading session state from Tablestore.""" + session = self._create_session_with_mocks() + + # Create state data stored in session metadata + state_data = { + "agent": {"name": "Friday", "value": 100}, + } + + mock_session = MagicMock() + mock_session.metadata = { + "__state__": json.dumps(state_data), + } + session._memory_store.get_session = AsyncMock( + return_value=mock_session, + ) + + # Create agent and load state + agent = SimpleStateModule() + self.assertEqual(agent.name, "test_agent") + self.assertEqual(agent.value, 42) + + with patch( + "agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + await session.load_session_state( + session_id="test_session_1", + user_id="user_1", + agent=agent, + ) + + # Verify state was loaded + self.assertEqual(agent.name, "Friday") + self.assertEqual(agent.value, 100) + + async def 
test_load_session_state_not_exist_allowed(self) -> None: + """Test loading non-existent session with allow_not_exist=True.""" + session = self._create_session_with_mocks() + + session._memory_store.get_session = AsyncMock(return_value=None) + + agent = SimpleStateModule() + original_name = agent.name + + with patch( + "agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + # Should not raise + await session.load_session_state( + session_id="nonexistent", + user_id="user_1", + allow_not_exist=True, + agent=agent, + ) + + # State should remain unchanged + self.assertEqual(agent.name, original_name) + + async def test_load_session_state_not_exist_disallowed(self) -> None: + """Test loading non-existent session with allow_not_exist=False.""" + session = self._create_session_with_mocks() + + session._memory_store.get_session = AsyncMock(return_value=None) + + agent = SimpleStateModule() + + with patch( + "agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + with self.assertRaises(ValueError): + await session.load_session_state( + session_id="nonexistent", + user_id="user_1", + allow_not_exist=False, + agent=agent, + ) + + async def test_load_session_no_state_data(self) -> None: + """Test loading session that exists but has no state data.""" + session = self._create_session_with_mocks() + + # Session exists but metadata has no __state__ key + mock_session = MagicMock() + mock_session.metadata = {} + session._memory_store.get_session = AsyncMock( + return_value=mock_session, + ) + + agent = SimpleStateModule() + original_name = agent.name + + with patch( + "agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + await session.load_session_state( + session_id="test_session_1", + user_id="user_1", + allow_not_exist=True, + agent=agent, + ) + + # State should remain unchanged + 
self.assertEqual(agent.name, original_name) + + async def test_close(self) -> None: + """Test closing the Tablestore session.""" + session = self._create_session_with_mocks() + + mock_store = session._memory_store + await session.close() + + mock_store.close.assert_called_once() + self.assertIsNone(session._memory_store) + self.assertFalse(session._initialized) + + async def test_close_when_not_initialized(self) -> None: + """Test closing when not initialized does nothing.""" + session = self._create_session_with_mocks() + session._memory_store = None + session._initialized = False + + # Should not raise + await session.close() + + async def test_save_and_load_with_memory_module(self) -> None: + """Test saving and loading a state module that contains memory.""" + session = self._create_session_with_mocks() + + session._memory_store.update_session = AsyncMock() + + # Create a memory module with messages + memory = InMemoryMemory() + await memory.add(Msg("Alice", "Hello!", "user")) + + with patch( + "agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + await session.save_session_state( + session_id="test_session_1", + user_id="user_1", + memory=memory, + ) + + # Verify the state was serialized correctly in session metadata + saved_session = session._memory_store.update_session.call_args[0][0] + saved_state = json.loads(saved_session.metadata["__state__"]) + self.assertIn("memory", saved_state) + self.assertIn("content", saved_state["memory"]) + + async def test_empty_user_id_defaults_to_default(self) -> None: + """Test that empty user_id falls back to 'default'.""" + session = self._create_session_with_mocks() + + session._memory_store.update_session = AsyncMock() + + agent = SimpleStateModule() + + with patch( + "agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + await session.save_session_state( + session_id="test_session_1", + 
user_id="", + agent=agent, + ) + + saved_session = session._memory_store.update_session.call_args[0][0] + self.assertEqual(saved_session.user_id, "default") + + # Also verify load uses "default" for empty user_id + mock_session = MagicMock() + mock_session.metadata = { + "__state__": '{"agent": {"name": "X", "value": 1}}', + } + session._memory_store.get_session = AsyncMock( + return_value=mock_session, + ) + + with patch( + "agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + await session.load_session_state( + session_id="test_session_1", + user_id="", + agent=agent, + ) + + session._memory_store.get_session.assert_called_once_with( + user_id="default", + session_id="test_session_1", + ) + + async def test_load_session_no_state_raises_when_disallowed(self) -> None: + """Test loading session with no state data raises when disallowed.""" + session = self._create_session_with_mocks() + + mock_session = MagicMock() + mock_session.metadata = {} + session._memory_store.get_session = AsyncMock( + return_value=mock_session, + ) + + agent = SimpleStateModule() + + with patch( + "agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + with self.assertRaises(ValueError): + await session.load_session_state( + session_id="test_session_1", + user_id="user_1", + allow_not_exist=False, + agent=agent, + ) + + async def test_load_partial_modules(self) -> None: + """Test loading only a subset of saved modules works correctly.""" + session = self._create_session_with_mocks() + + state_data = { + "agent1": {"name": "Agent1", "value": 10}, + "agent2": {"name": "Agent2", "value": 20}, + } + + mock_session = MagicMock() + mock_session.metadata = { + "__state__": json.dumps(state_data), + } + session._memory_store.get_session = AsyncMock( + return_value=mock_session, + ) + + # Only load agent1, skip agent2 + loaded = SimpleStateModule() + + with patch( + 
"agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + await session.load_session_state( + session_id="test_session_1", + user_id="user_1", + agent1=loaded, + ) + + self.assertEqual(loaded.name, "Agent1") + self.assertEqual(loaded.value, 10) + + async def test_async_context_manager(self) -> None: + """Test the async context manager protocol.""" + session = self._create_session_with_mocks() + + mock_store = AsyncMock() + mock_store.close = AsyncMock() + session._memory_store = mock_store + session._initialized = True + + with patch( + "agentscope.session._tablestore_session.TablestoreSession" + "._ensure_initialized", + new_callable=AsyncMock, + ): + async with session as entered: + self.assertIs(entered, session) + + # close() sets _memory_store to None, so check the saved reference + mock_store.close.assert_called_once() + self.assertIsNone(session._memory_store) + self.assertFalse(session._initialized)