Skip to content

Redis refactor dirty #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
628 changes: 585 additions & 43 deletions docs/extras/integrations/vectorstores/redis.ipynb

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions libs/langchain/langchain/memory/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Any, Dict, List

from langchain.schema.messages import get_buffer_string # noqa: 401


def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
"""
Expand Down
58 changes: 53 additions & 5 deletions libs/langchain/langchain/utilities/redis.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,64 @@
from __future__ import annotations

import logging
from typing import (
TYPE_CHECKING,
Any,
)
import re
from typing import TYPE_CHECKING, Any, List, Optional, Pattern
from urllib.parse import urlparse

import numpy as np

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from redis.client import Redis as RedisType

logger = logging.getLogger(__name__)

def _array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
return np.array(array).astype(dtype).tobytes()


class TokenEscaper:
    """
    Backslash-escape punctuation characters found in an input string.
    """

    # Characters that RediSearch requires us to escape during queries.
    # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
    DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/]"

    def __init__(self, escape_chars_re: Optional[Pattern] = None):
        """Store the pattern to escape with.

        Args:
            escape_chars_re: Compiled pattern matching the characters to
                escape. When omitted (or falsy), a pattern compiled from
                ``DEFAULT_ESCAPED_CHARS`` is used instead.
        """
        self.escaped_chars_re = escape_chars_re or re.compile(
            self.DEFAULT_ESCAPED_CHARS
        )

    def escape(self, value: str) -> str:
        """Return ``value`` with every pattern match prefixed by a backslash."""
        # Replacement template: the leading escaped backslash emits one literal
        # backslash, and \g<0> re-emits the matched text unchanged.
        return self.escaped_chars_re.sub(r"\\\g<0>", value)


def check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
    """Check that at least one of the required Redis modules is installed.

    Args:
        client: Redis client; only its ``module_list()`` method is called.
        required_modules: Candidate modules, each a dict with a ``"name"`` and
            a minimum ``"ver"``. The check passes as soon as one candidate is
            installed at that version or newer.

    Raises:
        ValueError: If none of the candidates is installed at a sufficient
            version. The message is also logged at error level.
    """

    def _as_text(value):  # type: ignore[no-untyped-def]
        # Module info comes back as bytes by default, but as str when the
        # client was created with decode_responses=True; accept both.
        return value.decode("utf-8") if isinstance(value, bytes) else value

    installed = {}
    for module in client.module_list():
        info = {_as_text(key): val for key, val in module.items()}
        installed[_as_text(info["name"])] = info

    for required in required_modules:
        found = installed.get(required["name"])
        if found is not None and int(found["ver"]) >= int(required["ver"]):
            return

    # No candidate matched. The literals below are concatenated, so each piece
    # carries its own trailing space to keep the sentence and URL intact.
    error_message = (
        "Redis cannot be used as a vector database without RediSearch >=2.4. "
        "Please head to https://redis.io/docs/stack/search/quick_start/ "
        "to know more about installing the RediSearch module within Redis Stack."
    )
    logger.error(error_message)
    raise ValueError(error_message)


def get_client(redis_url: str, **kwargs: Any) -> RedisType:
Expand Down
Loading