diff --git a/.github/workflows/test-langchain.yml b/.github/workflows/test-langchain.yml new file mode 100644 index 00000000..b8f0237c --- /dev/null +++ b/.github/workflows/test-langchain.yml @@ -0,0 +1,62 @@ +name: LangChain +on: + + pull_request: + branches: ~ + paths: + - '.github/workflows/test-langchain.yml' + - 'framework/langchain/**' + - 'testing/ngr.py' + push: + branches: [ main ] + paths: + - '.github/workflows/test-langchain.yml' + - 'framework/langchain/**' + - 'testing/ngr.py' + + # Allow job to be triggered manually. + workflow_dispatch: + +# Cancel in-progress jobs when pushing to the same branch. +concurrency: + cancel-in-progress: true + group: ${{ github.workflow }}-${{ github.ref }} + +jobs: + test: + name: "CrateDB: ${{ matrix.cratedb-version }} + Python: ${{ matrix.python-version }} + on ${{ matrix.os }}" + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ "ubuntu-latest" ] + cratedb-version: [ "nightly" ] + python-version: [ "3.11" ] + + services: + cratedb: + image: crate/crate:nightly + ports: + - 4200:4200 + - 5432:5432 + + steps: + + - name: Acquire sources + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + cache: 'pip' + cache-dependency-path: | + framework/langchain/requirements.txt + framework/langchain/requirements-dev.txt + + - name: Validate framework/langchain + run: | + python testing/ngr.py --accept-no-venv framework/langchain diff --git a/.gitignore b/.gitignore index 77abd6ea..9c88f8af 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ .idea .venv* +__pycache__ +.coverage +coverage.xml diff --git a/README.rst b/README.rst index 640d586b..ea677da8 100644 --- a/README.rst +++ b/README.rst @@ -50,6 +50,7 @@ Examples:: More examples:: + python testing/ngr.py framework/langchain python testing/ngr.py testing/testcontainers/java It is recommended to invoke ``ngr`` from within a Python 
virtualenv. diff --git a/framework/langchain/.gitignore b/framework/langchain/.gitignore new file mode 100644 index 00000000..d1b811b7 --- /dev/null +++ b/framework/langchain/.gitignore @@ -0,0 +1 @@ +*.sql diff --git a/framework/langchain/conftest.py b/framework/langchain/conftest.py new file mode 100644 index 00000000..98f88569 --- /dev/null +++ b/framework/langchain/conftest.py @@ -0,0 +1,52 @@ +import typing as t + +import pytest + + +def monkeypatch_pytest_notebook_treat_cell_exit_as_notebook_skip(): + """ + Patch `pytest-notebook`, in fact `nbclient.client.NotebookClient`, + to propagate cell-level `pytest.exit()` invocations as signals + to mark the whole notebook as skipped. + + In order not to be too intrusive, the feature only skips notebooks + when being explicitly instructed, by adding `[skip-notebook]` at the + end of the `reason` string. Example: + + import pytest + if "ACME_API_KEY" not in os.environ: + pytest.exit("ACME_API_KEY not given [skip-notebook]") + + https://github.com/chrisjsewell/pytest-notebook/issues/43 + """ + from nbclient.client import NotebookClient + from nbclient.exceptions import CellExecutionError + from nbformat import NotebookNode + + async_execute_cell_dist = NotebookClient.async_execute_cell + + async def async_execute_cell( + self, + cell: NotebookNode, + cell_index: int, + execution_count: t.Optional[int] = None, + store_history: bool = True, + ) -> NotebookNode: + try: + return await async_execute_cell_dist( + self, + cell, + cell_index, + execution_count=execution_count, + store_history=store_history, + ) + except CellExecutionError as ex: + if ex.ename == "Exit" and ex.evalue.endswith("[skip-notebook]"): + raise pytest.skip(ex.evalue) + else: + raise + + NotebookClient.async_execute_cell = async_execute_cell + + +monkeypatch_pytest_notebook_treat_cell_exit_as_notebook_skip() diff --git a/framework/langchain/conversational_memory.ipynb b/framework/langchain/conversational_memory.ipynb index da0b64bc..f3a0d1cb 100644 --- 
a/framework/langchain/conversational_memory.ipynb +++ b/framework/langchain/conversational_memory.ipynb @@ -58,12 +58,16 @@ "source": [ "from langchain.memory.chat_message_histories import CrateDBChatMessageHistory\n", "\n", + "# Connect to a self-managed CrateDB instance.\n", "CONNECTION_STRING = \"crate://crate@localhost/?schema=notebook\"\n", "\n", "chat_message_history = CrateDBChatMessageHistory(\n", "\tsession_id=\"test_session\",\n", "\tconnection_string=CONNECTION_STRING\n", - ")" + ")\n", + "\n", + "# Make sure to start with a blank canvas.\n", + "chat_message_history.clear()" ], "metadata": { "collapsed": false @@ -101,9 +105,11 @@ "outputs": [ { "data": { - "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]" + "text/plain": [ + "[HumanMessage(content='Hello'), AIMessage(content='Hi')]" + ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -147,7 +153,6 @@ "from datetime import datetime\n", "from typing import Any\n", "\n", - "from langchain.memory.chat_message_histories.cratedb import generate_autoincrement_identifier\n", "from langchain.memory.chat_message_histories.sql import BaseMessageConverter\n", "from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage\n", "\n", @@ -161,7 +166,7 @@ "class CustomMessage(Base):\n", "\t__tablename__ = \"custom_message_store\"\n", "\n", - "\tid = sa.Column(sa.BigInteger, primary_key=True, default=generate_autoincrement_identifier)\n", + "\tid = sa.Column(sa.BigInteger, primary_key=True, server_default=sa.func.now())\n", "\tsession_id = sa.Column(sa.Text)\n", "\ttype = sa.Column(sa.Text)\n", "\tcontent = sa.Column(sa.Text)\n", @@ -215,6 +220,9 @@ "\t\t)\n", "\t)\n", "\n", + "\t# Make sure to start with a blank canvas.\n", + "\tchat_message_history.clear()\n", + "\n", "\tchat_message_history.add_user_message(\"Hello\")\n", 
"\tchat_message_history.add_ai_message(\"Hi\")" ], @@ -233,9 +241,11 @@ "outputs": [ { "data": { - "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]" + "text/plain": [ + "[HumanMessage(content='Hello'), AIMessage(content='Hi')]" + ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -268,13 +278,13 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "outputs": [], "source": [ "import json\n", "import typing as t\n", "\n", - "from langchain.memory.chat_message_histories.cratedb import generate_autoincrement_identifier, CrateDBMessageConverter\n", + "from langchain.memory.chat_message_histories.cratedb import CrateDBMessageConverter\n", "from langchain.schema import _message_to_dict\n", "\n", "\n", @@ -282,7 +292,7 @@ "\n", "class MessageWithDifferentSessionIdColumn(Base):\n", "\t__tablename__ = \"message_store_different_session_id\"\n", - "\tid = sa.Column(sa.BigInteger, primary_key=True, default=generate_autoincrement_identifier)\n", + "\tid = sa.Column(sa.BigInteger, primary_key=True, server_default=sa.func.now())\n", "\tcustom_session_id = sa.Column(sa.Text)\n", "\tmessage = sa.Column(sa.Text)\n", "\n", @@ -307,6 +317,9 @@ "\t\tsession_id_field_name=\"custom_session_id\",\n", "\t)\n", "\n", + "\t# Make sure to start with a blank canvas.\n", + "\tchat_message_history.clear()\n", + "\n", "\tchat_message_history.add_user_message(\"Hello\")\n", "\tchat_message_history.add_ai_message(\"Hi\")" ], @@ -316,11 +329,13 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "outputs": [ { "data": { - "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]" + "text/plain": [ + "[HumanMessage(content='Hello'), AIMessage(content='Hi')]" + ] }, "execution_count": 9, "metadata": {}, diff --git 
a/framework/langchain/conversational_memory.py b/framework/langchain/conversational_memory.py index 74934e40..4eef99f7 100644 --- a/framework/langchain/conversational_memory.py +++ b/framework/langchain/conversational_memory.py @@ -19,11 +19,17 @@ from langchain.memory.chat_message_histories import CrateDBChatMessageHistory +CONNECTION_STRING = os.environ.get( + "CRATEDB_CONNECTION_STRING", + "crate://crate@localhost/?schema=doc" +) + + def main(): chat_message_history = CrateDBChatMessageHistory( session_id="test_session", - connection_string=os.environ.get("CRATEDB_CONNECTION_STRING") + connection_string=CONNECTION_STRING, ) chat_message_history.add_user_message("Hello") chat_message_history.add_ai_message("Hi") diff --git a/framework/langchain/document_loader.ipynb b/framework/langchain/document_loader.ipynb index b8756382..fea0437c 100644 --- a/framework/langchain/document_loader.ipynb +++ b/framework/langchain/document_loader.ipynb @@ -58,8 +58,8 @@ "output_type": "stream", "text": [ "\u001B[32mCONNECT OK\r\n", - "\u001B[0m\u001B[32mPSQL OK, 1 row affected (0.001 sec)\r\n", - "\u001B[0m\u001B[32mDELETE OK, 30 rows affected (0.010 sec)\r\n", + "\u001B[0m\u001B[32mPROVISIONING OK, 0 rows affected (0.001 sec)\r\n", + "\u001B[0m\u001B[32mCREATE OK, 1 row affected (0.010 sec)\r\n", "\u001B[0m\u001B[32mINSERT OK, 30 rows affected (0.011 sec)\r\n", "\u001B[0m\u001B[0m\u001B[32mCONNECT OK\r\n", "\u001B[0m\u001B[32mREFRESH OK, 1 row affected (0.026 sec)\r\n", @@ -95,6 +95,7 @@ "from langchain.document_loaders import CrateDBLoader\n", "from pprint import pprint\n", "\n", + "# Connect to a self-managed CrateDB instance.\n", "CONNECTION_STRING = \"crate://crate@localhost/?schema=notebook\"\n", "\n", "loader = CrateDBLoader(\n", @@ -115,11 +116,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Document(page_content='Team: Angels\\nPayroll (millions): 154.49\\nWins: 89', metadata={}),\n", - " Document(page_content='Team: Astros\\nPayroll (millions): 
60.65\\nWins: 55', metadata={}),\n", - " Document(page_content='Team: Athletics\\nPayroll (millions): 55.37\\nWins: 94', metadata={}),\n", - " Document(page_content='Team: Blue Jays\\nPayroll (millions): 75.48\\nWins: 73', metadata={}),\n", - " Document(page_content='Team: Braves\\nPayroll (millions): 83.31\\nWins: 94', metadata={})]\n" + "[Document(page_content='Team: Angels\\nPayroll (millions): 154.49\\nWins: 89'),\n", + " Document(page_content='Team: Astros\\nPayroll (millions): 60.65\\nWins: 55'),\n", + " Document(page_content='Team: Athletics\\nPayroll (millions): 55.37\\nWins: 94'),\n", + " Document(page_content='Team: Blue Jays\\nPayroll (millions): 75.48\\nWins: 73'),\n", + " Document(page_content='Team: Braves\\nPayroll (millions): 83.31\\nWins: 94')]\n" ] } ], diff --git a/framework/langchain/document_loader.py b/framework/langchain/document_loader.py index e6710af2..48e2d9fa 100644 --- a/framework/langchain/document_loader.py +++ b/framework/langchain/document_loader.py @@ -28,10 +28,16 @@ from pprint import pprint +CONNECTION_STRING = os.environ.get( + "CRATEDB_CONNECTION_STRING", + "crate://crate@localhost/?schema=doc" +) + + def main(): loader = CrateDBLoader( query="SELECT * FROM mlb_teams_2012 LIMIT 3;", - url=os.environ.get("CRATEDB_CONNECTION_STRING"), + url=CONNECTION_STRING, include_rownum_into_metadata=True, ) docs = loader.load() diff --git a/framework/langchain/pyproject.toml b/framework/langchain/pyproject.toml new file mode 100644 index 00000000..b9e54d6b --- /dev/null +++ b/framework/langchain/pyproject.toml @@ -0,0 +1,45 @@ +[tool.pytest.ini_options] +minversion = "2.0" +addopts = """ + -rfEX -p pytester --strict-markers --verbosity=3 --capture=no + """ +# --cov=. 
--cov-report=term-missing --cov-report=xml +env = [ + "CRATEDB_CONNECTION_STRING=crate://crate@localhost/?schema=testdrive", + "PYDEVD_DISABLE_FILE_VALIDATION=1", +] + +#log_level = "DEBUG" +#log_cli_level = "DEBUG" + +testpaths = [ + "*.py", +] +xfail_strict = true +markers = [ +] + +# pytest-notebook settings +nb_test_files = true +nb_coverage = true +nb_diff_replace = [ + # Compensate output of `crash`. + '"/cells/*/outputs/*/text" "\(\d.\d+ sec\)" "(0.000 sec)"', +] +# `vector_search.py` does not include any output(s). +nb_diff_ignore = [ + "/metadata/language_info", + "/cells/*/execution_count", + "/cells/*/outputs/*/execution_count", +] + +[tool.coverage.run] +branch = false + +[tool.coverage.report] +fail_under = 0 +show_missing = true +omit = [ + "conftest.py", + "test*.py", +] diff --git a/framework/langchain/readme.md b/framework/langchain/readme.md index fb6d52d4..f5ff852d 100644 --- a/framework/langchain/readme.md +++ b/framework/langchain/readme.md @@ -100,6 +100,25 @@ a cloud-based development environment is up and running. As soon as your project easily move to a different cluster tier or scale horizontally. +## Testing + +Run all tests. +```shell +pytest +``` + +Run tests selectively. +```shell +pytest -k document_loader +pytest -k "notebook and loader" +``` + +In order to force a regeneration of the Jupyter Notebook, use the +`--nb-force-regen` option. 
+```shell +pytest -k document_loader --nb-force-regen +``` + [Agents]: https://python.langchain.com/docs/modules/agents/ [Callbacks]: https://python.langchain.com/docs/modules/callbacks/ diff --git a/framework/langchain/requirements-dev.txt b/framework/langchain/requirements-dev.txt new file mode 100644 index 00000000..e0769c87 --- /dev/null +++ b/framework/langchain/requirements-dev.txt @@ -0,0 +1,6 @@ +coverage~=5.0 +ipykernel +pytest<8 +pytest-cov<5 +pytest-env<2 +pytest-notebook<0.9 diff --git a/framework/langchain/test.py b/framework/langchain/test.py new file mode 100644 index 00000000..b6f537e4 --- /dev/null +++ b/framework/langchain/test.py @@ -0,0 +1,122 @@ +import importlib +import io +import os +import sys +from pathlib import Path +from unittest import mock + +import pytest +from _pytest.python import Function + +HERE = Path(__file__).parent + + +def list_files(path: Path, pattern: str): + """ + Enumerate all files in given directory. + """ + files = path.glob(pattern) + files = [item.relative_to(path) for item in files] + return files + + +def list_notebooks(path: Path): + """ + Enumerate all Jupyter Notebook files found in given directory. + """ + return list_files(path, "**/*.ipynb") + + +def list_pyfiles(path: Path): + """ + Enumerate all regular Python files found in given directory. + """ + pyfiles = [] + for item in list_files(path, "**/*.py"): + if item.suffix != ".py" or item.name in ["conftest.py"] or item.name.startswith("test"): + continue + pyfiles.append(item) + return pyfiles + + +def str_list(things): + """ + Convert list to list of strings. + """ + return map(str, things) + + +@pytest.fixture(scope="function", autouse=True) +def db_init(): + """ + Initialize database. + """ + run_sql(statement="DROP TABLE IF EXISTS mlb_teams_2012;") + + +def db_provision_mlb_teams_2012(): + """ + Provision database. 
+ """ + run_sql(file="mlb_teams_2012.sql") + run_sql(statement="REFRESH TABLE mlb_teams_2012;") + + +def run_sql(statement: str = None, file: str = None): + """ + Run SQL from string or file. + """ + import crate.crash.command + sys.argv = ["foo", "--schema=testdrive"] + if statement: + sys.argv += ["--command", statement] + if file: + sys.stdin = io.StringIO(Path(file).read_text()) + with \ + mock.patch("crate.crash.repl.SQLCompleter._populate_keywords"), \ + mock.patch("crate.crash.command.CrateShell.close"): + try: + crate.crash.command.main() + except SystemExit as ex: + if ex.code != 0: + raise + + +@pytest.mark.parametrize("notebook", str_list(list_notebooks(HERE))) +def test_notebook(request, notebook: str): + """ + From individual Jupyter Notebook file, collect cells as pytest + test cases, and run them. + + Not using `NBRegressionFixture`, because it would manually need to be configured. + """ + from _pytest._py.path import LocalPath + from pytest_notebook.plugin import pytest_collect_file + tests = pytest_collect_file(LocalPath(notebook), request.node) + for test in tests.collect(): + test.runtest() + + +@pytest.mark.parametrize("pyfile", str_list(list_pyfiles(HERE))) +def test_file(request, pyfile: Path): + """ + From individual Python file, collect and wrap the `main` function into a test case. + """ + + # TODO: Make configurable. + entrypoint_symbol = "main" + + # Skip `vector_search.py` example, when no `OPENAI_API_KEY` is supplied. + if str(pyfile).endswith("vector_search.py"): + if "OPENAI_API_KEY" not in os.environ: + raise pytest.skip("OPENAI_API_KEY not given") + + # `document_loader.py` needs provisioning. 
+ if str(pyfile).endswith("document_loader.py"): + db_provision_mlb_teams_2012() + + path = Path(pyfile) + mod = importlib.import_module(path.stem) + fun = getattr(mod, entrypoint_symbol) + f = Function.from_parent(request.node, name="main", callobj=fun) + f.runtest() diff --git a/framework/langchain/vector_search.ipynb b/framework/langchain/vector_search.ipynb index 2b2353f6..00921d3f 100644 --- a/framework/langchain/vector_search.ipynb +++ b/framework/langchain/vector_search.ipynb @@ -61,8 +61,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You need to provide an OpenAI API key, optionally using the environment\n", - "variable `OPENAI_API_KEY`." + "You need to provide an OpenAI API key, using the environment variable\n", + "`OPENAI_API_KEY`, or by defining it within an `.env` file.\n", + "\n", + "```shell\n", + "export OPENAI_API_KEY=sk-YOUR_OPENAI_API_KEY\n", + "```" ] }, { @@ -80,23 +84,29 @@ "import getpass\n", "from dotenv import load_dotenv, find_dotenv\n", "\n", - "# Run `export OPENAI_API_KEY=sk-YOUR_OPENAI_API_KEY`.\n", - "# Get OpenAI api key from `.env` file.\n", - "# Otherwise, prompt for it.\n", - "_ = load_dotenv(find_dotenv())\n", - "OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', getpass.getpass(\"OpenAI API key:\"))\n", - "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY" + "# Load OpenAI API key from `OPENAI_API_KEY` environment variable or `.env` file.\n", + "# If it is not defined, prompt interactively.\n", + "load_dotenv(find_dotenv())\n", + "if \"OPENAI_API_KEY\" not in os.environ:\n", + " if \"PYTEST_CURRENT_TEST\" in os.environ:\n", + " import pytest\n", + " pytest.exit(\"OPENAI_API_KEY not given [skip-notebook]\")\n", + " else:\n", + " os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API key:\")\n", + "\n", + "# FIXME: Needs a patch to work, see https://github.com/chrisjsewell/pytest-notebook/issues/43.\n", + "# TODO: Refactor to `pueblo.getenvpass(\"OPENAI_API_KEY\", prompt=\"OpenAI API key:\")`." 
] }, { "cell_type": "markdown", "source": [ - "You also need to provide a connection string to your CrateDB database cluster,\n", - "optionally using the environment variable `CRATEDB_CONNECTION_STRING`.\n", + "You can also provide a connection string to your CrateDB database cluster,\n", + "using the environment variable `CRATEDB_CONNECTION_STRING`.\n", "\n", - "This example uses a CrateDB instance on your workstation, which you can start by\n", - "running [CrateDB using Docker]. Alternatively, you can also connect to a cluster\n", - "running on [CrateDB Cloud].\n", + "By default, the notebook will connect to a CrateDB server instance running on `localhost`.\n", + "You can start a sandbox instance on your workstation by running [CrateDB using Docker].\n", + "Alternatively, you can also connect to a cluster running on [CrateDB Cloud].\n", "\n", "[CrateDB Cloud]: https://console.cratedb.cloud/\n", "[CrateDB using Docker]: https://crate.io/docs/crate/tutorials/en/latest/basic/index.html#docker" @@ -112,12 +122,13 @@ "source": [ "import os\n", "\n", + "# Connect to a self-managed CrateDB instance.\n", "CONNECTION_STRING = os.environ.get(\n", " \"CRATEDB_CONNECTION_STRING\",\n", " \"crate://crate@localhost/?schema=notebook\",\n", ")\n", "\n", - "# For CrateDB Cloud, use:\n", + "# Connect to CrateDB Cloud.\n", "# CONNECTION_STRING = os.environ.get(\n", "# \"CRATEDB_CONNECTION_STRING\",\n", "# \"crate://username:password@hostname/?ssl=true&schema=notebook\",\n", @@ -134,11 +145,12 @@ "ExecuteTime": { "end_time": "2023-09-09T08:02:28.174088Z", "start_time": "2023-09-09T08:02:28.162698Z" - } + }, + "collapsed": true }, "outputs": [], "source": [ - "\"\"\"\n", + "_ = \"\"\"\n", "# Alternatively, the connection string can be assembled from individual\n", "# environment variables.\n", "import os\n", @@ -399,7 +411,7 @@ "### Overwriting a vector store\n", "\n", "If you have an existing collection, you can overwrite it by using `from_documents`,\n", - "aad setting 
`pre_delete_collection = True`." + "and setting `pre_delete_collection = True`." ] }, { diff --git a/framework/langchain/vector_search.py b/framework/langchain/vector_search.py index 8b71d11d..28e0d770 100644 --- a/framework/langchain/vector_search.py +++ b/framework/langchain/vector_search.py @@ -23,7 +23,7 @@ # Run program. python vector_search.py """ # noqa: E501 -from langchain.document_loaders import TextLoader +from langchain.document_loaders import UnstructuredURLLoader from langchain.embeddings.openai import OpenAIEmbeddings from langchain.text_splitter import CharacterTextSplitter from langchain.vectorstores import CrateDBVectorSearch @@ -33,7 +33,8 @@ def main(): # Load the document, split it into chunks, embed each chunk, # and load it into the vector store. - raw_documents = TextLoader("state_of_the_union.txt").load() + state_of_the_union_url = "https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt" + raw_documents = UnstructuredURLLoader(urls=[state_of_the_union_url]).load() text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) documents = text_splitter.split_documents(raw_documents) db = CrateDBVectorSearch.from_documents(documents, OpenAIEmbeddings())