diff --git a/docs/docs/integrations/document_loaders/cratedb.ipynb b/docs/docs/integrations/document_loaders/cratedb.ipynb
new file mode 100644
index 0000000000000..78a0e19138703
--- /dev/null
+++ b/docs/docs/integrations/document_loaders/cratedb.ipynb
@@ -0,0 +1,232 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# CrateDB\n",
+ "\n",
+ "This notebook demonstrates how to load documents from a [CrateDB] database,\n",
+ "using the [SQLAlchemy] document loader.\n",
+ "\n",
+ "It loads the result of a database query with one document per row.\n",
+ "\n",
+ "[CrateDB]: https://github.com/crate/crate\n",
+ "[SQLAlchemy]: https://www.sqlalchemy.org/"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Prerequisites"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "#!pip install crash 'langchain[cratedb]'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Populate database."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001B[32mCONNECT OK\r\n",
+ "\u001B[0m\u001B[32mPSQL OK, 1 row affected (0.001 sec)\r\n",
+ "\u001B[0m\u001B[32mDELETE OK, 30 rows affected (0.008 sec)\r\n",
+ "\u001B[0m\u001B[32mINSERT OK, 30 rows affected (0.011 sec)\r\n",
+ "\u001B[0m\u001B[0m\u001B[32mCONNECT OK\r\n",
+ "\u001B[0m\u001B[32mREFRESH OK, 1 row affected (0.001 sec)\r\n",
+ "\u001B[0m\u001B[0m"
+ ]
+ }
+ ],
+ "source": [
+ "!crash < ./example_data/mlb_teams_2012.sql\n",
+ "!crash --command \"REFRESH TABLE mlb_teams_2012;\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Usage"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from langchain.document_loaders import CrateDBLoader\n",
+ "from pprint import pprint\n",
+ "\n",
+ "CONNECTION_STRING = \"crate://crate@localhost/\"\n",
+ "\n",
+ "loader = CrateDBLoader(\n",
+ " 'SELECT * FROM mlb_teams_2012 ORDER BY \"Team\" LIMIT 5;',\n",
+ " url=CONNECTION_STRING,\n",
+ ")\n",
+ "documents = loader.load()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[Document(page_content='Team: Angels\\nPayroll (millions): 154.49\\nWins: 89', metadata={}),\n",
+ " Document(page_content='Team: Astros\\nPayroll (millions): 60.65\\nWins: 55', metadata={}),\n",
+ " Document(page_content='Team: Athletics\\nPayroll (millions): 55.37\\nWins: 94', metadata={}),\n",
+ " Document(page_content='Team: Blue Jays\\nPayroll (millions): 75.48\\nWins: 73', metadata={}),\n",
+ " Document(page_content='Team: Braves\\nPayroll (millions): 83.31\\nWins: 94', metadata={})]\n"
+ ]
+ }
+ ],
+ "source": [
+ "pprint(documents)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Specifying Which Columns are Content vs Metadata"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loader = CrateDBLoader(\n",
+ " 'SELECT * FROM mlb_teams_2012 ORDER BY \"Team\" LIMIT 5;',\n",
+ " url=CONNECTION_STRING,\n",
+ " page_content_columns=[\"Team\"],\n",
+ " metadata_columns=[\"Payroll (millions)\"],\n",
+ ")\n",
+ "documents = loader.load()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[Document(page_content='Team: Angels', metadata={'Payroll (millions)': 154.49}),\n",
+ " Document(page_content='Team: Astros', metadata={'Payroll (millions)': 60.65}),\n",
+ " Document(page_content='Team: Athletics', metadata={'Payroll (millions)': 55.37}),\n",
+ " Document(page_content='Team: Blue Jays', metadata={'Payroll (millions)': 75.48}),\n",
+ " Document(page_content='Team: Braves', metadata={'Payroll (millions)': 83.31})]\n"
+ ]
+ }
+ ],
+ "source": [
+ "pprint(documents)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Adding Source to Metadata"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loader = CrateDBLoader(\n",
+ " 'SELECT * FROM mlb_teams_2012 ORDER BY \"Team\" LIMIT 5;',\n",
+ " url=CONNECTION_STRING,\n",
+ " source_columns=[\"Team\"],\n",
+ ")\n",
+ "documents = loader.load()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[Document(page_content='Team: Angels\\nPayroll (millions): 154.49\\nWins: 89', metadata={'source': 'Angels'}),\n",
+ " Document(page_content='Team: Astros\\nPayroll (millions): 60.65\\nWins: 55', metadata={'source': 'Astros'}),\n",
+ " Document(page_content='Team: Athletics\\nPayroll (millions): 55.37\\nWins: 94', metadata={'source': 'Athletics'}),\n",
+ " Document(page_content='Team: Blue Jays\\nPayroll (millions): 75.48\\nWins: 73', metadata={'source': 'Blue Jays'}),\n",
+ " Document(page_content='Team: Braves\\nPayroll (millions): 83.31\\nWins: 94', metadata={'source': 'Braves'})]\n"
+ ]
+ }
+ ],
+ "source": [
+ "pprint(documents)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql b/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql
new file mode 100644
index 0000000000000..6d94aeaa773b8
--- /dev/null
+++ b/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql
@@ -0,0 +1,41 @@
+-- Provisioning table "mlb_teams_2012".
+--
+-- crash < mlb_teams_2012.sql
+-- psql postgresql://postgres@localhost < mlb_teams_2012.sql
+
+DROP TABLE IF EXISTS mlb_teams_2012;
+CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT);
+INSERT INTO mlb_teams_2012
+ ("Team", "Payroll (millions)", "Wins")
+VALUES
+ ('Nationals', 81.34, 98),
+ ('Reds', 82.20, 97),
+ ('Yankees', 197.96, 95),
+ ('Giants', 117.62, 94),
+ ('Braves', 83.31, 94),
+ ('Athletics', 55.37, 94),
+ ('Rangers', 120.51, 93),
+ ('Orioles', 81.43, 93),
+ ('Rays', 64.17, 90),
+ ('Angels', 154.49, 89),
+ ('Tigers', 132.30, 88),
+ ('Cardinals', 110.30, 88),
+ ('Dodgers', 95.14, 86),
+ ('White Sox', 96.92, 85),
+ ('Brewers', 97.65, 83),
+ ('Phillies', 174.54, 81),
+ ('Diamondbacks', 74.28, 81),
+ ('Pirates', 63.43, 79),
+ ('Padres', 55.24, 76),
+ ('Mariners', 81.97, 75),
+ ('Mets', 93.35, 74),
+ ('Blue Jays', 75.48, 73),
+ ('Royals', 60.91, 72),
+ ('Marlins', 118.07, 69),
+ ('Red Sox', 173.18, 69),
+ ('Indians', 78.43, 68),
+ ('Twins', 94.08, 66),
+ ('Rockies', 78.06, 64),
+ ('Cubs', 88.19, 61),
+ ('Astros', 60.65, 55)
+;
diff --git a/docs/docs/integrations/document_loaders/sqlalchemy.ipynb b/docs/docs/integrations/document_loaders/sqlalchemy.ipynb
new file mode 100644
index 0000000000000..5d603d7263c53
--- /dev/null
+++ b/docs/docs/integrations/document_loaders/sqlalchemy.ipynb
@@ -0,0 +1,237 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# SQLAlchemy\n",
+ "\n",
+ "This notebook demonstrates how to load documents from an [SQLite] database,\n",
+ "using the [SQLAlchemy] document loader.\n",
+ "\n",
+ "It loads the result of a database query with one document per row.\n",
+ "\n",
+ "[SQLAlchemy]: https://www.sqlalchemy.org/\n",
+ "[SQLite]: https://sqlite.org/"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Prerequisites"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "#!pip install langchain termsql"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Provide input data as SQLite database."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Overwriting example.csv\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%file example.csv\n",
+ "Team,Payroll\n",
+ "Nationals,81.34\n",
+ "Reds,82.20"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Nationals|81.34\r\n",
+ "Reds|82.2\r\n"
+ ]
+ }
+ ],
+ "source": [
+ "!termsql --infile=example.csv --head --delimiter=\",\" --outfile=example.sqlite --table=payroll"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Usage"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from langchain.document_loaders import SQLAlchemyLoader\n",
+ "from pprint import pprint\n",
+ "\n",
+ "loader = SQLAlchemyLoader(\n",
+ " \"SELECT * FROM payroll\",\n",
+ " url=\"sqlite:///example.sqlite\",\n",
+ ")\n",
+ "documents = loader.load()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[Document(page_content='Team: Nationals\\nPayroll: 81.34', metadata={}),\n",
+ " Document(page_content='Team: Reds\\nPayroll: 82.2', metadata={})]\n"
+ ]
+ }
+ ],
+ "source": [
+ "pprint(documents)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Specifying Which Columns are Content vs Metadata"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loader = SQLAlchemyLoader(\n",
+ " \"SELECT * FROM payroll\",\n",
+ " url=\"sqlite:///example.sqlite\",\n",
+ " page_content_columns=[\"Team\"],\n",
+ " metadata_columns=[\"Payroll\"],\n",
+ ")\n",
+ "documents = loader.load()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[Document(page_content='Team: Nationals', metadata={'Payroll': 81.34}),\n",
+ " Document(page_content='Team: Reds', metadata={'Payroll': 82.2})]\n"
+ ]
+ }
+ ],
+ "source": [
+ "pprint(documents)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Adding Source to Metadata"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loader = SQLAlchemyLoader(\n",
+ " \"SELECT * FROM payroll\",\n",
+ " url=\"sqlite:///example.sqlite\",\n",
+ " source_columns=[\"Team\"],\n",
+ ")\n",
+ "documents = loader.load()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[Document(page_content='Team: Nationals\\nPayroll: 81.34', metadata={'source': 'Nationals'}),\n",
+ " Document(page_content='Team: Reds\\nPayroll: 82.2', metadata={'source': 'Reds'})]\n"
+ ]
+ }
+ ],
+ "source": [
+ "pprint(documents)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb b/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb
new file mode 100644
index 0000000000000..f51f5f1d63fca
--- /dev/null
+++ b/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb
@@ -0,0 +1,357 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# CrateDB Chat Message History\n",
+ "\n",
+ "This notebook demonstrates how to use the `CrateDBChatMessageHistory`\n",
+ "to manage chat history in CrateDB, for supporting conversational memory."
+ ],
+ "metadata": {
+ "collapsed": false
+ },
+ "id": "f22eab3f84cbeb37"
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Prerequisites"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "!#pip install 'langchain[cratedb]'"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Configuration\n",
+ "\n",
+ "To use the storage wrapper, you will need to configure two details.\n",
+ "\n",
+ "1. Session Id - a unique identifier of the session, like user name, email, chat id etc.\n",
+ "2. Database connection string: An SQLAlchemy-compatible URI that specifies the database\n",
+ " connection. It will be passed to SQLAlchemy create_engine function."
+ ],
+ "metadata": {
+ "collapsed": false
+ },
+ "id": "f8f2830ee9ca1e01"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "outputs": [],
+ "source": [
+ "from langchain.memory.chat_message_histories import CrateDBChatMessageHistory\n",
+ "\n",
+ "CONNECTION_STRING = \"crate://crate@localhost:4200/?schema=example\"\n",
+ "\n",
+ "chat_message_history = CrateDBChatMessageHistory(\n",
+ "\tsession_id=\"test_session\",\n",
+ "\tconnection_string=CONNECTION_STRING\n",
+ ")"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Basic Usage"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "outputs": [],
+ "source": [
+ "chat_message_history.add_user_message(\"Hello\")\n",
+ "chat_message_history.add_ai_message(\"Hi\")"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2023-08-28T10:04:38.077748Z",
+ "start_time": "2023-08-28T10:04:36.105894Z"
+ }
+ },
+ "id": "4576e914a866fb40"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 61,
+ "outputs": [
+ {
+ "data": {
+ "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]"
+ },
+ "execution_count": 61,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "chat_message_history.messages"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2023-08-28T10:04:38.929396Z",
+ "start_time": "2023-08-28T10:04:38.915727Z"
+ }
+ },
+ "id": "b476688cbb32ba90"
+ },
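+ {
+ "cell_type": "markdown",
+ "source": [
+ "To wire the stored history into a conversation chain, you can plug it into a\n",
+ "standard LangChain memory class. A minimal sketch, assuming\n",
+ "`ConversationBufferMemory` from the same LangChain version:"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "from langchain.memory import ConversationBufferMemory\n",
+ "\n",
+ "# Back the conversational memory with the CrateDB-stored history.\n",
+ "memory = ConversationBufferMemory(\n",
+ "    chat_memory=chat_message_history, return_messages=True\n",
+ ")\n",
+ "memory.load_memory_variables({})"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },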
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Custom Storage Model\n",
+ "\n",
+ "The default data model, which stores information about conversation messages only\n",
+ "has two slots for storing message details, the session id, and the message dictionary.\n",
+ "\n",
+ "If you want to store additional information, like message date, author, language etc.,\n",
+ "please provide an implementation for a custom message converter.\n",
+ "\n",
+ "This example demonstrates how to create a custom message converter, by implementing\n",
+ "the `BaseMessageConverter` interface."
+ ],
+ "metadata": {
+ "collapsed": false
+ },
+ "id": "2e5337719d5614fd"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "outputs": [],
+ "source": [
+ "from datetime import datetime\n",
+ "from typing import Any\n",
+ "\n",
+ "from langchain.memory.chat_message_histories.cratedb import generate_autoincrement_identifier\n",
+ "from langchain.memory.chat_message_histories.sql import BaseMessageConverter\n",
+ "from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage\n",
+ "\n",
+ "import sqlalchemy as sa\n",
+ "from sqlalchemy.orm import declarative_base\n",
+ "\n",
+ "\n",
+ "Base = declarative_base()\n",
+ "\n",
+ "\n",
+ "class CustomMessage(Base):\n",
+ "\t__tablename__ = \"custom_message_store\"\n",
+ "\n",
+ "\tid = sa.Column(sa.BigInteger, primary_key=True, default=generate_autoincrement_identifier)\n",
+ "\tsession_id = sa.Column(sa.Text)\n",
+ "\ttype = sa.Column(sa.Text)\n",
+ "\tcontent = sa.Column(sa.Text)\n",
+ "\tcreated_at = sa.Column(sa.DateTime)\n",
+ "\tauthor_email = sa.Column(sa.Text)\n",
+ "\n",
+ "\n",
+ "class CustomMessageConverter(BaseMessageConverter):\n",
+ "\tdef __init__(self, author_email: str):\n",
+ "\t\tself.author_email = author_email\n",
+ "\t\n",
+ "\tdef from_sql_model(self, sql_message: Any) -> BaseMessage:\n",
+ "\t\tif sql_message.type == \"human\":\n",
+ "\t\t\treturn HumanMessage(\n",
+ "\t\t\t\tcontent=sql_message.content,\n",
+ "\t\t\t)\n",
+ "\t\telif sql_message.type == \"ai\":\n",
+ "\t\t\treturn AIMessage(\n",
+ "\t\t\t\tcontent=sql_message.content,\n",
+ "\t\t\t)\n",
+ "\t\telif sql_message.type == \"system\":\n",
+ "\t\t\treturn SystemMessage(\n",
+ "\t\t\t\tcontent=sql_message.content,\n",
+ "\t\t\t)\n",
+ "\t\telse:\n",
+ "\t\t\traise ValueError(f\"Unknown message type: {sql_message.type}\")\n",
+ "\t\n",
+ "\tdef to_sql_model(self, message: BaseMessage, session_id: str) -> Any:\n",
+ "\t\tnow = datetime.now()\n",
+ "\t\treturn CustomMessage(\n",
+ "\t\t\tsession_id=session_id,\n",
+ "\t\t\ttype=message.type,\n",
+ "\t\t\tcontent=message.content,\n",
+ "\t\t\tcreated_at=now,\n",
+ "\t\t\tauthor_email=self.author_email\n",
+ "\t\t)\n",
+ "\t\n",
+ "\tdef get_sql_model_class(self) -> Any:\n",
+ "\t\treturn CustomMessage\n",
+ "\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ "\n",
+ "\tBase.metadata.drop_all(bind=sa.create_engine(CONNECTION_STRING))\n",
+ "\n",
+ "\tchat_message_history = CrateDBChatMessageHistory(\n",
+ "\t\tsession_id=\"test_session\",\n",
+ "\t\tconnection_string=CONNECTION_STRING,\n",
+ "\t\tcustom_message_converter=CustomMessageConverter(\n",
+ "\t\t\tauthor_email=\"test@example.com\"\n",
+ "\t\t)\n",
+ "\t)\n",
+ "\n",
+ "\tchat_message_history.add_user_message(\"Hello\")\n",
+ "\tchat_message_history.add_ai_message(\"Hi\")"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2023-08-28T10:04:41.510498Z",
+ "start_time": "2023-08-28T10:04:41.494912Z"
+ }
+ },
+ "id": "fdfde84c07d071bb"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 60,
+ "outputs": [
+ {
+ "data": {
+ "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]"
+ },
+ "execution_count": 60,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "chat_message_history.messages"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2023-08-28T10:04:43.497990Z",
+ "start_time": "2023-08-28T10:04:43.492517Z"
+ }
+ },
+ "id": "4a6a54d8a9e2856f"
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Custom Name for Session Column\n",
+ "\n",
+ "The session id, a unique token identifying the session, is an important property of\n",
+ "this subsystem. If your database table stores it in a different column, you can use\n",
+ "the `session_id_field_name` keyword argument to adjust the name correspondingly."
+ ],
+ "metadata": {
+ "collapsed": false
+ },
+ "id": "622aded629a1adeb"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 57,
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "import typing as t\n",
+ "\n",
+ "from langchain.memory.chat_message_histories.cratedb import generate_autoincrement_identifier, CrateDBMessageConverter\n",
+ "from langchain.schema import _message_to_dict\n",
+ "\n",
+ "\n",
+ "Base = declarative_base()\n",
+ "\n",
+ "class MessageWithDifferentSessionIdColumn(Base):\n",
+ "\t__tablename__ = \"message_store_different_session_id\"\n",
+ "\tid = sa.Column(sa.BigInteger, primary_key=True, default=generate_autoincrement_identifier)\n",
+ "\tcustom_session_id = sa.Column(sa.Text)\n",
+ "\tmessage = sa.Column(sa.Text)\n",
+ "\n",
+ "\n",
+ "class CustomMessageConverterWithDifferentSessionIdColumn(CrateDBMessageConverter):\n",
+ " def __init__(self):\n",
+ " self.model_class = MessageWithDifferentSessionIdColumn\n",
+ "\n",
+ " def to_sql_model(self, message: BaseMessage, custom_session_id: str) -> t.Any:\n",
+ " return self.model_class(\n",
+ " custom_session_id=custom_session_id, message=json.dumps(_message_to_dict(message))\n",
+ " )\n",
+ "\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ "\tBase.metadata.drop_all(bind=sa.create_engine(CONNECTION_STRING))\n",
+ "\n",
+ "\tchat_message_history = CrateDBChatMessageHistory(\n",
+ "\t\tsession_id=\"test_session\",\n",
+ "\t\tconnection_string=CONNECTION_STRING,\n",
+ "\t\tcustom_message_converter=CustomMessageConverterWithDifferentSessionIdColumn(),\n",
+ "\t\tsession_id_field_name=\"custom_session_id\",\n",
+ "\t)\n",
+ "\n",
+ "\tchat_message_history.add_user_message(\"Hello\")\n",
+ "\tchat_message_history.add_ai_message(\"Hi\")"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
+ "outputs": [
+ {
+ "data": {
+ "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]"
+ },
+ "execution_count": 58,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "chat_message_history.messages"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/docs/integrations/providers/cratedb.mdx b/docs/docs/integrations/providers/cratedb.mdx
new file mode 100644
index 0000000000000..4764a7ad92369
--- /dev/null
+++ b/docs/docs/integrations/providers/cratedb.mdx
@@ -0,0 +1,203 @@
+# CrateDB
+
+This documentation section shows how to use CrateDB with LangChain. It covers
+the vector store functionality around [`FLOAT_VECTOR`] and [`KNN_MATCH`] for
+similarity search, as well as the document loader and conversational memory.
+
+
+## What is CrateDB?
+
+[CrateDB] is an open-source, distributed, and scalable SQL analytics database
+for storing and analyzing massive amounts of data in near real-time, even with
+complex queries. It is PostgreSQL-compatible, based on [Lucene], and inherits
+the shared-nothing distribution layer of [Elasticsearch].
+
+It provides a distributed, multi-tenant-capable relational database and search
+engine with HTTP and PostgreSQL interfaces, and schema-free objects. It supports
+sharding, partitioning, and replication out of the box.
+
+CrateDB enables you to efficiently store billions of records and terabytes of
+data, and to query them using SQL.
+
+- Provides a standards-based SQL interface for querying relational data, nested
+ documents, geospatial constraints, and vector embeddings at the same time.
+- Improves your operations by storing time-series data, relational metadata,
+ and vector embeddings within a single database.
+- Builds upon approved technologies from Lucene and Elasticsearch.
+
+
+## CrateDB Cloud
+
+- Offers on-demand CrateDB clusters without operational overhead,
+ with enterprise-grade features and [ISO 27001] certification.
+- The entrypoint to [CrateDB Cloud] is the [CrateDB Cloud Console].
+- Crate.io offers a free tier via [CrateDB Cloud CRFREE].
+- To get started, [sign up] for CrateDB Cloud, deploy a database cluster,
+ and follow the upcoming instructions.
+
+
+## Features
+
+The CrateDB adapter supports the _Vector Store_, _Document Loader_,
+and _Conversational Memory_ subsystems of LangChain.
+
+### Vector Store
+
+`CrateDBVectorSearch` is an API wrapper around CrateDB's `FLOAT_VECTOR` type
+and the corresponding `KNN_MATCH` function, based on SQLAlchemy and the
+[CrateDB SQLAlchemy dialect]. It provides an interface to store and retrieve floating
+point vectors, and to conduct similarity searches.
+
+Supports:
+- Approximate nearest neighbor search.
+- Euclidean distance.
+
+### Document Loader
+
+`CrateDBLoader` loads documents from a database table, using either an SQL
+query expression or an SQLAlchemy selectable instance.
+
+### Conversational Memory
+
+`CrateDBChatMessageHistory` uses CrateDB to manage conversation history.
+
+
+## Installation and Setup
+
+There are multiple ways to get started with CrateDB.
+
+### Install CrateDB on your local machine
+
+You can [download CrateDB], or use the [OCI image] to run CrateDB on Docker or Podman.
+Note that running a local single-node instance is not recommended for production use.
+
+```shell
+docker run --rm -it --name=cratedb --publish=4200:4200 --publish=5432:5432 \
+ --env=CRATE_HEAP_SIZE=4g crate/crate:nightly \
+ -Cdiscovery.type=single-node
+```
+
+### Deploy a cluster on CrateDB Cloud
+
+[CrateDB Cloud] is a managed CrateDB service. Sign up for a [free trial].
+
+### Install Client
+
+```bash
+pip install 'crate[sqlalchemy]' 'langchain[openai]' 'crash'
+```
+
+
+## Usage » Vector Store
+
+For a more detailed walkthrough of the `CrateDBVectorSearch` wrapper, there is also
+a corresponding [Jupyter notebook](/docs/extras/integrations/vectorstores/cratedb.html).
+
+### Provide input data
+The example uses the canonical `state_of_the_union.txt`.
+```shell
+wget https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt
+```
+
+### Set environment variables
+Use a valid OpenAI API key and an SQL connection string. The one below fits a local instance of CrateDB.
+```shell
+export OPENAI_API_KEY=<your-openai-api-key>
+export CRATEDB_CONNECTION_STRING=crate://crate@localhost
+```
+
+### Example
+
+Load and index documents, and invoke a query.
+```python
+from langchain.document_loaders import UnstructuredURLLoader
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.vectorstores import CrateDBVectorSearch
+
+
+def main():
+ # Load the document, split it into chunks, embed each chunk and load it into the vector store.
+    raw_documents = UnstructuredURLLoader(urls=["https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt"]).load()
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+ documents = text_splitter.split_documents(raw_documents)
+ db = CrateDBVectorSearch.from_documents(documents, OpenAIEmbeddings())
+
+ query = "What did the president say about Ketanji Brown Jackson"
+ docs = db.similarity_search(query)
+ print(docs[0].page_content)
+
+
+if __name__ == "__main__":
+ main()
+```
+
+
+## Usage » Document Loader
+
+For a more detailed walkthrough of the `CrateDBLoader`, there is also a corresponding
+[Jupyter notebook](/docs/extras/integrations/document_loaders/cratedb.html).
+
+
+### Provide input data
+```shell
+wget https://github.com/crate-workbench/langchain/raw/cratedb/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql
+crash < mlb_teams_2012.sql
+crash --command "REFRESH TABLE mlb_teams_2012;"
+```
+
+### Load documents by SQL query
+```python
+from langchain.document_loaders import CrateDBLoader
+from pprint import pprint
+
+def main():
+ loader = CrateDBLoader(
+ 'SELECT * FROM mlb_teams_2012 ORDER BY "Team" LIMIT 5;',
+ url="crate://crate@localhost/",
+ )
+ documents = loader.load()
+ pprint(documents)
+
+if __name__ == "__main__":
+ main()
+```
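+
+### Load documents by SQLAlchemy selectable
+
+`CrateDBLoader` also accepts an SQLAlchemy selectable instead of an SQL string.
+A minimal sketch, reflecting the table provisioned above:
+
+```python
+import sqlalchemy as sa
+from langchain.document_loaders import CrateDBLoader
+from pprint import pprint
+
+def main():
+    url = "crate://crate@localhost/"
+    # Reflect the existing table and build a selectable equivalent
+    # to the SQL query above.
+    table = sa.Table("mlb_teams_2012", sa.MetaData(), autoload_with=sa.create_engine(url))
+    loader = CrateDBLoader(sa.select(table).limit(5), url=url)
+    pprint(loader.load())
+
+if __name__ == "__main__":
+    main()
+```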
+
+
+## Usage » Conversational Memory
+
+For a more detailed walkthrough of the `CrateDBChatMessageHistory`, there is also a corresponding
+[Jupyter notebook](/docs/extras/integrations/memory/cratedb_chat_message_history.html).
+
+```python
+from langchain.memory.chat_message_histories import CrateDBChatMessageHistory
+from pprint import pprint
+
+def main():
+ chat_message_history = CrateDBChatMessageHistory(
+ session_id="test_session",
+ connection_string="crate://crate@localhost/",
+ )
+ chat_message_history.add_user_message("Hello")
+ chat_message_history.add_ai_message("Hi")
+    pprint(chat_message_history.messages)
+
+if __name__ == "__main__":
+ main()
+```
+
+
+[CrateDB]: https://github.com/crate/crate
+[CrateDB Cloud]: https://crate.io/product
+[CrateDB Cloud Console]: https://console.cratedb.cloud/
+[CrateDB Cloud CRFREE]: https://community.crate.io/t/new-cratedb-cloud-edge-feature-cratedb-cloud-free-tier/1402
+[CrateDB SQLAlchemy dialect]: https://crate.io/docs/python/en/latest/sqlalchemy.html
+[download CrateDB]: https://crate.io/download
+[Elasticsearch]: https://github.com/elastic/elasticsearch
+[`FLOAT_VECTOR`]: https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector
+[free trial]: https://crate.io/lp-crfree?utm_source=langchain
+[ISO 27001]: https://crate.io/blog/cratedb-elevates-its-security-standards-and-achieves-iso-27001-certification
+[`KNN_MATCH`]: https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match
+[Lucene]: https://github.com/apache/lucene
+[OCI image]: https://hub.docker.com/_/crate
+[sign up]: https://console.cratedb.cloud/
diff --git a/docs/docs/integrations/vectorstores/cratedb.ipynb b/docs/docs/integrations/vectorstores/cratedb.ipynb
new file mode 100644
index 0000000000000..06430e6355ae9
--- /dev/null
+++ b/docs/docs/integrations/vectorstores/cratedb.ipynb
@@ -0,0 +1,499 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# CrateDB\n",
+ "\n",
+ "This notebook shows how to use the CrateDB vector store functionality around\n",
+ "[`FLOAT_VECTOR`] and [`KNN_MATCH`]. You will learn how to use it for similarity\n",
+ "search and other purposes.\n",
+ "\n",
+ "It supports:\n",
+ "- Similarity Search with Euclidean Distance\n",
+ "- Maximal Marginal Relevance Search (MMR)\n",
+ "\n",
+ "## What is CrateDB?\n",
+ "\n",
+ "[CrateDB] is an open-source, distributed, and scalable SQL analytics database\n",
+ "for storing and analyzing massive amounts of data in near real-time, even with\n",
+ "complex queries. It is PostgreSQL-compatible, based on [Lucene], and inherits\n",
+ "the shared-nothing distribution layer of [Elasticsearch].\n",
+ "\n",
+ "This example uses the [Python client driver for CrateDB]. For more documentation,\n",
+ "see also [LangChain with CrateDB].\n",
+ "\n",
+ "\n",
+ "[CrateDB]: https://github.com/crate/crate\n",
+ "[Elasticsearch]: https://github.com/elastic/elasticsearch\n",
+ "[`FLOAT_VECTOR`]: https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector\n",
+ "[`KNN_MATCH`]: https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match\n",
+ "[LangChain with CrateDB]: /docs/extras/integrations/providers/cratedb.html\n",
+ "[Lucene]: https://github.com/apache/lucene\n",
+ "[Python client driver for CrateDB]: https://crate.io/docs/python/"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Getting Started"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": [],
+ "pycharm": {
+ "is_executing": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Install required packages: LangChain, OpenAI SDK, and the CrateDB Python driver.\n",
+ "!pip install 'langchain[cratedb,openai]'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You need to provide an OpenAI API key, optionally using the environment\n",
+ "variable `OPENAI_API_KEY`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-09-09T08:02:16.802456Z",
+ "start_time": "2023-09-09T08:02:07.065604Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import getpass\n",
+ "from dotenv import load_dotenv, find_dotenv\n",
+ "\n",
+ "# Run `export OPENAI_API_KEY=sk-YOUR_OPENAI_API_KEY`.\n",
+ "# Get OpenAI api key from `.env` file.\n",
+ "# Otherwise, prompt for it.\n",
+ "_ = load_dotenv(find_dotenv())\n",
+ "OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', getpass.getpass(\"OpenAI API key:\"))\n",
+ "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "You also need to provide a connection string to your CrateDB database cluster,\n",
+ "optionally using the environment variable `CRATEDB_CONNECTION_STRING`.\n",
+ "\n",
+ "This example uses a CrateDB instance on your workstation, which you can start by\n",
+ "running [CrateDB using Docker]. Alternatively, you can also connect to a cluster\n",
+ "running on [CrateDB Cloud].\n",
+ "\n",
+ "[CrateDB Cloud]: https://console.cratedb.cloud/\n",
+ "[CrateDB using Docker]: https://crate.io/docs/crate/tutorials/en/latest/basic/index.html#docker"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "CONNECTION_STRING = os.environ.get(\n",
+ " \"CRATEDB_CONNECTION_STRING\",\n",
+ " \"crate://crate@localhost:4200/?schema=langchain\",\n",
+ ")\n",
+ "\n",
+ "# For CrateDB Cloud, use:\n",
+ "# CONNECTION_STRING = os.environ.get(\n",
+ "# \"CRATEDB_CONNECTION_STRING\",\n",
+ "# \"crate://username:password@hostname:4200/?ssl=true&schema=langchain\",\n",
+ "# )"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-09-09T08:02:28.174088Z",
+ "start_time": "2023-09-09T08:02:28.162698Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "\"\"\"\n",
+ "# Alternatively, the connection string can be assembled from individual\n",
+ "# environment variables.\n",
+ "import os\n",
+ "\n",
+ "CONNECTION_STRING = CrateDBVectorSearch.connection_string_from_db_params(\n",
+ " driver=os.environ.get(\"CRATEDB_DRIVER\", \"crate\"),\n",
+ " host=os.environ.get(\"CRATEDB_HOST\", \"localhost\"),\n",
+ " port=int(os.environ.get(\"CRATEDB_PORT\", \"4200\")),\n",
+ " database=os.environ.get(\"CRATEDB_DATABASE\", \"langchain\"),\n",
+ " user=os.environ.get(\"CRATEDB_USER\", \"crate\"),\n",
+ " password=os.environ.get(\"CRATEDB_PASSWORD\", \"\"),\n",
+ ")\n",
+ "\"\"\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "You will start by importing all required modules."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "from langchain.embeddings.openai import OpenAIEmbeddings\n",
+ "from langchain.text_splitter import CharacterTextSplitter\n",
+ "from langchain.vectorstores import CrateDBVectorSearch\n",
+ "from langchain.document_loaders import UnstructuredURLLoader\n",
+ "from langchain.docstore.document import Document"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Load and Index Documents\n",
+ "\n",
+ "Next, you will read input data, and tokenize it. The module will create a table\n",
+ "with the name of the collection. Make sure the collection name is unique, and\n",
+ "that you have the permission to create a table."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "loader = UnstructuredURLLoader(\"https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt\")\n",
+ "documents = loader.load()\n",
+ "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
+ "docs = text_splitter.split_documents(documents)\n",
+ "\n",
+ "COLLECTION_NAME = \"state_of_the_union_test\"\n",
+ "\n",
+ "embeddings = OpenAIEmbeddings()\n",
+ "\n",
+ "db = CrateDBVectorSearch.from_documents(\n",
+ " embedding=embeddings,\n",
+ " documents=docs,\n",
+ " collection_name=COLLECTION_NAME,\n",
+ " connection_string=CONNECTION_STRING,\n",
+ ")"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "is_executing": true
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Search Documents\n",
+ "\n",
+ "### Similarity Search with Euclidean Distance\n",
+ "Searching by euclidean distance is the default."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-09-09T08:05:11.104135Z",
+ "start_time": "2023-09-09T08:05:10.548998Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "query = \"What did the president say about Ketanji Brown Jackson\"\n",
+ "docs_with_score = db.similarity_search_with_score(query)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-09-09T08:05:13.532334Z",
+ "start_time": "2023-09-09T08:05:13.523191Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "for doc, score in docs_with_score:\n",
+ " print(\"-\" * 80)\n",
+ " print(\"Score: \", score)\n",
+ " print(doc.page_content)\n",
+ " print(\"-\" * 80)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Maximal Marginal Relevance Search (MMR)\n",
+ "Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "docs_with_score = db.max_marginal_relevance_search_with_score(query)"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2023-09-09T08:05:23.276819Z",
+ "start_time": "2023-09-09T08:05:21.972256Z"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "for doc, score in docs_with_score:\n",
+ " print(\"-\" * 80)\n",
+ " print(\"Score: \", score)\n",
+ " print(doc.page_content)\n",
+ " print(\"-\" * 80)"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2023-09-09T08:05:27.478580Z",
+ "start_time": "2023-09-09T08:05:27.470138Z"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Searching in Multiple Collections\n",
+ "`CrateDBVectorSearchMultiCollection` is a special adapter which provides similarity search across\n",
+ "multiple collections. It can not be used for indexing documents."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "from langchain.vectorstores.cratedb import CrateDBVectorSearchMultiCollection\n",
+ "\n",
+ "multisearch = CrateDBVectorSearchMultiCollection(\n",
+ " collection_names=[\"test_collection_1\", \"test_collection_2\"],\n",
+ " embedding_function=embeddings,\n",
+ " connection_string=CONNECTION_STRING,\n",
+ ")\n",
+ "docs_with_score = multisearch.similarity_search_with_score(query)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Working with the Vector Store\n",
+ "\n",
+ "In the example above, you created a vector store from scratch. When\n",
+ "aiming to work with an existing vector store, you can initialize it directly."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "store = CrateDBVectorSearch(\n",
+ " collection_name=COLLECTION_NAME,\n",
+ " connection_string=CONNECTION_STRING,\n",
+ " embedding_function=embeddings,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Add Documents\n",
+ "\n",
+ "You can also add documents to an existing vector store."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "store.add_documents([Document(page_content=\"foo\")])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "docs_with_score = db.similarity_search_with_score(\"foo\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "docs_with_score[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "docs_with_score[1]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Overwriting a Vector Store\n",
+ "\n",
+ "If you have an existing collection, you can overwrite it by using `from_documents`,\n",
+ "aad setting `pre_delete_collection = True`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "db = CrateDBVectorSearch.from_documents(\n",
+ " documents=docs,\n",
+ " embedding=embeddings,\n",
+ " collection_name=COLLECTION_NAME,\n",
+ " connection_string=CONNECTION_STRING,\n",
+ " pre_delete_collection=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "docs_with_score = db.similarity_search_with_score(\"foo\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "docs_with_score[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Using a Vector Store as a Retriever"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "retriever = store.as_retriever()"
+ ]
+ },
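+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The retriever can now be queried like any other LangChain retriever.\n",
+ "A minimal sketch:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "docs = retriever.get_relevant_documents(\n",
+ "    \"What did the president say about Ketanji Brown Jackson\"\n",
+ ")\n",
+ "print(docs[0].page_content)"
+ ]
+ },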
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(retriever)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs/docs/modules/data_connection/document_loaders/sqlalchemy.mdx b/docs/docs/modules/data_connection/document_loaders/sqlalchemy.mdx
new file mode 100644
index 0000000000000..9f7e663db075e
--- /dev/null
+++ b/docs/docs/modules/data_connection/document_loaders/sqlalchemy.mdx
@@ -0,0 +1,155 @@
+# SQLAlchemy
+
+
+## About
+
+The [SQLAlchemy] document loader loads records from any supported database,
+see [SQLAlchemy dialects] for all supported SQL databases and dialects.
+
+You can either use plain SQL for querying, or an SQLAlchemy `Select`
+statement object if you are using SQLAlchemy Core or ORM.
+
+You can select which columns to place into the document, which columns
+to place into its metadata, which columns to use as a `source` attribute
+in the metadata, and whether to include the result row number and/or
+the SQL query expression in the metadata.
+
+
+## Example
+
+This example uses PostgreSQL, and the `psycopg2` driver.
+
+
+### Prerequisites
+
+```shell
+psql postgresql://postgres@localhost/ --command "CREATE DATABASE testdrive;"
+psql postgresql://postgres@localhost/testdrive < ./libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql
+```
+
+
+### Basic loading
+
+```python
+from langchain.document_loaders import SQLAlchemyLoader
+from pprint import pprint
+
+
+loader = SQLAlchemyLoader(
+ query="SELECT * FROM mlb_teams_2012 LIMIT 3;",
+ url="postgresql+psycopg2://postgres@localhost:5432/testdrive",
+)
+docs = loader.load()
+```
+
+```python
+pprint(docs)
+```
+
+
+
+```
+[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={}),
+ Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={}),
+ Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={})]
+```
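+
+### Loading with an SQLAlchemy selectable
+
+Instead of a plain SQL string, the `query` argument also accepts an SQLAlchemy
+`Select` statement. A minimal sketch, reflecting the same table:
+
+```python
+import sqlalchemy as sa
+from langchain.document_loaders import SQLAlchemyLoader
+
+url = "postgresql+psycopg2://postgres@localhost:5432/testdrive"
+
+# Reflect the existing table and build a selectable equivalent to the SQL above.
+table = sa.Table("mlb_teams_2012", sa.MetaData(), autoload_with=sa.create_engine(url))
+loader = SQLAlchemyLoader(query=sa.select(table).limit(3), url=url)
+docs = loader.load()
+```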
+
+
+
+
+## Enriching metadata
+
+Use the `include_rownum_into_metadata` and `include_query_into_metadata` options to
+optionally populate the `metadata` dictionary with corresponding information.
+
+Having the `query` in the metadata is useful when chains that answer questions
+over documents loaded from database tables need to refer back to their origin queries.
+
+```python
+loader = SQLAlchemyLoader(
+ query="SELECT * FROM mlb_teams_2012 LIMIT 3;",
+ url="postgresql+psycopg2://postgres@localhost:5432/testdrive",
+ include_rownum_into_metadata=True,
+ include_query_into_metadata=True,
+)
+docs = loader.load()
+```
+
+```python
+pprint(docs)
+```
+
+
+
+```
+[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={'row': 0, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'}),
+ Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={'row': 1, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'}),
+ Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={'row': 2, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'})]
+```
+
+
+
+
+## Customizing metadata
+
+Use the `page_content_columns` and `metadata_columns` options to control which
+columns go into the document's `page_content` and which into its `metadata`
+dictionary. When `page_content_columns` is empty, all columns will be used.
+
+```python
+loader = SQLAlchemyLoader(
+ query="SELECT * FROM mlb_teams_2012 LIMIT 3;",
+ url="postgresql+psycopg2://postgres@localhost:5432/testdrive",
+ page_content_columns=["Payroll (millions)", "Wins"],
+ metadata_columns=["Team"],
+)
+docs = loader.load()
+```
+
+```python
+pprint(docs)
+```
+
+
+
+```
+[Document(page_content='Payroll (millions): 81.34\nWins: 98', metadata={'Team': 'Nationals'}),
+ Document(page_content='Payroll (millions): 82.2\nWins: 97', metadata={'Team': 'Reds'}),
+ Document(page_content='Payroll (millions): 197.96\nWins: 95', metadata={'Team': 'Yankees'})]
+```
+
+
+
+
+## Specify column(s) to identify the document source
+
+Use the `source_columns` option to specify the columns to use as a "source" for the
+document created from each row. This is useful for identifying documents through
+their metadata. Typically, you may use the primary key column(s) for that purpose.
+
+```python
+loader = SQLAlchemyLoader(
+ query="SELECT * FROM mlb_teams_2012 LIMIT 3;",
+ url="postgresql+psycopg2://postgres@localhost:5432/testdrive",
+ source_columns="Team",
+)
+docs = loader.load()
+```
+
+```python
+pprint(docs)
+```
+
+
+
+```
+[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={'source': 'Nationals'}),
+ Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={'source': 'Reds'}),
+ Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={'source': 'Yankees'})]
+```
+
+
+
+
+[SQLAlchemy]: https://www.sqlalchemy.org/
+[SQLAlchemy dialects]: https://docs.sqlalchemy.org/en/20/dialects/
diff --git a/docs/vercel.json b/docs/vercel.json
index 025c56cdb69af..20de4ea57968c 100644
--- a/docs/vercel.json
+++ b/docs/vercel.json
@@ -552,6 +552,10 @@
"source": "/docs/integrations/chaindesk",
"destination": "/docs/integrations/providers/chaindesk"
},
+ {
+ "source": "/docs/integrations/cratedb",
+ "destination": "/docs/integrations/providers/cratedb"
+ },
{
"source": "/docs/integrations/databricks",
"destination": "/docs/integrations/providers/databricks"
@@ -1732,6 +1736,10 @@
"source": "/docs/modules/data_connection/document_loaders/integrations/copypaste",
"destination": "/docs/integrations/document_loaders/copypaste"
},
+ {
+ "source": "/docs/modules/data_connection/document_loaders/integrations/cratedb",
+ "destination": "/docs/integrations/document_loaders/cratedb"
+ },
{
"source": "/en/latest/modules/indexes/document_loaders/examples/csv.html",
"destination": "/docs/integrations/document_loaders/csv"
@@ -2680,6 +2688,14 @@
"source": "/docs/modules/data_connection/vectorstores/integrations/chroma",
"destination": "/docs/integrations/vectorstores/chroma"
},
+ {
+ "source": "/en/latest/modules/indexes/vectorstores/examples/cratedb.html",
+ "destination": "/docs/integrations/vectorstores/cratedb"
+ },
+ {
+ "source": "/docs/modules/data_connection/vectorstores/integrations/cratedb",
+ "destination": "/docs/integrations/vectorstores/cratedb"
+ },
{
"source": "/en/latest/modules/indexes/vectorstores/examples/deeplake.html",
"destination": "/docs/integrations/vectorstores/activeloop_deeplake"
@@ -2944,6 +2960,14 @@
"source": "/docs/integrations/memory/entity_memory_with_sqlite",
"destination": "/docs/integrations/memory/sqlite"
},
+ {
+ "source": "/en/latest/modules/memory/examples/cratedb_chat_message_history.html",
+ "destination": "/docs/integrations/memory/cratedb_chat_message_history"
+ },
+ {
+ "source": "/docs/modules/memory/integrations/cratedb_chat_message_history",
+ "destination": "/docs/integrations/memory/cratedb_chat_message_history"
+ },
{
"source": "/en/latest/modules/memory/examples/dynamodb_chat_message_history.html",
"destination": "/docs/integrations/memory/dynamodb_chat_message_history"
diff --git a/libs/experimental/tests/unit_tests/conftest.py b/libs/experimental/tests/unit_tests/conftest.py
index da45a330f50af..afb609f3ce31f 100644
--- a/libs/experimental/tests/unit_tests/conftest.py
+++ b/libs/experimental/tests/unit_tests/conftest.py
@@ -40,8 +40,8 @@ def test_something():
# Used to avoid repeated calls to `util.find_spec`
required_pkgs_info: Dict[str, bool] = {}
- only_extended = config.getoption("--only-extended") or False
- only_core = config.getoption("--only-core") or False
+ only_extended = config.getoption("--only-extended", False)
+ only_core = config.getoption("--only-core", False)
if only_extended and only_core:
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py
index 96cfd9e1b5486..a454a1baa9eb4 100644
--- a/libs/langchain/langchain/document_loaders/__init__.py
+++ b/libs/langchain/langchain/document_loaders/__init__.py
@@ -59,6 +59,7 @@
from langchain.document_loaders.concurrent import ConcurrentLoader
from langchain.document_loaders.confluence import ConfluenceLoader
from langchain.document_loaders.conllu import CoNLLULoader
+from langchain.document_loaders.cratedb import CrateDBLoader
from langchain.document_loaders.csv_loader import CSVLoader, UnstructuredCSVLoader
from langchain.document_loaders.cube_semantic import CubeSemanticLoader
from langchain.document_loaders.datadog_logs import DatadogLogsLoader
@@ -158,6 +159,7 @@
from langchain.document_loaders.slack_directory import SlackDirectoryLoader
from langchain.document_loaders.snowflake_loader import SnowflakeLoader
from langchain.document_loaders.spreedly import SpreedlyLoader
+from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader
from langchain.document_loaders.srt import SRTLoader
from langchain.document_loaders.stripe import StripeLoader
from langchain.document_loaders.telegram import (
@@ -244,6 +246,7 @@
"CollegeConfidentialLoader",
"ConcurrentLoader",
"ConfluenceLoader",
+ "CrateDBLoader",
"CubeSemanticLoader",
"DataFrameLoader",
"DatadogLogsLoader",
@@ -333,6 +336,7 @@
"SlackDirectoryLoader",
"SnowflakeLoader",
"SpreedlyLoader",
+ "SQLAlchemyLoader",
"StripeLoader",
"TelegramChatApiLoader",
"TelegramChatFileLoader",
diff --git a/libs/langchain/langchain/document_loaders/cratedb.py b/libs/langchain/langchain/document_loaders/cratedb.py
new file mode 100644
index 0000000000000..9e34b4d0cb9ec
--- /dev/null
+++ b/libs/langchain/langchain/document_loaders/cratedb.py
@@ -0,0 +1,5 @@
+from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader
+
+
+class CrateDBLoader(SQLAlchemyLoader):
+    """Load documents from a CrateDB database, one document per result row."""
diff --git a/libs/langchain/langchain/document_loaders/sqlalchemy.py b/libs/langchain/langchain/document_loaders/sqlalchemy.py
new file mode 100644
index 0000000000000..787c9f339b686
--- /dev/null
+++ b/libs/langchain/langchain/document_loaders/sqlalchemy.py
@@ -0,0 +1,112 @@
+from typing import Dict, List, Optional, Union
+
+import sqlalchemy as sa
+
+from langchain.docstore.document import Document
+from langchain.document_loaders.base import BaseLoader
+
+
+class SQLAlchemyLoader(BaseLoader):
+ """
+ Load documents by querying database tables supported by SQLAlchemy.
+ Each document represents one row of the result.
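+
+    Example (illustrative; table name and URL refer to the example data used
+    elsewhere in the documentation):
+
+    .. code-block:: python
+
+        from langchain.document_loaders import SQLAlchemyLoader
+
+        loader = SQLAlchemyLoader(
+            query="SELECT * FROM mlb_teams_2012 LIMIT 3;",
+            url="sqlite:///example.sqlite",
+        )
+        documents = loader.load()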
+ """
+
+ def __init__(
+ self,
+ query: Union[str, sa.Select],
+ url: str,
+ page_content_columns: Optional[List[str]] = None,
+ metadata_columns: Optional[List[str]] = None,
+ source_columns: Optional[List[str]] = None,
+ include_rownum_into_metadata: bool = False,
+ include_query_into_metadata: bool = False,
+ sqlalchemy_kwargs: Optional[Dict] = None,
+ ):
+ """
+
+ Args:
+ query: The query to execute.
+ url: The SQLAlchemy connection string of the database to connect to.
+ page_content_columns: The columns to write into the `page_content`
+ of the document. Optional.
+ metadata_columns: The columns to write into the `metadata` of the document.
+ Optional.
+ source_columns: The names of the columns to use as the `source` within the
+ metadata dictionary. Optional.
+ include_rownum_into_metadata: Whether to include the row number into the
+ metadata dictionary. Optional. Default: False.
+ include_query_into_metadata: Whether to include the query expression into
+ the metadata dictionary. Optional. Default: False.
+ sqlalchemy_kwargs: More keyword arguments for SQLAlchemy's `create_engine`.
+ """
+ self.query = query
+ self.url = url
+ self.page_content_columns = page_content_columns
+ self.metadata_columns = metadata_columns
+ self.source_columns = source_columns
+ self.include_rownum_into_metadata = include_rownum_into_metadata
+ self.include_query_into_metadata = include_query_into_metadata
+ self.sqlalchemy_kwargs = sqlalchemy_kwargs or {}
+
+ def load(self) -> List[Document]:
+ engine = sa.create_engine(self.url, **self.sqlalchemy_kwargs)
+
+ docs = []
+ with engine.connect() as conn:
+ if isinstance(self.query, sa.Select):
+ result = conn.execute(self.query)
+ query_sql = str(self.query.compile(bind=engine))
+ elif isinstance(self.query, str):
+ result = conn.execute(sa.text(self.query))
+ query_sql = self.query
+ else:
+ raise TypeError(
+ f"Unable to process query of unknown type: {self.query}"
+ )
+ field_names = list(result.mappings().keys())
+
+ if self.page_content_columns is None:
+ page_content_columns = field_names
+ else:
+ page_content_columns = self.page_content_columns
+
+ if self.metadata_columns is None:
+ metadata_columns = []
+ else:
+ metadata_columns = self.metadata_columns
+
+ for i, row in enumerate(result.mappings()):
+ page_content = "\n".join(
+ f"{column}: {value}"
+ for column, value in row.items()
+ if column in page_content_columns
+ )
+
+                metadata: Dict[str, Union[str, int]] = {}
+                if self.include_rownum_into_metadata:
+                    metadata["row"] = i
+                if self.include_query_into_metadata:
+                    metadata["query"] = query_sql
+
+                source_values = []
+                for column, value in row.items():
+                    if column in metadata_columns:
+                        metadata[column] = value
+                    if self.source_columns and column in self.source_columns:
+                        source_values.append(value)
+                if source_values:
+                    metadata["source"] = ",".join(map(str, source_values))
+
+                doc = Document(page_content=page_content, metadata=metadata)
+                docs.append(doc)
+
+ return docs
diff --git a/libs/langchain/langchain/memory/chat_message_histories/__init__.py b/libs/langchain/langchain/memory/chat_message_histories/__init__.py
index 83fc7fa519ad0..7275d66d1108b 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/__init__.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/__init__.py
@@ -5,6 +5,7 @@
CassandraChatMessageHistory,
)
from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory
+from langchain.memory.chat_message_histories.cratedb import CrateDBChatMessageHistory
from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory
from langchain.memory.chat_message_histories.elasticsearch import (
ElasticsearchChatMessageHistory,
@@ -38,6 +39,7 @@
"ChatMessageHistory",
"CassandraChatMessageHistory",
"CosmosDBChatMessageHistory",
+ "CrateDBChatMessageHistory",
"DynamoDBChatMessageHistory",
"ElasticsearchChatMessageHistory",
"FileChatMessageHistory",
diff --git a/libs/langchain/langchain/memory/chat_message_histories/cratedb.py b/libs/langchain/langchain/memory/chat_message_histories/cratedb.py
new file mode 100644
index 0000000000000..19007176cb193
--- /dev/null
+++ b/libs/langchain/langchain/memory/chat_message_histories/cratedb.py
@@ -0,0 +1,113 @@
+import json
+import typing as t
+
+import sqlalchemy as sa
+from cratedb_toolkit.sqlalchemy import (
+ patch_inspector,
+ polyfill_refresh_after_dml,
+ refresh_table,
+)
+
+from langchain.memory.chat_message_histories.sql import (
+ BaseMessageConverter,
+ SQLChatMessageHistory,
+)
+from langchain.schema import BaseMessage, _message_to_dict, messages_from_dict
+
+
+def create_message_model(table_name, DynamicBase): # type: ignore
+ """
+ Create a message model for a given table name.
+
+    This is a specialized version for CrateDB, generating integer-based
+    primary keys.
+    TODO: Find a way to map CrateDB's generate_random_uuid() to a variant
+    returning an integer value.
+
+ Args:
+ table_name: The name of the table to use.
+ DynamicBase: The base class to use for the model.
+
+ Returns:
+ The model class.
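+
+    Example (a sketch)::
+
+        Base = sa.orm.declarative_base()
+        Message = create_message_model("message_store", Base)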
+ """
+
+ # Model is declared inside a function to be able to use a dynamic table name.
+ class Message(DynamicBase):
+ __tablename__ = table_name
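+        # CrateDB has no auto-incrementing integer columns, so use the current
+        # timestamp (a bigint) as a server-side default for the primary key;
+        # see the TODO above for a possible alternative.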
+ id = sa.Column(sa.BigInteger, primary_key=True, server_default=sa.func.now())
+ session_id = sa.Column(sa.Text)
+ message = sa.Column(sa.Text)
+
+ return Message
+
+
+class CrateDBMessageConverter(BaseMessageConverter):
+ """
+    The default message converter for `CrateDBChatMessageHistory`.
+
+    It is the same as the generic `SQLChatMessageHistory` converter,
+    but swaps in a different `create_message_model` function.
+ """
+
+ def __init__(self, table_name: str):
+ self.model_class = create_message_model(table_name, sa.orm.declarative_base())
+
+ def from_sql_model(self, sql_message: t.Any) -> BaseMessage:
+ return messages_from_dict([json.loads(sql_message.message)])[0]
+
+ def to_sql_model(self, message: BaseMessage, session_id: str) -> t.Any:
+ return self.model_class(
+ session_id=session_id, message=json.dumps(_message_to_dict(message))
+ )
+
+ def get_sql_model_class(self) -> t.Any:
+ return self.model_class
+
+
+class CrateDBChatMessageHistory(SQLChatMessageHistory):
+ """
+    Chat message history stored in a CrateDB database.
+
+    It is the same as the generic `SQLChatMessageHistory` implementation,
+    but swaps in a different message converter by default.
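+
+    Example (a minimal sketch; assumes a CrateDB server on localhost)::
+
+        history = CrateDBChatMessageHistory(
+            session_id="test-session",
+            connection_string="crate://crate@localhost/",
+        )
+        history.add_user_message("Hello")
+        history.add_ai_message("Hi there!")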
+ """
+
+ DEFAULT_MESSAGE_CONVERTER: t.Type[BaseMessageConverter] = CrateDBMessageConverter
+
+ def __init__(
+ self,
+ session_id: str,
+ connection_string: str,
+ table_name: str = "message_store",
+ session_id_field_name: str = "session_id",
+ custom_message_converter: t.Optional[BaseMessageConverter] = None,
+ ):
+ # FIXME: Refactor elsewhere.
+ patch_inspector()
+
+ super().__init__(
+ session_id,
+ connection_string,
+ table_name=table_name,
+ session_id_field_name=session_id_field_name,
+ custom_message_converter=custom_message_converter,
+ )
+
+ # TODO: Check how this can be improved.
+ polyfill_refresh_after_dml(self.Session)
+
+ def _messages_query(self) -> sa.Select:
+ """
+ Construct an SQLAlchemy selectable to query for messages.
+ For CrateDB, add an `ORDER BY` clause on the primary key.
+ """
+ selectable = super()._messages_query()
+ selectable = selectable.order_by(self.sql_model_class.id)
+ return selectable
+
+ def clear(self) -> None:
+ """
+        Synchronize data after deletion. Needed for CrateDB, because the
+        `on_flush` event does not catch the `DELETE` operation.
+ """
+ outcome = super().clear()
+ with self.Session() as session:
+ refresh_table(session, self.sql_model_class)
+ return outcome
diff --git a/libs/langchain/langchain/memory/chat_message_histories/sql.py b/libs/langchain/langchain/memory/chat_message_histories/sql.py
index fcc3ac71ab1a6..edf59685dc2fe 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/sql.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/sql.py
@@ -1,9 +1,9 @@
import json
import logging
from abc import ABC, abstractmethod
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Type
-from sqlalchemy import Column, Integer, Text, create_engine
+from sqlalchemy import Column, Integer, Select, Text, create_engine, select
try:
from sqlalchemy.orm import declarative_base
@@ -23,6 +23,10 @@
class BaseMessageConverter(ABC):
"""The class responsible for converting BaseMessage to your SQLAlchemy model."""
+ @abstractmethod
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ raise NotImplementedError
+
@abstractmethod
def from_sql_model(self, sql_message: Any) -> BaseMessage:
"""Convert a SQLAlchemy model to a BaseMessage instance."""
@@ -52,7 +56,7 @@ def create_message_model(table_name, DynamicBase): # type: ignore
"""
- # Model decleared inside a function to have a dynamic table name
+ # Model declared inside a function to have a dynamic table name
class Message(DynamicBase):
__tablename__ = table_name
id = Column(Integer, primary_key=True)
@@ -83,6 +87,8 @@ def get_sql_model_class(self) -> Any:
class SQLChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in an SQL database."""
+ DEFAULT_MESSAGE_CONVERTER: Type[BaseMessageConverter] = DefaultMessageConverter
+
def __init__(
self,
session_id: str,
@@ -94,7 +100,9 @@ def __init__(
self.connection_string = connection_string
self.engine = create_engine(connection_string, echo=False)
self.session_id_field_name = session_id_field_name
- self.converter = custom_message_converter or DefaultMessageConverter(table_name)
+ self.converter = custom_message_converter or self.DEFAULT_MESSAGE_CONVERTER(
+ table_name
+ )
self.sql_model_class = self.converter.get_sql_model_class()
if not hasattr(self.sql_model_class, session_id_field_name):
raise ValueError("SQL model class must have session_id column")
@@ -106,21 +114,25 @@ def __init__(
def _create_table_if_not_exists(self) -> None:
self.sql_model_class.metadata.create_all(self.engine)
+ def _messages_query(self) -> Select:
+        """Construct an SQLAlchemy selectable to query for messages."""
+ return (
+ select(self.sql_model_class)
+ .where(
+ getattr(self.sql_model_class, self.session_id_field_name)
+ == self.session_id
+ )
+ .order_by(self.sql_model_class.id.asc())
+ )
+
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve all messages from db"""
with self.Session() as session:
- result = (
- session.query(self.sql_model_class)
- .where(
- getattr(self.sql_model_class, self.session_id_field_name)
- == self.session_id
- )
- .order_by(self.sql_model_class.id.asc())
- )
+ result = session.execute(self._messages_query())
messages = []
for record in result:
- messages.append(self.converter.from_sql_model(record))
+ messages.append(self.converter.from_sql_model(record[0]))
return messages
def add_message(self, message: BaseMessage) -> None:
diff --git a/libs/langchain/langchain/vectorstores/__init__.py b/libs/langchain/langchain/vectorstores/__init__.py
index 6c79ccfdf40e6..813535fc22fdb 100644
--- a/libs/langchain/langchain/vectorstores/__init__.py
+++ b/libs/langchain/langchain/vectorstores/__init__.py
@@ -134,6 +134,12 @@ def _import_clickhouse_settings() -> Any:
return ClickhouseSettings
+def _import_cratedb() -> Any:
+ from langchain.vectorstores.cratedb import CrateDBVectorSearch
+
+ return CrateDBVectorSearch
+
+
def _import_dashvector() -> Any:
from langchain.vectorstores.dashvector import DashVector
@@ -465,6 +471,8 @@ def __getattr__(name: str) -> Any:
return _import_clickhouse_settings()
elif name == "Clickhouse":
return _import_clickhouse()
+ elif name == "CrateDBVectorSearch":
+ return _import_cratedb()
elif name == "DashVector":
return _import_dashvector()
elif name == "DatabricksVectorSearch":
@@ -582,6 +590,7 @@ def __getattr__(name: str) -> Any:
"Clarifai",
"Clickhouse",
"ClickhouseSettings",
+ "CrateDBVectorSearch",
"DashVector",
"DatabricksVectorSearch",
"DeepLake",
diff --git a/libs/langchain/langchain/vectorstores/cratedb/__init__.py b/libs/langchain/langchain/vectorstores/cratedb/__init__.py
new file mode 100644
index 0000000000000..3b3e23270ec1b
--- /dev/null
+++ b/libs/langchain/langchain/vectorstores/cratedb/__init__.py
@@ -0,0 +1,8 @@
+from .base import CrateDBVectorSearch, StorageStrategy
+from .extended import CrateDBVectorSearchMultiCollection
+
+__all__ = [
+ "CrateDBVectorSearch",
+ "CrateDBVectorSearchMultiCollection",
+ "StorageStrategy",
+]
diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py
new file mode 100644
index 0000000000000..9257ce2772c99
--- /dev/null
+++ b/libs/langchain/langchain/vectorstores/cratedb/base.py
@@ -0,0 +1,495 @@
+from __future__ import annotations
+
+import enum
+import math
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+)
+
+import sqlalchemy
+from cratedb_toolkit.sqlalchemy.patch import patch_inspector
+from cratedb_toolkit.sqlalchemy.polyfill import (
+ polyfill_refresh_after_dml,
+ refresh_table,
+)
+from sqlalchemy.orm import sessionmaker
+
+from langchain.docstore.document import Document
+from langchain.schema.embeddings import Embeddings
+from langchain.utils import get_from_dict_or_env
+from langchain.vectorstores.cratedb.model import ModelFactory
+from langchain.vectorstores.pgvector import PGVector
+
+
+class DistanceStrategy(str, enum.Enum):
+ """Enumerator of similarity distance strategies."""
+
+ EUCLIDEAN = "euclidean"
+ COSINE = "cosine"
+ MAX_INNER_PRODUCT = "inner"
+
+
+class StorageStrategy(enum.Enum):
+ """Enumerator of storage strategies."""
+
+ # This storage strategy reflects the vanilla way the pgvector adapter manages
+ # the data model: There is a single `collection` table and a single
+ # `embedding` table.
+ LANGCHAIN_PGVECTOR = "langchain_pgvector"
+
+ # This storage strategy reflects a more advanced way to manage
+ # the data model: There is a single `collection` table, and multiple
+ # `embedding` tables, one per collection.
+ EMBEDDING_TABLE_PER_COLLECTION = "embedding_table_per_collection"
+
+
+DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN
+DEFAULT_STORAGE_STRATEGY = StorageStrategy.LANGCHAIN_PGVECTOR
+
+
+_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
+
+
+def _results_to_docs(docs_and_scores: Any) -> List[Document]:
+ """Return docs from docs and scores."""
+ return [doc for doc, _ in docs_and_scores]
+
+
+class CrateDBVectorSearch(PGVector):
+ """`CrateDB` vector store.
+
+    To use it, you should have the ``crate[sqlalchemy]`` Python package installed.
+
+ Args:
+ connection_string: Database connection string.
+ embedding_function: Any embedding function implementing
+ `langchain.embeddings.base.Embeddings` interface.
+        collection_name: The name of the collection to use. (default: langchain)
+            NOTE: This is not the name of the table, but the name of the
+            collection. The tables will be created when initializing the store
+            (if they do not exist already), so make sure the user has the right
+            permissions to create tables.
+ distance_strategy: The distance strategy to use. (default: EUCLIDEAN)
+ pre_delete_collection: If True, will delete the collection if it exists.
+ (default: False). Useful for testing.
+
+ Example:
+ .. code-block:: python
+
+ from langchain.vectorstores import CrateDBVectorSearch
+ from langchain.embeddings.openai import OpenAIEmbeddings
+
+ CONNECTION_STRING = "crate://crate@localhost:4200/test3"
+ COLLECTION_NAME = "state_of_the_union_test"
+ embeddings = OpenAIEmbeddings()
+            vectorstore = CrateDBVectorSearch.from_documents(
+ embedding=embeddings,
+ documents=docs,
+ collection_name=COLLECTION_NAME,
+ connection_string=CONNECTION_STRING,
+ )
+
+ """
+
+ # Select storage strategy: Either use two database tables (`collection`
+ # and `embedding`), or multiple `embedding` tables, one per collection.
+ STORAGE_STRATEGY: StorageStrategy = DEFAULT_STORAGE_STRATEGY
+
+ @classmethod
+ def configure(
+ cls, storage_strategy: StorageStrategy = DEFAULT_STORAGE_STRATEGY
+ ) -> None:
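+        """
+        Set the class-wide storage strategy; it applies to stores created
+        afterwards. A minimal sketch::
+
+            CrateDBVectorSearch.configure(
+                StorageStrategy.EMBEDDING_TABLE_PER_COLLECTION
+            )
+        """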
+ cls.STORAGE_STRATEGY = storage_strategy
+
+ def __post_init__(
+ self,
+ ) -> None:
+ """
+ Initialize the store.
+ """
+
+ # FIXME: Could be a bug in CrateDB SQLAlchemy dialect.
+ patch_inspector()
+
+ self._engine = self.create_engine()
+ self.Session = sessionmaker(self._engine)
+
+ # TODO: See what can be improved here.
+ polyfill_refresh_after_dml(self.Session)
+
+ # Need to defer initialization, because dimension size
+ # can only be figured out at runtime.
+ self.BaseModel = None
+ self.CollectionStore = None # type: ignore[assignment]
+ self.EmbeddingStore = None # type: ignore[assignment]
+
+ def __del__(self) -> None:
+ """
+ Work around premature session close.
+
+ sqlalchemy.orm.exc.DetachedInstanceError: Parent instance is not bound
+ to a Session; lazy load operation of attribute 'embeddings' cannot proceed.
+ -- https://docs.sqlalchemy.org/en/20/errors.html#error-bhk3
+
+ TODO: Review!
+ """ # noqa: E501
+ pass
+
+ def _init_models(self, embedding: List[float]) -> None:
+ """
+ Initialize SQLAlchemy model classes, using dimensionality from given vector.
+
+ It will only initialize the model classes once.
+ """
+
+ # TODO: Use a better way to run this only once.
+ if self.CollectionStore is not None and self.EmbeddingStore is not None:
+ return
+
+ size = len(embedding)
+ self._init_models_with_dimensionality(size=size)
+
+ def _init_models_with_dimensionality(self, size: int) -> None:
+ """
+ Initialize SQLAlchemy model classes, using given dimensionality value.
+ """
+ mf = self._get_model_factory(size=size)
+ self.BaseModel, self.CollectionStore, self.EmbeddingStore = (
+ mf.BaseModel, # type: ignore[assignment]
+ mf.CollectionStore,
+ mf.EmbeddingStore,
+ )
+
+ def _get_model_factory(self, size: Optional[int] = None) -> ModelFactory:
+ """
+ Initialize SQLAlchemy model classes, based on the selected storage strategy.
+ """
+ if self.STORAGE_STRATEGY is StorageStrategy.EMBEDDING_TABLE_PER_COLLECTION:
+ mf = ModelFactory(
+ dimensions=size, embedding_table=f"embedding_{self.collection_name}"
+ )
+ else:
+ mf = ModelFactory(dimensions=size)
+ return mf
+
+ def get_collection(self, session: sqlalchemy.orm.Session) -> Any:
+ if self.CollectionStore is None:
+ raise RuntimeError(
+ "Collection can't be accessed without specifying "
+ "dimension size of embedding vectors"
+ )
+ return self.CollectionStore.get_by_name(session, self.collection_name)
+
+ def add_embeddings(
+ self,
+ texts: Iterable[str],
+ embeddings: List[List[float]],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add embeddings to the vectorstore.
+
+ Args:
+ texts: Iterable of strings to add to the vectorstore.
+ embeddings: List of list of embedding vectors.
+ metadatas: List of metadatas associated with the texts.
+ kwargs: vectorstore specific parameters
+ """
+ if not embeddings:
+ return []
+ self._init_models(embeddings[0])
+
+        # When the user requested to delete the collection before running
+        # subsequent operations on it, handle the deletion gracefully when
+        # the table does not exist yet.
+ if self.pre_delete_collection:
+ try:
+ self.delete_collection()
+ except sqlalchemy.exc.ProgrammingError as ex:
+ if "RelationUnknown" not in str(ex):
+ raise
+
+ # Tables need to be created at runtime, because the `EmbeddingStore.embedding`
+ # field, a `FloatVector`, needs to be initialized with a dimensionality
+ # parameter, which is only obtained at runtime.
+ self.create_tables_if_not_exists()
+ self.create_collection()
+
+ # After setting up the table/collection at runtime, add embeddings.
+ embedding_ids = super().add_embeddings(
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+ )
+ refresh_table(self.Session(), self.EmbeddingStore)
+ return embedding_ids
+
+ def create_tables_if_not_exists(self) -> None:
+ """
+        Needs to be overridden, because this `Base` is different from
+        the parent's `Base`.
+ """
+ if self.BaseModel is None:
+ raise RuntimeError("Storage models not initialized")
+ self.BaseModel.metadata.create_all(self._engine)
+
+ def drop_tables(self) -> None:
+ """
+        Needs to be overridden, because this `Base` is different from
+        the parent's `Base`.
+ """
+ mf = self._get_model_factory()
+ mf.Base.metadata.drop_all(self._engine)
+
+ def delete(
+ self,
+ ids: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> None:
+        """
+        Delete vectors by ids or uuids.
+
+        Remark: Specialized for CrateDB to synchronize data.
+
+        Args:
+            ids: List of ids to delete.
+
+        Remark: The CrateDB patch needs to override this method, in order
+            to add a "REFRESH TABLE" statement afterwards. The other
+            patch, listening to `after_delete` events, does not seem to
+            be able to catch it.
+        """
+ super().delete(ids=ids, **kwargs)
+
+ # CrateDB: Synchronize data because `on_flush` does not catch it.
+ with self.Session() as session:
+ refresh_table(session, self.EmbeddingStore)
+
+ @property
+ def distance_strategy(self) -> Any:
+ if self._distance_strategy == DistanceStrategy.EUCLIDEAN:
+ return self.EmbeddingStore.embedding.euclidean_distance
+ elif self._distance_strategy == DistanceStrategy.COSINE:
+ raise NotImplementedError("Cosine similarity not implemented yet")
+ elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+ raise NotImplementedError("Dot-product similarity not implemented yet")
+ else:
+ raise ValueError(
+ f"Got unexpected value for distance: {self._distance_strategy}. "
+ f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}."
+ )
+
+ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
+ """Return docs and scores from results."""
+ docs = [
+ (
+ Document(
+ page_content=result.EmbeddingStore.document,
+ metadata=result.EmbeddingStore.cmetadata,
+ ),
+ result._score if self.embedding_function is not None else None,
+ )
+ for result in results
+ ]
+ return docs
+
+ def _query_collection(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ ) -> List[Any]:
+ """Query the collection."""
+ self._init_models(embedding)
+ with self.Session() as session:
+ collection = self.get_collection(session)
+ if not collection:
+ raise ValueError("Collection not found")
+ return self._query_collection_multi(
+ collections=[collection], embedding=embedding, k=k, filter=filter
+ )
+
+ def _query_collection_multi(
+ self,
+ collections: List[Any],
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ ) -> List[Any]:
+ """Query the collection."""
+ self._init_models(embedding)
+
+ collection_names = [coll.name for coll in collections]
+ collection_uuids = [coll.uuid for coll in collections]
+ self.logger.info(f"Querying collections: {collection_names}")
+
+ with self.Session() as session:
+ filter_by = self.EmbeddingStore.collection_id.in_(collection_uuids)
+
+ if filter is not None:
+ filter_clauses = []
+ for key, value in filter.items():
+ IN = "in"
+ if isinstance(value, dict) and IN in map(str.lower, value):
+ value_case_insensitive = {
+ k.lower(): v for k, v in value.items()
+ }
+ filter_by_metadata = self.EmbeddingStore.cmetadata[key].in_(
+ value_case_insensitive[IN]
+ )
+ filter_clauses.append(filter_by_metadata)
+ else:
+ filter_by_metadata = self.EmbeddingStore.cmetadata[key] == str(
+ value
+ ) # type: ignore[assignment]
+ filter_clauses.append(filter_by_metadata)
+
+ filter_by = sqlalchemy.and_(filter_by, *filter_clauses) # type: ignore[assignment]
+
+ _type = self.EmbeddingStore
+
+ results: List[Any] = (
+ session.query( # type: ignore[attr-defined]
+ self.EmbeddingStore,
+ # TODO: Original pgvector code uses `self.distance_strategy`.
+ # CrateDB currently only supports EUCLIDEAN.
+ # self.distance_strategy(embedding).label("distance")
+ sqlalchemy.literal_column(
+ f"{self.EmbeddingStore.__tablename__}._score"
+ ).label("_score"),
+ )
+ .filter(filter_by)
+ # CrateDB applies `KNN_MATCH` within the `WHERE` clause.
+ .filter(
+ sqlalchemy.func.knn_match(
+ self.EmbeddingStore.embedding, embedding, k
+ )
+ )
+ .order_by(sqlalchemy.desc("_score"))
+ .join(
+ self.CollectionStore,
+ self.EmbeddingStore.collection_id == self.CollectionStore.uuid,
+ )
+ .limit(k)
+ )
+ return results
+
+ @classmethod
+ def from_texts( # type: ignore[override]
+ cls: Type[CrateDBVectorSearch],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> CrateDBVectorSearch:
+ """
+ Return VectorStore initialized from texts and embeddings.
+ Database connection string is required.
+
+ Either pass it as a parameter, or set the CRATEDB_CONNECTION_STRING
+ environment variable.
+
+        Remark: Needs to be overridden, because CrateDB uses a different
+ DEFAULT_DISTANCE_STRATEGY.
+ """
+ return super().from_texts( # type: ignore[return-value]
+ texts,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ distance_strategy=distance_strategy, # type: ignore[arg-type]
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ @classmethod
+ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
+ connection_string: str = get_from_dict_or_env(
+ data=kwargs,
+ key="connection_string",
+ env_key="CRATEDB_CONNECTION_STRING",
+ )
+
+ if not connection_string:
+ raise ValueError(
+                "Database connection string is required. "
+ "Either pass it as a parameter, or set the "
+ "CRATEDB_CONNECTION_STRING environment variable."
+ )
+
+ return connection_string
+
+ @classmethod
+ def connection_string_from_db_params(
+ cls,
+ driver: str,
+ host: str,
+ port: int,
+ database: str,
+ user: str,
+ password: str,
+ ) -> str:
+ """Return connection string from database parameters."""
+ return str(
+ sqlalchemy.URL.create(
+ drivername=driver,
+ host=host,
+ port=port,
+ username=user,
+ password=password,
+ query={"schema": database},
+ )
+ )
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ """
+ The 'correct' relevance function
+ may differ depending on a few things, including:
+ - the distance / similarity metric used by the VectorStore
+ - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+ - embedding dimensionality
+ - etc.
+ """
+ if self.override_relevance_score_fn is not None:
+ return self.override_relevance_score_fn
+
+ # Default strategy is to rely on distance strategy provided
+ # in vectorstore constructor
+ if self._distance_strategy == DistanceStrategy.COSINE:
+ return self._cosine_relevance_score_fn
+ elif self._distance_strategy == DistanceStrategy.EUCLIDEAN:
+ return self._euclidean_relevance_score_fn
+ elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+ return self._max_inner_product_relevance_score_fn
+ else:
+ raise ValueError(
+ "No supported normalization function for distance_strategy of "
+ f"{self._distance_strategy}. Consider providing relevance_score_fn to "
+ "CrateDBVectorSearch constructor."
+ )
+
+ @staticmethod
+ def _euclidean_relevance_score_fn(score: float) -> float:
+ """Return a similarity score on a scale [0, 1]."""
+ # The 'correct' relevance function
+ # may differ depending on a few things, including:
+ # - the distance / similarity metric used by the VectorStore
+ # - the scale of your embeddings (OpenAI's are unit normed. Many
+ # others are not!)
+ # - embedding dimensionality
+ # - etc.
+ # This function converts the euclidean norm of normalized embeddings
+ # (0 is most similar, sqrt(2) most dissimilar)
+ # to a similarity function (0 to 1)
+
+ # Original:
+ # return 1.0 - distance / math.sqrt(2)
+ return score / math.sqrt(2)
diff --git a/libs/langchain/langchain/vectorstores/cratedb/extended.py b/libs/langchain/langchain/vectorstores/cratedb/extended.py
new file mode 100644
index 0000000000000..553ce281c5e24
--- /dev/null
+++ b/libs/langchain/langchain/vectorstores/cratedb/extended.py
@@ -0,0 +1,103 @@
+import logging
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+)
+
+import sqlalchemy
+from sqlalchemy.orm import sessionmaker
+
+from langchain.schema.embeddings import Embeddings
+from langchain.vectorstores.cratedb.base import (
+ DEFAULT_DISTANCE_STRATEGY,
+ CrateDBVectorSearch,
+ DistanceStrategy,
+ StorageStrategy,
+)
+from langchain.vectorstores.pgvector import _LANGCHAIN_DEFAULT_COLLECTION_NAME
+
+
+class CrateDBVectorSearchMultiCollection(CrateDBVectorSearch):
+ """
+ Provide functionality for searching multiple collections.
+    It cannot be used for indexing documents.
+
+ To use it, you should have the ``crate[sqlalchemy]`` Python package installed.
+
+ Synopsis::
+
+ from langchain.vectorstores.cratedb import CrateDBVectorSearchMultiCollection
+
+ multisearch = CrateDBVectorSearchMultiCollection(
+ collection_names=["collection_foo", "collection_bar"],
+ embedding_function=embeddings,
+ connection_string=CONNECTION_STRING,
+ )
+ docs_with_score = multisearch.similarity_search_with_score(query)
+ """
+
+ def __init__(
+ self,
+ connection_string: str,
+ embedding_function: Embeddings,
+ collection_names: List[str] = [_LANGCHAIN_DEFAULT_COLLECTION_NAME],
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, # type: ignore[arg-type]
+ logger: Optional[logging.Logger] = None,
+ relevance_score_fn: Optional[Callable[[float], float]] = None,
+ *,
+ connection: Optional[sqlalchemy.engine.Connection] = None,
+        engine_args: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ # Sanity checks.
+ # TODO: The CrateDBVectorSearchMultiCollection access variant needs further
+ # adjustments to support the EMBEDDING_TABLE_PER_COLLECTION storage
+ # strategy.
+ if self.STORAGE_STRATEGY is not StorageStrategy.LANGCHAIN_PGVECTOR:
+ raise NotImplementedError(
+ f"Multi-collection querying not supported "
+ f"by strategy: {self.STORAGE_STRATEGY}"
+ )
+
+ self.connection_string = connection_string
+ self.embedding_function = embedding_function
+ self.collection_names = collection_names
+ self._distance_strategy = distance_strategy # type: ignore[assignment]
+ self.logger = logger or logging.getLogger(__name__)
+ self.override_relevance_score_fn = relevance_score_fn
+ self.engine_args = engine_args or {}
+        # Use the provided connection if given, otherwise create a new one.
+ self._engine = self.create_engine()
+ self.Session = sessionmaker(self._engine)
+ self._conn = connection if connection else self.connect()
+ self.__post_init__()
+
+ @classmethod
+ def _from(cls, *args: List, **kwargs: Dict): # type: ignore[no-untyped-def,override]
+ raise NotImplementedError("This adapter can not be used for indexing documents")
+
+ def get_collections(self, session: sqlalchemy.orm.Session) -> Any:
+ if self.CollectionStore is None:
+ raise RuntimeError(
+ "Collection can't be accessed without specifying "
+ "dimension size of embedding vectors"
+ )
+ return self.CollectionStore.get_by_names(session, self.collection_names)
+
+ def _query_collection(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[Dict[str, str]] = None,
+ ) -> List[Any]:
+ """Query multiple collections."""
+ self._init_models(embedding)
+ with self.Session() as session:
+ collections = self.get_collections(session)
+ if not collections:
+ raise ValueError("No collections found")
+ return self._query_collection_multi(
+ collections=collections, embedding=embedding, k=k, filter=filter
+ )
diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/langchain/langchain/vectorstores/cratedb/model.py
new file mode 100644
index 0000000000000..167b75ba23e13
--- /dev/null
+++ b/libs/langchain/langchain/vectorstores/cratedb/model.py
@@ -0,0 +1,117 @@
+import uuid
+from typing import Any, List, Optional, Tuple
+
+import sqlalchemy
+from crate.client.sqlalchemy.types import ObjectType
+from sqlalchemy.orm import Session, declarative_base, relationship
+
+from langchain.vectorstores.cratedb.sqlalchemy_type import FloatVector
+
+
+def generate_uuid() -> str:
+ return str(uuid.uuid4())
+
+
+class ModelFactory:
+    """
+    Provide SQLAlchemy model objects at runtime.
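+
+    Example (a sketch; the dimensionality is normally derived from the
+    first embedding vector at runtime)::
+
+        mf = ModelFactory(dimensions=1536)
+        collection_model = mf.CollectionStore
+        embedding_model = mf.EmbeddingStore
+    """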
+
+ def __init__(
+ self,
+ dimensions: Optional[int] = None,
+ collection_table: Optional[str] = None,
+ embedding_table: Optional[str] = None,
+ ):
+        # The dimensionality is not relevant for all operations, for example
+        # when deleting records. Still, the model needs a value, so fall
+        # back to a dummy default.
+ self.dimensions = dimensions or 1024
+
+ # Set default values for table names.
+ collection_table = collection_table or "collection"
+ embedding_table = embedding_table or "embedding"
+
+ Base: Any = declarative_base(class_registry=dict())
+
+ # Optional: Use a custom schema for the langchain tables.
+ # Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any
+
+ class BaseModel(Base):
+ """Base model for the SQL stores."""
+
+ __abstract__ = True
+ uuid = sqlalchemy.Column(
+ sqlalchemy.String, primary_key=True, default=generate_uuid
+ )
+
+ class CollectionStore(BaseModel):
+ """Collection store."""
+
+ __tablename__ = collection_table
+
+ name = sqlalchemy.Column(sqlalchemy.String)
+ cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType)
+
+ embeddings = relationship(
+ "EmbeddingStore",
+ back_populates="collection",
+ cascade="all, delete-orphan",
+ passive_deletes=False,
+ )
+
+ @classmethod
+ def get_by_name(cls, session: Session, name: str) -> "CollectionStore":
+ return session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined]
+
+ @classmethod
+ def get_by_names(
+ cls, session: Session, names: List[str]
+ ) -> List["CollectionStore"]:
+ return session.query(cls).filter(cls.name.in_(names)).all() # type: ignore[attr-defined]
+
+ @classmethod
+ def get_or_create(
+ cls,
+ session: Session,
+ name: str,
+ cmetadata: Optional[dict] = None,
+ ) -> Tuple["CollectionStore", bool]:
+ """
+ Get or create a collection.
+            Returns a tuple `(collection, created)`, where `created` is
+            True if the collection was newly created.
+ """
+ created = False
+ collection = cls.get_by_name(session, name)
+ if collection:
+ return collection, created
+
+ collection = cls(name=name, cmetadata=cmetadata)
+ session.add(collection)
+ session.commit()
+ created = True
+ return collection, created
+
+ class EmbeddingStore(BaseModel):
+ """Embedding store."""
+
+ __tablename__ = embedding_table
+
+ collection_id = sqlalchemy.Column(
+ sqlalchemy.String,
+ sqlalchemy.ForeignKey(
+ f"{CollectionStore.__tablename__}.uuid",
+ ondelete="CASCADE",
+ ),
+ )
+ collection = relationship("CollectionStore", back_populates="embeddings")
+
+ embedding = sqlalchemy.Column(FloatVector(self.dimensions))
+ document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+ cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True)
+
+ # custom_id : any user defined id
+ custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+
+ self.Base = Base
+ self.BaseModel = BaseModel
+ self.CollectionStore = CollectionStore
+ self.EmbeddingStore = EmbeddingStore
diff --git a/libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py b/libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py
new file mode 100644
index 0000000000000..e784c3013a3d9
--- /dev/null
+++ b/libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py
@@ -0,0 +1,84 @@
+# TODO: Refactor to CrateDB SQLAlchemy dialect.
+import typing as t
+
+import numpy as np
+import numpy.typing as npt
+import sqlalchemy as sa
+from sqlalchemy.types import UserDefinedType
+
+__all__ = ["FloatVector"]
+
+
+def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]:
+ # from `pgvector.utils`
+ # could be ndarray if already cast by lower-level driver
+ if value is None or isinstance(value, np.ndarray):
+ return value
+
+ return np.array(value, dtype=np.float32)
+
+
+def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
+ # from `pgvector.utils`
+ if value is None:
+ return value
+
+ if isinstance(value, np.ndarray):
+ if value.ndim != 1:
+ raise ValueError("expected ndim to be 1")
+
+ if not np.issubdtype(value.dtype, np.integer) and not np.issubdtype(
+ value.dtype, np.floating
+ ):
+ raise ValueError("dtype must be numeric")
+
+ value = value.tolist()
+
+ if dim is not None and len(value) != dim:
+ raise ValueError("expected %d dimensions, not %d" % (dim, len(value)))
+
+ return value
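+
+
+# Example round-trip for the helpers above (a sketch):
+#   to_db(np.array([1.0, 2.0], dtype=np.float32), dim=2)  ->  [1.0, 2.0]
+#   from_db([1.0, 2.0])  ->  array([1., 2.], dtype=float32)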
+
+
+class FloatVector(UserDefinedType):
+ """
+ https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector
+ https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match
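+
+    Usage sketch (column definition with a fixed dimensionality)::
+
+        coordinates = sa.Column(FloatVector(3))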
+ """
+
+ cache_ok = True
+
+ def __init__(self, dim: t.Optional[int] = None) -> None:
+ super(UserDefinedType, self).__init__()
+ self.dim = dim
+
+ def get_col_spec(self, **kw: t.Any) -> str:
+ if self.dim is None:
+ return "FLOAT_VECTOR"
+ return "FLOAT_VECTOR(%d)" % self.dim
+
+ def bind_processor(self, dialect: sa.Dialect) -> t.Callable:
+ def process(value: t.Iterable) -> t.Optional[t.List]:
+ return to_db(value, self.dim)
+
+ return process
+
+ def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable:
+ def process(value: t.Any) -> t.Optional[npt.ArrayLike]:
+ return from_db(value)
+
+ return process
+
+ """
+ CrateDB currently only supports similarity function `VectorSimilarityFunction.EUCLIDEAN`.
+ -- https://github.com/crate/crate/blob/1ca5c6dbb2/server/src/main/java/io/crate/types/FloatVectorType.java#L55
+
+ On the other hand, pgvector use a comparator to apply different similarity functions as operators,
+ see `pgvector.sqlalchemy.Vector.comparator_factory`.
+
+ <->: l2/euclidean_distance
+ <#>: max_inner_product
+ <=>: cosine_distance
+
+ TODO: Discuss.
+ """ # noqa: E501
diff --git a/libs/langchain/langchain/vectorstores/pgvector.py b/libs/langchain/langchain/vectorstores/pgvector.py
index e1a58e81afdde..030f1962e43ae 100644
--- a/libs/langchain/langchain/vectorstores/pgvector.py
+++ b/libs/langchain/langchain/vectorstores/pgvector.py
@@ -23,12 +23,12 @@
import sqlalchemy
from sqlalchemy import delete
from sqlalchemy.dialects.postgresql import UUID
-from sqlalchemy.orm import Session
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import Session, sessionmaker
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
@@ -130,6 +130,8 @@ def __init__(
self.override_relevance_score_fn = relevance_score_fn
self.engine_args = engine_args or {}
# Create a connection if not provided, otherwise use the provided connection
+ self._engine = self.create_engine()
+ self.Session = sessionmaker(self._engine)
self._conn = connection if connection else self.connect()
self.__post_init__()
@@ -156,14 +158,15 @@ def __del__(self) -> None:
def embeddings(self) -> Embeddings:
return self.embedding_function
+    def create_engine(self) -> sqlalchemy.Engine:
+        return sqlalchemy.create_engine(self.connection_string, **self.engine_args)
+
def connect(self) -> sqlalchemy.engine.Connection:
- engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
- conn = engine.connect()
- return conn
+ return self._engine.connect()
def create_vector_extension(self) -> None:
try:
- with Session(self._conn) as session:
+ with self.Session() as session:
# The advisor lock fixes issue arising from concurrent
# creation of the vector extension.
# https://github.com/langchain-ai/langchain/issues/12933
@@ -181,24 +184,22 @@ def create_vector_extension(self) -> None:
raise Exception(f"Failed to create vector extension: {e}") from e
def create_tables_if_not_exists(self) -> None:
- with self._conn.begin():
- Base.metadata.create_all(self._conn)
+ Base.metadata.create_all(self._engine)
def drop_tables(self) -> None:
- with self._conn.begin():
- Base.metadata.drop_all(self._conn)
+ Base.metadata.drop_all(self._engine)
def create_collection(self) -> None:
if self.pre_delete_collection:
self.delete_collection()
- with Session(self._conn) as session:
+ with self.Session() as session:
self.CollectionStore.get_or_create(
session, self.collection_name, cmetadata=self.collection_metadata
)
def delete_collection(self) -> None:
self.logger.debug("Trying to delete collection")
- with Session(self._conn) as session:
+ with self.Session() as session:
collection = self.get_collection(session)
if not collection:
self.logger.warning("Collection not found")
@@ -209,7 +210,7 @@ def delete_collection(self) -> None:
@contextlib.contextmanager
def _make_session(self) -> Generator[Session, None, None]:
"""Create a context manager for the session, bind to _conn string."""
- yield Session(self._conn)
+ yield self.Session()
def delete(
self,
@@ -221,7 +222,7 @@ def delete(
Args:
ids: List of ids to delete.
"""
- with Session(self._conn) as session:
+ with self.Session() as session:
if ids is not None:
self.logger.debug(
"Trying to delete vectors by ids (represented by the model "
@@ -237,7 +238,7 @@ def get_collection(self, session: Session) -> Optional["CollectionStore"]:
return self.CollectionStore.get_by_name(session, self.collection_name)
@classmethod
- def __from(
+ def _from(
cls,
texts: List[str],
embeddings: List[List[float]],
@@ -295,10 +296,11 @@ def add_embeddings(
if not metadatas:
metadatas = [{} for _ in texts]
- with Session(self._conn) as session:
+ with self.Session() as session:
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")
+ documents = []
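+            # Collect all records first, then insert them in a single batch
+            # via `bulk_save_objects()`, instead of one `session.add()` per row.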
for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
embedding_store = self.EmbeddingStore(
embedding=embedding,
@@ -307,7 +309,8 @@ def add_embeddings(
custom_id=id,
collection_id=collection.uuid,
)
- session.add(embedding_store)
+ documents.append(embedding_store)
+ session.bulk_save_objects(documents)
session.commit()
return ids
@@ -400,7 +403,7 @@ def similarity_search_with_score_by_vector(
k: int = 4,
filter: Optional[dict] = None,
) -> List[Tuple[Document, float]]:
- results = self.__query_collection(embedding=embedding, k=k, filter=filter)
+ results = self._query_collection(embedding=embedding, k=k, filter=filter)
return self._results_to_docs_and_scores(results)
@@ -418,14 +421,14 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa
]
return docs
- def __query_collection(
+ def _query_collection(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
) -> List[Any]:
"""Query the collection."""
- with Session(self._conn) as session:
+ with self.Session() as session:
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")
@@ -512,7 +515,7 @@ def from_texts(
"""
embeddings = embedding.embed_documents(list(texts))
- return cls.__from(
+ return cls._from(
texts,
embeddings,
embedding,
@@ -557,7 +560,7 @@ def from_embeddings(
texts = [t[0] for t in text_embeddings]
embeddings = [t[1] for t in text_embeddings]
- return cls.__from(
+ return cls._from(
texts,
embeddings,
embedding,
@@ -718,7 +721,7 @@ def max_marginal_relevance_search_with_score_by_vector(
List[Tuple[Document, float]]: List of Documents selected by maximal marginal
relevance to the query and score for each.
"""
- results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)
+ results = self._query_collection(embedding=embedding, k=fetch_k, filter=filter)
embedding_list = [result.EmbeddingStore.embedding for result in results]
diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml
index b9bf2b0546d1c..30b515c099718 100644
--- a/libs/langchain/pyproject.toml
+++ b/libs/langchain/pyproject.toml
@@ -147,6 +147,8 @@ praw = {version = "^7.7.1", optional = true}
msal = {version = "^1.25.0", optional = true}
databricks-vectorsearch = {version = "^0.21", optional = true}
dgml-utils = {version = "^0.3.0", optional = true}
+crate = {version = "^0.34.0", extras=["sqlalchemy"], optional = true}
+cratedb-toolkit = {version = ">=0.0.1", optional = true}
[tool.poetry.group.test.dependencies]
# The only dependencies that should be added are
@@ -167,6 +169,7 @@ pytest-socket = "^0.6.0"
syrupy = "^4.0.2"
requests-mock = "^1.11.0"
langchain-core = {path = "../core", develop = true}
+sqlparse = "^0.4.4"
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
@@ -229,6 +232,7 @@ cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
javascript = ["esprima"]
+cratedb = ["crate", "cratedb-toolkit"]
azure = [
"azure-identity",
"azure-cosmos",
@@ -315,6 +319,8 @@ all = [
"librosa",
"python-arango",
"dgml-utils",
+ "crate",
+ "cratedb-toolkit",
]
cli = [
@@ -388,6 +394,8 @@ extended_testing = [
"databricks-vectorsearch",
"dgml-utils",
"cohere",
+ "crate",
+ "cratedb-toolkit",
]
[tool.ruff]
diff --git a/libs/langchain/tests/data.py b/libs/langchain/tests/data.py
index c3b240bbc57cd..be48867430641 100644
--- a/libs/langchain/tests/data.py
+++ b/libs/langchain/tests/data.py
@@ -9,3 +9,7 @@
HELLO_PDF = _EXAMPLES_DIR / "hello.pdf"
LAYOUT_PARSER_PAPER_PDF = _EXAMPLES_DIR / "layout-parser-paper.pdf"
DUPLICATE_CHARS = _EXAMPLES_DIR / "duplicate-chars.pdf"
+
+# Paths to data files
+MLB_TEAMS_2012_CSV = _EXAMPLES_DIR / "mlb_teams_2012.csv"
+MLB_TEAMS_2012_SQL = _EXAMPLES_DIR / "mlb_teams_2012.sql"
diff --git a/libs/langchain/tests/integration_tests/document_loaders/docker-compose/cratedb.yml b/libs/langchain/tests/integration_tests/document_loaders/docker-compose/cratedb.yml
new file mode 100644
index 0000000000000..b547b2c766f20
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/document_loaders/docker-compose/cratedb.yml
@@ -0,0 +1,20 @@
+version: "3"
+
+services:
+  cratedb:
+ image: crate/crate:nightly
+ environment:
+ - CRATE_HEAP_SIZE=4g
+ ports:
+ - "4200:4200"
+ - "5432:5432"
+ command: |
+ crate -Cdiscovery.type=single-node
+ healthcheck:
+ test:
+ [
+ "CMD-SHELL",
+ "curl --silent --fail http://localhost:4200/ || exit 1",
+ ]
+ interval: 5s
+ retries: 60
diff --git a/libs/langchain/tests/integration_tests/document_loaders/docker-compose/postgresql.yml b/libs/langchain/tests/integration_tests/document_loaders/docker-compose/postgresql.yml
new file mode 100644
index 0000000000000..f8ab2cfdeb418
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/document_loaders/docker-compose/postgresql.yml
@@ -0,0 +1,19 @@
+version: "3"
+
+services:
+ postgresql:
+ image: postgres:16
+ environment:
+ - POSTGRES_HOST_AUTH_METHOD=trust
+ ports:
+ - "5432:5432"
+ command: |
+ postgres -c log_statement=all
+ healthcheck:
+ test:
+ [
+ "CMD-SHELL",
+ "psql postgresql://postgres@localhost --command 'SELECT 1;' || exit 1",
+ ]
+ interval: 5s
+ retries: 60
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py
new file mode 100644
index 0000000000000..eec3a428a74e8
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py
@@ -0,0 +1,146 @@
+"""
+Test SQLAlchemy/CrateDB document loader functionality.
+
+cd tests/integration_tests/document_loaders/docker-compose
+docker-compose -f cratedb.yml up
+"""
+import logging
+import os
+import unittest
+
+import pytest
+import sqlalchemy as sa
+import sqlparse
+
+from langchain.document_loaders import CrateDBLoader
+from tests.data import MLB_TEAMS_2012_SQL
+
+logging.basicConfig(level=logging.DEBUG)
+
+try:
+ import crate.client.sqlalchemy # noqa: F401
+
+ crate_client_installed = True
+except ImportError:
+ crate_client_installed = False
+
+
+CONNECTION_STRING = os.environ.get(
+ "TEST_CRATEDB_CONNECTION_STRING", "crate://crate@localhost/?schema=testdrive"
+)
+
+
+@pytest.fixture
+def engine() -> sa.Engine:
+ """
+ Return an SQLAlchemy engine object.
+ """
+ return sa.create_engine(CONNECTION_STRING, echo=False)
+
+
+@pytest.fixture()
+def provision_database(engine: sa.Engine) -> None:
+ """
+ Provision database with table schema and data.
+ """
+ sql_statements = MLB_TEAMS_2012_SQL.read_text()
+ with engine.connect() as connection:
+ connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;"))
+ for statement in sqlparse.split(sql_statements):
+ connection.execute(sa.text(statement))
+ connection.execute(sa.text("REFRESH TABLE mlb_teams_2012;"))
+ connection.commit()
+
+
+@unittest.skipIf(not crate_client_installed, "CrateDB client not installed")
+def test_cratedb_loader_no_options() -> None:
+ """Test SQLAlchemy loader with CrateDB."""
+
+ loader = CrateDBLoader("SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING)
+ docs = loader.load()
+
+ assert len(docs) == 1
+ assert docs[0].page_content == "a: 1\nb: 2"
+ assert docs[0].metadata == {}
+
+
+@unittest.skipIf(not crate_client_installed, "CrateDB client not installed")
+def test_cratedb_loader_page_content_columns() -> None:
+ """Test SQLAlchemy loader with CrateDB."""
+
+ loader = CrateDBLoader(
+ "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b",
+ url=CONNECTION_STRING,
+ page_content_columns=["a"],
+ )
+ docs = loader.load()
+
+ assert len(docs) == 2
+ assert docs[0].page_content == "a: 1"
+ assert docs[0].metadata == {}
+
+ assert docs[1].page_content == "a: 3"
+ assert docs[1].metadata == {}
+
+
+@unittest.skipIf(not crate_client_installed, "CrateDB client not installed")
+def test_cratedb_loader_metadata_columns() -> None:
+ """Test SQLAlchemy loader with CrateDB."""
+
+ loader = CrateDBLoader(
+ "SELECT 1 AS a, 2 AS b",
+ url=CONNECTION_STRING,
+ page_content_columns=["a"],
+ metadata_columns=["b"],
+ )
+ docs = loader.load()
+
+ assert len(docs) == 1
+ assert docs[0].page_content == "a: 1"
+ assert docs[0].metadata == {"b": 2}
+
+
+@unittest.skipIf(not crate_client_installed, "CrateDB client not installed")
+def test_cratedb_loader_real_data_with_sql(provision_database: None) -> None:
+ """Test SQLAlchemy loader with CrateDB."""
+
+ loader = CrateDBLoader(
+ query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";',
+ url=CONNECTION_STRING,
+ )
+ docs = loader.load()
+
+ assert len(docs) == 30
+ assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
+ assert docs[0].metadata == {}
+
+
+@unittest.skipIf(not crate_client_installed, "CrateDB client not installed")
+def test_cratedb_loader_real_data_with_selectable(provision_database: None) -> None:
+ """Test SQLAlchemy loader with CrateDB."""
+
+ # Define an SQLAlchemy table.
+ mlb_teams_2012 = sa.Table(
+ "mlb_teams_2012",
+ sa.MetaData(),
+ sa.Column("Team", sa.VARCHAR),
+ sa.Column("Payroll (millions)", sa.FLOAT),
+ sa.Column("Wins", sa.BIGINT),
+ )
+
+ # Query the database table using an SQLAlchemy selectable.
+ select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team)
+ loader = CrateDBLoader(
+ query=select,
+ url=CONNECTION_STRING,
+ include_query_into_metadata=True,
+ )
+ docs = loader.load()
+
+ assert len(docs) == 30
+ assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
+ assert docs[0].metadata == {
+ "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", '
+ 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 '
+ 'ORDER BY mlb_teams_2012."Team"'
+ }
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py
new file mode 100644
index 0000000000000..29f52cb9f7a33
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py
@@ -0,0 +1,177 @@
+"""
+Test SQLAlchemy/PostgreSQL document loader functionality.
+
+cd tests/integration_tests/document_loaders/docker-compose
+docker-compose -f postgresql.yml up
+"""
+import logging
+import os
+import unittest
+
+import pytest
+import sqlalchemy as sa
+import sqlparse
+
+from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader
+from tests.data import MLB_TEAMS_2012_SQL
+
+logging.basicConfig(level=logging.DEBUG)
+
+
+try:
+ import psycopg2 # noqa: F401
+
+ psycopg2_installed = True
+except ImportError:
+ psycopg2_installed = False
+
+
+CONNECTION_STRING = os.environ.get(
+ "TEST_POSTGRESQL_CONNECTION_STRING",
+ "postgresql+psycopg2://postgres@localhost:5432/",
+)
+
+
+@pytest.fixture
+def engine() -> sa.Engine:
+ """
+ Return an SQLAlchemy engine object.
+ """
+ return sa.create_engine(CONNECTION_STRING, echo=False)
+
+
+@pytest.fixture()
+def provision_database(engine: sa.Engine) -> None:
+ """
+ Provision database with table schema and data.
+ """
+ sql_statements = MLB_TEAMS_2012_SQL.read_text()
+ with engine.connect() as connection:
+ connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;"))
+ for statement in sqlparse.split(sql_statements):
+ connection.execute(sa.text(statement))
+ connection.commit()
+
+
+@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed")
+def test_postgresql_loader_no_options() -> None:
+ """Test SQLAlchemy loader with psycopg2."""
+
+ loader = SQLAlchemyLoader("SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING)
+ docs = loader.load()
+
+ assert len(docs) == 1
+ assert docs[0].page_content == "a: 1\nb: 2"
+ assert docs[0].metadata == {}
+
+
+@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed")
+def test_postgresql_loader_include_rownum_into_metadata() -> None:
+ """Test SQLAlchemy loader with psycopg2."""
+
+ loader = SQLAlchemyLoader(
+ "SELECT 1 AS a, 2 AS b",
+ url=CONNECTION_STRING,
+ include_rownum_into_metadata=True,
+ )
+ docs = loader.load()
+
+ assert len(docs) == 1
+ assert docs[0].page_content == "a: 1\nb: 2"
+ assert docs[0].metadata == {"row": 0}
+
+
+@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed")
+def test_postgresql_loader_include_query_into_metadata() -> None:
+ """Test SQLAlchemy loader with psycopg2."""
+
+ loader = SQLAlchemyLoader(
+ "SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING, include_query_into_metadata=True
+ )
+ docs = loader.load()
+
+ assert len(docs) == 1
+ assert docs[0].page_content == "a: 1\nb: 2"
+ assert docs[0].metadata == {"query": "SELECT 1 AS a, 2 AS b"}
+
+
+@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed")
+def test_postgresql_loader_page_content_columns() -> None:
+ """Test SQLAlchemy loader with psycopg2."""
+
+ loader = SQLAlchemyLoader(
+ "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b",
+ url=CONNECTION_STRING,
+ page_content_columns=["a"],
+ )
+ docs = loader.load()
+
+ assert len(docs) == 2
+ assert docs[0].page_content == "a: 1"
+ assert docs[0].metadata == {}
+
+ assert docs[1].page_content == "a: 3"
+ assert docs[1].metadata == {}
+
+
+@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed")
+def test_postgresql_loader_metadata_columns() -> None:
+ """Test SQLAlchemy loader with psycopg2."""
+
+ loader = SQLAlchemyLoader(
+ "SELECT 1 AS a, 2 AS b",
+ url=CONNECTION_STRING,
+ page_content_columns=["a"],
+ metadata_columns=["b"],
+ )
+ docs = loader.load()
+
+ assert len(docs) == 1
+ assert docs[0].page_content == "a: 1"
+ assert docs[0].metadata == {"b": 2}
+
+
+@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed")
+def test_postgresql_loader_real_data_with_sql(provision_database: None) -> None:
+ """Test SQLAlchemy loader with psycopg2."""
+
+ loader = SQLAlchemyLoader(
+ query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";',
+ url=CONNECTION_STRING,
+ )
+ docs = loader.load()
+
+ assert len(docs) == 30
+ assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
+ assert docs[0].metadata == {}
+
+
+@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed")
+def test_postgresql_loader_real_data_with_selectable(provision_database: None) -> None:
+ """Test SQLAlchemy loader with psycopg2."""
+
+ # Define an SQLAlchemy table.
+ mlb_teams_2012 = sa.Table(
+ "mlb_teams_2012",
+ sa.MetaData(),
+ sa.Column("Team", sa.VARCHAR),
+ sa.Column("Payroll (millions)", sa.FLOAT),
+ sa.Column("Wins", sa.BIGINT),
+ )
+
+ # Query the database table using an SQLAlchemy selectable.
+ select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team)
+ loader = SQLAlchemyLoader(
+ query=select,
+ url=CONNECTION_STRING,
+ include_query_into_metadata=True,
+ )
+ docs = loader.load()
+
+ assert len(docs) == 30
+ assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
+ assert docs[0].metadata == {
+ "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", '
+ 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 '
+ 'ORDER BY mlb_teams_2012."Team"'
+ }
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py
new file mode 100644
index 0000000000000..f1fac2cecbc00
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py
@@ -0,0 +1,181 @@
+"""
+Test SQLAlchemy/SQLite document loader functionality.
+"""
+import logging
+import unittest
+
+import pytest
+import sqlalchemy as sa
+import sqlparse
+from _pytest.tmpdir import TempPathFactory
+
+from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader
+from tests.data import MLB_TEAMS_2012_SQL
+
+logging.basicConfig(level=logging.DEBUG)
+
+
+try:
+ import sqlite3 # noqa: F401
+
+ sqlite3_installed = True
+except ImportError:
+ sqlite3_installed = False
+
+
+@pytest.fixture(scope="module")
+def db_uri(tmp_path_factory: TempPathFactory) -> str:
+ """
+ Return an SQLAlchemy URI for a temporary SQLite database.
+ """
+ db_path = tmp_path_factory.getbasetemp().joinpath("testdrive.sqlite")
+ return f"sqlite:///{db_path}"
+
+
+@pytest.fixture(scope="module")
+def engine(db_uri: str) -> sa.Engine:
+ """
+ Return an SQLAlchemy engine object.
+ """
+ return sa.create_engine(db_uri, echo=False)
+
+
+@pytest.fixture()
+def provision_database(engine: sa.Engine) -> None:
+ """
+ Provision database with table schema and data.
+ """
+ sql_statements = MLB_TEAMS_2012_SQL.read_text()
+ with engine.connect() as connection:
+ connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;"))
+ for statement in sqlparse.split(sql_statements):
+ connection.execute(sa.text(statement))
+ connection.commit()
+
+
+@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed")
+def test_sqlite_loader_no_options(db_uri: str) -> None:
+ """Test SQLAlchemy loader with sqlite3."""
+
+ loader = SQLAlchemyLoader("SELECT 1 AS a, 2 AS b", url=db_uri)
+ docs = loader.load()
+
+ assert len(docs) == 1
+ assert docs[0].page_content == "a: 1\nb: 2"
+ assert docs[0].metadata == {}
+
+
+@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed")
+def test_sqlite_loader_include_rownum_into_metadata(db_uri: str) -> None:
+ """Test SQLAlchemy loader with sqlite3."""
+
+ loader = SQLAlchemyLoader(
+ "SELECT 1 AS a, 2 AS b",
+ url=db_uri,
+ include_rownum_into_metadata=True,
+ )
+ docs = loader.load()
+
+ assert len(docs) == 1
+ assert docs[0].page_content == "a: 1\nb: 2"
+ assert docs[0].metadata == {"row": 0}
+
+
+@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed")
+def test_sqlite_loader_include_query_into_metadata(db_uri: str) -> None:
+ """Test SQLAlchemy loader with sqlite3."""
+
+ loader = SQLAlchemyLoader(
+ "SELECT 1 AS a, 2 AS b", url=db_uri, include_query_into_metadata=True
+ )
+ docs = loader.load()
+
+ assert len(docs) == 1
+ assert docs[0].page_content == "a: 1\nb: 2"
+ assert docs[0].metadata == {"query": "SELECT 1 AS a, 2 AS b"}
+
+
+@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed")
+def test_sqlite_loader_page_content_columns(db_uri: str) -> None:
+ """Test SQLAlchemy loader with sqlite3."""
+
+ loader = SQLAlchemyLoader(
+ "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b",
+ url=db_uri,
+ page_content_columns=["a"],
+ )
+ docs = loader.load()
+
+ assert len(docs) == 2
+ assert docs[0].page_content == "a: 1"
+ assert docs[0].metadata == {}
+
+ assert docs[1].page_content == "a: 3"
+ assert docs[1].metadata == {}
+
+
+@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed")
+def test_sqlite_loader_metadata_columns(db_uri: str) -> None:
+ """Test SQLAlchemy loader with sqlite3."""
+
+ loader = SQLAlchemyLoader(
+ "SELECT 1 AS a, 2 AS b",
+ url=db_uri,
+ page_content_columns=["a"],
+ metadata_columns=["b"],
+ )
+ docs = loader.load()
+
+ assert len(docs) == 1
+ assert docs[0].page_content == "a: 1"
+ assert docs[0].metadata == {"b": 2}
+
+
+@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed")
+def test_sqlite_loader_real_data_with_sql(
+ db_uri: str, provision_database: None
+) -> None:
+ """Test SQLAlchemy loader with sqlite3."""
+
+ loader = SQLAlchemyLoader(
+ query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";',
+ url=db_uri,
+ )
+ docs = loader.load()
+
+ assert len(docs) == 30
+ assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
+ assert docs[0].metadata == {}
+
+
+@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed")
+def test_sqlite_loader_real_data_with_selectable(
+ db_uri: str, provision_database: None
+) -> None:
+ """Test SQLAlchemy loader with sqlite3."""
+
+ # Define an SQLAlchemy table.
+ mlb_teams_2012 = sa.Table(
+ "mlb_teams_2012",
+ sa.MetaData(),
+ sa.Column("Team", sa.VARCHAR),
+ sa.Column("Payroll (millions)", sa.FLOAT),
+ sa.Column("Wins", sa.BIGINT),
+ )
+
+ # Query the database table using an SQLAlchemy selectable.
+ select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team)
+ loader = SQLAlchemyLoader(
+ query=select,
+ url=db_uri,
+ include_query_into_metadata=True,
+ )
+ docs = loader.load()
+
+ assert len(docs) == 30
+ assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
+ assert docs[0].metadata == {
+ "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", '
+ 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 '
+ 'ORDER BY mlb_teams_2012."Team"'
+ }
diff --git a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv
new file mode 100644
index 0000000000000..b22ae961a1331
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv
@@ -0,0 +1,32 @@
+"Team", "Payroll (millions)", "Wins"
+"Nationals", 81.34, 98
+"Reds", 82.20, 97
+"Yankees", 197.96, 95
+"Giants", 117.62, 94
+"Braves", 83.31, 94
+"Athletics", 55.37, 94
+"Rangers", 120.51, 93
+"Orioles", 81.43, 93
+"Rays", 64.17, 90
+"Angels", 154.49, 89
+"Tigers", 132.30, 88
+"Cardinals", 110.30, 88
+"Dodgers", 95.14, 86
+"White Sox", 96.92, 85
+"Brewers", 97.65, 83
+"Phillies", 174.54, 81
+"Diamondbacks", 74.28, 81
+"Pirates", 63.43, 79
+"Padres", 55.24, 76
+"Mariners", 81.97, 75
+"Mets", 93.35, 74
+"Blue Jays", 75.48, 73
+"Royals", 60.91, 72
+"Marlins", 118.07, 69
+"Red Sox", 173.18, 69
+"Indians", 78.43, 68
+"Twins", 94.08, 66
+"Rockies", 78.06, 64
+"Cubs", 88.19, 61
+"Astros", 60.65, 55
+
diff --git a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql
new file mode 100644
index 0000000000000..9df72ef19954a
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql
@@ -0,0 +1,41 @@
+-- Provisioning table "mlb_teams_2012".
+--
+-- psql postgresql://postgres@localhost < mlb_teams_2012.sql
+-- crash < mlb_teams_2012.sql
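+--
+-- Note: The sqlite3-based loader tests do not shell out like this. Instead,
+-- they split this file into individual statements using `sqlparse`, and run
+-- each one through SQLAlchemy (see the `provision_database` test fixture).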
+
+DROP TABLE IF EXISTS mlb_teams_2012;
+CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT);
+INSERT INTO mlb_teams_2012
+ ("Team", "Payroll (millions)", "Wins")
+VALUES
+ ('Nationals', 81.34, 98),
+ ('Reds', 82.20, 97),
+ ('Yankees', 197.96, 95),
+ ('Giants', 117.62, 94),
+ ('Braves', 83.31, 94),
+ ('Athletics', 55.37, 94),
+ ('Rangers', 120.51, 93),
+ ('Orioles', 81.43, 93),
+ ('Rays', 64.17, 90),
+ ('Angels', 154.49, 89),
+ ('Tigers', 132.30, 88),
+ ('Cardinals', 110.30, 88),
+ ('Dodgers', 95.14, 86),
+ ('White Sox', 96.92, 85),
+ ('Brewers', 97.65, 83),
+ ('Phillies', 174.54, 81),
+ ('Diamondbacks', 74.28, 81),
+ ('Pirates', 63.43, 79),
+ ('Padres', 55.24, 76),
+ ('Mariners', 81.97, 75),
+ ('Mets', 93.35, 74),
+ ('Blue Jays', 75.48, 73),
+ ('Royals', 60.91, 72),
+ ('Marlins', 118.07, 69),
+ ('Red Sox', 173.18, 69),
+ ('Indians', 78.43, 68),
+ ('Twins', 94.08, 66),
+ ('Rockies', 78.06, 64),
+ ('Cubs', 88.19, 61),
+ ('Astros', 60.65, 55)
+;
diff --git a/libs/langchain/tests/integration_tests/memory/test_cratedb.py b/libs/langchain/tests/integration_tests/memory/test_cratedb.py
new file mode 100644
index 0000000000000..2c00b5d2b200b
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/memory/test_cratedb.py
@@ -0,0 +1,170 @@
+import json
+import os
+from typing import Any, Generator, Tuple
+
+import pytest
+import sqlalchemy as sa
+from sqlalchemy import Column, Integer, Text
+from sqlalchemy.orm import DeclarativeBase
+
+from langchain.memory import ConversationBufferMemory
+from langchain.memory.chat_message_histories import CrateDBChatMessageHistory
+from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
+from langchain.schema.messages import AIMessage, HumanMessage, _message_to_dict
+
+
+@pytest.fixture()
+def connection_string() -> str:
+ return os.environ.get(
+ "TEST_CRATEDB_CONNECTION_STRING", "crate://crate@localhost/?schema=testdrive"
+ )
+
+
+@pytest.fixture()
+def engine(connection_string: str) -> sa.Engine:
+ """
+ Return an SQLAlchemy engine object.
+ """
+ return sa.create_engine(connection_string, echo=True)
+
+
+@pytest.fixture(autouse=True)
+def reset_database(engine: sa.Engine) -> None:
+ """
+    Reset the database by dropping the message store table before each test.
+ """
+ with engine.connect() as connection:
+ connection.execute(sa.text("DROP TABLE IF EXISTS test_table;"))
+ connection.commit()
+
+
+@pytest.fixture()
+def sql_histories(
+ connection_string: str,
+) -> Generator[Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory], None, None]:
+ """
+    Provide two chat message histories, bound to different session IDs.
+ """
+ message_history = CrateDBChatMessageHistory(
+ session_id="123", connection_string=connection_string, table_name="test_table"
+ )
+ # Create history for other session
+ other_history = CrateDBChatMessageHistory(
+ session_id="456", connection_string=connection_string, table_name="test_table"
+ )
+
+ yield message_history, other_history
+ message_history.clear()
+ other_history.clear()
+
+
+def test_add_messages(
+ sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory]
+) -> None:
+ history1, _ = sql_histories
+ history1.add_user_message("Hello!")
+ history1.add_ai_message("Hi there!")
+
+ messages = history1.messages
+ assert len(messages) == 2
+ assert isinstance(messages[0], HumanMessage)
+ assert isinstance(messages[1], AIMessage)
+ assert messages[0].content == "Hello!"
+ assert messages[1].content == "Hi there!"
+
+
+def test_multiple_sessions(
+ sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory]
+) -> None:
+ history1, history2 = sql_histories
+
+ # first session
+ history1.add_user_message("Hello!")
+ history1.add_ai_message("Hi there!")
+    history1.add_user_message("What's cracking?")
+
+ # second session
+ history2.add_user_message("Hellox")
+
+ messages1 = history1.messages
+ messages2 = history2.messages
+
+    # Ensure the messages were added correctly in the first session.
+    assert len(messages1) == 3
+    assert messages1[0].content == "Hello!"
+    assert messages1[1].content == "Hi there!"
+    assert messages1[2].content == "What's cracking?"
+
+    # Ensure the second session is isolated from the first one.
+    assert len(messages2) == 1
+    assert messages2[0].content == "Hellox"
+
+
+def test_clear_messages(
+ sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory]
+) -> None:
+ sql_history, other_history = sql_histories
+ sql_history.add_user_message("Hello!")
+ sql_history.add_ai_message("Hi there!")
+ assert len(sql_history.messages) == 2
+ # Now create another history with different session id
+ other_history.add_user_message("Hellox")
+ assert len(other_history.messages) == 1
+ assert len(sql_history.messages) == 2
+ # Now clear the first history
+ sql_history.clear()
+ assert len(sql_history.messages) == 0
+ assert len(other_history.messages) == 1
+
+
+def test_model_no_session_id_field_error(connection_string: str) -> None:
+ class Base(DeclarativeBase):
+ pass
+
+ class Model(Base):
+ __tablename__ = "test_table"
+ id = Column(Integer, primary_key=True)
+ test_field = Column(Text)
+
+ class CustomMessageConverter(DefaultMessageConverter):
+ def get_sql_model_class(self) -> Any:
+ return Model
+
+ with pytest.raises(ValueError):
+ CrateDBChatMessageHistory(
+ "test",
+ connection_string,
+ custom_message_converter=CustomMessageConverter("test_table"),
+ )
+
+
+def test_memory_with_message_store(connection_string: str) -> None:
+ """
+ Test ConversationBufferMemory with a message store.
+ """
+ # Setup CrateDB as a message store.
+ message_history = CrateDBChatMessageHistory(
+ connection_string=connection_string, session_id="test-session"
+ )
+ memory = ConversationBufferMemory(
+ memory_key="baz", chat_memory=message_history, return_messages=True
+ )
+
+ # Add a few messages.
+ memory.chat_memory.add_ai_message("This is me, the AI")
+ memory.chat_memory.add_user_message("This is me, the human")
+
+ # Get the message history from the memory store and turn it into JSON.
+ messages = memory.chat_memory.messages
+ messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
+
+ # Verify the outcome.
+ assert "This is me, the AI" in messages_json
+ assert "This is me, the human" in messages_json
+
+ # Clear the conversation history, and verify that.
+ memory.chat_memory.clear()
+ assert memory.chat_memory.messages == []
diff --git a/libs/langchain/tests/integration_tests/vectorstores/docker-compose/cratedb.yml b/libs/langchain/tests/integration_tests/vectorstores/docker-compose/cratedb.yml
new file mode 100644
index 0000000000000..b547b2c766f20
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/vectorstores/docker-compose/cratedb.yml
@@ -0,0 +1,20 @@
+version: "3"
+
+services:
+ postgresql:
+ image: crate/crate:nightly
+ environment:
+ - CRATE_HEAP_SIZE=4g
+ ports:
+ - "4200:4200"
+ - "5432:5432"
+ command: |
+ crate -Cdiscovery.type=single-node
+ healthcheck:
+ test:
+ [
+ "CMD-SHELL",
+ "curl --silent --fail http://localhost:4200/ || exit 1",
+ ]
+ interval: 5s
+ retries: 60
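+
+# Note: Port 4200 serves CrateDB's HTTP API, which the `crate://` SQLAlchemy
+# dialect connects to, while port 5432 speaks the PostgreSQL wire protocol.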
diff --git a/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py b/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py
index 7b99c696444af..016ebe85fed18 100644
--- a/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py
+++ b/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py
@@ -52,7 +52,6 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_query(self, text: str) -> List[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
- return self.embed_documents([text])[0]
if text not in self.known_texts:
return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
return [float(1.0)] * (self.dimensionality - 1) + [
diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py
new file mode 100644
index 0000000000000..1687dd10a30d9
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py
@@ -0,0 +1,803 @@
+"""
+Test CrateDB `FLOAT_VECTOR` / `KNN_MATCH` functionality.
+
+cd tests/integration_tests/vectorstores/docker-compose
+docker-compose -f cratedb.yml up
+"""
+import os
+import re
+from typing import Dict, Generator, List, Tuple
+
+import pytest
+import sqlalchemy as sa
+import sqlalchemy.orm
+from pytest_mock import MockerFixture
+from sqlalchemy.exc import ProgrammingError
+from sqlalchemy.orm import Session
+
+from langchain.docstore.document import Document
+from langchain.vectorstores.cratedb import CrateDBVectorSearch
+from langchain.vectorstores.cratedb.base import StorageStrategy
+from langchain.vectorstores.cratedb.extended import CrateDBVectorSearchMultiCollection
+from langchain.vectorstores.cratedb.model import ModelFactory
+from tests.integration_tests.vectorstores.fake_embeddings import (
+ ConsistentFakeEmbeddings,
+ FakeEmbeddings,
+)
+
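+# Connection parameters default to a local CrateDB instance; they can be
+# overridden through the `TEST_CRATEDB_*` environment variables evaluated here.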
+SCHEMA_NAME = os.environ.get("TEST_CRATEDB_DATABASE", "testdrive")
+
+CONNECTION_STRING = CrateDBVectorSearch.connection_string_from_db_params(
+ driver=os.environ.get("TEST_CRATEDB_DRIVER", "crate"),
+ host=os.environ.get("TEST_CRATEDB_HOST", "localhost"),
+ port=int(os.environ.get("TEST_CRATEDB_PORT", "4200")),
+ database=SCHEMA_NAME,
+ user=os.environ.get("TEST_CRATEDB_USER", "crate"),
+ password=os.environ.get("TEST_CRATEDB_PASSWORD", ""),
+)
+
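+# 1536 matches the vector dimensionality of OpenAI's `text-embedding-ada-002`
+# model, which the fake embeddings below mimic.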
+ADA_TOKEN_COUNT = 1536
+
+
+@pytest.fixture
+def engine() -> sa.Engine:
+ """
+ Return an SQLAlchemy engine object.
+ """
+ return sa.create_engine(CONNECTION_STRING, echo=False)
+
+
+@pytest.fixture
+def session(engine: sa.Engine) -> Generator[sa.orm.Session, None, None]:
+ with engine.connect() as conn:
+ with Session(conn) as session:
+ yield session
+
+
+@pytest.fixture(autouse=True)
+def drop_tables(engine: sa.Engine) -> None:
+ """
+ Drop relevant database tables before invoking each test case.
+
+    TODO: Figure out how to drop only those database tables which have
+    actually been used, in order to increase performance.
+    Alternatively, reduce the number of different table names
+    used for testing.
+ """
+
+ def sqlalchemy_drop_model(mf: ModelFactory) -> None:
+ try:
+ mf.BaseModel.metadata.drop_all(engine, checkfirst=False)
+ except Exception as ex:
+ if "RelationUnknown" not in str(ex):
+ raise
+
+ collection_name_candidates = [
+ "test_collection",
+ "test_collection_filter",
+ "test_collection_foo",
+ "test_collection_bar",
+ "test_collection_1",
+ "test_collection_2",
+ ]
+
+ # Recycling for vanilla storage strategy.
+ mf = ModelFactory(embedding_table="embedding")
+ sqlalchemy_drop_model(mf)
+
+ # Recycling for advanced "embedding-table-per-collection" storage strategy.
+ for collection_name in collection_name_candidates:
+ mf = ModelFactory(embedding_table=f"embedding_{collection_name}")
+ sqlalchemy_drop_model(mf)
+
+
+@pytest.fixture
+def prune_tables(engine: sa.Engine) -> None:
+ """
+ Delete data from database tables.
+
+ Note: This fixture is currently not used.
+ """
+ with engine.connect() as conn:
+ with Session(conn) as session:
+ mf = ModelFactory()
+ try:
+ session.query(mf.CollectionStore).delete()
+ except ProgrammingError:
+ pass
+ try:
+ session.query(mf.EmbeddingStore).delete()
+ except ProgrammingError:
+ pass
+
+
+def ensure_collection(session: sa.orm.Session, name: str) -> None:
+ """
+ Create a (fake) collection item.
+ """
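+    # The DDL below mirrors the tables the vector store would create on its
+    # own; `FLOAT_VECTOR(123)` is an arbitrary dimensionality for this fixture.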
+ embedding_table_name = "embedding"
+ if (
+ CrateDBVectorSearch.STORAGE_STRATEGY
+ is StorageStrategy.EMBEDDING_TABLE_PER_COLLECTION
+ ):
+ embedding_table_name = f"embedding_{name}"
+ session.execute(
+ sa.text(
+ """
+ CREATE TABLE IF NOT EXISTS collection (
+ uuid TEXT,
+ name TEXT,
+ cmetadata OBJECT
+ );
+ """
+ )
+ )
+ session.execute(
+ sa.text(
+ f"""
+ CREATE TABLE IF NOT EXISTS {embedding_table_name} (
+ uuid TEXT,
+ collection_id TEXT,
+ embedding FLOAT_VECTOR(123),
+ document TEXT,
+ cmetadata OBJECT,
+ custom_id TEXT
+ );
+ """
+ )
+ )
+ try:
+ session.execute(
+ sa.text(
+ f"INSERT INTO collection (uuid, name, cmetadata) "
+ f"VALUES ('uuid-{name}', '{name}', {{}});"
+ )
+ )
+ session.execute(sa.text("REFRESH TABLE collection"))
+ except sa.exc.IntegrityError:
+ pass
+
+
+class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
+    """
+    Fake ada-sized embeddings: each vector's last component encodes the
+    position of the corresponding document within the input batch.
+    """
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Return simple embeddings."""
+ return [
+ [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
+ ]
+
+ def embed_query(self, text: str) -> List[float]:
+ """Return simple embeddings."""
+ return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]
+
+
+class ConsistentFakeEmbeddingsWithAdaDimension(ConsistentFakeEmbeddings):
+ """
+ Fake embeddings which remember all the texts seen so far to return
+ consistent vectors for the same texts.
+
+    In addition, they have a fixed dimensionality, which is important
+    in this case.
+ """
+
+ def __init__(self, *args: List, **kwargs: Dict) -> None:
+ super().__init__(dimensionality=ADA_TOKEN_COUNT)
+
+
+@pytest.fixture
+def two_stores() -> Tuple[CrateDBVectorSearch, CrateDBVectorSearch]:
+ """
+    Provide two different vector search handles to the test cases, each
+    associated with a different collection.
+ """
+ store_foo = CrateDBVectorSearch.from_texts(
+ texts=["foo"],
+ collection_name="test_collection_foo",
+ collection_metadata={"category": "foo"},
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ metadatas=[{"document": "foo"}],
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ store_bar = CrateDBVectorSearch.from_texts(
+ texts=["bar"],
+ collection_name="test_collection_bar",
+ collection_metadata={"category": "bar"},
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ metadatas=[{"document": "bar"}],
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ return store_foo, store_bar
+
+
+def test_cratedb_texts() -> None:
+ """Test end to end construction and search."""
+ texts = ["foo", "bar", "baz"]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ output = docsearch.similarity_search("foo", k=1)
+ assert output == [Document(page_content="foo")]
+
+
+def test_cratedb_embedding_dimension() -> None:
+ """Verify the `embedding` column uses the correct vector dimensionality."""
+ texts = ["foo", "bar", "baz"]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection",
+ embedding=ConsistentFakeEmbeddingsWithAdaDimension(),
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ with docsearch.Session() as session:
+ result = session.execute(sa.text(f"SHOW CREATE TABLE {SCHEMA_NAME}.embedding"))
+ record = result.first()
+ if not record:
+ raise ValueError("No data found")
+ ddl = record[0]
+ assert f'"embedding" FLOAT_VECTOR({ADA_TOKEN_COUNT})' in ddl
+
+
+def test_cratedb_embeddings() -> None:
+ """Test end to end construction with embeddings and search."""
+ texts = ["foo", "bar", "baz"]
+ text_embeddings = FakeEmbeddingsWithAdaDimension().embed_documents(texts)
+ text_embedding_pairs = list(zip(texts, text_embeddings))
+ docsearch = CrateDBVectorSearch.from_embeddings(
+ text_embeddings=text_embedding_pairs,
+ collection_name="test_collection",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ output = docsearch.similarity_search("foo", k=1)
+ assert output == [Document(page_content="foo")]
+
+
+def test_cratedb_with_metadatas() -> None:
+ """Test end to end construction and search."""
+ texts = ["foo", "bar", "baz"]
+ metadatas = [{"page": str(i)} for i in range(len(texts))]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ metadatas=metadatas,
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ output = docsearch.similarity_search("foo", k=1)
+ assert output == [Document(page_content="foo", metadata={"page": "0"})]
+
+
+def test_cratedb_with_metadatas_with_scores() -> None:
+ """Test end to end construction and search."""
+ texts = ["foo", "bar", "baz"]
+ metadatas = [{"page": str(i)} for i in range(len(texts))]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ metadatas=metadatas,
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ output = docsearch.similarity_search_with_score("foo", k=1)
+ assert output == [(Document(page_content="foo", metadata={"page": "0"}), 2.0)]
+
+
+def test_cratedb_with_filter_match() -> None:
+ """Test end to end construction and search."""
+ texts = ["foo", "bar", "baz"]
+ metadatas = [{"page": str(i)} for i in range(len(texts))]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection_filter",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ metadatas=metadatas,
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+    # Note: The original pgvector-based test asserts a score of 0.0 here;
+    # CrateDB's KNN scoring yields different absolute values.
+    # assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]  # noqa: E501
+ output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"})
+ assert output == [
+ (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(2.2, 0.1))
+ ]
+
+
+def test_cratedb_with_filter_distant_match() -> None:
+ """Test end to end construction and search."""
+ texts = ["foo", "bar", "baz"]
+ metadatas = [{"page": str(i)} for i in range(len(texts))]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection_filter",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ metadatas=metadatas,
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ output = docsearch.similarity_search_with_score("foo", k=2, filter={"page": "2"})
+ # Original score value: 0.0013003906671379406
+ assert output == [
+ (Document(page_content="baz", metadata={"page": "2"}), pytest.approx(1.5, 0.2))
+ ]
+
+
+def test_cratedb_with_filter_no_match() -> None:
+ """Test end to end construction and search."""
+ texts = ["foo", "bar", "baz"]
+ metadatas = [{"page": str(i)} for i in range(len(texts))]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection_filter",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ metadatas=metadatas,
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"})
+ assert output == []
+
+
+@pytest.fixture
+def storage_strategy_langchain_pgvector(mocker: MockerFixture) -> None:
+ mocker.patch(
+ "langchain.vectorstores.cratedb.base.CrateDBVectorSearch.STORAGE_STRATEGY",
+ StorageStrategy.LANGCHAIN_PGVECTOR,
+ )
+
+
+@pytest.fixture
+def storage_strategy_embedding_table_per_collection(mocker: MockerFixture) -> None:
+ mocker.patch(
+ "langchain.vectorstores.cratedb.base.CrateDBVectorSearch.STORAGE_STRATEGY",
+ StorageStrategy.EMBEDDING_TABLE_PER_COLLECTION,
+ )
+
+
+def test_cratedb_storage_strategy_langchain_pgvector(
+ storage_strategy_langchain_pgvector: None,
+ two_stores: Tuple[CrateDBVectorSearch, CrateDBVectorSearch],
+) -> None:
+ """
+ Verify collection construction and deletion using the vanilla storage strategy.
+
+    It uses two different collections of embeddings, thereby proving that
+    the embeddings are managed according to the storage strategy.
+
+    In this case, the embeddings of all collections are stored within a
+    single database table, called `embedding`.
+ """
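+    # Under this strategy, both collections share one `embedding` table,
+    # which is why each store initially sees two records below.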
+ store_foo, store_bar = two_stores
+
+ session = store_foo.Session()
+
+ # Verify data in database.
+ collection_foo = store_foo.get_collection(session)
+ collection_bar = store_bar.get_collection(session)
+ if collection_foo is None or collection_bar is None:
+ assert False, "Expected CollectionStore objects but received None"
+ assert collection_foo.embeddings[0].cmetadata == {"document": "foo"}
+ assert collection_bar.embeddings[0].cmetadata == {"document": "bar"}
+
+ # Verify number of records before deletion.
+ assert session.query(store_foo.EmbeddingStore).count() == 2
+ assert session.query(store_bar.EmbeddingStore).count() == 2
+
+ # Delete first collection.
+ store_foo.delete_collection()
+
+ # Verify that the "foo" collection has been deleted.
+ collection_foo = store_foo.get_collection(session)
+ collection_bar = store_bar.get_collection(session)
+ if collection_bar is None:
+ assert False, "Expected CollectionStore object but received None"
+ assert collection_foo is None
+ assert collection_bar.embeddings[0].cmetadata == {"document": "bar"}
+
+    # Verify the number of records after deletion, to prove that the
+    # associated embeddings have been deleted as well.
+ assert session.query(store_foo.EmbeddingStore).count() == 1
+ assert session.query(store_bar.EmbeddingStore).count() == 1
+
+
+def test_cratedb_storage_strategy_embedding_table_per_collection(
+ storage_strategy_embedding_table_per_collection: None,
+ two_stores: Tuple[CrateDBVectorSearch, CrateDBVectorSearch],
+) -> None:
+ """
+ Verify collection construction and deletion using a more advanced storage strategy.
+
+    It uses two different collections of embeddings, thereby proving that
+    the embeddings are managed according to the storage strategy.
+
+    In this case, the embeddings of each collection are stored in a separate
+    database table, called `embedding_{collection_name}`.
+ """
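+    # Under this strategy, each collection gets a dedicated table, so every
+    # store only ever sees its own single record below.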
+ store_foo, store_bar = two_stores
+
+ session = store_foo.Session()
+
+ # Verify data in database.
+    collection_foo = store_foo.get_collection(session)
+    collection_bar = store_bar.get_collection(session)
+    if collection_foo is None or collection_bar is None:
+        assert False, "Expected CollectionStore objects but received None"
+    assert collection_foo.embeddings[0].cmetadata == {"document": "foo"}
+    assert collection_bar.embeddings[0].cmetadata == {"document": "bar"}
+
+ # Verify number of records before deletion.
+ assert session.query(store_foo.EmbeddingStore).count() == 1
+ assert session.query(store_bar.EmbeddingStore).count() == 1
+
+ # Delete first collection.
+ store_foo.delete_collection()
+
+ # Verify that the "foo" collection has been deleted.
+    collection_foo = store_foo.get_collection(session)
+    collection_bar = store_bar.get_collection(session)
+    assert collection_foo is None
+    if collection_bar is None:
+        assert False, "Expected CollectionStore object but received None"
+    assert collection_bar.embeddings[0].cmetadata == {"document": "bar"}
+
+    # Verify the number of records after deletion, to prove that the
+    # associated embeddings have been deleted as well.
+ assert session.query(store_foo.EmbeddingStore).count() == 0
+ assert session.query(store_bar.EmbeddingStore).count() == 1
+
+
+def test_cratedb_collection_with_metadata() -> None:
+    """Test end-to-end collection construction."""
+ texts = ["foo", "bar", "baz"]
+ metadatas = [{"page": str(i)} for i in range(len(texts))]
+ cratedb_vector = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection",
+ collection_metadata={"foo": "bar"},
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ metadatas=metadatas,
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ collection = cratedb_vector.get_collection(cratedb_vector.Session())
+ if collection is None:
+ assert False, "Expected a CollectionStore object but received None"
+ else:
+ assert collection.name == "test_collection"
+ assert collection.cmetadata == {"foo": "bar"}
+
+
+def test_cratedb_collection_no_embedding_dimension() -> None:
+ """
+ Verify that addressing collections fails when not specifying dimensions.
+ """
+ cratedb_vector = CrateDBVectorSearch(
+ embedding_function=None, # type: ignore[arg-type]
+ connection_string=CONNECTION_STRING,
+ )
+ session = Session(cratedb_vector.connect())
+ with pytest.raises(RuntimeError) as ex:
+ cratedb_vector.get_collection(session)
+ assert ex.match(
+ "Collection can't be accessed without specifying "
+ "dimension size of embedding vectors"
+ )
+
+
+def test_cratedb_collection_read_only(session: Session) -> None:
+ """
+ Test using a collection, without adding any embeddings upfront.
+
+ This happens when just invoking the "retrieval" case.
+    This happens when only the "retrieval" use case is invoked.
+ In this scenario, embedding dimensionality needs to be figured out
+ from the supplied `embedding_function`.
+ """
+
+ # Create a fake collection item.
+ ensure_collection(session, "baz2")
+
+ # This test case needs an embedding _with_ dimensionality.
+ # Otherwise, the data access layer is unable to figure it
+ # out at runtime.
+ embedding = ConsistentFakeEmbeddingsWithAdaDimension()
+
+ vectorstore = CrateDBVectorSearch(
+ collection_name="baz2",
+ connection_string=CONNECTION_STRING,
+ embedding_function=embedding,
+ )
+ output = vectorstore.similarity_search("foo", k=1)
+
+ # No documents/embeddings have been loaded, the collection is empty.
+ # This is why there are also no results.
+ assert output == []
+
+
+def test_cratedb_with_filter_in_set() -> None:
+ """Test end to end construction and search."""
+ texts = ["foo", "bar", "baz"]
+ metadatas = [{"page": str(i)} for i in range(len(texts))]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection_filter",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ metadatas=metadatas,
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ output = docsearch.similarity_search_with_score(
+ "foo", k=2, filter={"page": {"IN": ["0", "2"]}}
+ )
+ # Original score values: 0.0, 0.0013003906671379406
+ assert output == [
+ (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(3.0, 0.1)),
+ (Document(page_content="baz", metadata={"page": "2"}), pytest.approx(2.2, 0.1)),
+ ]
+
+
+def test_cratedb_delete_docs() -> None:
+ """Add and delete documents."""
+ texts = ["foo", "bar", "baz"]
+ metadatas = [{"page": str(i)} for i in range(len(texts))]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection_filter",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ metadatas=metadatas,
+ ids=["1", "2", "3"],
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ docsearch.delete(["1", "2"])
+ with docsearch._make_session() as session:
+ records = list(session.query(docsearch.EmbeddingStore).all())
+ # ignoring type error since mypy cannot determine whether
+ # the list is sortable
+ assert sorted(record.custom_id for record in records) == ["3"] # type: ignore
+
+ docsearch.delete(["2", "3"]) # Should not raise on missing ids
+ with docsearch._make_session() as session:
+ records = list(session.query(docsearch.EmbeddingStore).all())
+ # ignoring type error since mypy cannot determine whether
+ # the list is sortable
+ assert sorted(record.custom_id for record in records) == [] # type: ignore
+
+
+def test_cratedb_relevance_score() -> None:
+ """Test to make sure the relevance score is scaled to 0-1."""
+ texts = ["foo", "bar", "baz"]
+ metadatas = [{"page": str(i)} for i in range(len(texts))]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ metadatas=metadatas,
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+
+ output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
+ # Original score values: 1.0, 0.9996744261675065, 0.9986996093328621
+ assert output == [
+ (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(1.4, 0.1)),
+ (Document(page_content="bar", metadata={"page": "1"}), pytest.approx(1.1, 0.1)),
+ (Document(page_content="baz", metadata={"page": "2"}), pytest.approx(0.8, 0.1)),
+ ]
+
+
+def test_cratedb_retriever_search_threshold() -> None:
+ """Test using retriever for searching with threshold."""
+ texts = ["foo", "bar", "baz"]
+ metadatas = [{"page": str(i)} for i in range(len(texts))]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ metadatas=metadatas,
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+
+ retriever = docsearch.as_retriever(
+ search_type="similarity_score_threshold",
+ search_kwargs={"k": 3, "score_threshold": 0.999},
+ )
+ output = retriever.get_relevant_documents("summer")
+ assert output == [
+ Document(page_content="foo", metadata={"page": "0"}),
+ Document(page_content="bar", metadata={"page": "1"}),
+ ]
+
+
+def test_cratedb_retriever_search_threshold_custom_normalization_fn() -> None:
+ """Test searching with threshold and custom normalization function"""
+ texts = ["foo", "bar", "baz"]
+ metadatas = [{"page": str(i)} for i in range(len(texts))]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ metadatas=metadatas,
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ relevance_score_fn=lambda d: d * 0,
+ )
+
+ retriever = docsearch.as_retriever(
+ search_type="similarity_score_threshold",
+ search_kwargs={"k": 3, "score_threshold": 0.5},
+ )
+ output = retriever.get_relevant_documents("foo")
+ assert output == []
+
+
+def test_cratedb_max_marginal_relevance_search() -> None:
+ """Test max marginal relevance search."""
+ texts = ["foo", "bar", "baz"]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3)
+ assert output == [Document(page_content="foo")]
+
+
+def test_cratedb_max_marginal_relevance_search_with_score() -> None:
+ """Test max marginal relevance search with relevance scores."""
+ texts = ["foo", "bar", "baz"]
+ docsearch = CrateDBVectorSearch.from_texts(
+ texts=texts,
+ collection_name="test_collection",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3)
+ assert output == [(Document(page_content="foo"), 2.0)]
+
+
+def test_cratedb_multicollection_search_success() -> None:
+ """
+ `CrateDBVectorSearchMultiCollection` provides functionality for
+ searching multiple collections.
+ """
+
+ store_1 = CrateDBVectorSearch.from_texts(
+ texts=["Räuber", "Hotzenplotz"],
+ collection_name="test_collection_1",
+ embedding=ConsistentFakeEmbeddingsWithAdaDimension(),
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+ _ = CrateDBVectorSearch.from_texts(
+ texts=["John", "Doe"],
+ collection_name="test_collection_2",
+ embedding=ConsistentFakeEmbeddingsWithAdaDimension(),
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+
+ # Probe the first store.
+ output = store_1.similarity_search("Räuber", k=1)
+ assert Document(page_content="Räuber") in output[:2]
+ output = store_1.similarity_search("Hotzenplotz", k=1)
+ assert Document(page_content="Hotzenplotz") in output[:2]
+ output = store_1.similarity_search("John Doe", k=1)
+ assert Document(page_content="Räuber") in output[:2]
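+    # Even a query matching neither document still returns the nearest
+    # neighbor from this store.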
+
+ # Probe the multi-store.
+ multisearch = CrateDBVectorSearchMultiCollection(
+ collection_names=["test_collection_1", "test_collection_2"],
+ embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(),
+ connection_string=CONNECTION_STRING,
+ )
+ output = multisearch.similarity_search("Räuber Hotzenplotz", k=2)
+ assert Document(page_content="Räuber") in output[:2]
+ output = multisearch.similarity_search("John Doe", k=2)
+ assert Document(page_content="John") in output[:2]
+
+
+def test_cratedb_multicollection_fail_indexing_not_permitted() -> None:
+ """
+ `CrateDBVectorSearchMultiCollection` does not provide functionality for
+ indexing documents.
+ """
+
+ with pytest.raises(NotImplementedError) as ex:
+ CrateDBVectorSearchMultiCollection.from_texts(
+ texts=["foo"],
+ collection_name="test_collection",
+ embedding=FakeEmbeddingsWithAdaDimension(),
+ connection_string=CONNECTION_STRING,
+ )
+ assert ex.match("This adapter can not be used for indexing documents")
+
+
+def test_cratedb_multicollection_search_table_does_not_exist() -> None:
+ """
+ `CrateDBVectorSearchMultiCollection` will fail when the `collection`
+ table does not exist.
+ """
+
+ store = CrateDBVectorSearchMultiCollection(
+ collection_names=["unknown"],
+ embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(),
+ connection_string=CONNECTION_STRING,
+ )
+ with pytest.raises(ProgrammingError) as ex:
+ store.similarity_search("foo")
+ assert ex.match(re.escape("RelationUnknown[Relation 'collection' unknown]"))
+
+
+def test_cratedb_multicollection_search_unknown_collection() -> None:
+ """
+ `CrateDBVectorSearchMultiCollection` will fail when not able to identify
+ collections to search in.
+ """
+
+ CrateDBVectorSearch.from_texts(
+ texts=["Räuber", "Hotzenplotz"],
+ collection_name="test_collection",
+ embedding=ConsistentFakeEmbeddingsWithAdaDimension(),
+ connection_string=CONNECTION_STRING,
+ pre_delete_collection=True,
+ )
+
+ store = CrateDBVectorSearchMultiCollection(
+ collection_names=["unknown"],
+ embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(),
+ connection_string=CONNECTION_STRING,
+ )
+ with pytest.raises(ValueError) as ex:
+ store.similarity_search("foo")
+ assert ex.match("No collections found")
+
+
+def test_cratedb_multicollection_no_embedding_dimension() -> None:
+ """
+ Verify that addressing collections fails when not specifying dimensions.
+ """
+ store = CrateDBVectorSearchMultiCollection(
+ embedding_function=None, # type: ignore[arg-type]
+ connection_string=CONNECTION_STRING,
+ )
+ session = Session(store.connect())
+ with pytest.raises(RuntimeError) as ex:
+ store.get_collection(session)
+ assert ex.match(
+ "Collection can't be accessed without specifying "
+ "dimension size of embedding vectors"
+ )
+
+
+def test_cratedb_multicollection_storage_strategy_embedding_table_per_collection() -> (
+ None
+):
+ """
+    Verify that multi-collection querying trips the corresponding safety
+    checks when configured to use the `EMBEDDING_TABLE_PER_COLLECTION`
+    storage strategy.
+
+    The two are not supported together yet.
+ """
+ CrateDBVectorSearchMultiCollection.configure(
+ storage_strategy=StorageStrategy.EMBEDDING_TABLE_PER_COLLECTION
+ )
+ with pytest.raises(NotImplementedError) as ex:
+ CrateDBVectorSearchMultiCollection(
+ embedding_function=None, # type: ignore[arg-type]
+ connection_string=CONNECTION_STRING,
+ )
+ assert ex.match(
+ "Multi-collection querying not supported by strategy: "
+ "StorageStrategy.EMBEDDING_TABLE_PER_COLLECTION"
+ )
diff --git a/libs/langchain/tests/unit_tests/conftest.py b/libs/langchain/tests/unit_tests/conftest.py
index da45a330f50af..afb609f3ce31f 100644
--- a/libs/langchain/tests/unit_tests/conftest.py
+++ b/libs/langchain/tests/unit_tests/conftest.py
@@ -40,8 +40,8 @@ def test_something():
# Used to avoid repeated calls to `util.find_spec`
required_pkgs_info: Dict[str, bool] = {}
- only_extended = config.getoption("--only-extended") or False
- only_core = config.getoption("--only-core") or False
+ only_extended = config.getoption("--only-extended", False)
+ only_core = config.getoption("--only-core", False)
if only_extended and only_core:
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")