diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index c78148e..4c268a3 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -22,3 +22,9 @@ steps: name: python:3.11 entrypoint: python args: ["-m", "pytest"] + env: + - 'REDIS_URL=$_REDIS_URL' + +options: + pool: + name: 'projects/$PROJECT_ID/locations/$LOCATION/workerPools/pool1' diff --git a/pyproject.toml b/pyproject.toml index cd4e1ac..2017123 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ license = {file = "LICENSE"} requires-python = ">=3.8" dependencies = [ "langchain==0.1.1", + "redis>=5.0.0", ] [project.urls] @@ -37,4 +38,4 @@ profile = "black" [tool.mypy] python_version = 3.8 -warn_unused_configs = true \ No newline at end of file +warn_unused_configs = true diff --git a/src/langchain_google_memorystore_redis/__init__.py b/src/langchain_google_memorystore_redis/__init__.py index 6d5e14b..df4fd2d 100644 --- a/src/langchain_google_memorystore_redis/__init__.py +++ b/src/langchain_google_memorystore_redis/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from langchain_google_memorystore_redis.redis_chat_message_history import MemorystoreChatMessageHistory + +__all__ = ["MemorystoreChatMessageHistory"] diff --git a/src/langchain_google_memorystore_redis/redis_chat_message_history.py b/src/langchain_google_memorystore_redis/redis_chat_message_history.py new file mode 100644 index 0000000..f321627 --- /dev/null +++ b/src/langchain_google_memorystore_redis/redis_chat_message_history.py @@ -0,0 +1,69 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import redis +from typing import List, Optional + +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import ( + BaseMessage, + message_to_dict, + messages_from_dict, +) + + +class MemorystoreChatMessageHistory(BaseChatMessageHistory): + """Chat message history stored in a Cloud Memorystore for Redis database.""" + + def __init__( + self, + client: redis.Redis, + session_id: str, + ttl: Optional[int] = None, + ): + """Initializes the chat message history for Memorystore for Redis. + + Args: + client: A redis.Redis object that connects to the Redis instance. + session_id: The session ID for this chat message history. + ttl: The expiration time in seconds of the whole chat history after + the most recent add_message was called. + """ + + self.redis = client + self.key = session_id + self.ttl = ttl + + def __del__(self): + self.redis.close() + + @property + def messages(self) -> List[BaseMessage]: + """Retrieve all messages chronologically stored in this session.""" + all_elements = self.redis.lrange(self.key, 0, -1) + messages = messages_from_dict( + [json.loads(e.decode("utf-8")) for e in all_elements] + ) + return messages + + def add_message(self, message: BaseMessage) -> None: + """Append one message to this session.""" + self.redis.rpush(self.key, json.dumps(message_to_dict(message))) + if self.ttl: + self.redis.expire(self.key, self.ttl) + + def clear(self) -> None: + """Clear all messages in this session.""" + self.redis.delete(self.key) diff --git a/tests/test_chat_message_history.py b/tests/test_chat_message_history.py new file mode 100644 index 0000000..65da34a --- /dev/null +++ b/tests/test_chat_message_history.py @@ -0,0 +1,70 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import random +import re +import string +import time +from typing import Iterator +import uuid +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_google_memorystore_redis import MemorystoreChatMessageHistory +import pytest +import redis + + +def test_redis_multiple_sessions() -> None: + client = redis.from_url( + get_env_var("REDIS_URL", "URL of the Redis instance") + ) + + session_id1 = uuid.uuid4().hex + history1 = MemorystoreChatMessageHistory( + client=client, + session_id=session_id1, + ) + session_id2 = uuid.uuid4().hex + history2 = MemorystoreChatMessageHistory( + client=client, + session_id=session_id2, + ) + + history1.add_ai_message("Hey! I am AI!") + history2.add_user_message("Hey! I am human!") + messages1 = history1.messages + messages2 = history2.messages + + assert len(messages1) == 1 + assert len(messages2) == 1 + assert isinstance(messages1[0], AIMessage) + assert messages1[0].content == "Hey! I am AI!" + assert isinstance(messages2[0], HumanMessage) + assert messages2[0].content == "Hey! I am human!" + + history1.clear() + assert len(history1.messages) == 0 + assert len(history2.messages) == 1 + + history2.clear() + assert len(history1.messages) == 0 + assert len(history2.messages) == 0 + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v