Skip to content

Add "copy" method to AstraDBVectorStore #111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions libs/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,18 +283,20 @@ def __init__(
# - a single ("langchain", <version of langchain_core>)
# - if such is provided, a (component_name, <version of langchain_astradb>)
# (note: if component_name is None, astrapy strips it out automatically)
self.ext_callers = ext_callers
self.component_name = component_name
norm_ext_callers = [
cpair
for cpair in (
_raw_caller if isinstance(_raw_caller, tuple) else (_raw_caller, None)
for _raw_caller in (ext_callers or [])
for _raw_caller in (self.ext_callers or [])
)
if cpair[0] is not None or cpair[1] is not None
]
full_callers = [
*norm_ext_callers,
LC_CORE_CALLER,
(component_name, LC_ASTRADB_VERSION),
(self.component_name, LC_ASTRADB_VERSION),
]

# create the callers
Expand Down Expand Up @@ -343,9 +345,10 @@ def __init__(
async_astra_db_client=async_astra_db_client,
)
self.collection_name = collection_name
self.collection_embedding_api_key = collection_embedding_api_key
self.collection = self.database.get_collection(
name=self.collection_name,
embedding_api_key=collection_embedding_api_key,
embedding_api_key=self.collection_embedding_api_key,
)
self.async_collection = self.collection.to_async()

Expand Down Expand Up @@ -395,6 +398,57 @@ def __init__(
except ValueError as validation_error:
raise validation_error from data_api_exception

def copy(
self,
*,
token: str | TokenProvider | None = None,
ext_callers: list[tuple[str | None, str | None] | str | None] | None = None,
component_name: str | None = None,
collection_embedding_api_key: str | EmbeddingHeadersProvider | None = None,
) -> _AstraDBCollectionEnvironment:
"""Create a copy, possibly with changed attributes.

This method creates a shallow copy of this environment. If a parameter
is passed and differs from None, it will replace the corresponding value
in the copy.

The method allows changing only the parameters that ensure the copy is
functional and does not trigger side-effects:
for example, one cannot create a copy acting on a new collection.
In those cases, one should create a new instance
of ``_AstraDBCollectionEnvironment`` from scratch.

Attributes:
token: API token for Astra DB usage, either in the form of a string
or a subclass of ``astrapy.authentication.TokenProvider``.
In order to suppress token usage in the copy, explicitly pass
``astrapy.authentication.StaticTokenProvider(None)``.
ext_callers: additional custom (caller_name, caller_version) pairs
to attach to the User-Agent header when issuing Data API requests.
component_name: a value for the LangChain component name to use when
identifying the originator of the Data API requests.
collection_embedding_api_key: the API Key to supply in each Data API
request if necessary. This is necessary if using the Vectorize
feature and no secret is stored with the database.
In order to suppress the API Key in the copy, explicitly pass
``astrapy.authentication.EmbeddingAPIKeyHeaderProvider(None)``.
"""
return _AstraDBCollectionEnvironment(
collection_name=self.collection_name,
token=self.token if token is None else token,
api_endpoint=self.api_endpoint,
keyspace=self.keyspace,
environment=self.environment,
ext_callers=self.ext_callers if ext_callers is None else ext_callers,
component_name=self.component_name
if component_name is None
else component_name,
setup_mode=SetupMode.OFF,
collection_embedding_api_key=self.collection_embedding_api_key
if collection_embedding_api_key
else collection_embedding_api_key,
)

async def _asetup_db(
self,
*,
Expand Down
80 changes: 79 additions & 1 deletion libs/astradb/langchain_astradb/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
)

import numpy as np
from astrapy.constants import Environment
from astrapy.exceptions import InsertManyException
from astrapy.info import CollectionVectorServiceOptions
from langchain_community.vectorstores.utils import maximal_marginal_relevance
from langchain_core.runnables.utils import gather_with_concurrency
from langchain_core.vectorstores import VectorStore
Expand Down Expand Up @@ -54,7 +56,6 @@
from astrapy.db import (
AsyncAstraDB as AsyncAstraDBClient,
)
from astrapy.info import CollectionVectorServiceOptions
from astrapy.results import UpdateResult
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
Expand Down Expand Up @@ -723,6 +724,83 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]:
# so here the final score transformation is not reversing the interval.
return lambda score: score

