diff --git a/docs/docs/integrations/document_loaders/cratedb.ipynb b/docs/docs/integrations/document_loaders/cratedb.ipynb new file mode 100644 index 0000000000000..78a0e19138703 --- /dev/null +++ b/docs/docs/integrations/document_loaders/cratedb.ipynb @@ -0,0 +1,232 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CrateDB\n", + "\n", + "This notebook demonstrates how to load documents from a [CrateDB] database,\n", + "using the [SQLAlchemy] document loader.\n", + "\n", + "It loads the result of a database query with one document per row.\n", + "\n", + "[CrateDB]: https://github.com/crate/crate\n", + "[SQLAlchemy]: https://www.sqlalchemy.org/" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Prerequisites" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#!pip install crash 'langchain[cratedb]'" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Populate database." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001B[32mCONNECT OK\r\n", + "\u001B[0m\u001B[32mPSQL OK, 1 row affected (0.001 sec)\r\n", + "\u001B[0m\u001B[32mDELETE OK, 30 rows affected (0.008 sec)\r\n", + "\u001B[0m\u001B[32mINSERT OK, 30 rows affected (0.011 sec)\r\n", + "\u001B[0m\u001B[0m\u001B[32mCONNECT OK\r\n", + "\u001B[0m\u001B[32mREFRESH OK, 1 row affected (0.001 sec)\r\n", + "\u001B[0m\u001B[0m" + ] + } + ], + "source": [ + "!crash < ./example_data/mlb_teams_2012.sql\n", + "!crash --command \"REFRESH TABLE mlb_teams_2012;\"" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Usage" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.document_loaders import CrateDBLoader\n", + "from pprint import pprint\n", + "\n", + "CONNECTION_STRING = \"crate://crate@localhost/\"\n", + "\n", + "loader = CrateDBLoader(\n", + " 'SELECT * FROM mlb_teams_2012 ORDER BY \"Team\" LIMIT 5;',\n", + " url=CONNECTION_STRING,\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Angels\\nPayroll (millions): 154.49\\nWins: 89', metadata={}),\n", + " Document(page_content='Team: Astros\\nPayroll (millions): 60.65\\nWins: 55', metadata={}),\n", + " Document(page_content='Team: Athletics\\nPayroll (millions): 55.37\\nWins: 94', metadata={}),\n", + " Document(page_content='Team: Blue Jays\\nPayroll (millions): 75.48\\nWins: 73', metadata={}),\n", + " Document(page_content='Team: Braves\\nPayroll (millions): 83.31\\nWins: 94', metadata={})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specifying Which Columns are Content vs Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "loader = CrateDBLoader(\n", + " 'SELECT * FROM mlb_teams_2012 ORDER BY \"Team\" LIMIT 5;',\n", + " url=CONNECTION_STRING,\n", + " page_content_columns=[\"Team\"],\n", + " metadata_columns=[\"Payroll (millions)\"],\n", + ")\n", + "documents = 
loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Angels', metadata={'Payroll (millions)': 154.49}),\n", + " Document(page_content='Team: Astros', metadata={'Payroll (millions)': 60.65}),\n", + " Document(page_content='Team: Athletics', metadata={'Payroll (millions)': 55.37}),\n", + " Document(page_content='Team: Blue Jays', metadata={'Payroll (millions)': 75.48}),\n", + " Document(page_content='Team: Braves', metadata={'Payroll (millions)': 83.31})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding Source to Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "loader = CrateDBLoader(\n", + " 'SELECT * FROM mlb_teams_2012 ORDER BY \"Team\" LIMIT 5;',\n", + " url=CONNECTION_STRING,\n", + " source_columns=[\"Team\"],\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Angels\\nPayroll (millions): 154.49\\nWins: 89', metadata={'source': 'Angels'}),\n", + " Document(page_content='Team: Astros\\nPayroll (millions): 60.65\\nWins: 55', metadata={'source': 'Astros'}),\n", + " Document(page_content='Team: Athletics\\nPayroll (millions): 55.37\\nWins: 94', metadata={'source': 'Athletics'}),\n", + " Document(page_content='Team: Blue Jays\\nPayroll (millions): 75.48\\nWins: 73', metadata={'source': 'Blue Jays'}),\n", + " Document(page_content='Team: Braves\\nPayroll (millions): 83.31\\nWins: 94', metadata={'source': 'Braves'})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql b/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql new file mode 100644 index 0000000000000..6d94aeaa773b8 --- /dev/null +++ b/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql @@ -0,0 +1,41 @@ +-- Provisioning table "mlb_teams_2012". 
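+-- Contains payroll (millions of USD) and win totals for all 30 MLB teams of the 2012 season.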
+-- +-- crash < mlb_teams_2012.sql +-- psql postgresql://postgres@localhost < mlb_teams_2012.sql + +DROP TABLE IF EXISTS mlb_teams_2012; +CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT); +INSERT INTO mlb_teams_2012 + ("Team", "Payroll (millions)", "Wins") +VALUES + ('Nationals', 81.34, 98), + ('Reds', 82.20, 97), + ('Yankees', 197.96, 95), + ('Giants', 117.62, 94), + ('Braves', 83.31, 94), + ('Athletics', 55.37, 94), + ('Rangers', 120.51, 93), + ('Orioles', 81.43, 93), + ('Rays', 64.17, 90), + ('Angels', 154.49, 89), + ('Tigers', 132.30, 88), + ('Cardinals', 110.30, 88), + ('Dodgers', 95.14, 86), + ('White Sox', 96.92, 85), + ('Brewers', 97.65, 83), + ('Phillies', 174.54, 81), + ('Diamondbacks', 74.28, 81), + ('Pirates', 63.43, 79), + ('Padres', 55.24, 76), + ('Mariners', 81.97, 75), + ('Mets', 93.35, 74), + ('Blue Jays', 75.48, 73), + ('Royals', 60.91, 72), + ('Marlins', 118.07, 69), + ('Red Sox', 173.18, 69), + ('Indians', 78.43, 68), + ('Twins', 94.08, 66), + ('Rockies', 78.06, 64), + ('Cubs', 88.19, 61), + ('Astros', 60.65, 55) +; diff --git a/docs/docs/integrations/document_loaders/sqlalchemy.ipynb b/docs/docs/integrations/document_loaders/sqlalchemy.ipynb new file mode 100644 index 0000000000000..5d603d7263c53 --- /dev/null +++ b/docs/docs/integrations/document_loaders/sqlalchemy.ipynb @@ -0,0 +1,237 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SQLAlchemy\n", + "\n", + "This notebook demonstrates how to load documents from an [SQLite] database,\n", + "using the [SQLAlchemy] document loader.\n", + "\n", + "It loads the result of a database query with one document per row.\n", + "\n", + "[SQLAlchemy]: https://www.sqlalchemy.org/\n", + "[SQLite]: https://sqlite.org/" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Prerequisites" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#!pip install langchain termsql" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Provide input data as SQLite database." 
+ ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting example.csv\n" + ] + } + ], + "source": [ + "%%file example.csv\n", + "Team,Payroll\n", + "Nationals,81.34\n", + "Reds,82.20" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Nationals|81.34\r\n", + "Reds|82.2\r\n" + ] + } + ], + "source": [ + "!termsql --infile=example.csv --head --delimiter=\",\" --outfile=example.sqlite --table=payroll" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Usage" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.document_loaders import SQLAlchemyLoader\n", + "from pprint import pprint\n", + "\n", + "loader = SQLAlchemyLoader(\n", + " \"SELECT * FROM payroll\",\n", + " url=\"sqlite:///example.sqlite\",\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Nationals\\nPayroll: 81.34', metadata={}),\n", + " Document(page_content='Team: Reds\\nPayroll: 82.2', metadata={})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specifying Which Columns are Content vs Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "loader = SQLAlchemyLoader(\n", + " \"SELECT * FROM payroll\",\n", + " url=\"sqlite:///example.sqlite\",\n", + " page_content_columns=[\"Team\"],\n", + " metadata_columns=[\"Payroll\"],\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Nationals', metadata={'Payroll': 81.34}),\n", + " Document(page_content='Team: Reds', metadata={'Payroll': 82.2})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding Source to Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "loader = SQLAlchemyLoader(\n", + " \"SELECT * FROM payroll\",\n", + " url=\"sqlite:///example.sqlite\",\n", + " source_columns=[\"Team\"],\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Nationals\\nPayroll: 81.34', metadata={'source': 'Nationals'}),\n", + " Document(page_content='Team: Reds\\nPayroll: 82.2', metadata={'source': 'Reds'})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": 
"ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb b/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb new file mode 100644 index 0000000000000..f51f5f1d63fca --- /dev/null +++ b/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb @@ -0,0 +1,357 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# CrateDB Chat Message History\n", + "\n", + "This notebook demonstrates how to use the `CrateDBChatMessageHistory`\n", + "to manage chat history in CrateDB, for supporting conversational memory." + ], + "metadata": { + "collapsed": false + }, + "id": "f22eab3f84cbeb37" + }, + { + "cell_type": "markdown", + "source": [ + "## Prerequisites" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "!#pip install 'langchain[cratedb]'" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Configuration\n", + "\n", + "To use the storage wrapper, you will need to configure two details.\n", + "\n", + "1. Session Id - a unique identifier of the session, like user name, email, chat id etc.\n", + "2. Database connection string: An SQLAlchemy-compatible URI that specifies the database\n", + " connection. It will be passed to SQLAlchemy create_engine function." + ], + "metadata": { + "collapsed": false + }, + "id": "f8f2830ee9ca1e01" + }, + { + "cell_type": "code", + "execution_count": 52, + "outputs": [], + "source": [ + "from langchain.memory.chat_message_histories import CrateDBChatMessageHistory\n", + "\n", + "CONNECTION_STRING = \"crate://crate@localhost:4200/?schema=example\"\n", + "\n", + "chat_message_history = CrateDBChatMessageHistory(\n", + "\tsession_id=\"test_session\",\n", + "\tconnection_string=CONNECTION_STRING\n", + ")" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Basic Usage" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 53, + "outputs": [], + "source": [ + "chat_message_history.add_user_message(\"Hello\")\n", + "chat_message_history.add_ai_message(\"Hi\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-28T10:04:38.077748Z", + "start_time": "2023-08-28T10:04:36.105894Z" + } + }, + "id": "4576e914a866fb40" + }, + { + "cell_type": "code", + "execution_count": 61, + "outputs": [ + { + "data": { + "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]" + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat_message_history.messages" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-28T10:04:38.929396Z", + "start_time": "2023-08-28T10:04:38.915727Z" + } + }, + "id": "b476688cbb32ba90" + }, + { + "cell_type": "markdown", + "source": [ + "## Custom Storage Model\n", + "\n", + "The default data model, which stores information about conversation messages only\n", + "has two slots for storing message details, the session id, and the message dictionary.\n", + "\n", + "If you want to store additional information, like message date, author, language etc.,\n", + "please provide an implementation for a custom message converter.\n", + "\n", + "This example demonstrates how to create a custom message 
converter, by implementing\n", + "the `BaseMessageConverter` interface." + ], + "metadata": { + "collapsed": false + }, + "id": "2e5337719d5614fd" + }, + { + "cell_type": "code", + "execution_count": 55, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "from typing import Any\n", + "\n", + "from langchain.memory.chat_message_histories.cratedb import generate_autoincrement_identifier\n", + "from langchain.memory.chat_message_histories.sql import BaseMessageConverter\n", + "from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage\n", + "\n", + "import sqlalchemy as sa\n", + "from sqlalchemy.orm import declarative_base\n", + "\n", + "\n", + "Base = declarative_base()\n", + "\n", + "\n", + "class CustomMessage(Base):\n", + "\t__tablename__ = \"custom_message_store\"\n", + "\n", + "\tid = sa.Column(sa.BigInteger, primary_key=True, default=generate_autoincrement_identifier)\n", + "\tsession_id = sa.Column(sa.Text)\n", + "\ttype = sa.Column(sa.Text)\n", + "\tcontent = sa.Column(sa.Text)\n", + "\tcreated_at = sa.Column(sa.DateTime)\n", + "\tauthor_email = sa.Column(sa.Text)\n", + "\n", + "\n", + "class CustomMessageConverter(BaseMessageConverter):\n", + "\tdef __init__(self, author_email: str):\n", + "\t\tself.author_email = author_email\n", + "\t\n", + "\tdef from_sql_model(self, sql_message: Any) -> BaseMessage:\n", + "\t\tif sql_message.type == \"human\":\n", + "\t\t\treturn HumanMessage(\n", + "\t\t\t\tcontent=sql_message.content,\n", + "\t\t\t)\n", + "\t\telif sql_message.type == \"ai\":\n", + "\t\t\treturn AIMessage(\n", + "\t\t\t\tcontent=sql_message.content,\n", + "\t\t\t)\n", + "\t\telif sql_message.type == \"system\":\n", + "\t\t\treturn SystemMessage(\n", + "\t\t\t\tcontent=sql_message.content,\n", + "\t\t\t)\n", + "\t\telse:\n", + "\t\t\traise ValueError(f\"Unknown message type: {sql_message.type}\")\n", + "\t\n", + "\tdef to_sql_model(self, message: BaseMessage, session_id: str) -> Any:\n", + "\t\tnow = datetime.now()\n", + "\t\treturn CustomMessage(\n", + "\t\t\tsession_id=session_id,\n", + "\t\t\ttype=message.type,\n", + "\t\t\tcontent=message.content,\n", + "\t\t\tcreated_at=now,\n", + "\t\t\tauthor_email=self.author_email\n", + "\t\t)\n", + "\t\n", + "\tdef get_sql_model_class(self) -> Any:\n", + "\t\treturn CustomMessage\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + "\n", + "\tBase.metadata.drop_all(bind=sa.create_engine(CONNECTION_STRING))\n", + "\n", + "\tchat_message_history = CrateDBChatMessageHistory(\n", + "\t\tsession_id=\"test_session\",\n", + "\t\tconnection_string=CONNECTION_STRING,\n", + "\t\tcustom_message_converter=CustomMessageConverter(\n", + "\t\t\tauthor_email=\"test@example.com\"\n", + "\t\t)\n", + "\t)\n", + "\n", + "\tchat_message_history.add_user_message(\"Hello\")\n", + "\tchat_message_history.add_ai_message(\"Hi\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-28T10:04:41.510498Z", + "start_time": "2023-08-28T10:04:41.494912Z" + } + }, + "id": "fdfde84c07d071bb" + }, + { + "cell_type": "code", + "execution_count": 60, + "outputs": [ + { + "data": { + "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]" + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat_message_history.messages" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-28T10:04:43.497990Z", + "start_time": 
"2023-08-28T10:04:43.492517Z" + } + }, + "id": "4a6a54d8a9e2856f" + }, + { + "cell_type": "markdown", + "source": [ + "## Custom Name for Session Column\n", + "\n", + "The session id, a unique token identifying the session, is an important property of\n", + "this subsystem. If your database table stores it in a different column, you can use\n", + "the `session_id_field_name` keyword argument to adjust the name correspondingly." + ], + "metadata": { + "collapsed": false + }, + "id": "622aded629a1adeb" + }, + { + "cell_type": "code", + "execution_count": 57, + "outputs": [], + "source": [ + "import json\n", + "import typing as t\n", + "\n", + "from langchain.memory.chat_message_histories.cratedb import generate_autoincrement_identifier, CrateDBMessageConverter\n", + "from langchain.schema import _message_to_dict\n", + "\n", + "\n", + "Base = declarative_base()\n", + "\n", + "class MessageWithDifferentSessionIdColumn(Base):\n", + "\t__tablename__ = \"message_store_different_session_id\"\n", + "\tid = sa.Column(sa.BigInteger, primary_key=True, default=generate_autoincrement_identifier)\n", + "\tcustom_session_id = sa.Column(sa.Text)\n", + "\tmessage = sa.Column(sa.Text)\n", + "\n", + "\n", + "class CustomMessageConverterWithDifferentSessionIdColumn(CrateDBMessageConverter):\n", + " def __init__(self):\n", + " self.model_class = MessageWithDifferentSessionIdColumn\n", + "\n", + " def to_sql_model(self, message: BaseMessage, custom_session_id: str) -> t.Any:\n", + " return self.model_class(\n", + " custom_session_id=custom_session_id, message=json.dumps(_message_to_dict(message))\n", + " )\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + "\tBase.metadata.drop_all(bind=sa.create_engine(CONNECTION_STRING))\n", + "\n", + "\tchat_message_history = CrateDBChatMessageHistory(\n", + "\t\tsession_id=\"test_session\",\n", + "\t\tconnection_string=CONNECTION_STRING,\n", + "\t\tcustom_message_converter=CustomMessageConverterWithDifferentSessionIdColumn(),\n", + "\t\tsession_id_field_name=\"custom_session_id\",\n", + "\t)\n", + "\n", + "\tchat_message_history.add_user_message(\"Hello\")\n", + "\tchat_message_history.add_ai_message(\"Hi\")" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 58, + "outputs": [ + { + "data": { + "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]" + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat_message_history.messages" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/docs/integrations/providers/cratedb.mdx b/docs/docs/integrations/providers/cratedb.mdx new file mode 100644 index 0000000000000..4764a7ad92369 --- /dev/null +++ b/docs/docs/integrations/providers/cratedb.mdx @@ -0,0 +1,203 @@ +# CrateDB + +This documentation section shows how to use the CrateDB vector store +functionality around [`FLOAT_VECTOR`] and [`KNN_MATCH`]. You will learn +how to use it for similarity search and other purposes. + + +## What is CrateDB? 
+
+[CrateDB] is an open-source, distributed, and scalable SQL analytics database
+for storing and analyzing massive amounts of data in near real-time, even with
+complex queries. It is PostgreSQL-compatible, based on [Lucene], and inherits
+the shared-nothing distribution layer of [Elasticsearch].
+
+It provides a distributed, multi-tenant-capable relational database and search
+engine with HTTP and PostgreSQL interfaces, and schema-free objects. It supports
+sharding, partitioning, and replication out of the box.
+
+CrateDB enables you to store billions of records and terabytes of data, and
+to query them efficiently using SQL.
+
+- Provides a standards-based SQL interface for querying relational data, nested
+  documents, geospatial constraints, and vector embeddings at the same time.
+- Improves your operations by storing time-series data, relational metadata,
+  and vector embeddings within a single database.
+- Builds upon proven technologies from Lucene and Elasticsearch.
+
+
+## CrateDB Cloud
+
+- Offers on-demand CrateDB clusters without operational overhead,
+  with enterprise-grade features and [ISO 27001] certification.
+- The entrypoint to [CrateDB Cloud] is the [CrateDB Cloud Console].
+- Crate.io offers a free tier via [CrateDB Cloud CRFREE].
+- To get started, [sign up] for CrateDB Cloud, deploy a database cluster,
+  and follow the upcoming instructions.
+
+
+## Features
+
+The CrateDB adapter supports the _Vector Store_, _Document Loader_,
+and _Conversational Memory_ subsystems of LangChain.
+
+### Vector Store
+
+`CrateDBVectorSearch` is an API wrapper around CrateDB's `FLOAT_VECTOR` type
+and the corresponding `KNN_MATCH` function, based on SQLAlchemy and CrateDB's
+SQLAlchemy dialect. It provides an interface to store and retrieve floating
+point vectors, and to conduct similarity searches.
+
+It supports:
+- Approximate nearest neighbor search.
+- Euclidean distance.
+
+### Document Loader
+
+`CrateDBLoader` loads documents from a database table, using an SQL query
+expression or an SQLAlchemy selectable instance.
+
+### Conversational Memory
+
+`CrateDBChatMessageHistory` uses CrateDB to manage conversation history.
+
+
+## Installation and Setup
+
+There are multiple ways to get started with CrateDB.
+
+### Install CrateDB on your local machine
+
+You can [download CrateDB], or use the [OCI image] to run CrateDB on Docker or Podman.
+Note that this is not recommended for production use.
+
+```shell
+docker run --rm -it --name=cratedb --publish=4200:4200 --publish=5432:5432 \
+  --env=CRATE_HEAP_SIZE=4g crate/crate:nightly \
+  -Cdiscovery.type=single-node
+```
+
+### Deploy a cluster on CrateDB Cloud
+
+[CrateDB Cloud] is a managed CrateDB service. Sign up for a [free trial].
+
+### Install Client
+
+```bash
+pip install 'crate[sqlalchemy]' 'langchain[openai]' 'crash'
+```
+
+
+## Usage » Vector Store
+
+For a more detailed walkthrough of the `CrateDBVectorSearch` wrapper, there is also
+a corresponding [Jupyter notebook](/docs/extras/integrations/vectorstores/cratedb.html).
+
+### Provide input data
+The example uses the canonical `state_of_the_union.txt`.
+```shell
+wget https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt
+```
+
+### Set environment variables
+Use a valid OpenAI API key and SQL connection string. This one fits a local instance of CrateDB.
+```shell
+export OPENAI_API_KEY=foobar  # FIXME
+export CRATEDB_CONNECTION_STRING=crate://crate@localhost
+```
+
+### Example
+
+Load and index documents, and invoke query.
+```python
+from langchain.document_loaders import UnstructuredURLLoader
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.vectorstores import CrateDBVectorSearch
+
+
+def main():
+    # Load the document, split it into chunks, embed each chunk and load it into the vector store.
+    raw_documents = UnstructuredURLLoader(urls=["https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt"]).load()
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    documents = text_splitter.split_documents(raw_documents)
+    db = CrateDBVectorSearch.from_documents(documents, OpenAIEmbeddings())
+
+    query = "What did the president say about Ketanji Brown Jackson"
+    docs = db.similarity_search(query)
+    print(docs[0].page_content)
+
+
+if __name__ == "__main__":
+    main()
+```
+
+
+## Usage » Document Loader
+
+For a more detailed walkthrough of the `CrateDBLoader`, there is also a corresponding
+[Jupyter notebook](/docs/extras/integrations/document_loaders/cratedb.html).
+
+
+### Provide input data
+```shell
+wget https://github.com/crate-workbench/langchain/raw/cratedb/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql
+crash < mlb_teams_2012.sql
+crash --command "REFRESH TABLE mlb_teams_2012;"
+```
+
+### Load documents by SQL query
+```python
+from langchain.document_loaders import CrateDBLoader
+from pprint import pprint
+
+def main():
+    loader = CrateDBLoader(
+        'SELECT * FROM mlb_teams_2012 ORDER BY "Team" LIMIT 5;',
+        url="crate://crate@localhost/",
+    )
+    documents = loader.load()
+    pprint(documents)
+
+if __name__ == "__main__":
+    main()
+```
+
+
+## Usage » Conversational Memory
+
+For a more detailed walkthrough of the `CrateDBChatMessageHistory`, there is also a corresponding
+[Jupyter notebook](/docs/extras/integrations/memory/cratedb_chat_message_history.html).
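+
+The history can also back a conversational chain. The following is a rough
+sketch, not part of the walkthrough above: it assumes a configured OpenAI API
+key, and combines the history with the standard `ConversationBufferMemory`
+and `ConversationChain` components.
+
+```python
+from langchain.chains import ConversationChain
+from langchain.llms import OpenAI
+from langchain.memory import ConversationBufferMemory
+from langchain.memory.chat_message_histories import CrateDBChatMessageHistory
+
+# Keep the conversation state in CrateDB instead of in process memory.
+message_history = CrateDBChatMessageHistory(
+    session_id="test_session",
+    connection_string="crate://crate@localhost/",
+)
+memory = ConversationBufferMemory(chat_memory=message_history)
+conversation = ConversationChain(llm=OpenAI(), memory=memory)
+conversation.run("Hello")
+```
+
+The plain message history API looks like this: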
+```python
+from langchain.memory.chat_message_histories import CrateDBChatMessageHistory
+from pprint import pprint
+
+def main():
+    chat_message_history = CrateDBChatMessageHistory(
+        session_id="test_session",
+        connection_string="crate://crate@localhost/",
+    )
+    chat_message_history.add_user_message("Hello")
+    chat_message_history.add_ai_message("Hi")
+    pprint(chat_message_history.messages)
+
+if __name__ == "__main__":
+    main()
+```
+
+
+[CrateDB]: https://github.com/crate/crate
+[CrateDB Cloud]: https://crate.io/product
+[CrateDB Cloud Console]: https://console.cratedb.cloud/
+[CrateDB Cloud CRFREE]: https://community.crate.io/t/new-cratedb-cloud-edge-feature-cratedb-cloud-free-tier/1402
+[CrateDB SQLAlchemy dialect]: https://crate.io/docs/python/en/latest/sqlalchemy.html
+[download CrateDB]: https://crate.io/download
+[Elasticsearch]: https://github.com/elastic/elasticsearch
+[`FLOAT_VECTOR`]: https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector
+[free trial]: https://crate.io/lp-crfree?utm_source=langchain
+[ISO 27001]: https://crate.io/blog/cratedb-elevates-its-security-standards-and-achieves-iso-27001-certification
+[`KNN_MATCH`]: https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match
+[Lucene]: https://github.com/apache/lucene
+[OCI image]: https://hub.docker.com/_/crate
+[sign up]: https://console.cratedb.cloud/
diff --git a/docs/docs/integrations/vectorstores/cratedb.ipynb b/docs/docs/integrations/vectorstores/cratedb.ipynb
new file mode 100644
index 0000000000000..06430e6355ae9
--- /dev/null
+++ b/docs/docs/integrations/vectorstores/cratedb.ipynb
@@ -0,0 +1,499 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# CrateDB\n",
+    "\n",
+    "This notebook shows how to use the CrateDB vector store functionality around\n",
+    "[`FLOAT_VECTOR`] and [`KNN_MATCH`]. You will learn how to use it for similarity\n",
+    "search and other purposes.\n",
+    "\n",
+    "It supports:\n",
+    "- Similarity Search with Euclidean Distance\n",
+    "- Maximal Marginal Relevance Search (MMR)\n",
+    "\n",
+    "## What is CrateDB?\n",
+    "\n",
+    "[CrateDB] is an open-source, distributed, and scalable SQL analytics database\n",
+    "for storing and analyzing massive amounts of data in near real-time, even with\n",
+    "complex queries. It is PostgreSQL-compatible, based on [Lucene], and inherits\n",
+    "the shared-nothing distribution layer of [Elasticsearch].\n",
+    "\n",
+    "This example uses the [Python client driver for CrateDB]. 
For more documentation,\n", + "see also [LangChain with CrateDB].\n", + "\n", + "\n", + "[CrateDB]: https://github.com/crate/crate\n", + "[Elasticsearch]: https://github.com/elastic/elasticsearch\n", + "[`FLOAT_VECTOR`]: https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector\n", + "[`KNN_MATCH`]: https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match\n", + "[LangChain with CrateDB]: /docs/extras/integrations/providers/cratedb.html\n", + "[Lucene]: https://github.com/apache/lucene\n", + "[Python client driver for CrateDB]: https://crate.io/docs/python/" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Getting Started" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [], + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "# Install required packages: LangChain, OpenAI SDK, and the CrateDB Python driver.\n", + "!pip install 'langchain[cratedb,openai]'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You need to provide an OpenAI API key, optionally using the environment\n", + "variable `OPENAI_API_KEY`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2023-09-09T08:02:16.802456Z", + "start_time": "2023-09-09T08:02:07.065604Z" + } + }, + "outputs": [], + "source": [ + "import os\n", + "import getpass\n", + "from dotenv import load_dotenv, find_dotenv\n", + "\n", + "# Run `export OPENAI_API_KEY=sk-YOUR_OPENAI_API_KEY`.\n", + "# Get OpenAI api key from `.env` file.\n", + "# Otherwise, prompt for it.\n", + "_ = load_dotenv(find_dotenv())\n", + "OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', getpass.getpass(\"OpenAI API key:\"))\n", + "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY" + ] + }, + { + "cell_type": "markdown", + "source": [ + "You also need to provide a connection string to your CrateDB database cluster,\n", + "optionally using the environment variable `CRATEDB_CONNECTION_STRING`.\n", + "\n", + "This example uses a CrateDB instance on your workstation, which you can start by\n", + "running [CrateDB using Docker]. 
Alternatively, you can also connect to a cluster\n", + "running on [CrateDB Cloud].\n", + "\n", + "[CrateDB Cloud]: https://console.cratedb.cloud/\n", + "[CrateDB using Docker]: https://crate.io/docs/crate/tutorials/en/latest/basic/index.html#docker" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "import os\n", + "\n", + "CONNECTION_STRING = os.environ.get(\n", + " \"CRATEDB_CONNECTION_STRING\",\n", + " \"crate://crate@localhost:4200/?schema=langchain\",\n", + ")\n", + "\n", + "# For CrateDB Cloud, use:\n", + "# CONNECTION_STRING = os.environ.get(\n", + "# \"CRATEDB_CONNECTION_STRING\",\n", + "# \"crate://username:password@hostname:4200/?ssl=true&schema=langchain\",\n", + "# )" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2023-09-09T08:02:28.174088Z", + "start_time": "2023-09-09T08:02:28.162698Z" + } + }, + "outputs": [], + "source": [ + "\"\"\"\n", + "# Alternatively, the connection string can be assembled from individual\n", + "# environment variables.\n", + "import os\n", + "\n", + "CONNECTION_STRING = CrateDBVectorSearch.connection_string_from_db_params(\n", + " driver=os.environ.get(\"CRATEDB_DRIVER\", \"crate\"),\n", + " host=os.environ.get(\"CRATEDB_HOST\", \"localhost\"),\n", + " port=int(os.environ.get(\"CRATEDB_PORT\", \"4200\")),\n", + " database=os.environ.get(\"CRATEDB_DATABASE\", \"langchain\"),\n", + " user=os.environ.get(\"CRATEDB_USER\", \"crate\"),\n", + " password=os.environ.get(\"CRATEDB_PASSWORD\", \"\"),\n", + ")\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "source": [ + "You will start by importing all required modules." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.text_splitter import CharacterTextSplitter\n", + "from langchain.vectorstores import CrateDBVectorSearch\n", + "from langchain.document_loaders import UnstructuredURLLoader\n", + "from langchain.docstore.document import Document" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Load and Index Documents\n", + "\n", + "Next, you will read input data, and tokenize it. The module will create a table\n", + "with the name of the collection. Make sure the collection name is unique, and\n", + "that you have the permission to create a table." 
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "loader = UnstructuredURLLoader(urls=[\"https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt\"])\n",
+    "documents = loader.load()\n",
+    "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
+    "docs = text_splitter.split_documents(documents)\n",
+    "\n",
+    "COLLECTION_NAME = \"state_of_the_union_test\"\n",
+    "\n",
+    "embeddings = OpenAIEmbeddings()\n",
+    "\n",
+    "db = CrateDBVectorSearch.from_documents(\n",
+    "    embedding=embeddings,\n",
+    "    documents=docs,\n",
+    "    collection_name=COLLECTION_NAME,\n",
+    "    connection_string=CONNECTION_STRING,\n",
+    ")"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "is_executing": true
+    }
+   }
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Search Documents\n",
+    "\n",
+    "### Similarity Search with Euclidean Distance\n",
+    "Searching by Euclidean distance is the default."
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2023-09-09T08:05:11.104135Z",
+     "start_time": "2023-09-09T08:05:10.548998Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
+    "docs_with_score = db.similarity_search_with_score(query)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2023-09-09T08:05:13.532334Z",
+     "start_time": "2023-09-09T08:05:13.523191Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "for doc, score in docs_with_score:\n",
+    "    print(\"-\" * 80)\n",
+    "    print(\"Score: \", score)\n",
+    "    print(doc.page_content)\n",
+    "    print(\"-\" * 80)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### Maximal Marginal Relevance Search (MMR)\n",
+    "Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents."
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "docs_with_score = db.max_marginal_relevance_search_with_score(query)"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2023-09-09T08:05:23.276819Z",
+     "start_time": "2023-09-09T08:05:21.972256Z"
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "for doc, score in docs_with_score:\n",
+    "    print(\"-\" * 80)\n",
+    "    print(\"Score: \", score)\n",
+    "    print(doc.page_content)\n",
+    "    print(\"-\" * 80)"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2023-09-09T08:05:27.478580Z",
+     "start_time": "2023-09-09T08:05:27.470138Z"
+    }
+   }
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### Searching in Multiple Collections\n",
+    "`CrateDBVectorSearchMultiCollection` is a special adapter which provides similarity search across\n",
+    "multiple collections. It cannot be used for indexing documents."
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "from langchain.vectorstores.cratedb import CrateDBVectorSearchMultiCollection\n",
+    "\n",
+    "multisearch = CrateDBVectorSearchMultiCollection(\n",
+    "    collection_names=[\"test_collection_1\", \"test_collection_2\"],\n",
+    "    embedding_function=embeddings,\n",
+    "    connection_string=CONNECTION_STRING,\n",
+    ")\n",
+    "docs_with_score = multisearch.similarity_search_with_score(query)"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Working with the Vector Store\n",
+    "\n",
+    "In the example above, you created a vector store from scratch. When\n",
+    "aiming to work with an existing vector store, you can initialize it directly."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "store = CrateDBVectorSearch(\n",
+    "    collection_name=COLLECTION_NAME,\n",
+    "    connection_string=CONNECTION_STRING,\n",
+    "    embedding_function=embeddings,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Add Documents\n",
+    "\n",
+    "You can also add documents to an existing vector store."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "store.add_documents([Document(page_content=\"foo\")])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "docs_with_score = db.similarity_search_with_score(\"foo\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "docs_with_score[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "docs_with_score[1]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Overwriting a Vector Store\n",
+    "\n",
+    "If you have an existing collection, you can overwrite it by using `from_documents`,\n",
+    "and setting `pre_delete_collection = True`."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "db = CrateDBVectorSearch.from_documents(\n", + " documents=docs,\n", + " embedding=embeddings,\n", + " collection_name=COLLECTION_NAME,\n", + " connection_string=CONNECTION_STRING,\n", + " pre_delete_collection=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "docs_with_score = db.similarity_search_with_score(\"foo\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "docs_with_score[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using a Vector Store as a Retriever" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "retriever = store.as_retriever()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(retriever)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/docs/modules/data_connection/document_loaders/sqlalchemy.mdx b/docs/docs/modules/data_connection/document_loaders/sqlalchemy.mdx new file mode 100644 index 0000000000000..9f7e663db075e --- /dev/null +++ b/docs/docs/modules/data_connection/document_loaders/sqlalchemy.mdx @@ -0,0 +1,155 @@ +# SQLAlchemy + + +## About + +The [SQLAlchemy] document loader loads records from any supported database, +see [SQLAlchemy dialects] for all supported SQL databases and dialects. + +You can either use plain SQL for querying, or use an SQLAlchemy `Select` +statement object, if you are using SQLAlchemy-Core or -ORM. + +You can select which columns to place into the document, which columns +to place into its metadata, which columns to use as a `source` attribute +in metadata, and whether to include the result row number and/or the SQL +query expression into the metadata. + + +## Example + +This example uses PostgreSQL, and the `psycopg2` driver. + + +### Prerequisites + +```shell +psql postgresql://postgres@localhost/ --command "CREATE DATABASE testdrive;" +psql postgresql://postgres@localhost/testdrive < ./libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql +``` + + +### Basic loading + +```python +from langchain.document_loaders import SQLAlchemyLoader +from pprint import pprint + + +loader = SQLAlchemyLoader( + query="SELECT * FROM mlb_teams_2012 LIMIT 3;", + url="postgresql+psycopg2://postgres@localhost:5432/testdrive", +) +docs = loader.load() +``` + +```python +pprint(docs) +``` + + + +``` +[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={}), + Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={}), + Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={})] +``` + + + + +## Enriching metadata + +Use the `include_rownum_into_metadata` and `include_query_into_metadata` options to +optionally populate the `metadata` dictionary with corresponding information. 
+ +Having the `query` within metadata is useful when using documents loaded from +database tables for chains that answer questions using their origin queries. + +```python +loader = SQLAlchemyLoader( + query="SELECT * FROM mlb_teams_2012 LIMIT 3;", + url="postgresql+psycopg2://postgres@localhost:5432/testdrive", + include_rownum_into_metadata=True, + include_query_into_metadata=True, +) +docs = loader.load() +``` + +```python +pprint(docs) +``` + + + +``` +[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={'row': 0, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'}), + Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={'row': 1, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'}), + Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={'row': 2, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'})] +``` + + + + +## Customizing metadata + +Use the `page_content_columns`, and `metadata_columns` options to optionally populate +the `metadata` dictionary with corresponding information. When `page_content_columns` +is empty, all columns will be used. + +```python +loader = SQLAlchemyLoader( + query="SELECT * FROM mlb_teams_2012 LIMIT 3;", + url="postgresql+psycopg2://postgres@localhost:5432/testdrive", + page_content_columns=["Payroll (millions)", "Wins"], + metadata_columns=["Team"], +) +docs = loader.load() +``` + +```python +pprint(docs) +``` + + + +``` +[Document(page_content='Payroll (millions): 81.34\nWins: 98', metadata={'Team': 'Nationals'}), + Document(page_content='Payroll (millions): 82.2\nWins: 97', metadata={'Team': 'Reds'}), + Document(page_content='Payroll (millions): 197.96\nWins: 95', metadata={'Team': 'Yankees'})] +``` + + + + +## Specify column(s) to identify the document source + +Use the `source_columns` option to specify the columns to use as a "source" for the +document created from each row. This is useful for identifying documents through +their metadata. Typically, you may use the primary key column(s) for that purpose. 
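+
+As noted in the introduction, the loader also accepts an SQLAlchemy `Select`
+statement in place of a plain SQL string. The following sketch illustrates
+this; reflecting the table from the `testdrive` database first is an
+assumption of this illustration, not a required step.
+
+```python
+import sqlalchemy as sa
+
+from langchain.document_loaders import SQLAlchemyLoader
+
+url = "postgresql+psycopg2://postgres@localhost:5432/testdrive"
+
+# Reflect the table, so it can be used with the SQLAlchemy expression API.
+table = sa.Table("mlb_teams_2012", sa.MetaData(), autoload_with=sa.create_engine(url))
+
+loader = SQLAlchemyLoader(
+    query=sa.select(table).limit(3),
+    url=url,
+    source_columns=["Team"],
+)
+docs = loader.load()
+```
+
+The equivalent invocation with a plain SQL string: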
+ +```python +loader = SQLAlchemyLoader( + query="SELECT * FROM mlb_teams_2012 LIMIT 3;", + url="postgresql+psycopg2://postgres@localhost:5432/testdrive", + source_columns="Team", +) +docs = loader.load() +``` + +```python +pprint(docs) +``` + + + +``` +[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={'source': 'Nationals'}), + Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={'source': 'Reds'}), + Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={'source': 'Yankees'})] +``` + + + + +[SQLAlchemy]: https://www.sqlalchemy.org/ +[SQLAlchemy dialects]: https://docs.sqlalchemy.org/en/20/dialects/ diff --git a/docs/vercel.json b/docs/vercel.json index 025c56cdb69af..20de4ea57968c 100644 --- a/docs/vercel.json +++ b/docs/vercel.json @@ -552,6 +552,10 @@ "source": "/docs/integrations/chaindesk", "destination": "/docs/integrations/providers/chaindesk" }, + { + "source": "/docs/integrations/cratedb", + "destination": "/docs/integrations/providers/cratedb" + }, { "source": "/docs/integrations/databricks", "destination": "/docs/integrations/providers/databricks" @@ -1732,6 +1736,10 @@ "source": "/docs/modules/data_connection/document_loaders/integrations/copypaste", "destination": "/docs/integrations/document_loaders/copypaste" }, + { + "source": "/docs/modules/data_connection/document_loaders/integrations/cratedb", + "destination": "/docs/integrations/document_loaders/cratedb" + }, { "source": "/en/latest/modules/indexes/document_loaders/examples/csv.html", "destination": "/docs/integrations/document_loaders/csv" @@ -2680,6 +2688,14 @@ "source": "/docs/modules/data_connection/vectorstores/integrations/chroma", "destination": "/docs/integrations/vectorstores/chroma" }, + { + "source": "/en/latest/modules/indexes/vectorstores/examples/cratedb.html", + "destination": "/docs/integrations/vectorstores/cratedb" + }, + { + "source": "/docs/modules/data_connection/vectorstores/integrations/cratedb", + "destination": "/docs/integrations/vectorstores/cratedb" + }, { "source": "/en/latest/modules/indexes/vectorstores/examples/deeplake.html", "destination": "/docs/integrations/vectorstores/activeloop_deeplake" @@ -2944,6 +2960,14 @@ "source": "/docs/integrations/memory/entity_memory_with_sqlite", "destination": "/docs/integrations/memory/sqlite" }, + { + "source": "/en/latest/modules/memory/examples/cratedb_chat_message_history.html", + "destination": "/docs/integrations/memory/cratedb_chat_message_history" + }, + { + "source": "/docs/modules/memory/integrations/cratedb_chat_message_history", + "destination": "/docs/integrations/memory/cratedb_chat_message_history" + }, { "source": "/en/latest/modules/memory/examples/dynamodb_chat_message_history.html", "destination": "/docs/integrations/memory/dynamodb_chat_message_history" diff --git a/libs/experimental/tests/unit_tests/conftest.py b/libs/experimental/tests/unit_tests/conftest.py index da45a330f50af..afb609f3ce31f 100644 --- a/libs/experimental/tests/unit_tests/conftest.py +++ b/libs/experimental/tests/unit_tests/conftest.py @@ -40,8 +40,8 @@ def test_something(): # Used to avoid repeated calls to `util.find_spec` required_pkgs_info: Dict[str, bool] = {} - only_extended = config.getoption("--only-extended") or False - only_core = config.getoption("--only-core") or False + only_extended = config.getoption("--only-extended", False) + only_core = config.getoption("--only-core", False) if only_extended and only_core: raise ValueError("Cannot specify both 
`--only-extended` and `--only-core`.") diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py index 96cfd9e1b5486..a454a1baa9eb4 100644 --- a/libs/langchain/langchain/document_loaders/__init__.py +++ b/libs/langchain/langchain/document_loaders/__init__.py @@ -59,6 +59,7 @@ from langchain.document_loaders.concurrent import ConcurrentLoader from langchain.document_loaders.confluence import ConfluenceLoader from langchain.document_loaders.conllu import CoNLLULoader +from langchain.document_loaders.cratedb import CrateDBLoader from langchain.document_loaders.csv_loader import CSVLoader, UnstructuredCSVLoader from langchain.document_loaders.cube_semantic import CubeSemanticLoader from langchain.document_loaders.datadog_logs import DatadogLogsLoader @@ -158,6 +159,7 @@ from langchain.document_loaders.slack_directory import SlackDirectoryLoader from langchain.document_loaders.snowflake_loader import SnowflakeLoader from langchain.document_loaders.spreedly import SpreedlyLoader +from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader from langchain.document_loaders.srt import SRTLoader from langchain.document_loaders.stripe import StripeLoader from langchain.document_loaders.telegram import ( @@ -244,6 +246,7 @@ "CollegeConfidentialLoader", "ConcurrentLoader", "ConfluenceLoader", + "CrateDBLoader", "CubeSemanticLoader", "DataFrameLoader", "DatadogLogsLoader", @@ -333,6 +336,7 @@ "SlackDirectoryLoader", "SnowflakeLoader", "SpreedlyLoader", + "SQLAlchemyLoader", "StripeLoader", "TelegramChatApiLoader", "TelegramChatFileLoader", diff --git a/libs/langchain/langchain/document_loaders/cratedb.py b/libs/langchain/langchain/document_loaders/cratedb.py new file mode 100644 index 0000000000000..9e34b4d0cb9ec --- /dev/null +++ b/libs/langchain/langchain/document_loaders/cratedb.py @@ -0,0 +1,5 @@ +from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader + + +class CrateDBLoader(SQLAlchemyLoader): + pass diff --git a/libs/langchain/langchain/document_loaders/sqlalchemy.py b/libs/langchain/langchain/document_loaders/sqlalchemy.py new file mode 100644 index 0000000000000..787c9f339b686 --- /dev/null +++ b/libs/langchain/langchain/document_loaders/sqlalchemy.py @@ -0,0 +1,112 @@ +from typing import Dict, List, Optional, Union + +import sqlalchemy as sa + +from langchain.docstore.document import Document +from langchain.document_loaders.base import BaseLoader + + +class SQLAlchemyLoader(BaseLoader): + """ + Load documents by querying database tables supported by SQLAlchemy. + Each document represents one row of the result. + """ + + def __init__( + self, + query: Union[str, sa.Select], + url: str, + page_content_columns: Optional[List[str]] = None, + metadata_columns: Optional[List[str]] = None, + source_columns: Optional[List[str]] = None, + include_rownum_into_metadata: bool = False, + include_query_into_metadata: bool = False, + sqlalchemy_kwargs: Optional[Dict] = None, + ): + """ + + Args: + query: The query to execute. + url: The SQLAlchemy connection string of the database to connect to. + page_content_columns: The columns to write into the `page_content` + of the document. Optional. + metadata_columns: The columns to write into the `metadata` of the document. + Optional. + source_columns: The names of the columns to use as the `source` within the + metadata dictionary. Optional. + include_rownum_into_metadata: Whether to include the row number into the + metadata dictionary. Optional. Default: False. 
+ include_query_into_metadata: Whether to include the query expression into + the metadata dictionary. Optional. Default: False. + sqlalchemy_kwargs: More keyword arguments for SQLAlchemy's `create_engine`. + """ + self.query = query + self.url = url + self.page_content_columns = page_content_columns + self.metadata_columns = metadata_columns + self.source_columns = source_columns + self.include_rownum_into_metadata = include_rownum_into_metadata + self.include_query_into_metadata = include_query_into_metadata + self.sqlalchemy_kwargs = sqlalchemy_kwargs or {} + + def load(self) -> List[Document]: + try: + import sqlalchemy as sa + except ImportError: + raise ImportError( + "Could not import sqlalchemy python package. " + "Please install it with `pip install sqlalchemy`." + ) + + engine = sa.create_engine(self.url, **self.sqlalchemy_kwargs) + + docs = [] + with engine.connect() as conn: + if isinstance(self.query, sa.Select): + result = conn.execute(self.query) + query_sql = str(self.query.compile(bind=engine)) + elif isinstance(self.query, str): + result = conn.execute(sa.text(self.query)) + query_sql = self.query + else: + raise TypeError( + f"Unable to process query of unknown type: {self.query}" + ) + field_names = list(result.mappings().keys()) + + if self.page_content_columns is None: + page_content_columns = field_names + else: + page_content_columns = self.page_content_columns + + if self.metadata_columns is None: + metadata_columns = [] + else: + metadata_columns = self.metadata_columns + + for i, row in enumerate(result.mappings()): + page_content = "\n".join( + f"{column}: {value}" + for column, value in row.items() + if column in page_content_columns + ) + + metadata: Dict[str, Union[str, int]] = {} + if self.include_rownum_into_metadata: + metadata["row"] = i + if self.include_query_into_metadata: + metadata["query"] = query_sql + + source_values = [] + for column, value in row.items(): + if column in metadata_columns: + metadata[column] = value + if self.source_columns and column in self.source_columns: + source_values.append(value) + if source_values: + metadata["source"] = ",".join(source_values) + + doc = Document(page_content=page_content, metadata=metadata) + docs.append(doc) + + return docs diff --git a/libs/langchain/langchain/memory/chat_message_histories/__init__.py b/libs/langchain/langchain/memory/chat_message_histories/__init__.py index 83fc7fa519ad0..7275d66d1108b 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/__init__.py +++ b/libs/langchain/langchain/memory/chat_message_histories/__init__.py @@ -5,6 +5,7 @@ CassandraChatMessageHistory, ) from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory +from langchain.memory.chat_message_histories.cratedb import CrateDBChatMessageHistory from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory from langchain.memory.chat_message_histories.elasticsearch import ( ElasticsearchChatMessageHistory, @@ -38,6 +39,7 @@ "ChatMessageHistory", "CassandraChatMessageHistory", "CosmosDBChatMessageHistory", + "CrateDBChatMessageHistory", "DynamoDBChatMessageHistory", "ElasticsearchChatMessageHistory", "FileChatMessageHistory", diff --git a/libs/langchain/langchain/memory/chat_message_histories/cratedb.py b/libs/langchain/langchain/memory/chat_message_histories/cratedb.py new file mode 100644 index 0000000000000..19007176cb193 --- /dev/null +++ b/libs/langchain/langchain/memory/chat_message_histories/cratedb.py @@ -0,0 +1,113 @@ +import json +import 
typing as t
+
+import sqlalchemy as sa
+from cratedb_toolkit.sqlalchemy import (
+    patch_inspector,
+    polyfill_refresh_after_dml,
+    refresh_table,
+)
+
+from langchain.memory.chat_message_histories.sql import (
+    BaseMessageConverter,
+    SQLChatMessageHistory,
+)
+from langchain.schema import BaseMessage, _message_to_dict, messages_from_dict
+
+
+def create_message_model(table_name, DynamicBase):  # type: ignore
+    """
+    Create a message model for a given table name.
+
+    This is a specialized version for CrateDB, generating integer-based primary keys.
+    TODO: Find a way to converge CrateDB's generate_random_uuid() into a variant
+          returning its integer value.
+
+    Args:
+        table_name: The name of the table to use.
+        DynamicBase: The base class to use for the model.
+
+    Returns:
+        The model class.
+    """
+
+    # Model is declared inside a function to be able to use a dynamic table name.
+    class Message(DynamicBase):
+        __tablename__ = table_name
+        id = sa.Column(sa.BigInteger, primary_key=True, server_default=sa.func.now())
+        session_id = sa.Column(sa.Text)
+        message = sa.Column(sa.Text)
+
+    return Message
+
+
+class CrateDBMessageConverter(BaseMessageConverter):
+    """
+    The default message converter for `CrateDBChatMessageHistory`.
+
+    It is the same as the generic `SQLChatMessageHistory` converter,
+    but swaps in a different `create_message_model` function.
+    """
+
+    def __init__(self, table_name: str):
+        self.model_class = create_message_model(table_name, sa.orm.declarative_base())
+
+    def from_sql_model(self, sql_message: t.Any) -> BaseMessage:
+        return messages_from_dict([json.loads(sql_message.message)])[0]
+
+    def to_sql_model(self, message: BaseMessage, session_id: str) -> t.Any:
+        return self.model_class(
+            session_id=session_id, message=json.dumps(_message_to_dict(message))
+        )
+
+    def get_sql_model_class(self) -> t.Any:
+        return self.model_class
+
+
+class CrateDBChatMessageHistory(SQLChatMessageHistory):
+    """
+    It is the same as the generic `SQLChatMessageHistory` implementation,
+    but swaps in a different message converter by default.
+    """
+
+    DEFAULT_MESSAGE_CONVERTER: t.Type[BaseMessageConverter] = CrateDBMessageConverter
+
+    def __init__(
+        self,
+        session_id: str,
+        connection_string: str,
+        table_name: str = "message_store",
+        session_id_field_name: str = "session_id",
+        custom_message_converter: t.Optional[BaseMessageConverter] = None,
+    ):
+        # FIXME: Refactor elsewhere.
+        patch_inspector()
+
+        super().__init__(
+            session_id,
+            connection_string,
+            table_name=table_name,
+            session_id_field_name=session_id_field_name,
+            custom_message_converter=custom_message_converter,
+        )
+
+        # TODO: Check how this can be improved.
+        polyfill_refresh_after_dml(self.Session)
+
+    def _messages_query(self) -> sa.Select:
+        """
+        Construct an SQLAlchemy selectable to query for messages.
+        For CrateDB, add an `ORDER BY` clause on the primary key.
+        """
+        selectable = super()._messages_query()
+        selectable = selectable.order_by(self.sql_model_class.id)
+        return selectable
+
+    def clear(self) -> None:
+        """
+        Needed for CrateDB to synchronize data, because `on_flush` does not catch it.
+ """ + outcome = super().clear() + with self.Session() as session: + refresh_table(session, self.sql_model_class) + return outcome diff --git a/libs/langchain/langchain/memory/chat_message_histories/sql.py b/libs/langchain/langchain/memory/chat_message_histories/sql.py index fcc3ac71ab1a6..edf59685dc2fe 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/sql.py +++ b/libs/langchain/langchain/memory/chat_message_histories/sql.py @@ -1,9 +1,9 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any, List, Optional, Type -from sqlalchemy import Column, Integer, Text, create_engine +from sqlalchemy import Column, Integer, Select, Text, create_engine, select try: from sqlalchemy.orm import declarative_base @@ -23,6 +23,10 @@ class BaseMessageConverter(ABC): """The class responsible for converting BaseMessage to your SQLAlchemy model.""" + @abstractmethod + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError + @abstractmethod def from_sql_model(self, sql_message: Any) -> BaseMessage: """Convert a SQLAlchemy model to a BaseMessage instance.""" @@ -52,7 +56,7 @@ def create_message_model(table_name, DynamicBase): # type: ignore """ - # Model decleared inside a function to have a dynamic table name + # Model declared inside a function to have a dynamic table name class Message(DynamicBase): __tablename__ = table_name id = Column(Integer, primary_key=True) @@ -83,6 +87,8 @@ def get_sql_model_class(self) -> Any: class SQLChatMessageHistory(BaseChatMessageHistory): """Chat message history stored in an SQL database.""" + DEFAULT_MESSAGE_CONVERTER: Type[BaseMessageConverter] = DefaultMessageConverter + def __init__( self, session_id: str, @@ -94,7 +100,9 @@ def __init__( self.connection_string = connection_string self.engine = create_engine(connection_string, echo=False) self.session_id_field_name = session_id_field_name - self.converter = custom_message_converter or DefaultMessageConverter(table_name) + self.converter = custom_message_converter or self.DEFAULT_MESSAGE_CONVERTER( + table_name + ) self.sql_model_class = self.converter.get_sql_model_class() if not hasattr(self.sql_model_class, session_id_field_name): raise ValueError("SQL model class must have session_id column") @@ -106,21 +114,25 @@ def __init__( def _create_table_if_not_exists(self) -> None: self.sql_model_class.metadata.create_all(self.engine) + def _messages_query(self) -> Select: + """Construct an SQLAlchemy selectable to query for messages""" + return ( + select(self.sql_model_class) + .where( + getattr(self.sql_model_class, self.session_id_field_name) + == self.session_id + ) + .order_by(self.sql_model_class.id.asc()) + ) + @property def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve all messages from db""" with self.Session() as session: - result = ( - session.query(self.sql_model_class) - .where( - getattr(self.sql_model_class, self.session_id_field_name) - == self.session_id - ) - .order_by(self.sql_model_class.id.asc()) - ) + result = session.execute(self._messages_query()) messages = [] for record in result: - messages.append(self.converter.from_sql_model(record)) + messages.append(self.converter.from_sql_model(record[0])) return messages def add_message(self, message: BaseMessage) -> None: diff --git a/libs/langchain/langchain/vectorstores/__init__.py b/libs/langchain/langchain/vectorstores/__init__.py index 6c79ccfdf40e6..813535fc22fdb 100644 --- 
a/libs/langchain/langchain/vectorstores/__init__.py +++ b/libs/langchain/langchain/vectorstores/__init__.py @@ -134,6 +134,12 @@ def _import_clickhouse_settings() -> Any: return ClickhouseSettings +def _import_cratedb() -> Any: + from langchain.vectorstores.cratedb import CrateDBVectorSearch + + return CrateDBVectorSearch + + def _import_dashvector() -> Any: from langchain.vectorstores.dashvector import DashVector @@ -465,6 +471,8 @@ def __getattr__(name: str) -> Any: return _import_clickhouse_settings() elif name == "Clickhouse": return _import_clickhouse() + elif name == "CrateDBVectorSearch": + return _import_cratedb() elif name == "DashVector": return _import_dashvector() elif name == "DatabricksVectorSearch": @@ -582,6 +590,7 @@ def __getattr__(name: str) -> Any: "Clarifai", "Clickhouse", "ClickhouseSettings", + "CrateDBVectorSearch", "DashVector", "DatabricksVectorSearch", "DeepLake", diff --git a/libs/langchain/langchain/vectorstores/cratedb/__init__.py b/libs/langchain/langchain/vectorstores/cratedb/__init__.py new file mode 100644 index 0000000000000..3b3e23270ec1b --- /dev/null +++ b/libs/langchain/langchain/vectorstores/cratedb/__init__.py @@ -0,0 +1,8 @@ +from .base import CrateDBVectorSearch, StorageStrategy +from .extended import CrateDBVectorSearchMultiCollection + +__all__ = [ + "CrateDBVectorSearch", + "CrateDBVectorSearchMultiCollection", + "StorageStrategy", +] diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py new file mode 100644 index 0000000000000..9257ce2772c99 --- /dev/null +++ b/libs/langchain/langchain/vectorstores/cratedb/base.py @@ -0,0 +1,495 @@ +from __future__ import annotations + +import enum +import math +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, +) + +import sqlalchemy +from cratedb_toolkit.sqlalchemy.patch import patch_inspector +from cratedb_toolkit.sqlalchemy.polyfill import ( + polyfill_refresh_after_dml, + refresh_table, +) +from sqlalchemy.orm import sessionmaker + +from langchain.docstore.document import Document +from langchain.schema.embeddings import Embeddings +from langchain.utils import get_from_dict_or_env +from langchain.vectorstores.cratedb.model import ModelFactory +from langchain.vectorstores.pgvector import PGVector + + +class DistanceStrategy(str, enum.Enum): + """Enumerator of similarity distance strategies.""" + + EUCLIDEAN = "euclidean" + COSINE = "cosine" + MAX_INNER_PRODUCT = "inner" + + +class StorageStrategy(enum.Enum): + """Enumerator of storage strategies.""" + + # This storage strategy reflects the vanilla way the pgvector adapter manages + # the data model: There is a single `collection` table and a single + # `embedding` table. + LANGCHAIN_PGVECTOR = "langchain_pgvector" + + # This storage strategy reflects a more advanced way to manage + # the data model: There is a single `collection` table, and multiple + # `embedding` tables, one per collection. + EMBEDDING_TABLE_PER_COLLECTION = "embedding_table_per_collection" + + +DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN +DEFAULT_STORAGE_STRATEGY = StorageStrategy.LANGCHAIN_PGVECTOR + + +_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" + + +def _results_to_docs(docs_and_scores: Any) -> List[Document]: + """Return docs from docs and scores.""" + return [doc for doc, _ in docs_and_scores] + + +class CrateDBVectorSearch(PGVector): + """`CrateDB` vector store. + + To use it, you should have the ``crate[sqlalchemy]`` python package installed. 
+
+    Args:
+        connection_string: Database connection string.
+        embedding_function: Any embedding function implementing
+            `langchain.embeddings.base.Embeddings` interface.
+        collection_name: The name of the collection to use. (default: langchain)
+            NOTE: This is not the name of the table, but the name of the collection.
+            The tables will be created when initializing the store, if they do not
+            exist yet. So, make sure the user has the right permissions to create
+            tables.
+        distance_strategy: The distance strategy to use. (default: EUCLIDEAN)
+        pre_delete_collection: If True, will delete the collection if it exists.
+            (default: False). Useful for testing.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.vectorstores import CrateDBVectorSearch
+            from langchain.embeddings.openai import OpenAIEmbeddings
+
+            CONNECTION_STRING = "crate://crate@localhost:4200/test3"
+            COLLECTION_NAME = "state_of_the_union_test"
+            embeddings = OpenAIEmbeddings()
+            vectorstore = CrateDBVectorSearch.from_documents(
+                embedding=embeddings,
+                documents=docs,
+                collection_name=COLLECTION_NAME,
+                connection_string=CONNECTION_STRING,
+            )
+
+    """
+
+    # Select storage strategy: Either use two database tables (`collection`
+    # and `embedding`), or multiple `embedding` tables, one per collection.
+    STORAGE_STRATEGY: StorageStrategy = DEFAULT_STORAGE_STRATEGY
+
+    @classmethod
+    def configure(
+        cls, storage_strategy: StorageStrategy = DEFAULT_STORAGE_STRATEGY
+    ) -> None:
+        cls.STORAGE_STRATEGY = storage_strategy
+
+    def __post_init__(
+        self,
+    ) -> None:
+        """
+        Initialize the store.
+        """
+
+        # FIXME: Could be a bug in the CrateDB SQLAlchemy dialect.
+        patch_inspector()
+
+        self._engine = self.create_engine()
+        self.Session = sessionmaker(self._engine)
+
+        # TODO: See what can be improved here.
+        polyfill_refresh_after_dml(self.Session)
+
+        # Need to defer initialization, because dimension size
+        # can only be figured out at runtime.
+        self.BaseModel = None
+        self.CollectionStore = None  # type: ignore[assignment]
+        self.EmbeddingStore = None  # type: ignore[assignment]
+
+    def __del__(self) -> None:
+        """
+        Work around premature session close.
+
+        sqlalchemy.orm.exc.DetachedInstanceError: Parent instance is not bound
+        to a Session; lazy load operation of attribute 'embeddings' cannot proceed.
+        -- https://docs.sqlalchemy.org/en/20/errors.html#error-bhk3
+
+        TODO: Review!
+        """  # noqa: E501
+        pass
+
+    def _init_models(self, embedding: List[float]) -> None:
+        """
+        Initialize SQLAlchemy model classes, using the dimensionality of the
+        given vector.
+
+        It will only initialize the model classes once.
+        """
+
+        # TODO: Use a better way to run this only once.
+        if self.CollectionStore is not None and self.EmbeddingStore is not None:
+            return
+
+        size = len(embedding)
+        self._init_models_with_dimensionality(size=size)
+
+    def _init_models_with_dimensionality(self, size: int) -> None:
+        """
+        Initialize SQLAlchemy model classes, using the given dimensionality value.
+        """
+        mf = self._get_model_factory(size=size)
+        self.BaseModel, self.CollectionStore, self.EmbeddingStore = (
+            mf.BaseModel,  # type: ignore[assignment]
+            mf.CollectionStore,
+            mf.EmbeddingStore,
+        )
+
+    def _get_model_factory(self, size: Optional[int] = None) -> ModelFactory:
+        """
+        Initialize SQLAlchemy model classes, based on the selected storage strategy.
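
A sketch of flipping the storage strategy switch introduced above. It has to happen before instantiating a store, because `STORAGE_STRATEGY` is consulted in `_get_model_factory()`.

    from langchain.vectorstores.cratedb import CrateDBVectorSearch, StorageStrategy

    # Use one `embedding_<collection>` table per collection instead of the
    # single shared `embedding` table of the vanilla pgvector layout.
    CrateDBVectorSearch.configure(
        storage_strategy=StorageStrategy.EMBEDDING_TABLE_PER_COLLECTION
    )
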
+ """ + if self.STORAGE_STRATEGY is StorageStrategy.EMBEDDING_TABLE_PER_COLLECTION: + mf = ModelFactory( + dimensions=size, embedding_table=f"embedding_{self.collection_name}" + ) + else: + mf = ModelFactory(dimensions=size) + return mf + + def get_collection(self, session: sqlalchemy.orm.Session) -> Any: + if self.CollectionStore is None: + raise RuntimeError( + "Collection can't be accessed without specifying " + "dimension size of embedding vectors" + ) + return self.CollectionStore.get_by_name(session, self.collection_name) + + def add_embeddings( + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Add embeddings to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. + embeddings: List of list of embedding vectors. + metadatas: List of metadatas associated with the texts. + kwargs: vectorstore specific parameters + """ + if not embeddings: + return [] + self._init_models(embeddings[0]) + + # When the user requested to delete the collection before running subsequent + # operations on it, run the deletion gracefully if the table does not exist + # yet. + if self.pre_delete_collection: + try: + self.delete_collection() + except sqlalchemy.exc.ProgrammingError as ex: + if "RelationUnknown" not in str(ex): + raise + + # Tables need to be created at runtime, because the `EmbeddingStore.embedding` + # field, a `FloatVector`, needs to be initialized with a dimensionality + # parameter, which is only obtained at runtime. + self.create_tables_if_not_exists() + self.create_collection() + + # After setting up the table/collection at runtime, add embeddings. + embedding_ids = super().add_embeddings( + texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + refresh_table(self.Session(), self.EmbeddingStore) + return embedding_ids + + def create_tables_if_not_exists(self) -> None: + """ + Need to overwrite because this `Base` is different from parent's `Base`. + """ + if self.BaseModel is None: + raise RuntimeError("Storage models not initialized") + self.BaseModel.metadata.create_all(self._engine) + + def drop_tables(self) -> None: + """ + Need to overwrite because this `Base` is different from parent's `Base`. + """ + mf = self._get_model_factory() + mf.Base.metadata.drop_all(self._engine) + + def delete( + self, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """ + Delete vectors by ids or uuids. + + Remark: Specialized for CrateDB to synchronize data. + + Args: + ids: List of ids to delete. + + Remark: Patch for CrateDB needs to overwrite this, in order to + add a "REFRESH TABLE" statement afterwards. The other + patch, listening to `after_delete` events seems not be + able to catch it. + """ + super().delete(ids=ids, **kwargs) + + # CrateDB: Synchronize data because `on_flush` does not catch it. 
+ with self.Session() as session: + refresh_table(session, self.EmbeddingStore) + + @property + def distance_strategy(self) -> Any: + if self._distance_strategy == DistanceStrategy.EUCLIDEAN: + return self.EmbeddingStore.embedding.euclidean_distance + elif self._distance_strategy == DistanceStrategy.COSINE: + raise NotImplementedError("Cosine similarity not implemented yet") + elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + raise NotImplementedError("Dot-product similarity not implemented yet") + else: + raise ValueError( + f"Got unexpected value for distance: {self._distance_strategy}. " + f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}." + ) + + def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]: + """Return docs and scores from results.""" + docs = [ + ( + Document( + page_content=result.EmbeddingStore.document, + metadata=result.EmbeddingStore.cmetadata, + ), + result._score if self.embedding_function is not None else None, + ) + for result in results + ] + return docs + + def _query_collection( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> List[Any]: + """Query the collection.""" + self._init_models(embedding) + with self.Session() as session: + collection = self.get_collection(session) + if not collection: + raise ValueError("Collection not found") + return self._query_collection_multi( + collections=[collection], embedding=embedding, k=k, filter=filter + ) + + def _query_collection_multi( + self, + collections: List[Any], + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> List[Any]: + """Query the collection.""" + self._init_models(embedding) + + collection_names = [coll.name for coll in collections] + collection_uuids = [coll.uuid for coll in collections] + self.logger.info(f"Querying collections: {collection_names}") + + with self.Session() as session: + filter_by = self.EmbeddingStore.collection_id.in_(collection_uuids) + + if filter is not None: + filter_clauses = [] + for key, value in filter.items(): + IN = "in" + if isinstance(value, dict) and IN in map(str.lower, value): + value_case_insensitive = { + k.lower(): v for k, v in value.items() + } + filter_by_metadata = self.EmbeddingStore.cmetadata[key].in_( + value_case_insensitive[IN] + ) + filter_clauses.append(filter_by_metadata) + else: + filter_by_metadata = self.EmbeddingStore.cmetadata[key] == str( + value + ) # type: ignore[assignment] + filter_clauses.append(filter_by_metadata) + + filter_by = sqlalchemy.and_(filter_by, *filter_clauses) # type: ignore[assignment] + + _type = self.EmbeddingStore + + results: List[Any] = ( + session.query( # type: ignore[attr-defined] + self.EmbeddingStore, + # TODO: Original pgvector code uses `self.distance_strategy`. + # CrateDB currently only supports EUCLIDEAN. + # self.distance_strategy(embedding).label("distance") + sqlalchemy.literal_column( + f"{self.EmbeddingStore.__tablename__}._score" + ).label("_score"), + ) + .filter(filter_by) + # CrateDB applies `KNN_MATCH` within the `WHERE` clause. 
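+                # For illustration (an editor's sketch, not generated output), the
+                # statement assembled here corresponds roughly to the following SQL,
+                # assuming the table names of the default model factory, and a
+                # shortened query vector with k=4:
+                #
+                #   SELECT embedding.*, embedding._score AS _score
+                #   FROM embedding
+                #   JOIN collection ON embedding.collection_id = collection.uuid
+                #   WHERE embedding.collection_id IN (...)
+                #     AND knn_match(embedding.embedding, [0.1, 0.2, ...], 4)
+                #   ORDER BY _score DESC
+                #   LIMIT 4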
+                .filter(
+                    sqlalchemy.func.knn_match(
+                        self.EmbeddingStore.embedding, embedding, k
+                    )
+                )
+                .order_by(sqlalchemy.desc("_score"))
+                .join(
+                    self.CollectionStore,
+                    self.EmbeddingStore.collection_id == self.CollectionStore.uuid,
+                )
+                .limit(k)
+            )
+            return results
+
+    @classmethod
+    def from_texts(  # type: ignore[override]
+        cls: Type[CrateDBVectorSearch],
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+        ids: Optional[List[str]] = None,
+        pre_delete_collection: bool = False,
+        **kwargs: Any,
+    ) -> CrateDBVectorSearch:
+        """
+        Return VectorStore initialized from texts and embeddings.
+        Database connection string is required.
+
+        Either pass it as a parameter, or set the CRATEDB_CONNECTION_STRING
+        environment variable.
+
+        Remark: Needs to be overridden, because CrateDB uses a different
+                DEFAULT_DISTANCE_STRATEGY.
+        """
+        return super().from_texts(  # type: ignore[return-value]
+            texts,
+            embedding,
+            metadatas=metadatas,
+            ids=ids,
+            collection_name=collection_name,
+            distance_strategy=distance_strategy,  # type: ignore[arg-type]
+            pre_delete_collection=pre_delete_collection,
+            **kwargs,
+        )
+
+    @classmethod
+    def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
+        connection_string: str = get_from_dict_or_env(
+            data=kwargs,
+            key="connection_string",
+            env_key="CRATEDB_CONNECTION_STRING",
+        )
+
+        if not connection_string:
+            raise ValueError(
+                "Database connection string is required. "
+                "Either pass it as a parameter, or set the "
+                "CRATEDB_CONNECTION_STRING environment variable."
+            )
+
+        return connection_string
+
+    @classmethod
+    def connection_string_from_db_params(
+        cls,
+        driver: str,
+        host: str,
+        port: int,
+        database: str,
+        user: str,
+        password: str,
+    ) -> str:
+        """Return connection string from database parameters."""
+        return str(
+            sqlalchemy.URL.create(
+                drivername=driver,
+                host=host,
+                port=port,
+                username=user,
+                password=password,
+                query={"schema": database},
+            )
+        )
+
+    def _select_relevance_score_fn(self) -> Callable[[float], float]:
+        """
+        The 'correct' relevance function
+        may differ depending on a few things, including:
+        - the distance / similarity metric used by the VectorStore
+        - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+        - embedding dimensionality
+        - etc.
+        """
+        if self.override_relevance_score_fn is not None:
+            return self.override_relevance_score_fn
+
+        # Default strategy is to rely on the distance strategy provided
+        # in the vectorstore constructor.
+        if self._distance_strategy == DistanceStrategy.COSINE:
+            return self._cosine_relevance_score_fn
+        elif self._distance_strategy == DistanceStrategy.EUCLIDEAN:
+            return self._euclidean_relevance_score_fn
+        elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+            return self._max_inner_product_relevance_score_fn
+        else:
+            raise ValueError(
+                "No supported normalization function for distance_strategy of "
+                f"{self._distance_strategy}. Consider providing relevance_score_fn to "
+                "CrateDBVectorSearch constructor."
+            )
+
+    @staticmethod
+    def _euclidean_relevance_score_fn(score: float) -> float:
+        """Return a similarity score on a scale [0, 1]."""
+        # The 'correct' relevance function
+        # may differ depending on a few things, including:
+        # - the distance / similarity metric used by the VectorStore
+        # - the scale of your embeddings (OpenAI's are unit normed. Many
+        #   others are not!)
+        # - embedding dimensionality
+        # - etc.
+        # This function converts the euclidean norm of normalized embeddings
+        # (0 is most similar, sqrt(2) most dissimilar)
+        # to a similarity function (0 to 1).
+
+        # Original:
+        # return 1.0 - distance / math.sqrt(2)
+        return score / math.sqrt(2)
diff --git a/libs/langchain/langchain/vectorstores/cratedb/extended.py b/libs/langchain/langchain/vectorstores/cratedb/extended.py
new file mode 100644
index 0000000000000..553ce281c5e24
--- /dev/null
+++ b/libs/langchain/langchain/vectorstores/cratedb/extended.py
@@ -0,0 +1,103 @@
+import logging
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+)
+
+import sqlalchemy
+from sqlalchemy.orm import sessionmaker
+
+from langchain.schema.embeddings import Embeddings
+from langchain.vectorstores.cratedb.base import (
+    DEFAULT_DISTANCE_STRATEGY,
+    CrateDBVectorSearch,
+    DistanceStrategy,
+    StorageStrategy,
+)
+from langchain.vectorstores.pgvector import _LANGCHAIN_DEFAULT_COLLECTION_NAME
+
+
+class CrateDBVectorSearchMultiCollection(CrateDBVectorSearch):
+    """
+    Provide functionality for searching multiple collections.
+    It cannot be used for indexing documents.
+
+    To use it, you should have the ``crate[sqlalchemy]`` Python package installed.
+
+    Synopsis::
+
+        from langchain.vectorstores.cratedb import CrateDBVectorSearchMultiCollection
+
+        multisearch = CrateDBVectorSearchMultiCollection(
+            collection_names=["collection_foo", "collection_bar"],
+            embedding_function=embeddings,
+            connection_string=CONNECTION_STRING,
+        )
+        docs_with_score = multisearch.similarity_search_with_score(query)
+    """
+
+    def __init__(
+        self,
+        connection_string: str,
+        embedding_function: Embeddings,
+        collection_names: List[str] = [_LANGCHAIN_DEFAULT_COLLECTION_NAME],
+        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,  # type: ignore[arg-type]
+        logger: Optional[logging.Logger] = None,
+        relevance_score_fn: Optional[Callable[[float], float]] = None,
+        *,
+        connection: Optional[sqlalchemy.engine.Connection] = None,
+        engine_args: Optional[dict[str, Any]] = None,
+    ) -> None:
+        # Sanity checks.
+        # TODO: The CrateDBVectorSearchMultiCollection access variant needs further
+        #       adjustments to support the EMBEDDING_TABLE_PER_COLLECTION storage
+        #       strategy.
+ if self.STORAGE_STRATEGY is not StorageStrategy.LANGCHAIN_PGVECTOR: + raise NotImplementedError( + f"Multi-collection querying not supported " + f"by strategy: {self.STORAGE_STRATEGY}" + ) + + self.connection_string = connection_string + self.embedding_function = embedding_function + self.collection_names = collection_names + self._distance_strategy = distance_strategy # type: ignore[assignment] + self.logger = logger or logging.getLogger(__name__) + self.override_relevance_score_fn = relevance_score_fn + self.engine_args = engine_args or {} + # Create a connection if not provided, otherwise use the provided connection + self._engine = self.create_engine() + self.Session = sessionmaker(self._engine) + self._conn = connection if connection else self.connect() + self.__post_init__() + + @classmethod + def _from(cls, *args: List, **kwargs: Dict): # type: ignore[no-untyped-def,override] + raise NotImplementedError("This adapter can not be used for indexing documents") + + def get_collections(self, session: sqlalchemy.orm.Session) -> Any: + if self.CollectionStore is None: + raise RuntimeError( + "Collection can't be accessed without specifying " + "dimension size of embedding vectors" + ) + return self.CollectionStore.get_by_names(session, self.collection_names) + + def _query_collection( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> List[Any]: + """Query multiple collections.""" + self._init_models(embedding) + with self.Session() as session: + collections = self.get_collections(session) + if not collections: + raise ValueError("No collections found") + return self._query_collection_multi( + collections=collections, embedding=embedding, k=k, filter=filter + ) diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/langchain/langchain/vectorstores/cratedb/model.py new file mode 100644 index 0000000000000..167b75ba23e13 --- /dev/null +++ b/libs/langchain/langchain/vectorstores/cratedb/model.py @@ -0,0 +1,117 @@ +import uuid +from typing import Any, List, Optional, Tuple + +import sqlalchemy +from crate.client.sqlalchemy.types import ObjectType +from sqlalchemy.orm import Session, declarative_base, relationship + +from langchain.vectorstores.cratedb.sqlalchemy_type import FloatVector + + +def generate_uuid() -> str: + return str(uuid.uuid4()) + + +class ModelFactory: + """Provide SQLAlchemy model objects at runtime.""" + + def __init__( + self, + dimensions: Optional[int] = None, + collection_table: Optional[str] = None, + embedding_table: Optional[str] = None, + ): + # While it does not have any function here, you will still need to supply a + # dummy dimension size value for operations like deleting records. + self.dimensions = dimensions or 1024 + + # Set default values for table names. + collection_table = collection_table or "collection" + embedding_table = embedding_table or "embedding" + + Base: Any = declarative_base(class_registry=dict()) + + # Optional: Use a custom schema for the langchain tables. 
+ # Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any + + class BaseModel(Base): + """Base model for the SQL stores.""" + + __abstract__ = True + uuid = sqlalchemy.Column( + sqlalchemy.String, primary_key=True, default=generate_uuid + ) + + class CollectionStore(BaseModel): + """Collection store.""" + + __tablename__ = collection_table + + name = sqlalchemy.Column(sqlalchemy.String) + cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType) + + embeddings = relationship( + "EmbeddingStore", + back_populates="collection", + cascade="all, delete-orphan", + passive_deletes=False, + ) + + @classmethod + def get_by_name(cls, session: Session, name: str) -> "CollectionStore": + return session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] + + @classmethod + def get_by_names( + cls, session: Session, names: List[str] + ) -> List["CollectionStore"]: + return session.query(cls).filter(cls.name.in_(names)).all() # type: ignore[attr-defined] + + @classmethod + def get_or_create( + cls, + session: Session, + name: str, + cmetadata: Optional[dict] = None, + ) -> Tuple["CollectionStore", bool]: + """ + Get or create a collection. + Returns [Collection, bool] where the bool is True + if the collection was created. + """ + created = False + collection = cls.get_by_name(session, name) + if collection: + return collection, created + + collection = cls(name=name, cmetadata=cmetadata) + session.add(collection) + session.commit() + created = True + return collection, created + + class EmbeddingStore(BaseModel): + """Embedding store.""" + + __tablename__ = embedding_table + + collection_id = sqlalchemy.Column( + sqlalchemy.String, + sqlalchemy.ForeignKey( + f"{CollectionStore.__tablename__}.uuid", + ondelete="CASCADE", + ), + ) + collection = relationship("CollectionStore", back_populates="embeddings") + + embedding = sqlalchemy.Column(FloatVector(self.dimensions)) + document = sqlalchemy.Column(sqlalchemy.String, nullable=True) + cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True) + + # custom_id : any user defined id + custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) + + self.Base = Base + self.BaseModel = BaseModel + self.CollectionStore = CollectionStore + self.EmbeddingStore = EmbeddingStore diff --git a/libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py b/libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py new file mode 100644 index 0000000000000..e784c3013a3d9 --- /dev/null +++ b/libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py @@ -0,0 +1,84 @@ +# TODO: Refactor to CrateDB SQLAlchemy dialect. 
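
A sketch of how the `ModelFactory` defined in `model.py` above is driven at runtime; the dimensionality of 1536 and the table and connection names are illustrative assumptions.

    import sqlalchemy as sa
    from langchain.vectorstores.cratedb.model import ModelFactory

    # Dimensions are only known once the first embedding arrives, which is why
    # `CrateDBVectorSearch` defers this step to `_init_models()`.
    mf = ModelFactory(dimensions=1536, embedding_table="embedding_my_collection")
    engine = sa.create_engine("crate://crate@localhost/")
    # Creates the `collection` table plus the per-collection embedding table.
    mf.Base.metadata.create_all(engine)
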
+import typing as t
+
+import numpy as np
+import numpy.typing as npt
+import sqlalchemy as sa
+from sqlalchemy.types import UserDefinedType
+
+__all__ = ["FloatVector"]
+
+
+def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]:
+    # From `pgvector.utils`.
+    # Could be an ndarray if already cast by a lower-level driver.
+    if value is None or isinstance(value, np.ndarray):
+        return value
+
+    return np.array(value, dtype=np.float32)
+
+
+def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
+    # From `pgvector.utils`.
+    if value is None:
+        return value
+
+    if isinstance(value, np.ndarray):
+        if value.ndim != 1:
+            raise ValueError("expected ndim to be 1")
+
+        if not np.issubdtype(value.dtype, np.integer) and not np.issubdtype(
+            value.dtype, np.floating
+        ):
+            raise ValueError("dtype must be numeric")
+
+        value = value.tolist()
+
+    if dim is not None and len(value) != dim:
+        raise ValueError("expected %d dimensions, not %d" % (dim, len(value)))
+
+    return value
+
+
+class FloatVector(UserDefinedType):
+    """
+    https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector
+    https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match
+    """
+
+    cache_ok = True
+
+    def __init__(self, dim: t.Optional[int] = None) -> None:
+        super(UserDefinedType, self).__init__()
+        self.dim = dim
+
+    def get_col_spec(self, **kw: t.Any) -> str:
+        if self.dim is None:
+            return "FLOAT_VECTOR"
+        return "FLOAT_VECTOR(%d)" % self.dim
+
+    def bind_processor(self, dialect: sa.Dialect) -> t.Callable:
+        def process(value: t.Iterable) -> t.Optional[t.List]:
+            return to_db(value, self.dim)
+
+        return process
+
+    def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable:
+        def process(value: t.Any) -> t.Optional[npt.ArrayLike]:
+            return from_db(value)
+
+        return process
+
+    """
+    CrateDB currently only supports the similarity function
+    `VectorSimilarityFunction.EUCLIDEAN`.
+    -- https://github.com/crate/crate/blob/1ca5c6dbb2/server/src/main/java/io/crate/types/FloatVectorType.java#L55
+
+    On the other hand, pgvector uses a comparator to apply different similarity
+    functions as operators, see `pgvector.sqlalchemy.Vector.comparator_factory`.
+
+    <->: l2/euclidean_distance
+    <#>: max_inner_product
+    <=>: cosine_distance
+
+    TODO: Discuss.
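
A sketch of using the `FloatVector` type above directly in a table definition; the table and column names are illustrative.

    import sqlalchemy as sa
    from langchain.vectorstores.cratedb.sqlalchemy_type import FloatVector

    metadata = sa.MetaData()
    # Emits CrateDB's `FLOAT_VECTOR(3)` column type, see `get_col_spec()`.
    demo = sa.Table(
        "documents_demo",
        metadata,
        sa.Column("id", sa.String, primary_key=True),
        sa.Column("embedding", FloatVector(3)),
    )
    # `bind_processor`/`result_processor` round-trip Python lists and NumPy
    # arrays through `to_db()`/`from_db()`.
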
+ """ # noqa: E501 diff --git a/libs/langchain/langchain/vectorstores/pgvector.py b/libs/langchain/langchain/vectorstores/pgvector.py index e1a58e81afdde..030f1962e43ae 100644 --- a/libs/langchain/langchain/vectorstores/pgvector.py +++ b/libs/langchain/langchain/vectorstores/pgvector.py @@ -23,12 +23,12 @@ import sqlalchemy from sqlalchemy import delete from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Session try: from sqlalchemy.orm import declarative_base except ImportError: from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session, sessionmaker from langchain_core.documents import Document from langchain_core.embeddings import Embeddings @@ -130,6 +130,8 @@ def __init__( self.override_relevance_score_fn = relevance_score_fn self.engine_args = engine_args or {} # Create a connection if not provided, otherwise use the provided connection + self._engine = self.create_engine() + self.Session = sessionmaker(self._engine) self._conn = connection if connection else self.connect() self.__post_init__() @@ -156,14 +158,15 @@ def __del__(self) -> None: def embeddings(self) -> Embeddings: return self.embedding_function + def create_engine(self) -> sqlalchemy.Engine: + return sqlalchemy.create_engine(self.connection_string, echo=False) + def connect(self) -> sqlalchemy.engine.Connection: - engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args) - conn = engine.connect() - return conn + return self._engine.connect() def create_vector_extension(self) -> None: try: - with Session(self._conn) as session: + with self.Session() as session: # The advisor lock fixes issue arising from concurrent # creation of the vector extension. # https://github.com/langchain-ai/langchain/issues/12933 @@ -181,24 +184,22 @@ def create_vector_extension(self) -> None: raise Exception(f"Failed to create vector extension: {e}") from e def create_tables_if_not_exists(self) -> None: - with self._conn.begin(): - Base.metadata.create_all(self._conn) + Base.metadata.create_all(self._engine) def drop_tables(self) -> None: - with self._conn.begin(): - Base.metadata.drop_all(self._conn) + Base.metadata.drop_all(self._engine) def create_collection(self) -> None: if self.pre_delete_collection: self.delete_collection() - with Session(self._conn) as session: + with self.Session() as session: self.CollectionStore.get_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) def delete_collection(self) -> None: self.logger.debug("Trying to delete collection") - with Session(self._conn) as session: + with self.Session() as session: collection = self.get_collection(session) if not collection: self.logger.warning("Collection not found") @@ -209,7 +210,7 @@ def delete_collection(self) -> None: @contextlib.contextmanager def _make_session(self) -> Generator[Session, None, None]: """Create a context manager for the session, bind to _conn string.""" - yield Session(self._conn) + yield self.Session() def delete( self, @@ -221,7 +222,7 @@ def delete( Args: ids: List of ids to delete. 
""" - with Session(self._conn) as session: + with self.Session() as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -237,7 +238,7 @@ def get_collection(self, session: Session) -> Optional["CollectionStore"]: return self.CollectionStore.get_by_name(session, self.collection_name) @classmethod - def __from( + def _from( cls, texts: List[str], embeddings: List[List[float]], @@ -295,10 +296,11 @@ def add_embeddings( if not metadatas: metadatas = [{} for _ in texts] - with Session(self._conn) as session: + with self.Session() as session: collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") + documents = [] for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids): embedding_store = self.EmbeddingStore( embedding=embedding, @@ -307,7 +309,8 @@ def add_embeddings( custom_id=id, collection_id=collection.uuid, ) - session.add(embedding_store) + documents.append(embedding_store) + session.bulk_save_objects(documents) session.commit() return ids @@ -400,7 +403,7 @@ def similarity_search_with_score_by_vector( k: int = 4, filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: - results = self.__query_collection(embedding=embedding, k=k, filter=filter) + results = self._query_collection(embedding=embedding, k=k, filter=filter) return self._results_to_docs_and_scores(results) @@ -418,14 +421,14 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa ] return docs - def __query_collection( + def _query_collection( self, embedding: List[float], k: int = 4, filter: Optional[Dict[str, str]] = None, ) -> List[Any]: """Query the collection.""" - with Session(self._conn) as session: + with self.Session() as session: collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -512,7 +515,7 @@ def from_texts( """ embeddings = embedding.embed_documents(list(texts)) - return cls.__from( + return cls._from( texts, embeddings, embedding, @@ -557,7 +560,7 @@ def from_embeddings( texts = [t[0] for t in text_embeddings] embeddings = [t[1] for t in text_embeddings] - return cls.__from( + return cls._from( texts, embeddings, embedding, @@ -718,7 +721,7 @@ def max_marginal_relevance_search_with_score_by_vector( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. 
""" - results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) + results = self._query_collection(embedding=embedding, k=fetch_k, filter=filter) embedding_list = [result.EmbeddingStore.embedding for result in results] diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index b9bf2b0546d1c..30b515c099718 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -147,6 +147,8 @@ praw = {version = "^7.7.1", optional = true} msal = {version = "^1.25.0", optional = true} databricks-vectorsearch = {version = "^0.21", optional = true} dgml-utils = {version = "^0.3.0", optional = true} +crate = {version = "^0.34.0", extras=["sqlalchemy"], optional = true} +cratedb-toolkit = {version = ">=0.0.1", optional = true} [tool.poetry.group.test.dependencies] # The only dependencies that should be added are @@ -167,6 +169,7 @@ pytest-socket = "^0.6.0" syrupy = "^4.0.2" requests-mock = "^1.11.0" langchain-core = {path = "../core", develop = true} +sqlparse = "^0.4.4" [tool.poetry.group.codespell.dependencies] codespell = "^2.2.0" @@ -229,6 +232,7 @@ cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] javascript = ["esprima"] +cratedb = ["crate", "cratedb-toolkit"] azure = [ "azure-identity", "azure-cosmos", @@ -315,6 +319,8 @@ all = [ "librosa", "python-arango", "dgml-utils", + "crate", + "cratedb-toolkit", ] cli = [ @@ -388,6 +394,8 @@ extended_testing = [ "databricks-vectorsearch", "dgml-utils", "cohere", + "crate", + "cratedb-toolkit", ] [tool.ruff] diff --git a/libs/langchain/tests/data.py b/libs/langchain/tests/data.py index c3b240bbc57cd..be48867430641 100644 --- a/libs/langchain/tests/data.py +++ b/libs/langchain/tests/data.py @@ -9,3 +9,7 @@ HELLO_PDF = _EXAMPLES_DIR / "hello.pdf" LAYOUT_PARSER_PAPER_PDF = _EXAMPLES_DIR / "layout-parser-paper.pdf" DUPLICATE_CHARS = _EXAMPLES_DIR / "duplicate-chars.pdf" + +# Paths to data files +MLB_TEAMS_2012_CSV = _EXAMPLES_DIR / "mlb_teams_2012.csv" +MLB_TEAMS_2012_SQL = _EXAMPLES_DIR / "mlb_teams_2012.sql" diff --git a/libs/langchain/tests/integration_tests/document_loaders/docker-compose/cratedb.yml b/libs/langchain/tests/integration_tests/document_loaders/docker-compose/cratedb.yml new file mode 100644 index 0000000000000..b547b2c766f20 --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/docker-compose/cratedb.yml @@ -0,0 +1,20 @@ +version: "3" + +services: + postgresql: + image: crate/crate:nightly + environment: + - CRATE_HEAP_SIZE=4g + ports: + - "4200:4200" + - "5432:5432" + command: | + crate -Cdiscovery.type=single-node + healthcheck: + test: + [ + "CMD-SHELL", + "curl --silent --fail http://localhost:4200/ || exit 1", + ] + interval: 5s + retries: 60 diff --git a/libs/langchain/tests/integration_tests/document_loaders/docker-compose/postgresql.yml b/libs/langchain/tests/integration_tests/document_loaders/docker-compose/postgresql.yml new file mode 100644 index 0000000000000..f8ab2cfdeb418 --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/docker-compose/postgresql.yml @@ -0,0 +1,19 @@ +version: "3" + +services: + postgresql: + image: postgres:16 + environment: + - POSTGRES_HOST_AUTH_METHOD=trust + ports: + - "5432:5432" + command: | + postgres -c log_statement=all + healthcheck: + test: + [ + "CMD-SHELL", + "psql postgresql://postgres@localhost --command 'SELECT 1;' || exit 1", + ] + interval: 5s + retries: 60 diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py 
b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py new file mode 100644 index 0000000000000..eec3a428a74e8 --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py @@ -0,0 +1,146 @@ +""" +Test SQLAlchemy/CrateDB document loader functionality. + +cd tests/integration_tests/document_loaders/docker-compose +docker-compose -f cratedb.yml up +""" +import logging +import os +import unittest + +import pytest +import sqlalchemy as sa +import sqlparse + +from langchain.document_loaders import CrateDBLoader +from tests.data import MLB_TEAMS_2012_SQL + +logging.basicConfig(level=logging.DEBUG) + +try: + import crate.client.sqlalchemy # noqa: F401 + + crate_client_installed = True +except ImportError: + crate_client_installed = False + + +CONNECTION_STRING = os.environ.get( + "TEST_CRATEDB_CONNECTION_STRING", "crate://crate@localhost/?schema=testdrive" +) + + +@pytest.fixture +def engine() -> sa.Engine: + """ + Return an SQLAlchemy engine object. + """ + return sa.create_engine(CONNECTION_STRING, echo=False) + + +@pytest.fixture() +def provision_database(engine: sa.Engine) -> None: + """ + Provision database with table schema and data. + """ + sql_statements = MLB_TEAMS_2012_SQL.read_text() + with engine.connect() as connection: + connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;")) + for statement in sqlparse.split(sql_statements): + connection.execute(sa.text(statement)) + connection.execute(sa.text("REFRESH TABLE mlb_teams_2012;")) + connection.commit() + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_no_options() -> None: + """Test SQLAlchemy loader with CrateDB.""" + + loader = CrateDBLoader("SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {} + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_page_content_columns() -> None: + """Test SQLAlchemy loader with CrateDB.""" + + loader = CrateDBLoader( + "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b", + url=CONNECTION_STRING, + page_content_columns=["a"], + ) + docs = loader.load() + + assert len(docs) == 2 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {} + + assert docs[1].page_content == "a: 3" + assert docs[1].metadata == {} + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_metadata_columns() -> None: + """Test SQLAlchemy loader with CrateDB.""" + + loader = CrateDBLoader( + "SELECT 1 AS a, 2 AS b", + url=CONNECTION_STRING, + page_content_columns=["a"], + metadata_columns=["b"], + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {"b": 2} + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_real_data_with_sql(provision_database: None) -> None: + """Test SQLAlchemy loader with CrateDB.""" + + loader = CrateDBLoader( + query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";', + url=CONNECTION_STRING, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == {} + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_real_data_with_selectable(provision_database: None) 
-> None: + """Test SQLAlchemy loader with CrateDB.""" + + # Define an SQLAlchemy table. + mlb_teams_2012 = sa.Table( + "mlb_teams_2012", + sa.MetaData(), + sa.Column("Team", sa.VARCHAR), + sa.Column("Payroll (millions)", sa.FLOAT), + sa.Column("Wins", sa.BIGINT), + ) + + # Query the database table using an SQLAlchemy selectable. + select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team) + loader = CrateDBLoader( + query=select, + url=CONNECTION_STRING, + include_query_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == { + "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", ' + 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 ' + 'ORDER BY mlb_teams_2012."Team"' + } diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py new file mode 100644 index 0000000000000..29f52cb9f7a33 --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py @@ -0,0 +1,177 @@ +""" +Test SQLAlchemy/PostgreSQL document loader functionality. + +cd tests/integration_tests/document_loaders/docker-compose +docker-compose -f postgresql.yml up +""" +import logging +import os +import unittest + +import pytest +import sqlalchemy as sa +import sqlparse + +from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader +from tests.data import MLB_TEAMS_2012_SQL + +logging.basicConfig(level=logging.DEBUG) + + +try: + import psycopg2 # noqa: F401 + + psycopg2_installed = True +except ImportError: + psycopg2_installed = False + + +CONNECTION_STRING = os.environ.get( + "TEST_POSTGRESQL_CONNECTION_STRING", + "postgresql+psycopg2://postgres@localhost:5432/", +) + + +@pytest.fixture +def engine() -> sa.Engine: + """ + Return an SQLAlchemy engine object. + """ + return sa.create_engine(CONNECTION_STRING, echo=False) + + +@pytest.fixture() +def provision_database(engine: sa.Engine) -> None: + """ + Provision database with table schema and data. 
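
The `provision_database` fixtures in these test modules rely on `sqlparse.split()` to feed the SQL file to the database one statement at a time, since each `connection.execute(sa.text(...))` call expects a single statement. A minimal standalone sketch:

    import sqlparse

    sql = "CREATE TABLE t (x INT); INSERT INTO t VALUES (42);"
    for statement in sqlparse.split(sql):
        print(statement)  # one executable statement per iteration
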
+ """ + sql_statements = MLB_TEAMS_2012_SQL.read_text() + with engine.connect() as connection: + connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;")) + for statement in sqlparse.split(sql_statements): + connection.execute(sa.text(statement)) + connection.commit() + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_no_options() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader("SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_include_rownum_into_metadata() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", + url=CONNECTION_STRING, + include_rownum_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {"row": 0} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_include_query_into_metadata() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING, include_query_into_metadata=True + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {"query": "SELECT 1 AS a, 2 AS b"} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_page_content_columns() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b", + url=CONNECTION_STRING, + page_content_columns=["a"], + ) + docs = loader.load() + + assert len(docs) == 2 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {} + + assert docs[1].page_content == "a: 3" + assert docs[1].metadata == {} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_metadata_columns() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", + url=CONNECTION_STRING, + page_content_columns=["a"], + metadata_columns=["b"], + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {"b": 2} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_real_data_with_sql(provision_database: None) -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";', + url=CONNECTION_STRING, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == {} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_real_data_with_selectable(provision_database: None) -> None: + """Test SQLAlchemy loader with psycopg2.""" + + # Define an SQLAlchemy table. + mlb_teams_2012 = sa.Table( + "mlb_teams_2012", + sa.MetaData(), + sa.Column("Team", sa.VARCHAR), + sa.Column("Payroll (millions)", sa.FLOAT), + sa.Column("Wins", sa.BIGINT), + ) + + # Query the database table using an SQLAlchemy selectable. 
+ select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team) + loader = SQLAlchemyLoader( + query=select, + url=CONNECTION_STRING, + include_query_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == { + "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", ' + 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 ' + 'ORDER BY mlb_teams_2012."Team"' + } diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py new file mode 100644 index 0000000000000..f1fac2cecbc00 --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py @@ -0,0 +1,181 @@ +""" +Test SQLAlchemy/SQLite document loader functionality. +""" +import logging +import unittest + +import pytest +import sqlalchemy as sa +import sqlparse +from _pytest.tmpdir import TempPathFactory + +from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader +from tests.data import MLB_TEAMS_2012_SQL + +logging.basicConfig(level=logging.DEBUG) + + +try: + import sqlite3 # noqa: F401 + + sqlite3_installed = True +except ImportError: + sqlite3_installed = False + + +@pytest.fixture(scope="module") +def db_uri(tmp_path_factory: TempPathFactory) -> str: + """ + Return an SQLAlchemy URI for a temporary SQLite database. + """ + db_path = tmp_path_factory.getbasetemp().joinpath("testdrive.sqlite") + return f"sqlite:///{db_path}" + + +@pytest.fixture(scope="module") +def engine(db_uri: str) -> sa.Engine: + """ + Return an SQLAlchemy engine object. + """ + return sa.create_engine(db_uri, echo=False) + + +@pytest.fixture() +def provision_database(engine: sa.Engine) -> None: + """ + Provision database with table schema and data. 
+ """ + sql_statements = MLB_TEAMS_2012_SQL.read_text() + with engine.connect() as connection: + connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;")) + for statement in sqlparse.split(sql_statements): + connection.execute(sa.text(statement)) + connection.commit() + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_no_options(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader("SELECT 1 AS a, 2 AS b", url=db_uri) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_include_rownum_into_metadata(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", + url=db_uri, + include_rownum_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {"row": 0} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_include_query_into_metadata(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", url=db_uri, include_query_into_metadata=True + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {"query": "SELECT 1 AS a, 2 AS b"} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_page_content_columns(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b", + url=db_uri, + page_content_columns=["a"], + ) + docs = loader.load() + + assert len(docs) == 2 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {} + + assert docs[1].page_content == "a: 3" + assert docs[1].metadata == {} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_metadata_columns(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", + url=db_uri, + page_content_columns=["a"], + metadata_columns=["b"], + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {"b": 2} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_real_data_with_sql( + db_uri: str, provision_database: None +) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";', + url=db_uri, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == {} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_real_data_with_selectable( + db_uri: str, provision_database: None +) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + # Define an SQLAlchemy table. + mlb_teams_2012 = sa.Table( + "mlb_teams_2012", + sa.MetaData(), + sa.Column("Team", sa.VARCHAR), + sa.Column("Payroll (millions)", sa.FLOAT), + sa.Column("Wins", sa.BIGINT), + ) + + # Query the database table using an SQLAlchemy selectable. 
+ select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team) + loader = SQLAlchemyLoader( + query=select, + url=db_uri, + include_query_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == { + "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", ' + 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 ' + 'ORDER BY mlb_teams_2012."Team"' + } diff --git a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv new file mode 100644 index 0000000000000..b22ae961a1331 --- /dev/null +++ b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv @@ -0,0 +1,32 @@ +"Team", "Payroll (millions)", "Wins" +"Nationals", 81.34, 98 +"Reds", 82.20, 97 +"Yankees", 197.96, 95 +"Giants", 117.62, 94 +"Braves", 83.31, 94 +"Athletics", 55.37, 94 +"Rangers", 120.51, 93 +"Orioles", 81.43, 93 +"Rays", 64.17, 90 +"Angels", 154.49, 89 +"Tigers", 132.30, 88 +"Cardinals", 110.30, 88 +"Dodgers", 95.14, 86 +"White Sox", 96.92, 85 +"Brewers", 97.65, 83 +"Phillies", 174.54, 81 +"Diamondbacks", 74.28, 81 +"Pirates", 63.43, 79 +"Padres", 55.24, 76 +"Mariners", 81.97, 75 +"Mets", 93.35, 74 +"Blue Jays", 75.48, 73 +"Royals", 60.91, 72 +"Marlins", 118.07, 69 +"Red Sox", 173.18, 69 +"Indians", 78.43, 68 +"Twins", 94.08, 66 +"Rockies", 78.06, 64 +"Cubs", 88.19, 61 +"Astros", 60.65, 55 + diff --git a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql new file mode 100644 index 0000000000000..9df72ef19954a --- /dev/null +++ b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql @@ -0,0 +1,41 @@ +-- Provisioning table "mlb_teams_2012". 
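+--
+-- Contains 2012 MLB team payrolls and win totals, one row per team
+-- (30 rows), mirroring the CSV fixture of the same name.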
+-- +-- psql postgresql://postgres@localhost < mlb_teams_2012.sql +-- crash < mlb_teams_2012.sql + +DROP TABLE IF EXISTS mlb_teams_2012; +CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT); +INSERT INTO mlb_teams_2012 + ("Team", "Payroll (millions)", "Wins") +VALUES + ('Nationals', 81.34, 98), + ('Reds', 82.20, 97), + ('Yankees', 197.96, 95), + ('Giants', 117.62, 94), + ('Braves', 83.31, 94), + ('Athletics', 55.37, 94), + ('Rangers', 120.51, 93), + ('Orioles', 81.43, 93), + ('Rays', 64.17, 90), + ('Angels', 154.49, 89), + ('Tigers', 132.30, 88), + ('Cardinals', 110.30, 88), + ('Dodgers', 95.14, 86), + ('White Sox', 96.92, 85), + ('Brewers', 97.65, 83), + ('Phillies', 174.54, 81), + ('Diamondbacks', 74.28, 81), + ('Pirates', 63.43, 79), + ('Padres', 55.24, 76), + ('Mariners', 81.97, 75), + ('Mets', 93.35, 74), + ('Blue Jays', 75.48, 73), + ('Royals', 60.91, 72), + ('Marlins', 118.07, 69), + ('Red Sox', 173.18, 69), + ('Indians', 78.43, 68), + ('Twins', 94.08, 66), + ('Rockies', 78.06, 64), + ('Cubs', 88.19, 61), + ('Astros', 60.65, 55) +; diff --git a/libs/langchain/tests/integration_tests/memory/test_cratedb.py b/libs/langchain/tests/integration_tests/memory/test_cratedb.py new file mode 100644 index 0000000000000..2c00b5d2b200b --- /dev/null +++ b/libs/langchain/tests/integration_tests/memory/test_cratedb.py @@ -0,0 +1,170 @@ +import json +import os +from typing import Any, Generator, Tuple + +import pytest +import sqlalchemy as sa +from sqlalchemy import Column, Integer, Text +from sqlalchemy.orm import DeclarativeBase + +from langchain.memory import ConversationBufferMemory +from langchain.memory.chat_message_histories import CrateDBChatMessageHistory +from langchain.memory.chat_message_histories.sql import DefaultMessageConverter +from langchain.schema.messages import AIMessage, HumanMessage, _message_to_dict + + +@pytest.fixture() +def connection_string() -> str: + return os.environ.get( + "TEST_CRATEDB_CONNECTION_STRING", "crate://crate@localhost/?schema=testdrive" + ) + + +@pytest.fixture() +def engine(connection_string: str) -> sa.Engine: + """ + Return an SQLAlchemy engine object. + """ + return sa.create_engine(connection_string, echo=True) + + +@pytest.fixture(autouse=True) +def reset_database(engine: sa.Engine) -> None: + """ + Provision database with table schema and data. + """ + with engine.connect() as connection: + connection.execute(sa.text("DROP TABLE IF EXISTS test_table;")) + connection.commit() + + +@pytest.fixture() +def sql_histories( + connection_string: str, +) -> Generator[Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory], None, None]: + """ + Provide the test cases with data fixtures. + """ + message_history = CrateDBChatMessageHistory( + session_id="123", connection_string=connection_string, table_name="test_table" + ) + # Create history for other session + other_history = CrateDBChatMessageHistory( + session_id="456", connection_string=connection_string, table_name="test_table" + ) + + yield message_history, other_history + message_history.clear() + other_history.clear() + + +def test_add_messages( + sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory] +) -> None: + history1, _ = sql_histories + history1.add_user_message("Hello!") + history1.add_ai_message("Hi there!") + + messages = history1.messages + assert len(messages) == 2 + assert isinstance(messages[0], HumanMessage) + assert isinstance(messages[1], AIMessage) + assert messages[0].content == "Hello!" 
+    assert messages[1].content == "Hi there!"
+
+
+def test_multiple_sessions(
+    sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory]
+) -> None:
+    history1, history2 = sql_histories
+
+    # First session.
+    history1.add_user_message("Hello!")
+    history1.add_ai_message("Hi there!")
+    history1.add_user_message("What's cracking?")
+
+    # Second session.
+    history2.add_user_message("Hellox")
+
+    messages1 = history1.messages
+    messages2 = history2.messages
+
+    # Ensure the messages have been added correctly in the first session.
+    assert len(messages1) == 3, "Expected three messages in the first session"
+    assert messages1[0].content == "Hello!"
+    assert messages1[1].content == "Hi there!"
+    assert messages1[2].content == "What's cracking?"
+
+    # Ensure the second session has been kept separate from the first.
+    assert len(messages2) == 1
+    assert messages2[0].content == "Hellox"
+
+
+def test_clear_messages(
+    sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory]
+) -> None:
+    sql_history, other_history = sql_histories
+    sql_history.add_user_message("Hello!")
+    sql_history.add_ai_message("Hi there!")
+    assert len(sql_history.messages) == 2
+    # Now create another history with a different session id.
+    other_history.add_user_message("Hellox")
+    assert len(other_history.messages) == 1
+    assert len(sql_history.messages) == 2
+    # Now clear the first history.
+    sql_history.clear()
+    assert len(sql_history.messages) == 0
+    assert len(other_history.messages) == 1
+
+
+def test_model_no_session_id_field_error(connection_string: str) -> None:
+    class Base(DeclarativeBase):
+        pass
+
+    class Model(Base):
+        __tablename__ = "test_table"
+        id = Column(Integer, primary_key=True)
+        test_field = Column(Text)
+
+    class CustomMessageConverter(DefaultMessageConverter):
+        def get_sql_model_class(self) -> Any:
+            return Model
+
+    with pytest.raises(ValueError):
+        CrateDBChatMessageHistory(
+            "test",
+            connection_string,
+            custom_message_converter=CustomMessageConverter("test_table"),
+        )
+
+
+def test_memory_with_message_store(connection_string: str) -> None:
+    """
+    Test ConversationBufferMemory with a message store.
+    """
+    # Set up CrateDB as a message store.
+    message_history = CrateDBChatMessageHistory(
+        connection_string=connection_string, session_id="test-session"
+    )
+    memory = ConversationBufferMemory(
+        memory_key="baz", chat_memory=message_history, return_messages=True
+    )
+
+    # Add a few messages.
+    memory.chat_memory.add_ai_message("This is me, the AI")
+    memory.chat_memory.add_user_message("This is me, the human")
+
+    # Get the message history from the memory store and turn it into JSON.
+    messages = memory.chat_memory.messages
+    messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
+
+    # Verify the outcome.
+    assert "This is me, the AI" in messages_json
+    assert "This is me, the human" in messages_json
+
+    # Clear the conversation history, and verify the outcome.
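+    # Note: `clear()` is expected to remove only this session's records from
+    # the message store, mirroring the per-session behavior exercised in
+    # `test_clear_messages` above.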
+ memory.chat_memory.clear() + assert memory.chat_memory.messages == [] diff --git a/libs/langchain/tests/integration_tests/vectorstores/docker-compose/cratedb.yml b/libs/langchain/tests/integration_tests/vectorstores/docker-compose/cratedb.yml new file mode 100644 index 0000000000000..b547b2c766f20 --- /dev/null +++ b/libs/langchain/tests/integration_tests/vectorstores/docker-compose/cratedb.yml @@ -0,0 +1,20 @@ +version: "3" + +services: + postgresql: + image: crate/crate:nightly + environment: + - CRATE_HEAP_SIZE=4g + ports: + - "4200:4200" + - "5432:5432" + command: | + crate -Cdiscovery.type=single-node + healthcheck: + test: + [ + "CMD-SHELL", + "curl --silent --fail http://localhost:4200/ || exit 1", + ] + interval: 5s + retries: 60 diff --git a/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py b/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py index 7b99c696444af..016ebe85fed18 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py +++ b/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py @@ -52,7 +52,6 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_query(self, text: str) -> List[float]: """Return consistent embeddings for the text, if seen before, or a constant one if the text is unknown.""" - return self.embed_documents([text])[0] if text not in self.known_texts: return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)] return [float(1.0)] * (self.dimensionality - 1) + [ diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py new file mode 100644 index 0000000000000..1687dd10a30d9 --- /dev/null +++ b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py @@ -0,0 +1,803 @@ +""" +Test CrateDB `FLOAT_VECTOR` / `KNN_MATCH` functionality. + +cd tests/integration_tests/vectorstores/docker-compose +docker-compose -f cratedb.yml up +""" +import os +import re +from typing import Dict, Generator, List, Tuple + +import pytest +import sqlalchemy as sa +import sqlalchemy.orm +from pytest_mock import MockerFixture +from sqlalchemy.exc import ProgrammingError +from sqlalchemy.orm import Session + +from langchain.docstore.document import Document +from langchain.vectorstores.cratedb import CrateDBVectorSearch +from langchain.vectorstores.cratedb.base import StorageStrategy +from langchain.vectorstores.cratedb.extended import CrateDBVectorSearchMultiCollection +from langchain.vectorstores.cratedb.model import ModelFactory +from tests.integration_tests.vectorstores.fake_embeddings import ( + ConsistentFakeEmbeddings, + FakeEmbeddings, +) + +SCHEMA_NAME = os.environ.get("TEST_CRATEDB_DATABASE", "testdrive") + +CONNECTION_STRING = CrateDBVectorSearch.connection_string_from_db_params( + driver=os.environ.get("TEST_CRATEDB_DRIVER", "crate"), + host=os.environ.get("TEST_CRATEDB_HOST", "localhost"), + port=int(os.environ.get("TEST_CRATEDB_PORT", "4200")), + database=SCHEMA_NAME, + user=os.environ.get("TEST_CRATEDB_USER", "crate"), + password=os.environ.get("TEST_CRATEDB_PASSWORD", ""), +) + +ADA_TOKEN_COUNT = 1536 + + +@pytest.fixture +def engine() -> sa.Engine: + """ + Return an SQLAlchemy engine object. 
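+
+    The engine is created from `CONNECTION_STRING`, which is assembled above
+    from the `TEST_CRATEDB_*` environment variables.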
+ """ + return sa.create_engine(CONNECTION_STRING, echo=False) + + +@pytest.fixture +def session(engine: sa.Engine) -> Generator[sa.orm.Session, None, None]: + with engine.connect() as conn: + with Session(conn) as session: + yield session + + +@pytest.fixture(autouse=True) +def drop_tables(engine: sa.Engine) -> None: + """ + Drop relevant database tables before invoking each test case. + + TODO: Check how only those database tables can be dropped, which + have actually been used, in order to increase performance. + Alternatively, reduce the number of different table names + used for testing. + """ + + def sqlalchemy_drop_model(mf: ModelFactory) -> None: + try: + mf.BaseModel.metadata.drop_all(engine, checkfirst=False) + except Exception as ex: + if "RelationUnknown" not in str(ex): + raise + + collection_name_candidates = [ + "test_collection", + "test_collection_filter", + "test_collection_foo", + "test_collection_bar", + "test_collection_1", + "test_collection_2", + ] + + # Recycling for vanilla storage strategy. + mf = ModelFactory(embedding_table="embedding") + sqlalchemy_drop_model(mf) + + # Recycling for advanced "embedding-table-per-collection" storage strategy. + for collection_name in collection_name_candidates: + mf = ModelFactory(embedding_table=f"embedding_{collection_name}") + sqlalchemy_drop_model(mf) + + +@pytest.fixture +def prune_tables(engine: sa.Engine) -> None: + """ + Delete data from database tables. + + Note: This fixture is currently not used. + """ + with engine.connect() as conn: + with Session(conn) as session: + mf = ModelFactory() + try: + session.query(mf.CollectionStore).delete() + except ProgrammingError: + pass + try: + session.query(mf.EmbeddingStore).delete() + except ProgrammingError: + pass + + +def ensure_collection(session: sa.orm.Session, name: str) -> None: + """ + Create a (fake) collection item. + """ + embedding_table_name = "embedding" + if ( + CrateDBVectorSearch.STORAGE_STRATEGY + is StorageStrategy.EMBEDDING_TABLE_PER_COLLECTION + ): + embedding_table_name = f"embedding_{name}" + session.execute( + sa.text( + """ + CREATE TABLE IF NOT EXISTS collection ( + uuid TEXT, + name TEXT, + cmetadata OBJECT + ); + """ + ) + ) + session.execute( + sa.text( + f""" + CREATE TABLE IF NOT EXISTS {embedding_table_name} ( + uuid TEXT, + collection_id TEXT, + embedding FLOAT_VECTOR(123), + document TEXT, + cmetadata OBJECT, + custom_id TEXT + ); + """ + ) + ) + try: + session.execute( + sa.text( + f"INSERT INTO collection (uuid, name, cmetadata) " + f"VALUES ('uuid-{name}', '{name}', {{}});" + ) + ) + session.execute(sa.text("REFRESH TABLE collection")) + except sa.exc.IntegrityError: + pass + + +class FakeEmbeddingsWithAdaDimension(FakeEmbeddings): + """Fake embeddings functionality for testing.""" + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return simple embeddings.""" + return [ + [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts)) + ] + + def embed_query(self, text: str) -> List[float]: + """Return simple embeddings.""" + return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)] + + +class ConsistentFakeEmbeddingsWithAdaDimension(ConsistentFakeEmbeddings): + """ + Fake embeddings which remember all the texts seen so far to return + consistent vectors for the same texts. + + Other than this, they also have a fixed dimensionality, which is + important in this case. 
+ """ + + def __init__(self, *args: List, **kwargs: Dict) -> None: + super().__init__(dimensionality=ADA_TOKEN_COUNT) + + +@pytest.fixture +def two_stores() -> Tuple[CrateDBVectorSearch, CrateDBVectorSearch]: + """ + Provide two different vector search handles to test case functions, + associated with two different collections, correspondingly. + """ + store_foo = CrateDBVectorSearch.from_texts( + texts=["foo"], + collection_name="test_collection_foo", + collection_metadata={"category": "foo"}, + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=[{"document": "foo"}], + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + store_bar = CrateDBVectorSearch.from_texts( + texts=["bar"], + collection_name="test_collection_bar", + collection_metadata={"category": "bar"}, + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=[{"document": "bar"}], + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + return store_foo, store_bar + + +def test_cratedb_texts() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + +def test_cratedb_embedding_dimension() -> None: + """Verify the `embedding` column uses the correct vector dimensionality.""" + texts = ["foo", "bar", "baz"] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + with docsearch.Session() as session: + result = session.execute(sa.text(f"SHOW CREATE TABLE {SCHEMA_NAME}.embedding")) + record = result.first() + if not record: + raise ValueError("No data found") + ddl = record[0] + assert f'"embedding" FLOAT_VECTOR({ADA_TOKEN_COUNT})' in ddl + + +def test_cratedb_embeddings() -> None: + """Test end to end construction with embeddings and search.""" + texts = ["foo", "bar", "baz"] + text_embeddings = FakeEmbeddingsWithAdaDimension().embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + docsearch = CrateDBVectorSearch.from_embeddings( + text_embeddings=text_embedding_pairs, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + +def test_cratedb_with_metadatas() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={"page": "0"})] + + +def test_cratedb_with_metadatas_with_scores() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + 
collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=1)
+    assert output == [(Document(page_content="foo", metadata={"page": "0"}), 2.0)]
+
+
+def test_cratedb_with_filter_match() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = CrateDBVectorSearch.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    # TODO: Original:
+    #       assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]  # noqa: E501
+    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"})
+    assert output == [
+        (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(2.2, 0.1))
+    ]
+
+
+def test_cratedb_with_filter_distant_match() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = CrateDBVectorSearch.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=2, filter={"page": "2"})
+    # Original score value: 0.0013003906671379406
+    assert output == [
+        (Document(page_content="baz", metadata={"page": "2"}), pytest.approx(1.5, 0.2))
+    ]
+
+
+def test_cratedb_with_filter_no_match() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = CrateDBVectorSearch.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"})
+    assert output == []
+
+
+@pytest.fixture
+def storage_strategy_langchain_pgvector(mocker: MockerFixture) -> None:
+    mocker.patch(
+        "langchain.vectorstores.cratedb.base.CrateDBVectorSearch.STORAGE_STRATEGY",
+        StorageStrategy.LANGCHAIN_PGVECTOR,
+    )
+
+
+@pytest.fixture
+def storage_strategy_embedding_table_per_collection(mocker: MockerFixture) -> None:
+    mocker.patch(
+        "langchain.vectorstores.cratedb.base.CrateDBVectorSearch.STORAGE_STRATEGY",
+        StorageStrategy.EMBEDDING_TABLE_PER_COLLECTION,
+    )
+
+
+def test_cratedb_storage_strategy_langchain_pgvector(
+    storage_strategy_langchain_pgvector: None,
+    two_stores: Tuple[CrateDBVectorSearch, CrateDBVectorSearch],
+) -> None:
+    """
+    Verify collection construction and deletion using the vanilla storage strategy.
+
+    It uses two different collections of embeddings. In doing so, it proves
+    that the embeddings are managed according to the storage strategy.
+
+    In this case, embeddings for multiple collections are stored in a single
+    database table, called `embedding`.
+    """
+    store_foo, store_bar = two_stores
+
+    session = store_foo.Session()
+
+    # Verify data in database.
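+    # `get_collection()` presumably returns the `CollectionStore` record
+    # matching the store's collection name, or `None` when it does not
+    # exist; the `None` checks below rely on that contract.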
+    collection_foo = store_foo.get_collection(session)
+    collection_bar = store_bar.get_collection(session)
+    if collection_foo is None or collection_bar is None:
+        assert False, "Expected CollectionStore objects but received None"
+    assert collection_foo.embeddings[0].cmetadata == {"document": "foo"}
+    assert collection_bar.embeddings[0].cmetadata == {"document": "bar"}
+
+    # Verify number of records before deletion.
+    assert session.query(store_foo.EmbeddingStore).count() == 2
+    assert session.query(store_bar.EmbeddingStore).count() == 2
+
+    # Delete first collection.
+    store_foo.delete_collection()
+
+    # Verify that the "foo" collection has been deleted.
+    collection_foo = store_foo.get_collection(session)
+    collection_bar = store_bar.get_collection(session)
+    if collection_bar is None:
+        assert False, "Expected CollectionStore object but received None"
+    assert collection_foo is None
+    assert collection_bar.embeddings[0].cmetadata == {"document": "bar"}
+
+    # Verify the number of records after deletion, to prove that the
+    # associated embeddings have also been deleted.
+    assert session.query(store_foo.EmbeddingStore).count() == 1
+    assert session.query(store_bar.EmbeddingStore).count() == 1
+
+
+def test_cratedb_storage_strategy_embedding_table_per_collection(
+    storage_strategy_embedding_table_per_collection: None,
+    two_stores: Tuple[CrateDBVectorSearch, CrateDBVectorSearch],
+) -> None:
+    """
+    Verify collection construction and deletion using a more advanced storage strategy.
+
+    It uses two different collections of embeddings. In doing so, it proves
+    that the embeddings are managed according to the storage strategy.
+
+    In this case, embeddings for multiple collections are stored in separate
+    database tables, called `embedding_{collection_name}`.
+    """
+    store_foo, store_bar = two_stores
+
+    session = store_foo.Session()
+
+    # Verify data in database.
+    collection_foo = store_foo.get_collection(session)
+    collection_bar = store_bar.get_collection(session)
+    assert collection_foo.embeddings[0].cmetadata == {"document": "foo"}
+    assert collection_bar.embeddings[0].cmetadata == {"document": "bar"}
+
+    # Verify number of records before deletion.
+    assert session.query(store_foo.EmbeddingStore).count() == 1
+    assert session.query(store_bar.EmbeddingStore).count() == 1
+
+    # Delete first collection.
+    store_foo.delete_collection()
+
+    # Verify that the "foo" collection has been deleted.
+    collection_foo = store_foo.get_collection(session)
+    collection_bar = store_bar.get_collection(session)
+    assert collection_foo is None
+    assert collection_bar.embeddings[0].cmetadata == {"document": "bar"}
+
+    # Verify the number of records after deletion, to prove that the
+    # associated embeddings have also been deleted.
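+    # Under this strategy, each store counts rows in its own
+    # `embedding_<collection>` table, so deleting the "foo" collection
+    # leaves zero rows behind while "bar" keeps its single row.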
+ assert session.query(store_foo.EmbeddingStore).count() == 0 + assert session.query(store_bar.EmbeddingStore).count() == 1 + + +def test_cratedb_collection_with_metadata() -> None: + """Test end to end collection construction""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + cratedb_vector = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + collection_metadata={"foo": "bar"}, + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + collection = cratedb_vector.get_collection(cratedb_vector.Session()) + if collection is None: + assert False, "Expected a CollectionStore object but received None" + else: + assert collection.name == "test_collection" + assert collection.cmetadata == {"foo": "bar"} + + +def test_cratedb_collection_no_embedding_dimension() -> None: + """ + Verify that addressing collections fails when not specifying dimensions. + """ + cratedb_vector = CrateDBVectorSearch( + embedding_function=None, # type: ignore[arg-type] + connection_string=CONNECTION_STRING, + ) + session = Session(cratedb_vector.connect()) + with pytest.raises(RuntimeError) as ex: + cratedb_vector.get_collection(session) + assert ex.match( + "Collection can't be accessed without specifying " + "dimension size of embedding vectors" + ) + + +def test_cratedb_collection_read_only(session: Session) -> None: + """ + Test using a collection, without adding any embeddings upfront. + + This happens when just invoking the "retrieval" case. + + In this scenario, embedding dimensionality needs to be figured out + from the supplied `embedding_function`. + """ + + # Create a fake collection item. + ensure_collection(session, "baz2") + + # This test case needs an embedding _with_ dimensionality. + # Otherwise, the data access layer is unable to figure it + # out at runtime. + embedding = ConsistentFakeEmbeddingsWithAdaDimension() + + vectorstore = CrateDBVectorSearch( + collection_name="baz2", + connection_string=CONNECTION_STRING, + embedding_function=embedding, + ) + output = vectorstore.similarity_search("foo", k=1) + + # No documents/embeddings have been loaded, the collection is empty. + # This is why there are also no results. 
+    assert output == []
+
+
+def test_cratedb_with_filter_in_set() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = CrateDBVectorSearch.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score(
+        "foo", k=2, filter={"page": {"IN": ["0", "2"]}}
+    )
+    # Original score values: 0.0, 0.0013003906671379406
+    assert output == [
+        (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(3.0, 0.1)),
+        (Document(page_content="baz", metadata={"page": "2"}), pytest.approx(2.2, 0.1)),
+    ]
+
+
+def test_cratedb_delete_docs() -> None:
+    """Add and delete documents."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = CrateDBVectorSearch.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        ids=["1", "2", "3"],
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    docsearch.delete(["1", "2"])
+    with docsearch._make_session() as session:
+        records = list(session.query(docsearch.EmbeddingStore).all())
+        # Ignoring the type error, since mypy cannot determine whether
+        # the list is sortable.
+        assert sorted(record.custom_id for record in records) == ["3"]  # type: ignore
+
+    docsearch.delete(["2", "3"])  # Should not raise on missing ids.
+    with docsearch._make_session() as session:
+        records = list(session.query(docsearch.EmbeddingStore).all())
+        # Ignoring the type error, since mypy cannot determine whether
+        # the list is sortable.
+        assert sorted(record.custom_id for record in records) == []  # type: ignore
+
+
+def test_cratedb_relevance_score() -> None:
+    """
+    Verify the relevance scores. Note: Unlike with pgvector, the scores here
+    are not scaled into the range 0-1; see the original values noted below.
+    """
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = CrateDBVectorSearch.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+
+    output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
+    # Original score values: 1.0, 0.9996744261675065, 0.9986996093328621
+    assert output == [
+        (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(1.4, 0.1)),
+        (Document(page_content="bar", metadata={"page": "1"}), pytest.approx(1.1, 0.1)),
+        (Document(page_content="baz", metadata={"page": "2"}), pytest.approx(0.8, 0.1)),
+    ]
+
+
+def test_cratedb_retriever_search_threshold() -> None:
+    """Test using the retriever for searching with a score threshold."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = CrateDBVectorSearch.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+
+    retriever = docsearch.as_retriever(
+        search_type="similarity_score_threshold",
+        search_kwargs={"k": 3, "score_threshold": 0.999},
+    )
+    output = retriever.get_relevant_documents("summer")
+    assert output == [
+        Document(page_content="foo", metadata={"page": "0"}),
+        Document(page_content="bar", metadata={"page": "1"}),
+    ]
+
+
+def 
test_cratedb_retriever_search_threshold_custom_normalization_fn() -> None: + """Test searching with threshold and custom normalization function""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + relevance_score_fn=lambda d: d * 0, + ) + + retriever = docsearch.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={"k": 3, "score_threshold": 0.5}, + ) + output = retriever.get_relevant_documents("foo") + assert output == [] + + +def test_cratedb_max_marginal_relevance_search() -> None: + """Test max marginal relevance search.""" + texts = ["foo", "bar", "baz"] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3) + assert output == [Document(page_content="foo")] + + +def test_cratedb_max_marginal_relevance_search_with_score() -> None: + """Test max marginal relevance search with relevance scores.""" + texts = ["foo", "bar", "baz"] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3) + assert output == [(Document(page_content="foo"), 2.0)] + + +def test_cratedb_multicollection_search_success() -> None: + """ + `CrateDBVectorSearchMultiCollection` provides functionality for + searching multiple collections. + """ + + store_1 = CrateDBVectorSearch.from_texts( + texts=["Räuber", "Hotzenplotz"], + collection_name="test_collection_1", + embedding=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + _ = CrateDBVectorSearch.from_texts( + texts=["John", "Doe"], + collection_name="test_collection_2", + embedding=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + + # Probe the first store. + output = store_1.similarity_search("Räuber", k=1) + assert Document(page_content="Räuber") in output[:2] + output = store_1.similarity_search("Hotzenplotz", k=1) + assert Document(page_content="Hotzenplotz") in output[:2] + output = store_1.similarity_search("John Doe", k=1) + assert Document(page_content="Räuber") in output[:2] + + # Probe the multi-store. + multisearch = CrateDBVectorSearchMultiCollection( + collection_names=["test_collection_1", "test_collection_2"], + embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + ) + output = multisearch.similarity_search("Räuber Hotzenplotz", k=2) + assert Document(page_content="Räuber") in output[:2] + output = multisearch.similarity_search("John Doe", k=2) + assert Document(page_content="John") in output[:2] + + +def test_cratedb_multicollection_fail_indexing_not_permitted() -> None: + """ + `CrateDBVectorSearchMultiCollection` does not provide functionality for + indexing documents. 
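+
+    It is a query-only adapter: its `from_texts` constructor is expected to
+    raise `NotImplementedError`, as asserted below.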
+ """ + + with pytest.raises(NotImplementedError) as ex: + CrateDBVectorSearchMultiCollection.from_texts( + texts=["foo"], + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + ) + assert ex.match("This adapter can not be used for indexing documents") + + +def test_cratedb_multicollection_search_table_does_not_exist() -> None: + """ + `CrateDBVectorSearchMultiCollection` will fail when the `collection` + table does not exist. + """ + + store = CrateDBVectorSearchMultiCollection( + collection_names=["unknown"], + embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + ) + with pytest.raises(ProgrammingError) as ex: + store.similarity_search("foo") + assert ex.match(re.escape("RelationUnknown[Relation 'collection' unknown]")) + + +def test_cratedb_multicollection_search_unknown_collection() -> None: + """ + `CrateDBVectorSearchMultiCollection` will fail when not able to identify + collections to search in. + """ + + CrateDBVectorSearch.from_texts( + texts=["Räuber", "Hotzenplotz"], + collection_name="test_collection", + embedding=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + + store = CrateDBVectorSearchMultiCollection( + collection_names=["unknown"], + embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + ) + with pytest.raises(ValueError) as ex: + store.similarity_search("foo") + assert ex.match("No collections found") + + +def test_cratedb_multicollection_no_embedding_dimension() -> None: + """ + Verify that addressing collections fails when not specifying dimensions. + """ + store = CrateDBVectorSearchMultiCollection( + embedding_function=None, # type: ignore[arg-type] + connection_string=CONNECTION_STRING, + ) + session = Session(store.connect()) + with pytest.raises(RuntimeError) as ex: + store.get_collection(session) + assert ex.match( + "Collection can't be accessed without specifying " + "dimension size of embedding vectors" + ) + + +def test_cratedb_multicollection_storage_strategy_embedding_table_per_collection() -> ( + None +): + """ + Verify that using the multi-collection querying trips corresponding safety checks, + when configured to use the `EMBEDDING_TABLE_PER_COLLECTION` storage strategy. + + They are not supported together, yet. 
+ """ + CrateDBVectorSearchMultiCollection.configure( + storage_strategy=StorageStrategy.EMBEDDING_TABLE_PER_COLLECTION + ) + with pytest.raises(NotImplementedError) as ex: + CrateDBVectorSearchMultiCollection( + embedding_function=None, # type: ignore[arg-type] + connection_string=CONNECTION_STRING, + ) + assert ex.match( + "Multi-collection querying not supported by strategy: " + "StorageStrategy.EMBEDDING_TABLE_PER_COLLECTION" + ) diff --git a/libs/langchain/tests/unit_tests/conftest.py b/libs/langchain/tests/unit_tests/conftest.py index da45a330f50af..afb609f3ce31f 100644 --- a/libs/langchain/tests/unit_tests/conftest.py +++ b/libs/langchain/tests/unit_tests/conftest.py @@ -40,8 +40,8 @@ def test_something(): # Used to avoid repeated calls to `util.find_spec` required_pkgs_info: Dict[str, bool] = {} - only_extended = config.getoption("--only-extended") or False - only_core = config.getoption("--only-core") or False + only_extended = config.getoption("--only-extended", False) + only_core = config.getoption("--only-core", False) if only_extended and only_core: raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")