diff --git a/.github/actions/run-notebook/action.yml b/.github/actions/run-notebook/action.yml new file mode 100644 index 00000000..80533f2c --- /dev/null +++ b/.github/actions/run-notebook/action.yml @@ -0,0 +1,51 @@ +name: "Run Notebook" +description: "Run a notebook" + +inputs: + notebook: + description: "The notebook to run" + required: true + PINECONE_API_KEY: + description: "The Pinecone API key" + required: true + OPENAI_API_KEY: + description: "The OpenAI API key" + required: true + +runs: + using: 'composite' + steps: + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + shell: bash + run: | + pip install --upgrade pip + pip install nbformat + + - id: convert + shell: bash + name: Convert notebook into tmpdir script + run: | + python .github/actions/run-notebook/convert-notebook.py ${{ inputs.notebook }} + + - name: View the run script + shell: bash + run: | + cat ${{ steps.convert.outputs.script_path }} + + - name: View converted notebook content + shell: bash + run: | + cat ${{ steps.convert.outputs.notebook_path }} + + - name: Run the converted notebook + shell: bash + run: | + bash ${{ steps.convert.outputs.script_path }} + env: + PINECONE_API_KEY: ${{ inputs.PINECONE_API_KEY }} + OPENAI_API_KEY: ${{ inputs.OPENAI_API_KEY }} \ No newline at end of file diff --git a/.github/actions/run-notebook/convert-notebook.py b/.github/actions/run-notebook/convert-notebook.py new file mode 100755 index 00000000..2dfe2c57 --- /dev/null +++ b/.github/actions/run-notebook/convert-notebook.py @@ -0,0 +1,89 @@ +#! 
/usr/bin/env python + +# Convert a notebook to a Python script + +import os +import sys +import nbformat +import shutil +from tempfile import mkdtemp +from tempfile import TemporaryDirectory + +# Get the notebook filename from the command line +filename = "../../../" + sys.argv[1] +print(f"Processing notebook: {filename}") +nb_source_path = os.path.join(os.path.dirname(__file__), filename) + +temp_dir = mkdtemp() +venv_path = os.path.join(temp_dir, 'venv') +os.makedirs(venv_path, exist_ok=True) + +# Copy file into temp directory +temp_nb_path = os.path.join(temp_dir, 'notebook.ipynb') +print(f"Copying notebook to {temp_nb_path}") +shutil.copy(nb_source_path, temp_nb_path) + +with open(temp_nb_path, "r", encoding="utf-8") as f: + nb = nbformat.read(f, as_version=4) + +# Extract pip install commands (assumes they are written as "!pip install ..." or "%pip install ...") +# This grabs every code cell whose source contains "!pip" or "%pip"; the whole cell is copied into run.sh. +activate_venv = """ +#!/bin/bash + +set -ex + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Create new virtual environment +python -m venv "${SCRIPT_DIR}/venv" + +# Activate the virtual environment +source "${SCRIPT_DIR}/venv/bin/activate" +pip install --upgrade pip +""" +run_commands = [activate_venv] +for cell in nb.cells: + if cell.cell_type == "code": + if "!pip" in cell.source or "%pip" in cell.source: + # Replace all instances of "!pip" and "%pip" with "pip" + command = cell.source.replace("!pip", "pip").replace("%pip", "pip") + run_commands.append(command) + +run_commands.append(""" +# Run the notebook executable code +python "${SCRIPT_DIR}/notebook.py" +""") + +run_commands.append(""" +# Deactivate the virtual environment +deactivate +""") + +# Save pip install commands to a run.sh script +run_script_path = os.path.join(temp_dir, 'run.sh') +with open(run_script_path, 'w', encoding="utf-8") as f: + f.write("\n".join(run_commands)) + +print(f"Setup script saved to {run_script_path}") + +# Collect cells that are not 
pip install commands +executable_cells = [] +for cell in nb.cells: + if cell.cell_type == "code": + if "pip" not in cell.source: + executable_cells.append(cell) + +# Save executable cells to a notebook.py file +script_path = os.path.join(temp_dir, 'notebook.py') +with open(script_path, 'w', encoding="utf-8") as f: + for cell in executable_cells: + f.write(cell.source + '\n') + +print(f"Script saved to {script_path}") + +# Output script and notebook path to github actions output +with open(os.environ['GITHUB_OUTPUT'], 'a') as f: + f.write(f"script_path={run_script_path}\n") + f.write(f"notebook_path={script_path}\n") + \ No newline at end of file diff --git a/.github/actions/validate-json/action.yml b/.github/actions/validate-json/action.yml new file mode 100644 index 00000000..b000e2d2 --- /dev/null +++ b/.github/actions/validate-json/action.yml @@ -0,0 +1,21 @@ +name: "Validate JSON" +description: "Validate JSON" + +runs: + using: 'composite' + steps: + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + shell: bash + run: | + pip install --upgrade pip + pip install nbformat + + - name: Validate all notebooks + shell: bash + run: | + python .github/actions/validate-json/validate-notebook-formats.py \ No newline at end of file diff --git a/.github/scripts/validate-notebook-formats.py b/.github/actions/validate-json/validate-notebook-formats.py similarity index 100% rename from .github/scripts/validate-notebook-formats.py rename to .github/actions/validate-json/validate-notebook-formats.py diff --git a/.github/scripts/version-census.py b/.github/scripts/version-census.py index 029ae619..9a4ee636 100644 --- a/.github/scripts/version-census.py +++ b/.github/scripts/version-census.py @@ -98,7 +98,8 @@ def main(): print() print(f"Notebooks using {client_type}:") for version, notebooks in sorted(pinecone_versions.items()): - if client_type in version: + client = version.split("==")[0] + if client_type 
== client: print(f" {version}: {len(notebooks)} notebooks") for notebook in notebooks: print(" - ", notebook) diff --git a/.github/workflows/client-versions.yaml b/.github/workflows/client-versions.yaml index 60744195..cc6e1eaf 100644 --- a/.github/workflows/client-versions.yaml +++ b/.github/workflows/client-versions.yaml @@ -1,10 +1,8 @@ name: "Report: Client Version Usage" on: - push: - branches: - - main - pull_request: + workflow_dispatch: + workflow_call: jobs: analyze-client-versions: diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml new file mode 100644 index 00000000..c5ce87de --- /dev/null +++ b/.github/workflows/pr.yaml @@ -0,0 +1,16 @@ +name: "Pull Request" + +on: + pull_request: + push: + branches: + - main + +jobs: + report-client-versions: + uses: './.github/workflows/client-versions.yaml' + secrets: inherit + + test-notebooks: + uses: './.github/workflows/test-notebooks-changed.yaml' + secrets: inherit diff --git a/.github/workflows/test-notebooks-all.yaml b/.github/workflows/test-notebooks-all.yaml new file mode 100644 index 00000000..a22bb66b --- /dev/null +++ b/.github/workflows/test-notebooks-all.yaml @@ -0,0 +1,46 @@ +name: "Test: All Notebooks" + +on: + workflow_dispatch: + inputs: + directory: + description: 'Directory to search for notebooks' + required: true + default: 'docs' + type: string + +jobs: + validate-notebooks: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: ./.github/actions/validate-json + + list-notebooks: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + - name: Find all *.ipynb files + id: set-matrix + run: | + # Get list of all .ipynb files in target directory + NOTEBOOKS=$(find ${{ inputs.directory }} -name "*.ipynb" | jq -R -s -c 'split("\n")[:-1]') + echo "matrix={\"notebook\":$NOTEBOOKS}" >> $GITHUB_OUTPUT + + test-notebooks: + needs: list-notebooks + runs-on: ubuntu-latest + strategy: + fail-fast: false 
+ max-parallel: 10 + matrix: ${{ fromJSON(needs.list-notebooks.outputs.matrix) }} + steps: + - uses: actions/checkout@v4 + + - uses: ./.github/actions/run-notebook + with: + notebook: ${{ matrix.notebook }} + PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} diff --git a/.github/workflows/test-notebooks-changed.yaml b/.github/workflows/test-notebooks-changed.yaml new file mode 100644 index 00000000..2d329fcd --- /dev/null +++ b/.github/workflows/test-notebooks-changed.yaml @@ -0,0 +1,62 @@ +name: "Test: Notebook Execution" + +on: + workflow_call: + inputs: + base_ref: + required: false + type: string + default: 'master' + +jobs: + validate-notebooks: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: ./.github/actions/validate-json + + detect-changes: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + has_changes: ${{ steps.set-matrix.outputs.has_changes }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Required for git diff + + - name: Fetch base branch + run: git fetch origin ${{ inputs.base_ref }} + + - name: Detect changed notebooks + id: set-matrix + run: | + # Get list of changed .ipynb files + CHANGED_NOTEBOOKS=$(git diff --name-only origin/${{ inputs.base_ref }}...HEAD | grep '\.ipynb$' || true) + if [ -z "$CHANGED_NOTEBOOKS" ]; then + echo "No notebook changes detected" + echo "has_changes=false" >> $GITHUB_OUTPUT + echo "matrix={\"notebook\":[]}" >> $GITHUB_OUTPUT + else + # Convert newlines to JSON array format + NOTEBOOK_LIST=$(echo "$CHANGED_NOTEBOOKS" | jq -R -s -c 'split("\n")[:-1]') + echo "has_changes=true" >> $GITHUB_OUTPUT + echo "matrix={\"notebook\":$NOTEBOOK_LIST}" >> $GITHUB_OUTPUT + fi + + test-notebooks: + needs: + - detect-changes + - validate-notebooks + if: needs.detect-changes.outputs.has_changes == 'true' + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: ${{ 
fromJSON(needs.detect-changes.outputs.matrix) }} + steps: + - uses: actions/checkout@v4 + - uses: ./.github/actions/run-notebook + with: + notebook: ${{ matrix.notebook }} + PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} diff --git a/.github/workflows/validate.yaml b/.github/workflows/validate.yaml deleted file mode 100644 index f8bca719..00000000 --- a/.github/workflows/validate.yaml +++ /dev/null @@ -1,25 +0,0 @@ -name: Validate Notebook JSON - -on: - push: - branches: - - main - pull_request: - -jobs: - validate-notebooks: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.11' - - - name: Install dependencies - run: pip install nbformat - - - name: Validate all notebooks - run: | - python .github/scripts/validate-notebook-formats.py \ No newline at end of file diff --git a/docs/pinecone-reranker.ipynb b/docs/pinecone-reranker.ipynb index 833ba53f..f6d44ffb 100644 --- a/docs/pinecone-reranker.ipynb +++ b/docs/pinecone-reranker.ipynb @@ -1,550 +1,910 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "EE8es7PuO2RM", - "metadata": { - "id": "EE8es7PuO2RM" - }, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/docs/pinecone-reranker.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/docs/pinecone-reranker.ipynb)\n", - "\n", - "# Pinecone Serverless Reranking in Action\n", - "\n", - "### Overview\n", - "\n", - "\n", - "Reranking models are designed to enhance search relevance. They work by assessing the similarity between a query and a document, producing a numerical score that reflects how well the document matches the query. 
This score is then used to reorder documents, prioritizing those most relevant to the user's search.\n", - "\n", - "The process of reranking is crucial in improving the quality of information presented to users or supplied as context to Large Language Models (LLMs) by helping to filter out less relevant results and bringing the most pertinent information to the forefront.\n", - "\n", - "We now offer reranking support within the Pinecone Inference API. This feature eliminates the need for users to manage and deploy these models themselves. You can find a more through overview of our reranking [here](https://www.pinecone.io/learn/refine-with-rerank/).\n", - "\n", - "Below is the flow of a sample application utilizing a reranker:" - ] - }, - { - "cell_type": "markdown", - "id": "mZVVVzs2dQI0", - "metadata": { - "id": "mZVVVzs2dQI0" - }, - "source": [ - "![reranker.png]()" - ] - }, - { - "cell_type": "markdown", - "id": "OzKJdRqsfHIC", - "metadata": { - "id": "OzKJdRqsfHIC" - }, - "source": [ - "### Steps in This Notebook:\n", - "\n", - "1. **Load Libraries**\n", - "2. **Load Small Documents Object**\n", - "3. **Execute Reranking Model**\n", - "4. **Show Results**\n", - "5. **Create Index**\n", - "6. **Upsert Sample Data**\n", - "7. **Embed Query**\n", - "8. **Execute Search**\n", - "9. **View Results**\n", - "10. **Rerank Results**\n", - "\n", - "\n", - "The main dataset we will be using consists of randomly generated doctor’s notes sample data. 
The original JSON data has been embedded into vectors, which we will load into Pinecone.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "Ns7xj3uxO2RO", - "metadata": { - "id": "Ns7xj3uxO2RO" - }, - "outputs": [], - "source": [ - "# Installation\n", - "!pip install -U pinecone-client\n", - "!pip install -U --pre pinecone-plugin-inference\n", - "!pip install -U pinecone-notebooks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "_NsyrR-1Z02X", - "metadata": { - "id": "_NsyrR-1Z02X" - }, - "outputs": [], - "source": [ - "import os\n", - "\n", - "if not os.environ.get(\"PINECONE_API_KEY\"):\n", - " from pinecone_notebooks.colab import Authenticate\n", - " Authenticate()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "tGxwpB7OZjFn", - "metadata": { - "id": "tGxwpB7OZjFn" - }, - "outputs": [], - "source": [ - "from pinecone import Pinecone\n", - "\n", - "api_key = os.environ.get(\"PINECONE_API_KEY\")\n", - "\n", - "# configure client\n", - "pc = Pinecone(api_key=api_key)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "uj9tOi7WO2RP", - "metadata": { - "id": "uj9tOi7WO2RP" - }, - "outputs": [], - "source": [ - "# Create query and documents\n", - "query = \"Tell me about Apple's products\"\n", - "documents = [\n", - " \"Apple is a popular fruit known for its sweetness and crisp texture.\",\n", - " \"Apple is known for its innovative products like the iPhone.\",\n", - " \"Many people enjoy eating apples as a healthy snack.\",\n", - " \"Apple Inc. 
has revolutionized the tech industry with its sleek designs and user-friendly interfaces.\",\n", - " \"An apple a day keeps the doctor away, as the saying goes.\"\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4xGTyQv3g7iR", - "metadata": { - "id": "4xGTyQv3g7iR" - }, - "outputs": [], - "source": [ - "documents" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "Jx5F7QYPO2RP", - "metadata": { - "id": "Jx5F7QYPO2RP" - }, - "outputs": [], - "source": [ - "# Perform reranking to get top_n results based on the query\n", - "reranked_results = pc.inference.rerank(\n", - " model=\"bge-reranker-v2-m3\",\n", - " query=query,\n", - " documents=documents,\n", - " top_n=3,\n", - " return_documents=True\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "jY9NtvKMO2RP", - "metadata": { - "id": "jY9NtvKMO2RP" - }, - "outputs": [], - "source": [ - "# Display the reranked results\n", - "print(query)\n", - "\n", - "print(\"Top 3 Reranked Documents:\")\n", - "for i, entry in enumerate(reranked_results.data): # Access the 'data' attribute\n", - " document_text = entry['document']['text'] # Extract the text of the document\n", - " score = entry['score'] # Extract the score\n", - " print(f\"{i+1}: Score: {score}, Document: {document_text}\")\n", - "\n", - "#Note the reranker ranks Apple the company over apple the fruit based on the context of the query" - ] - }, - { - "cell_type": "markdown", - "id": "dC73hnorO2RQ", - "metadata": { - "id": "dC73hnorO2RQ" - }, - "source": [ - "### Enhanced Medical Note Retrieval for Improved Clinical Decision-Making\n", - "**Scenario**: A healthcare system allows doctors to search through a large dataset of medical notes to find relevant patient information.\n", - "\n", - "**Application**: After an initial list of relevant notes is generated from a search query, a reranker can fine-tune the order by considering factors such as the specificity of the medical conditions 
mentioned, and the relevance to the patient's current symptoms or treatment plan. This ensures that the most critical and contextually relevant notes are presented first, aiding in quicker and more accurate clinical decision-making." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "vqdr1UlDO2RP", - "metadata": { - "id": "vqdr1UlDO2RP" - }, - "outputs": [], - "source": [ - "import os\n", - "import time\n", - "import pandas as pd\n", - "from google.colab import files\n", - "from pinecone import Pinecone, ServerlessSpec\n", - "from transformers import AutoTokenizer, AutoModel\n", - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "UvZ5s6SsZPG_", - "metadata": { - "id": "UvZ5s6SsZPG_" - }, - "outputs": [], - "source": [ - "# Get cloud and region settings\n", - "cloud = os.getenv('PINECONE_CLOUD', 'aws')\n", - "region = os.getenv('PINECONE_REGION', 'us-east-1')\n", - "\n", - "# Define serverless specifications\n", - "spec = ServerlessSpec(cloud=cloud, region=region)\n", - "\n", - "# Define index name\n", - "index_name = 'pinecone-reranker'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ySCKGQ8XDx43", - "metadata": { - "id": "ySCKGQ8XDx43" - }, - "outputs": [], - "source": [ - "import time\n", - "\n", - "if index_name in pc.list_indexes().names():\n", - " pc.delete_index(index_name)\n", - "\n", - "# Create a new index\n", - "pc.create_index(index_name, dimension=384, metric='cosine', spec=spec)\n", - "\n", - "# wait for index to be initialized\n", - "while not pc.describe_index(index_name).status['ready']:\n", - " time.sleep(1)\n", - "print(f\"Index {index_name} has been successfully created.\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "sSHvtj9kEJS4", - "metadata": { - "id": "sSHvtj9kEJS4" - }, - "outputs": [], - "source": [ - "index = pc.Index(index_name)\n", - "# wait a moment for connection\n", - "time.sleep(1)\n", - "\n", - 
"index.describe_index_stats()" - ] - }, - { - "cell_type": "markdown", - "id": "xOXoiGFeVaZ_", - "metadata": { - "id": "xOXoiGFeVaZ_" - }, - "source": [ - "### Load Data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "TP27ES75VOGv", - "metadata": { - "id": "TP27ES75VOGv" - }, - "outputs": [], - "source": [ - "# Step 1: Upload the file from from the repository: docs/data/sample_notes_data.jsonl to Google Colab\n", - "uploaded = files.upload()\n", - "\n", - "# Step 2: Assuming the file is uploaded, read it into a DataFrame\n", - "file_name = next(iter(uploaded.keys()))\n", - "df = pd.read_json(file_name, orient='records', lines=True)\n", - "\n", - "# Show head of the DataFrame\n", - "df.head()" - ] - }, - { - "cell_type": "markdown", - "id": "MKvl9Pr9c6Vs", - "metadata": { - "id": "MKvl9Pr9c6Vs" - }, - "source": [ - "### Connect to Index and UPSERT Data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "-TDP4VMQVu4C", - "metadata": { - "id": "-TDP4VMQVu4C" - }, - "outputs": [], - "source": [ - "# Connect to index\n", - "index = pc.Index(index_name)\n", - "time.sleep(1)\n", - "\n", - "# Upsert data into index from DataFrame\n", - "index.upsert_from_dataframe(df)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eu43tFQwg3YE", - "metadata": { - "id": "eu43tFQwg3YE" - }, - "outputs": [], - "source": [ - "# View index stats\n", - "index.describe_index_stats()" - ] - }, - { - "cell_type": "markdown", - "id": "POAoISsAeZAt", - "metadata": { - "id": "POAoISsAeZAt" - }, - "source": [ - "## Embedding Function\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "h4lkfpnPeXmx", - "metadata": { - "id": "h4lkfpnPeXmx" - }, - "outputs": [], - "source": [ - "def get_embedding(input_question):\n", - " model_name = 'sentence-transformers/all-MiniLM-L6-v2' #Hugging Face Model\n", - " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", - " model = AutoModel.from_pretrained(model_name)\n", 
- "\n", - " encoded_input = tokenizer(input_question, padding=True, truncation=True, return_tensors='pt')\n", - "\n", - " with torch.no_grad():\n", - " model_output = model(**encoded_input)\n", - "\n", - " embedding = model_output.last_hidden_state[0].mean(dim=0)\n", - " return embedding" - ] - }, - { - "cell_type": "markdown", - "id": "RL9odEJ9dDSG", - "metadata": { - "id": "RL9odEJ9dDSG" - }, - "source": [ - "## Execute Query" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "xskqy0AbV14d", - "metadata": { - "id": "xskqy0AbV14d" - }, - "outputs": [], - "source": [ - "# Build a query to search\n", - "question = \"what if my patient has leg pain\"\n", - "query = get_embedding(question).tolist()\n", - "\n", - "# Get results\n", - "results = index.query(vector=[query], top_k=10, include_metadata=True)\n", - "\n", - "# Sort results by score in descending order\n", - "sorted_matches = sorted(results['matches'], key=lambda x: x['score'], reverse=True)" - ] - }, - { - "cell_type": "markdown", - "id": "JkVM2XQpdPvv", - "metadata": { - "id": "JkVM2XQpdPvv" - }, - "source": [ - "## Show Results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eLVSmaHxV8XP", - "metadata": { - "id": "eLVSmaHxV8XP" - }, - "outputs": [], - "source": [ - "# Print results\n", - "print(f'Original question: {question}')\n", - "print('---\\nResults:\\n--------------')\n", - "for match in sorted_matches:\n", - " print(f'ID: {match[\"id\"]}')\n", - " print(f'Score: {match[\"score\"]}')\n", - " print(f'Metadata: {match[\"metadata\"]}')\n", - " print('---')" - ] - }, - { - "cell_type": "markdown", - "id": "-L62zHVmdcQP", - "metadata": { - "id": "-L62zHVmdcQP" - }, - "source": [ - "### Perform Rerank" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32WVD7lPo7Tb", - "metadata": { - "id": "32WVD7lPo7Tb" - }, - "outputs": [], - "source": [ - "# Create documents with concatenated metadata field as \"reranking_field\" field\n", - "documents 
= [\n", - " {\n", - " 'id': match['id'],\n", - " 'reranking_field': ' '.join([f\"{key}: {value}\" for key, value in match['metadata'].items()])\n", - " }\n", - " for match in results['matches']\n", - "]\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "WLWt80fwq18w", - "metadata": { - "id": "WLWt80fwq18w" - }, - "outputs": [], - "source": [ - "# Define a more specific query for reranking\n", - "query = \"what if my patient had knee surgery\"\n", - "\n", - "# Perform reranking based on the query and specified field\n", - "reranked_results_field = pc.inference.rerank(\n", - " model=\"bge-reranker-v2-m3\",\n", - " query=query,\n", - " documents=documents, # Use the transformed documents\n", - " rank_fields=[\"reranking_field\"],\n", - " top_n=2,\n", - " return_documents=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0iDXuSUYsTm4", - "metadata": { - "id": "0iDXuSUYsTm4" - }, - "outputs": [], - "source": [ - "# Print results\n", - "print(f'Original question: {query}')\n", - "print('---\\nResults:\\n--------------')\n", - "for match in reranked_results_field.data:\n", - " print(f'ID: {match[\"document\"][\"id\"]}')\n", - " print(f'Score: {match[\"score\"]:.6f}')\n", - " print(f'Text: {match[\"document\"][\"reranking_field\"]}')\n", - " print('---')\n" - ] - }, - { - "cell_type": "markdown", - "id": "8PQTLfT-Frv8", - "metadata": { - "id": "8PQTLfT-Frv8" - }, - "source": [ - "Now let's delete the index to save resources" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "lEooi2--F0yR", - "metadata": { - "id": "lEooi2--F0yR" - }, - "outputs": [], - "source": [ - "pc.delete_index(index_name)" - ] + "cells": [ + { + "cell_type": "markdown", + "id": "EE8es7PuO2RM", + "metadata": { + "id": "EE8es7PuO2RM" + }, + "source": [ + "[![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/docs/pinecone-reranker.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/docs/pinecone-reranker.ipynb)\n", + "\n", + "# Pinecone Serverless Reranking in Action\n", + "\n", + "### Overview\n", + "\n", + "\n", + "Reranking models are designed to enhance search relevance. They work by assessing the similarity between a query and a document, producing a numerical score that reflects how well the document matches the query. This score is then used to reorder documents, prioritizing those most relevant to the user's search.\n", + "\n", + "The process of reranking is crucial in improving the quality of information presented to users or supplied as context to Large Language Models (LLMs) by helping to filter out less relevant results and bringing the most pertinent information to the forefront.\n", + "\n", + "We now offer reranking support within the Pinecone Inference API. This feature eliminates the need for users to manage and deploy these models themselves. You can find a more thorough overview of our reranking [here](https://www.pinecone.io/learn/refine-with-rerank/).\n", + "\n", + "Below is the flow of a sample application utilizing a reranker:" + ] + }, + { + "cell_type": "markdown", + "id": "mZVVVzs2dQI0", + "metadata": { + "id": "mZVVVzs2dQI0" + }, + "source": [ + "![reranker.png]()" + ] + }, + { + "cell_type": "markdown", + "id": "OzKJdRqsfHIC", + "metadata": { + "id": "OzKJdRqsfHIC" + }, + "source": [ + "### Steps in This Notebook:\n", + "\n", + "1. **Load Libraries**\n", + "2. **Load Small Documents Object**\n", + "3. **Execute Reranking Model**\n", + "4. **Show Results**\n", + "5. **Create Index**\n", + "6. **Upsert Sample Data**\n", + "7. **Embed Query**\n", + "8. **Execute Search**\n", + "9. 
**View Results**\n", + "10. **Rerank Results**\n", + "\n", + "\n", + "The main dataset we will be using consists of randomly generated doctor’s notes sample data. The original JSON data has been embedded into vectors, which we will load into Pinecone.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "Ns7xj3uxO2RO", + "metadata": { + "id": "Ns7xj3uxO2RO" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: pinecone==6.0.1 in /opt/conda/lib/python3.12/site-packages (6.0.1)\n", + "Requirement already satisfied: certifi>=2019.11.17 in /opt/conda/lib/python3.12/site-packages (from pinecone==6.0.1) (2025.1.31)\n", + "Requirement already satisfied: pinecone-plugin-interface<0.0.8,>=0.0.7 in /opt/conda/lib/python3.12/site-packages (from pinecone==6.0.1) (0.0.7)\n", + "Requirement already satisfied: python-dateutil>=2.5.3 in /opt/conda/lib/python3.12/site-packages (from pinecone==6.0.1) (2.9.0.post0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4 in /opt/conda/lib/python3.12/site-packages (from pinecone==6.0.1) (4.12.2)\n", + "Requirement already satisfied: urllib3>=1.26.5 in /opt/conda/lib/python3.12/site-packages (from pinecone==6.0.1) (2.3.0)\n", + "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.12/site-packages (from python-dateutil>=2.5.3->pinecone==6.0.1) (1.17.0)\n", + "Requirement already satisfied: pinecone-notebooks in /opt/conda/lib/python3.12/site-packages (0.1.1)\n" + ] } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.8" + ], + "source": [ + "# Installation\n", + "!pip install -U 
pinecone==6.0.1\n", + "!pip install -U pinecone-notebooks" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "_NsyrR-1Z02X", + "metadata": { + "id": "_NsyrR-1Z02X" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "if not os.environ.get(\"PINECONE_API_KEY\"):\n", + " from pinecone_notebooks.colab import Authenticate\n", + " Authenticate()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "tGxwpB7OZjFn", + "metadata": { + "id": "tGxwpB7OZjFn" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from pinecone import Pinecone\n", + "\n", + "api_key = os.environ.get(\"PINECONE_API_KEY\")\n", + "\n", + "# Instantiate the Pinecone client\n", + "pc = Pinecone(api_key=api_key)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "uj9tOi7WO2RP", + "metadata": { + "id": "uj9tOi7WO2RP" + }, + "outputs": [], + "source": [ + "# Create query and documents\n", + "query = \"Tell me about Apple's products\"\n", + "documents = [\n", + " \"Apple is a popular fruit known for its sweetness and crisp texture.\",\n", + " \"Apple is known for its innovative products like the iPhone.\",\n", + " \"Many people enjoy eating apples as a healthy snack.\",\n", + " \"Apple Inc. 
has revolutionized the tech industry with its sleek designs and user-friendly interfaces.\",\n", + " \"An apple a day keeps the doctor away, as the saying goes.\"\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4xGTyQv3g7iR", + "metadata": { + "id": "4xGTyQv3g7iR" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['Apple is a popular fruit known for its sweetness and crisp texture.',\n", + " 'Apple is known for its innovative products like the iPhone.',\n", + " 'Many people enjoy eating apples as a healthy snack.',\n", + " 'Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces.',\n", + " 'An apple a day keeps the doctor away, as the saying goes.']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" } + ], + "source": [ + "documents" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "Jx5F7QYPO2RP", + "metadata": { + "id": "Jx5F7QYPO2RP" + }, + "outputs": [], + "source": [ + "from pinecone import RerankModel\n", + "\n", + "# Perform reranking to get top_n results based on the query\n", + "reranked_results = pc.inference.rerank(\n", + " model=RerankModel.Bge_Reranker_V2_M3,\n", + " query=query,\n", + " documents=documents,\n", + " top_n=3,\n", + " return_documents=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "jY9NtvKMO2RP", + "metadata": { + "id": "jY9NtvKMO2RP" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Query: Tell me about Apple's products\n", + "Reranked Results:\n", + " 1. Score: 0.83907574\n", + " Document: Apple is known for its innovative products like the iPhone.\n", + "\n", + " 2. Score: 0.23196201\n", + " Document: Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces.\n", + "\n", + " 3. 
Score: 0.1742697\n", + " Document: Apple is a popular fruit known for its sweetness and crisp texture.\n", + "\n" + ] + } + ], + "source": [ + "def show_reranked_results(query, matches):\n", + " \"\"\"A utility function to print our reranked results\"\"\"\n", + " print(f'Query: {query}')\n", + " print('Reranked Results:')\n", + " for i, match in enumerate(matches):\n", + " print(f'{str(i+1).rjust(4)}. Score: {match.score}')\n", + " print(f' Document: {match.document.text}')\n", + " print('')\n", + "\n", + "# Note the reranker ranks Apple the company over apple the fruit based on the context of the query\n", + "show_reranked_results(query, reranked_results.data)" + ] + }, + { + "cell_type": "markdown", + "id": "dC73hnorO2RQ", + "metadata": { + "id": "dC73hnorO2RQ" + }, + "source": [ + "### Enhanced Medical Note Retrieval for Improved Clinical Decision-Making\n", + "**Scenario**: A healthcare system allows doctors to search through a large dataset of medical notes to find relevant patient information.\n", + "\n", + "**Application**: After an initial list of relevant notes is generated from a search query, a reranker can fine-tune the order by considering factors such as the specificity of the medical conditions mentioned, and the relevance to the patient's current symptoms or treatment plan. This ensures that the most critical and contextually relevant notes are presented first, aiding in quicker and more accurate clinical decision-making." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f73e8ba3-9ea9-45b5-9c1e-ddbffd31dc4d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: pandas in /opt/conda/lib/python3.12/site-packages (2.2.3)\n", + "Requirement already satisfied: torch in /opt/conda/lib/python3.12/site-packages (2.6.0)\n", + "Requirement already satisfied: transformers in /opt/conda/lib/python3.12/site-packages (4.49.0)\n", + "Requirement already satisfied: numpy>=1.26.0 in /opt/conda/lib/python3.12/site-packages (from pandas) (2.2.3)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.12/site-packages (from pandas) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.12/site-packages (from pandas) (2025.1)\n", + "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.12/site-packages (from pandas) (2025.1)\n", + "Requirement already satisfied: filelock in /opt/conda/lib/python3.12/site-packages (from torch) (3.17.0)\n", + "Requirement already satisfied: typing-extensions>=4.10.0 in /opt/conda/lib/python3.12/site-packages (from torch) (4.12.2)\n", + "Requirement already satisfied: setuptools in /opt/conda/lib/python3.12/site-packages (from torch) (75.8.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /opt/conda/lib/python3.12/site-packages (from torch) (1.13.1)\n", + "Requirement already satisfied: networkx in /opt/conda/lib/python3.12/site-packages (from torch) (3.4.2)\n", + "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.12/site-packages (from torch) (3.1.5)\n", + "Requirement already satisfied: fsspec in /opt/conda/lib/python3.12/site-packages (from torch) (2025.2.0)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.12/site-packages (from sympy==1.13.1->torch) (1.3.0)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.26.0 in 
/opt/conda/lib/python3.12/site-packages (from transformers) (0.29.1)\n", + "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.12/site-packages (from transformers) (24.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.12/site-packages (from transformers) (6.0.2)\n", + "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.12/site-packages (from transformers) (2024.11.6)\n", + "Requirement already satisfied: requests in /opt/conda/lib/python3.12/site-packages (from transformers) (2.32.3)\n", + "Requirement already satisfied: tokenizers<0.22,>=0.21 in /opt/conda/lib/python3.12/site-packages (from transformers) (0.21.0)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /opt/conda/lib/python3.12/site-packages (from transformers) (0.5.2)\n", + "Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.12/site-packages (from transformers) (4.67.1)\n", + "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas) (1.17.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.12/site-packages (from jinja2->torch) (3.0.2)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /opt/conda/lib/python3.12/site-packages (from requests->transformers) (3.4.1)\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.12/site-packages (from requests->transformers) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.12/site-packages (from requests->transformers) (2.3.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.12/site-packages (from requests->transformers) (2025.1.31)\n" + ] + } + ], + "source": [ + "!pip install pandas torch transformers" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "vqdr1UlDO2RP", + "metadata": { + "id": "vqdr1UlDO2RP" + }, + "outputs": [], + "source": [ + 
"import os\n", + "import time\n", + "import pandas as pd\n", + "from pinecone import Pinecone, ServerlessSpec\n", + "from transformers import AutoTokenizer, AutoModel\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "UvZ5s6SsZPG_", + "metadata": { + "id": "UvZ5s6SsZPG_" + }, + "outputs": [], + "source": [ + "# Get cloud and region settings\n", + "cloud = os.getenv('PINECONE_CLOUD', 'aws')\n", + "region = os.getenv('PINECONE_REGION', 'us-east-1')\n", + "\n", + "# Define serverless specifications\n", + "spec = ServerlessSpec(cloud=cloud, region=region)\n", + "\n", + "# Define index name\n", + "index_name = 'pinecone-reranker'" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ySCKGQ8XDx43", + "metadata": { + "id": "ySCKGQ8XDx43" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{\n", + " \"name\": \"pinecone-reranker\",\n", + " \"metric\": \"cosine\",\n", + " \"host\": \"pinecone-reranker-dojoi3u.svc.aped-4627-b74a.pinecone.io\",\n", + " \"spec\": {\n", + " \"serverless\": {\n", + " \"cloud\": \"aws\",\n", + " \"region\": \"us-east-1\"\n", + " }\n", + " },\n", + " \"status\": {\n", + " \"ready\": true,\n", + " \"state\": \"Ready\"\n", + " },\n", + " \"vector_type\": \"dense\",\n", + " \"dimension\": 384,\n", + " \"deletion_protection\": \"disabled\",\n", + " \"tags\": null\n", + "}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "if pc.has_index(name=index_name):\n", + " pc.delete_index(name=index_name)\n", + "\n", + "# Create a new index\n", + "pc.create_index(\n", + " name=index_name, \n", + " dimension=384, \n", + " metric='cosine', \n", + " spec=spec\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "xOXoiGFeVaZ_", + "metadata": { + "id": "xOXoiGFeVaZ_" + }, + "source": [ + "### Load Data" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "TP27ES75VOGv", + "metadata": { + "id": "TP27ES75VOGv" + }, + 
"outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idvaluesmetadata
0P011[-0.2027486265, 0.2769146562, -0.1509393603, 0...{'advice': 'rest, hydrate', 'symptoms': 'heada...
1P001[0.1842793673, 0.4459365904, -0.0770567134, 0....{'tests': 'EKG, stress test', 'symptoms': 'che...
2P002[-0.2040648609, -0.1739618927, -0.2897160649, ...{'HbA1c': '7.2', 'condition': 'diabetes', 'med...
3P003[0.1889383644, 0.2924542725, -0.2335938066, -0...{'symptoms': 'cough, wheezing', 'diagnosis': '...
4P004[-0.12171068040000001, 0.1674752235, -0.231888...{'referral': 'dermatology', 'condition': 'susp...
\n", + "
" + ], + "text/plain": [ + " id values \\\n", + "0 P011 [-0.2027486265, 0.2769146562, -0.1509393603, 0... \n", + "1 P001 [0.1842793673, 0.4459365904, -0.0770567134, 0.... \n", + "2 P002 [-0.2040648609, -0.1739618927, -0.2897160649, ... \n", + "3 P003 [0.1889383644, 0.2924542725, -0.2335938066, -0... \n", + "4 P004 [-0.12171068040000001, 0.1674752235, -0.231888... \n", + "\n", + " metadata \n", + "0 {'advice': 'rest, hydrate', 'symptoms': 'heada... \n", + "1 {'tests': 'EKG, stress test', 'symptoms': 'che... \n", + "2 {'HbA1c': '7.2', 'condition': 'diabetes', 'med... \n", + "3 {'symptoms': 'cough, wheezing', 'diagnosis': '... \n", + "4 {'referral': 'dermatology', 'condition': 'susp... " + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import requests\n", + "import tempfile\n", + "\n", + "with tempfile.TemporaryDirectory() as tmpdirname:\n", + " # Construct the full path for the file within the temporary directory.\n", + " file_path = os.path.join(tmpdirname, \"sample_notes_data.jsonl\")\n", + " \n", + " # Download the file from github\n", + " url = \"https://raw.githubusercontent.com/pinecone-io/examples/refs/heads/master/docs/data/sample_notes_data.jsonl\"\n", + " response = requests.get(url)\n", + " response.raise_for_status() # Raise an exception for any HTTP errors.\n", + " \n", + " # Write the file content to the temporary directory.\n", + " with open(file_path, \"wb\") as f:\n", + " f.write(response.content)\n", + "\n", + " df = pd.read_json(file_path, orient='records', lines=True)\n", + "\n", + "# Show head of the DataFrame\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "MKvl9Pr9c6Vs", + "metadata": { + "id": "MKvl9Pr9c6Vs" + }, + "source": [ + "### Upsert data to the index" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "-TDP4VMQVu4C", + "metadata": { + "id": "-TDP4VMQVu4C" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"sending upsert requests: 100%|██████████| 100/100 [00:00<00:00, 200.29it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "{'upserted_count': 100}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Instantiate an index client\n", + "index = pc.Index(name=index_name)\n", + "\n", + "# Upsert data into index from DataFrame\n", + "index.upsert_from_dataframe(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "144d6557-da46-46e5-901d-3cb5204a8d54", + "metadata": { + "id": "eu43tFQwg3YE" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Vector count: 0\n", + "Vector count: 0\n", + "Vector count: 0\n", + "Vector count: 0\n", + "Vector count: 0\n", + "Vector count: 0\n", + "Vector count: 0\n", + "Vector count: 0\n", + "Vector count: 0\n", + "Vector count: 100\n" + ] + }, + { + "data": { + "text/plain": [ + "{'dimension': 384,\n", + " 'index_fullness': 0.0,\n", + " 'metric': 'cosine',\n", + " 'namespaces': {'': {'vector_count': 100}},\n", + " 'total_vector_count': 100,\n", + " 'vector_type': 'dense'}" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import time\n", + "\n", + "def is_fresh(index):\n", + " stats = index.describe_index_stats()\n", + " vector_count = stats.total_vector_count\n", + " print(f\"Vector count: \", vector_count)\n", + " return vector_count > 0\n", + "\n", + "while not is_fresh(index):\n", + " # It takes a few moments for vectors we just upserted\n", + " # to become available for querying\n", + " time.sleep(5)\n", + "\n", + "# View index stats\n", + "index.describe_index_stats()" + ] + }, + { + "cell_type": "markdown", + "id": "POAoISsAeZAt", + "metadata": { + "id": "POAoISsAeZAt" + }, + "source": [ + "## Embedding Function\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "h4lkfpnPeXmx", + "metadata": { + "id": "h4lkfpnPeXmx" + }, + "outputs": 
[], + "source": [ + "def get_embedding(input_question):\n", + " model_name = 'sentence-transformers/all-MiniLM-L6-v2' # HuggingFace Model\n", + " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + " model = AutoModel.from_pretrained(model_name)\n", + "\n", + " encoded_input = tokenizer(input_question, padding=True, truncation=True, return_tensors='pt')\n", + "\n", + " with torch.no_grad():\n", + " model_output = model(**encoded_input)\n", + "\n", + " embedding = model_output.last_hidden_state[0].mean(dim=0)\n", + " return embedding" + ] + }, + { + "cell_type": "markdown", + "id": "RL9odEJ9dDSG", + "metadata": { + "id": "RL9odEJ9dDSG" + }, + "source": [ + "## Execute Query" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "xskqy0AbV14d", + "metadata": { + "id": "xskqy0AbV14d" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/pytorch/third_party/ideep/mkl-dnn/src/cpu/aarch64/xbyak_aarch64/src/util_impl_linux.h, 451: Can't read MIDR_EL1 sysfs entry\n" + ] + } + ], + "source": [ + "# Build a query to search\n", + "question = \"what if my patient has leg pain\"\n", + "query = get_embedding(question).tolist()\n", + "\n", + "# Get results\n", + "results = index.query(vector=[query], top_k=10, include_metadata=True)\n", + "\n", + "# Sort results by score in descending order\n", + "sorted_matches = sorted(results['matches'], key=lambda x: x['score'], reverse=True)" + ] + }, + { + "cell_type": "markdown", + "id": "JkVM2XQpdPvv", + "metadata": { + "id": "JkVM2XQpdPvv" + }, + "source": [ + "## Show Results" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "eLVSmaHxV8XP", + "metadata": { + "id": "eLVSmaHxV8XP" + }, + "outputs": [], + "source": [ + "def show_results(question, matches):\n", + " \"\"\"A utility function to print our results\"\"\"\n", + " print(f'Question: \\'{question}\\'')\n", + " print('\\nResults:')\n", + " for i, match in enumerate(matches):\n", + " 
print(f'{str(i+1).rjust(4)}. ID: {match[\"id\"]}')\n", + " print(f' Score: {match[\"score\"]}')\n", + " print(f' Metadata: {match[\"metadata\"]}')\n", + " print('')" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "5b279e16-dc9b-4f71-a607-a9a969bea5a4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question: 'what if my patient has leg pain'\n", + "\n", + "Results:\n", + " 1. ID: P0100\n", + " Score: 0.517953098\n", + " Metadata: {'advice': 'over-the-counter pain relief, stretching', 'symptoms': 'muscle pain'}\n", + "\n", + " 2. ID: P095\n", + " Score: 0.500854671\n", + " Metadata: {'symptoms': 'back pain', 'treatment': 'physical therapy'}\n", + "\n", + " 3. ID: P047\n", + " Score: 0.500854671\n", + " Metadata: {'symptoms': 'back pain', 'treatment': 'physical therapy'}\n", + "\n", + " 4. ID: P007\n", + " Score: 0.459922969\n", + " Metadata: {'surgery': 'knee arthroscopy', 'symptoms': 'pain, swelling', 'treatment': 'physical therapy'}\n", + "\n", + " 5. ID: P028\n", + " Score: 0.446633637\n", + " Metadata: {'condition': 'knee pain', 'referral': 'orthopedics'}\n", + "\n", + " 6. ID: P059\n", + " Score: 0.429972351\n", + " Metadata: {'symptoms': 'joint pain', 'treatment': 'NSAIDs, rest'}\n", + "\n", + " 7. ID: P020\n", + " Score: 0.424824864\n", + " Metadata: {'condition': 'sprained ankle', 'tests': 'X-ray'}\n", + "\n", + " 8. ID: P068\n", + " Score: 0.414039701\n", + " Metadata: {'condition': 'broken arm', 'treatment': 'cast'}\n", + "\n", + " 9. ID: P092\n", + " Score: 0.408774346\n", + " Metadata: {'condition': 'dehydration', 'treatment': 'IV fluids'}\n", + "\n", + " 10. 
ID: P044\n", + " Score: 0.408774346\n", + " Metadata: {'condition': 'dehydration', 'treatment': 'IV fluids'}\n", + "\n" + ] + } + ], + "source": [ + "show_results(question, sorted_matches)" + ] + }, + { + "cell_type": "markdown", + "id": "-L62zHVmdcQP", + "metadata": { + "id": "-L62zHVmdcQP" + }, + "source": [ + "### Perform Rerank" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "32WVD7lPo7Tb", + "metadata": { + "id": "32WVD7lPo7Tb" + }, + "outputs": [], + "source": [ + "# Build one document per match, joining all of its metadata fields into a single \"reranking_field\" string\n", + "transformed_documents = [\n", + " {\n", + " 'id': match['id'],\n", + " 'reranking_field': '; '.join([f\"{key}: {value}\" for key, value in match['metadata'].items()])\n", + " }\n", + " for match in results['matches']\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "WLWt80fwq18w", + "metadata": { + "id": "WLWt80fwq18w" + }, + "outputs": [], + "source": [ + "# Define a more specific query for reranking\n", + "query = \"what if my patient had knee surgery\"\n", + "\n", + "# Perform reranking based on the query and specified field\n", + "reranked_results_field = pc.inference.rerank(\n", + " model=\"bge-reranker-v2-m3\",\n", + " query=query,\n", + " documents=transformed_documents,\n", + " rank_fields=[\"reranking_field\"],\n", + " top_n=2,\n", + " return_documents=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "0iDXuSUYsTm4", + "metadata": { + "id": "0iDXuSUYsTm4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question: 'what if my patient had knee surgery'\n", + "\n", + "Reranked Results:\n", + " 1. ID: P007\n", + " Score: 0.18184364\n", + " Reranking Field: surgery: knee arthroscopy; symptoms: pain, swelling; treatment: physical therapy\n", + "\n", + " 2.
ID: P028\n", + " Score: 0.0054905633\n", + " Reranking Field: condition: knee pain; referral: orthopedics\n", + "\n" + ] + } + ], + "source": [ + "def show_reranked_results(question, matches):\n", + " \"\"\"A utility function to print our reranked results\"\"\"\n", + " print(f'Question: \\'{question}\\'')\n", + " print('\\nReranked Results:')\n", + " for i, match in enumerate(matches):\n", + " print(f'{str(i+1).rjust(4)}. ID: {match.document.id}')\n", + " print(f' Score: {match.score}')\n", + " print(f' Reranking Field: {match.document.reranking_field}')\n", + " print('')\n", + "\n", + "show_reranked_results(query, reranked_results_field.data)" + ] + }, + { + "cell_type": "markdown", + "id": "8PQTLfT-Frv8", + "metadata": { + "id": "8PQTLfT-Frv8" + }, + "source": [ + "Now let's delete the index to save resources." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "lEooi2--F0yR", + "metadata": { + "id": "lEooi2--F0yR" + }, + "outputs": [], + "source": [ + "pc.delete_index(name=index_name)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 5 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 }