def copy(
self,
*,
token: str | TokenProvider | None = None,
ext_callers: list[tuple[str | None, str | None] | str | None] | None = None,
component_name: str | None = None,
collection_embedding_api_key: str | EmbeddingHeadersProvider | None = None,
) -> AstraDBVectorStore:
"""Create a copy, possibly with changed attributes.

This method creates a shallow copy of this environment. If a parameter
is passed and differs from None, it will replace the corresponding value
in the copy.

The method allows changing only the parameters that ensure the copy is
functional and does not trigger side-effects:
for example, one cannot create a copy acting on a new collection.
In those cases, one should create a new instance of ``AstraDBVectorStore``
from scratch.

Attributes:
token: API token for Astra DB usage, either in the form of a string
or a subclass of ``astrapy.authentication.TokenProvider``.
In order to suppress token usage in the copy, explicitly pass
``astrapy.authentication.StaticTokenProvider(None)``.
ext_callers: additional custom (caller_name, caller_version) pairs
to attach to the User-Agent header when issuing Data API requests.
component_name: a value for the LangChain component name to use when
identifying the originator of the Data API requests.
collection_embedding_api_key: the API Key to supply in each Data API
request if necessary. This is necessary if using the Vectorize
feature and no secret is stored with the database.
In order to suppress the API Key in the copy, explicitly pass
``astrapy.authentication.EmbeddingAPIKeyHeaderProvider(None)``.
"""
copy = AstraDBVectorStore(
collection_name="moot",
api_endpoint="http://moot",
environment=Environment.OTHER,
namespace="moot",
setup_mode=SetupMode.OFF,
collection_vector_service_options=CollectionVectorServiceOptions(
provider="moot",
model_name="moot",
Comment on lines +763 to +770
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not passing the final values instead of moot directly ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My intent here was to ensure future readers understand the nature of this "trick", that the instantiation is misused so to speak - exploiting a path in the constructor where nothing happens.

The concern was that passing the real values (which however need to be supplied with more member-setting right thereafter) might mislead readers.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exploiting a path in the constructor where nothing happens.
I am not very proud of this solution, to be honest. But it probably seems overkill to refactor everything else to make this little flow more straightforward.

),
)
copy.collection_name = self.collection_name
copy.token = self.token if token is None else token
copy.api_endpoint = self.api_endpoint
copy.environment = self.environment
copy.namespace = self.namespace
copy.indexing_policy = self.indexing_policy
copy.autodetect_collection = self.autodetect_collection
copy.embedding_dimension = self.embedding_dimension
copy.embedding = self.embedding
copy.metric = self.metric
copy.collection_embedding_api_key = (
self.collection_embedding_api_key
if collection_embedding_api_key is None
else collection_embedding_api_key
)
copy.collection_vector_service_options = self.collection_vector_service_options
copy.document_codec = self.document_codec
copy.batch_size = self.batch_size
copy.bulk_insert_batch_concurrency = self.bulk_insert_batch_concurrency
copy.bulk_insert_overwrite_concurrency = self.bulk_insert_overwrite_concurrency
copy.bulk_delete_concurrency = self.bulk_delete_concurrency
# Now the .astra_env attribute:
copy.astra_env = self.astra_env.copy(
token=token,
ext_callers=ext_callers,
component_name=component_name,
collection_embedding_api_key=collection_embedding_api_key,
)

return copy

def clear(self) -> None:
"""Empty the collection of all its stored entries."""
self.astra_env.ensure_db_setup()
Expand Down
48 changes: 46 additions & 2 deletions libs/astradb/tests/integration_tests/test_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from typing import TYPE_CHECKING, Any

import pytest
from astrapy.authentication import StaticTokenProvider
from astrapy.authentication import EmbeddingAPIKeyHeaderProvider, StaticTokenProvider
from langchain_core.documents import Document

from langchain_astradb.utils.astradb import SetupMode
from langchain_astradb.utils.astradb import COMPONENT_NAME_VECTORSTORE, SetupMode
from langchain_astradb.vectorstores import AstraDBVectorStore

from .conftest import (
Expand Down Expand Up @@ -1784,3 +1784,47 @@ async def test_astradb_vectorstore_coreclients_init_async(
assert len(f_rec_warnings) == 1
assert len(results) == 1
assert results[0].page_content == "[1,2]"

@pytest.mark.parametrize(
"vector_store",
[
"vector_store_d2",
"vector_store_vz",
],
ids=["nonvectorize_store", "vectorize_store"],
)
def test_astradb_vectorstore_copy(
self,
*,
vector_store: str,
request: pytest.FixtureRequest,
) -> None:
"""Verify changed attributes in 'copy', down in the astra_env of the store."""
vstore0: AstraDBVectorStore = request.getfixturevalue(vector_store)

# component_name, deep test
# Note this line encodes assumptions on astrapy internals that will fail on 2.0:
caller_names0 = {caller[0] for caller in vstore0.astra_env.collection.callers}
assert COMPONENT_NAME_VECTORSTORE in caller_names0

vstore1 = vstore0.copy(component_name="xyz_component")

# Note this line encodes assumptions on astrapy internals that will fail on 2.0:
caller_names1 = {caller[0] for caller in vstore1.astra_env.collection.callers}
assert COMPONENT_NAME_VECTORSTORE not in caller_names1
assert "xyz_component" in caller_names1

# other changeable attributes (this check does not enter astrapy at all)
token2 = StaticTokenProvider("xyz")
apikey2 = EmbeddingAPIKeyHeaderProvider(None)
vstore2 = vstore0.copy(
token=token2,
ext_callers=[("cnx", "cvx")],
component_name="component_name2",
collection_embedding_api_key=apikey2,
)

assert vstore2.astra_env.token == token2
assert vstore2.astra_env.ext_callers == [("cnx", "cvx")]
assert vstore2.astra_env.component_name == "component_name2"
assert vstore2.astra_env.collection_embedding_api_key == apikey2
56 changes: 56 additions & 0 deletions libs/astradb/tests/unit_tests/test_callers.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,62 @@ def test_callers_component_vectorstore(self, httpserver: HTTPServer) -> None:
ext_callers=[("ec0", "ev0")],
)

def test_callers_vectorstore_copy(self, httpserver: HTTPServer) -> None:
"""
Test of "copy" for the vectorstore, checking the actual headers.
"""
base_endpoint = httpserver.url_for("/")
base_path = "/v1/ks"
coll_name = "my_coll"
new_component_name = "NEW_COMPONENT_NAME"
no_results_json: dict[str, Any] = {
"data": {
"documents": [],
"nextPageState": None,
}
}

httpserver.expect_oneshot_request(
base_path,
method="POST",
headers={
"User-Agent": "ec0/ev0",
},
header_value_matcher=hv_prefix_matcher_factory(COMPONENT_NAME_VECTORSTORE),
).respond_with_json({})

vs0 = AstraDBVectorStore(
collection_name=coll_name,
api_endpoint=base_endpoint,
environment=Environment.OTHER,
namespace="ks",
embedding=ParserEmbeddings(2),
ext_callers=[("ec0", "ev0")],
)

# a clone with different component name:
vs1 = vs0.copy(component_name=new_component_name)
httpserver.expect_oneshot_request(
base_path + "/" + coll_name,
method="POST",
headers={
"User-Agent": "ec0/ev0",
},
header_value_matcher=hv_prefix_matcher_factory(new_component_name),
).respond_with_json(no_results_json)
vs1.similarity_search("[0,1]")

# the original one is untouched:
httpserver.expect_oneshot_request(
base_path + "/" + coll_name,
method="POST",
headers={
"User-Agent": "ec0/ev0",
},
header_value_matcher=hv_prefix_matcher_factory(COMPONENT_NAME_VECTORSTORE),
).respond_with_json(no_results_json)
vs0.similarity_search("[0,1]")

def test_callers_component_graphvectorstore(self, httpserver: HTTPServer) -> None:
"""
End-to-end testing of callers passed through the components.
Expand Down