diff --git a/pinecone/control/pinecone.py b/pinecone/control/pinecone.py index b42b93712..569625685 100644 --- a/pinecone/control/pinecone.py +++ b/pinecone/control/pinecone.py @@ -1,6 +1,7 @@ import time import logging from typing import Optional, Dict, Any, Union +from enum import Enum from .index_host_store import IndexHostStore from .pinecone_interface import PineconeDBControlInterface @@ -14,6 +15,8 @@ from pinecone.utils import normalize_host, setup_openapi_client, build_plugin_setup_client from pinecone.core.openapi.db_control.models import ( CreateCollectionRequest, + CreateIndexForModelRequest, + CreateIndexForModelRequestEmbed, CreateIndexRequest, ConfigureIndexRequest, ConfigureIndexRequestSpec, @@ -26,12 +29,29 @@ PodSpecMetadataConfig, ) from pinecone.core.openapi.db_control import API_VERSION -from pinecone.models import ServerlessSpec, PodSpec, IndexModel, IndexList, CollectionList +from pinecone.models import ( + ServerlessSpec, + PodSpec, + IndexModel, + IndexList, + CollectionList, + IndexEmbed, +) from .langchain_import_warnings import _build_langchain_attribute_error_message from pinecone.utils import parse_non_empty_args, docslinks from pinecone.data import _Index, _AsyncioIndex, _Inference -from pinecone.enums import Metric, VectorType, DeletionProtection, PodType +from pinecone.enums import ( + Metric, + VectorType, + DeletionProtection, + PodType, + CloudProvider, + AwsRegion, + GcpRegion, + AzureRegion, +) +from .types import CreateIndexForModelEmbedTypedDict from pinecone_plugin_interface import load_and_install as install_plugins @@ -127,7 +147,27 @@ def load_plugins(self): except Exception as e: logger.error(f"Error loading plugins: {e}") - def _parse_index_spec(self, spec: Union[Dict, ServerlessSpec, PodSpec]) -> IndexSpec: + def __parse_tags(self, tags: Optional[Dict[str, str]]) -> IndexTags: + if tags is None: + return IndexTags() + else: + return IndexTags(**tags) + + def __parse_deletion_protection( + self, deletion_protection: 
Union[DeletionProtection, str] + ) -> DeletionProtectionModel: + deletion_protection = self.__parse_enum_to_string(deletion_protection) + if deletion_protection in ["enabled", "disabled"]: + return DeletionProtectionModel(deletion_protection) + else: + raise ValueError("deletion_protection must be either 'enabled' or 'disabled'") + + def __parse_enum_to_string(self, value: Union[Enum, str]) -> str: + if isinstance(value, Enum): + return value.value + return value + + def __parse_index_spec(self, spec: Union[Dict, ServerlessSpec, PodSpec]) -> IndexSpec: if isinstance(spec, dict): if "serverless" in spec: index_spec = IndexSpec(serverless=ServerlessSpecModel(**spec["serverless"])) @@ -185,85 +225,133 @@ def create_index( deletion_protection: Optional[Union[DeletionProtection, str]] = DeletionProtection.DISABLED, vector_type: Optional[Union[VectorType, str]] = VectorType.DENSE, tags: Optional[Dict[str, str]] = None, - ): - # Convert Enums to their string values if necessary - metric = metric.value if isinstance(metric, Metric) else str(metric) - vector_type = vector_type.value if isinstance(vector_type, VectorType) else str(vector_type) - deletion_protection = ( - deletion_protection.value - if isinstance(deletion_protection, DeletionProtection) - else str(deletion_protection) - ) + ) -> IndexModel: + if metric is not None: + metric = self.__parse_enum_to_string(metric) + if vector_type is not None: + vector_type = self.__parse_enum_to_string(vector_type) + if deletion_protection is not None: + dp = self.__parse_deletion_protection(deletion_protection) + else: + dp = None + + tags_obj = self.__parse_tags(tags) + index_spec = self.__parse_index_spec(spec) if vector_type == VectorType.SPARSE.value and dimension is not None: raise ValueError("dimension should not be specified for sparse indexes") - if deletion_protection in ["enabled", "disabled"]: - dp = DeletionProtectionModel(deletion_protection) + args = parse_non_empty_args( + [ + ("name", name), + ("dimension", 
dimension), + ("metric", metric), + ("spec", index_spec), + ("deletion_protection", dp), + ("vector_type", vector_type), + ("tags", tags_obj), + ] + ) + + req = CreateIndexRequest(**args) + resp = self.index_api.create_index(create_index_request=req) + + if timeout == -1: + return IndexModel(resp) + return self.__poll_describe_index_until_ready(name, timeout) + + def create_index_for_model( + self, + name: str, + cloud: Union[CloudProvider, str], + region: Union[AwsRegion, GcpRegion, AzureRegion, str], + embed: Union[IndexEmbed, CreateIndexForModelEmbedTypedDict], + tags: Optional[Dict[str, str]] = None, + deletion_protection: Optional[Union[DeletionProtection, str]] = DeletionProtection.DISABLED, + timeout: Optional[int] = None, + ) -> IndexModel: + cloud = self.__parse_enum_to_string(cloud) + region = self.__parse_enum_to_string(region) + if deletion_protection is not None: + dp = self.__parse_deletion_protection(deletion_protection) else: - raise ValueError("deletion_protection must be either 'enabled' or 'disabled'") + dp = None + tags_obj = self.__parse_tags(tags) - if tags is None: - tags_obj = None + if isinstance(embed, IndexEmbed): + parsed_embed = embed.as_dict() else: - tags_obj = IndexTags(**tags) + # if dict, we need to parse enum values, if any, to string + # and verify required fields are present + required_fields = ["model", "field_map"] + for field in required_fields: + if field not in embed: + raise ValueError(f"{field} is required in embed") + parsed_embed = {} + for key, value in embed.items(): + if isinstance(value, Enum): + parsed_embed[key] = value.value + else: + parsed_embed[key] = value + + args = parse_non_empty_args( + [ + ("name", name), + ("cloud", cloud), + ("region", region), + ("embed", CreateIndexForModelRequestEmbed(**parsed_embed)), + ("deletion_protection", dp), + ("tags", tags_obj), + ] + ) - index_spec = self._parse_index_spec(spec) + req = CreateIndexForModelRequest(**args) + resp = self.index_api.create_index_for_model(req) 
- api_instance = self.index_api - api_instance.create_index( - create_index_request=CreateIndexRequest( - **parse_non_empty_args( - [ - ("name", name), - ("dimension", dimension), - ("metric", metric), - ("spec", index_spec), - ("deletion_protection", dp), - ("vector_type", vector_type), - ("tags", tags_obj), - ] - ) - ) - ) + if timeout == -1: + return IndexModel(resp) + return self.__poll_describe_index_until_ready(name, timeout) + + def __poll_describe_index_until_ready(self, name: str, timeout: Optional[int] = None): + description = None def is_ready(): - status = self._get_status(name) - ready = status["ready"] - return ready + nonlocal description + description = self.describe_index(name=name) + return description.status.ready total_wait_time = 0 - if timeout == -1: - logger.debug(f"Skipping wait for index {name} to be ready") - return if timeout is None: + # Wait indefinitely while not is_ready(): logger.debug( f"Waiting for index {name} to be ready. Total wait time: {total_wait_time}" ) total_wait_time += 5 time.sleep(5) + else: - while (not is_ready()) and timeout >= 0: + # Wait for a maximum of timeout seconds + while not is_ready(): + if timeout < 0: + logger.error(f"Index {name} is not ready. Timeout reached.") + link = docslinks["API_DESCRIBE_INDEX"] + timeout_msg = ( + f"Please call describe_index() to confirm index status. See docs at {link}" + ) + raise TimeoutError(timeout_msg) + logger.debug( f"Waiting for index {name} to be ready. Total wait time: {total_wait_time}" ) total_wait_time += 5 time.sleep(5) timeout -= 5 - if timeout and timeout < 0: - logger.error(f"Index {name} is not ready. 
Timeout reached.") - raise ( - TimeoutError( - "Please call the describe_index API ({}) to confirm index status.".format( - docslinks["API_DESCRIBE_INDEX"] - ) - ) - ) + + return description def delete_index(self, name: str, timeout: Optional[int] = None): - api_instance = self.index_api - api_instance.delete_index(name) + self.index_api.delete_index(name) self.index_host_store.delete_host(self.config, name) def get_remaining(): @@ -292,7 +380,7 @@ def list_indexes(self) -> IndexList: response = self.index_api.list_indexes() return IndexList(response) - def describe_index(self, name: str): + def describe_index(self, name: str) -> IndexModel: api_instance = self.index_api description = api_instance.describe_index(name) host = description.host @@ -373,11 +461,6 @@ def describe_collection(self, name: str): api_instance = self.index_api return api_instance.describe_collection(name).to_dict() - def _get_status(self, name: str): - api_instance = self.index_api - response = api_instance.describe_index(name) - return response["status"] - @staticmethod def from_texts(*args, **kwargs): raise AttributeError(_build_langchain_attribute_error_message("from_texts")) diff --git a/pinecone/control/pinecone_interface.py b/pinecone/control/pinecone_interface.py index 24b4dc3a5..d9e773d07 100644 --- a/pinecone/control/pinecone_interface.py +++ b/pinecone/control/pinecone_interface.py @@ -8,8 +8,25 @@ from pinecone.core.openapi.db_control.api.manage_indexes_api import ManageIndexesApi -from pinecone.models import ServerlessSpec, PodSpec, IndexList, CollectionList -from pinecone.enums import Metric, VectorType, DeletionProtection, PodType +from pinecone.models import ( + ServerlessSpec, + PodSpec, + IndexList, + CollectionList, + IndexModel, + IndexEmbed, +) +from pinecone.enums import ( + Metric, + VectorType, + DeletionProtection, + PodType, + CloudProvider, + AwsRegion, + GcpRegion, + AzureRegion, +) +from .types import CreateIndexForModelEmbedTypedDict class 
PineconeDBControlInterface(ABC): @@ -246,6 +263,44 @@ def create_index( """ pass + @abstractmethod + def create_index_for_model( + self, + name: str, + cloud: Union[CloudProvider, str], + region: Union[AwsRegion, GcpRegion, AzureRegion, str], + embed: Union[IndexEmbed, CreateIndexForModelEmbedTypedDict], + tags: Optional[Dict[str, str]] = None, + deletion_protection: Optional[Union[DeletionProtection, str]] = DeletionProtection.DISABLED, + timeout: Optional[int] = None, + ) -> IndexModel: + """ + Create an index for a model. + + This operation creates a serverless index for a model. The index is used to store embeddings generated by the model. The index can be used to search and retrieve embeddings. + + :param name: The name of the index to create. Must be unique within your project and + cannot be changed once created. Allowed characters are lowercase letters, numbers, + and hyphens and the name may not begin or end with hyphens. Maximum length is 45 characters. + :type name: str + :param cloud: The cloud provider to use for the index. + :type cloud: str + :param region: The region to use for the index. + :type region: str + :param embed: The embedding configuration for the index. + :type embed: Union[Dict, IndexEmbed] + :param tags: A dictionary of tags to associate with the index. + :type tags: Optional[Dict[str, str]] + :param deletion_protection: If enabled, the index cannot be deleted. If disabled, the index can be deleted. Default: "disabled" + :type deletion_protection: Optional[Literal["enabled", "disabled"]] + :type timeout: Optional[int] + :param timeout: Specify the number of seconds to wait until index gets ready. If None, wait indefinitely; if >=0, time out after this many seconds; + if -1, return immediately and do not wait. Default: None + :return: The index that was created. + :rtype: IndexModel + """ + pass + @abstractmethod def delete_index(self, name: str, timeout: Optional[int] = None): """Deletes a Pinecone index. 
@@ -316,7 +371,7 @@ def list_indexes(self) -> IndexList: pass @abstractmethod - def describe_index(self, name: str): + def describe_index(self, name: str) -> IndexModel: """Describes a Pinecone index. :param name: the name of the index to describe. diff --git a/pinecone/control/types/__init__.py b/pinecone/control/types/__init__.py new file mode 100644 index 000000000..12d16270d --- /dev/null +++ b/pinecone/control/types/__init__.py @@ -0,0 +1 @@ +from .create_index_for_model_embed import CreateIndexForModelEmbedTypedDict diff --git a/pinecone/control/types/create_index_for_model_embed.py b/pinecone/control/types/create_index_for_model_embed.py new file mode 100644 index 000000000..123474a0a --- /dev/null +++ b/pinecone/control/types/create_index_for_model_embed.py @@ -0,0 +1,11 @@ +from typing import TypedDict, Dict, Union +from ...enums import Metric +from ...data.features.inference import EmbedModel + + +class CreateIndexForModelEmbedTypedDict(TypedDict): + model: Union[EmbedModel, str] + field_map: Dict + metric: Union[Metric, str] + read_parameters: Dict + write_parameters: Dict diff --git a/pinecone/deprecated_plugins.py b/pinecone/deprecated_plugins.py index c412ba302..5bf857c0b 100644 --- a/pinecone/deprecated_plugins.py +++ b/pinecone/deprecated_plugins.py @@ -1,5 +1,6 @@ class DeprecatedPluginError(Exception): - def __init__(self, message): + def __init__(self, plugin_name: str): + message = f"The `{plugin_name}` package has been deprecated. The features from that plugin have been incorporated into the main `pinecone` package with no need for additional plugins. Please remove the `{plugin_name}` package from your dependencies to ensure you have the most up-to-date version of these features." 
super().__init__(message) @@ -8,10 +9,14 @@ def check_for_deprecated_plugins(): from pinecone_plugins.inference import __installables__ # type: ignore if __installables__ is not None: - raise DeprecatedPluginError( - "The `pinecone-plugin-inference` package has been deprecated. The embed and rerank functionality has been incorporated into the main `pinecone` package with no need for additional plugins. Please remove the `pinecone-plugin-inference` package from your dependencies to ensure you have the most up-to-date version of these features." - ) + raise DeprecatedPluginError("pinecone-plugin-inference") + except ImportError: + pass + + try: + from pinecone_plugins.records import __installables__ # type: ignore + + if __installables__ is not None: + raise DeprecatedPluginError("pinecone-plugin-records") except ImportError: - # ImportError is expected if the plugin is not installed, - # which is the good case. pass diff --git a/pinecone/models/__init__.py b/pinecone/models/__init__.py index e9b187ea6..86306c1e1 100644 --- a/pinecone/models/__init__.py +++ b/pinecone/models/__init__.py @@ -1,11 +1,11 @@ from .index_description import ServerlessSpecDefinition, PodSpecDefinition from .collection_description import CollectionDescription from .serverless_spec import ServerlessSpec -from .pod_spec import PodSpec, PodType +from .pod_spec import PodSpec from .index_list import IndexList from .collection_list import CollectionList from .index_model import IndexModel -from ..enums.metric import Metric +from .index_embed import IndexEmbed __all__ = [ "CollectionDescription", @@ -16,4 +16,5 @@ "IndexList", "CollectionList", "IndexModel", + "IndexEmbed", ] diff --git a/pinecone/models/index_embed.py b/pinecone/models/index_embed.py new file mode 100644 index 000000000..4d1ccfe39 --- /dev/null +++ b/pinecone/models/index_embed.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass +from typing import Optional, Dict, Any, Union + +from ..enums import Metric +from 
..data.features.inference import EmbedModel + + +@dataclass(frozen=True) +class IndexEmbed: + """ + IndexEmbed represents the index embedding configuration when creating an index from a model. + """ + + model: str + """ + The name of the embedding model to use for the index. + Required. + """ + + field_map: Dict[str, Any] + """ + A mapping of field names to their types. + Required. + """ + + metric: Optional[str] = None + """ + The metric to use for the index. If not provided, the default metric for the model is used. + Optional. + """ + + read_parameters: Optional[Dict[str, Any]] = None + """ + The parameters to use when reading from the index. + Optional. + """ + + write_parameters: Optional[Dict[str, Any]] = None + """ + The parameters to use when writing to the index. + Optional. + """ + + def as_dict(self) -> Dict[str, Any]: + """ + Returns the IndexEmbed as a dictionary. + """ + return self.__dict__ + + def __init__( + self, + model: Union[EmbedModel, str], + field_map: Dict[str, Any], + metric: Optional[Union[Metric, str]] = None, + read_parameters: Optional[Dict[str, Any]] = None, + write_parameters: Optional[Dict[str, Any]] = None, + ): + object.__setattr__( + self, "model", model.value if isinstance(model, EmbedModel) else str(model) + ) + object.__setattr__(self, "field_map", field_map) + object.__setattr__(self, "metric", metric.value if isinstance(metric, Metric) else metric) + object.__setattr__( + self, "read_parameters", read_parameters if read_parameters is not None else {} + ) + object.__setattr__( + self, "write_parameters", write_parameters if write_parameters is not None else {} + ) diff --git a/tests/integration/control/serverless/test_configure_index_deletion_protection.py b/tests/integration/control/serverless/test_configure_index_deletion_protection.py index fb12897ab..86b7ef60b 100644 --- a/tests/integration/control/serverless/test_configure_index_deletion_protection.py +++ 
b/tests/integration/control/serverless/test_configure_index_deletion_protection.py @@ -23,7 +23,7 @@ def test_deletion_protection(self, client, create_sl_index_params, dp_enabled, d client.delete_index(name) - @pytest.mark.parametrize("deletion_protection", ["invalid", None]) + @pytest.mark.parametrize("deletion_protection", ["invalid"]) def test_deletion_protection_invalid_options( self, client, create_sl_index_params, deletion_protection ): diff --git a/tests/integration/control/serverless/test_create_index_sl_happy_path.py b/tests/integration/control/serverless/test_create_index.py similarity index 54% rename from tests/integration/control/serverless/test_create_index_sl_happy_path.py rename to tests/integration/control/serverless/test_create_index.py index 8069f3408..5e7f46fe6 100644 --- a/tests/integration/control/serverless/test_create_index_sl_happy_path.py +++ b/tests/integration/control/serverless/test_create_index.py @@ -1,5 +1,6 @@ import pytest from pinecone import ( + Pinecone, Metric, VectorType, DeletionProtection, @@ -10,44 +11,82 @@ class TestCreateSLIndexHappyPath: - def test_create_index(self, client, create_sl_index_params): - name = create_sl_index_params["name"] - dimension = create_sl_index_params["dimension"] - client.create_index(**create_sl_index_params) - desc = client.describe_index(name) - assert desc.name == name - assert desc.dimension == dimension + def test_create_index(self, client: Pinecone, index_name): + resp = client.create_index( + name=index_name, + dimension=10, + spec=ServerlessSpec(cloud=CloudProvider.AWS, region=AwsRegion.US_EAST_1), + ) + assert resp.name == index_name + assert resp.dimension == 10 + assert resp.metric == "cosine" # default value + assert resp.vector_type == "dense" # default value + assert resp.deletion_protection == "disabled" # default value + + desc = client.describe_index(name=index_name) + assert desc.name == index_name + assert desc.dimension == 10 assert desc.metric == "cosine" assert 
desc.deletion_protection == "disabled" # default value assert desc.vector_type == "dense" # default value + def test_create_skip_wait(self, client, index_name): + resp = client.create_index( + name=index_name, + dimension=10, + spec=ServerlessSpec(cloud=CloudProvider.AWS, region=AwsRegion.US_EAST_1), + timeout=-1, + ) + assert resp.name == index_name + assert resp.dimension == 10 + assert resp.metric == "cosine" + + def test_create_infinite_wait(self, client, index_name): + resp = client.create_index( + name=index_name, + dimension=10, + spec=ServerlessSpec(cloud=CloudProvider.AWS, region=AwsRegion.US_EAST_1), + timeout=None, + ) + assert resp.name == index_name + assert resp.dimension == 10 + assert resp.metric == "cosine" + @pytest.mark.parametrize("metric", ["cosine", "euclidean", "dotproduct"]) def test_create_default_index_with_metric(self, client, create_sl_index_params, metric): create_sl_index_params["metric"] = metric client.create_index(**create_sl_index_params) desc = client.describe_index(create_sl_index_params["name"]) - assert desc.metric == metric + if isinstance(metric, str): + assert desc.metric == metric + else: + assert desc.metric == metric.value assert desc.vector_type == "dense" @pytest.mark.parametrize( - "metric_enum,vector_type_enum,dim", + "metric_enum,vector_type_enum,dim,tags", [ - (Metric.COSINE, VectorType.DENSE, 10), - (Metric.EUCLIDEAN, VectorType.DENSE, 10), - (Metric.DOTPRODUCT, VectorType.SPARSE, None), + (Metric.COSINE, VectorType.DENSE, 10, None), + (Metric.EUCLIDEAN, VectorType.DENSE, 10, {"env": "prod"}), + (Metric.DOTPRODUCT, VectorType.SPARSE, None, {"env": "dev"}), ], ) - def test_create_with_enum_values(self, client, index_name, metric_enum, vector_type_enum, dim): + def test_create_with_enum_values( + self, client, index_name, metric_enum, vector_type_enum, dim, tags + ): args = { "name": index_name, "metric": metric_enum, "vector_type": vector_type_enum, "deletion_protection": DeletionProtection.DISABLED, "spec": 
ServerlessSpec(cloud=CloudProvider.AWS, region=AwsRegion.US_EAST_1), + "tags": tags, } if dim is not None: args["dimension"] = dim + client.create_index(**args) + desc = client.describe_index(index_name) assert desc.metric == metric_enum.value assert desc.vector_type == vector_type_enum.value @@ -56,6 +95,8 @@ def test_create_with_enum_values(self, client, index_name, metric_enum, vector_t assert desc.name == index_name assert desc.spec.serverless.cloud == "aws" assert desc.spec.serverless.region == "us-east-1" + if tags: + assert desc.tags.to_dict() == tags @pytest.mark.parametrize("metric", ["cosine", "euclidean", "dotproduct"]) def test_create_dense_index_with_metric(self, client, create_sl_index_params, metric): diff --git a/tests/integration/control/serverless/test_create_index_for_model.py b/tests/integration/control/serverless/test_create_index_for_model.py new file mode 100644 index 000000000..5f0258f75 --- /dev/null +++ b/tests/integration/control/serverless/test_create_index_for_model.py @@ -0,0 +1,68 @@ +import pytest +from pinecone import EmbedModel, CloudProvider, AwsRegion, IndexEmbed, Metric + + +class TestCreateIndexForModel: + @pytest.mark.parametrize( + "model_val,cloud_val,region_val", + [ + ("multilingual-e5-large", "aws", "us-east-1"), + (EmbedModel.Multilingual_E5_Large, CloudProvider.AWS, AwsRegion.US_EAST_1), + (EmbedModel.Pinecone_Sparse_English_V0, CloudProvider.AWS, AwsRegion.US_EAST_1), + ], + ) + def test_create_index_for_model(self, client, model_val, index_name, cloud_val, region_val): + field_map = {"text": "my-sample-text"} + index = client.create_index_for_model( + name=index_name, + cloud=cloud_val, + region=region_val, + embed={"model": model_val, "field_map": field_map}, + timeout=-1, + ) + assert index.name == index_name + assert index.spec.serverless.cloud == "aws" + assert index.spec.serverless.region == "us-east-1" + assert index.embed.field_map == field_map + if isinstance(model_val, EmbedModel): + assert index.embed.model 
== model_val.value + else: + assert index.embed.model == model_val + + def test_create_index_for_model_with_index_embed_obj(self, client, index_name): + field_map = {"text": "my-sample-text"} + index = client.create_index_for_model( + name=index_name, + cloud=CloudProvider.AWS, + region=AwsRegion.US_EAST_1, + embed=IndexEmbed( + metric=Metric.COSINE, model=EmbedModel.Multilingual_E5_Large, field_map=field_map + ), + timeout=-1, + ) + assert index.name == index_name + assert index.spec.serverless.cloud == "aws" + assert index.spec.serverless.region == "us-east-1" + assert index.embed.field_map == field_map + assert index.embed.model == EmbedModel.Multilingual_E5_Large.value + + @pytest.mark.parametrize( + "model_val,metric_val", + [(EmbedModel.Multilingual_E5_Large, Metric.COSINE), ("multilingual-e5-large", "cosine")], + ) + def test_create_index_for_model_with_index_embed_dict( + self, client, index_name, model_val, metric_val + ): + field_map = {"text": "my-sample-text"} + index = client.create_index_for_model( + name=index_name, + cloud=CloudProvider.AWS, + region=AwsRegion.US_EAST_1, + embed={"metric": metric_val, "field_map": field_map, "model": model_val}, + timeout=-1, + ) + assert index.name == index_name + assert index.spec.serverless.cloud == "aws" + assert index.spec.serverless.region == "us-east-1" + assert index.embed.field_map == field_map + assert index.embed.model == EmbedModel.Multilingual_E5_Large.value diff --git a/tests/integration/control/serverless/test_create_index_for_model_errors.py b/tests/integration/control/serverless/test_create_index_for_model_errors.py new file mode 100644 index 000000000..c08c581dc --- /dev/null +++ b/tests/integration/control/serverless/test_create_index_for_model_errors.py @@ -0,0 +1,122 @@ +import pytest +from pinecone import ( + EmbedModel, + CloudProvider, + AwsRegion, + Metric, + PineconeApiException, + PineconeApiValueError, +) + + +class TestCreateIndexForModelErrors: + def 
test_create_index_for_model_with_invalid_model(self, client, index_name): + with pytest.raises(PineconeApiException) as e: + client.create_index_for_model( + name=index_name, + cloud=CloudProvider.AWS, + region=AwsRegion.US_EAST_1, + embed={ + "model": "invalid-model", + "field_map": {"text": "my-sample-text"}, + "metric": Metric.COSINE, + }, + timeout=-1, + ) + assert "Model invalid-model not found." in str(e.value) + + def test_invalid_cloud(self, client, index_name): + with pytest.raises(PineconeApiValueError) as e: + client.create_index_for_model( + name=index_name, + cloud="invalid-cloud", + region=AwsRegion.US_EAST_1, + embed={ + "model": EmbedModel.Multilingual_E5_Large, + "field_map": {"text": "my-sample-text"}, + "metric": Metric.COSINE, + }, + timeout=-1, + ) + assert "Invalid value for `cloud`" in str(e.value) + + def test_invalid_region(self, client, index_name): + with pytest.raises(PineconeApiException) as e: + client.create_index_for_model( + name=index_name, + cloud=CloudProvider.AWS, + region="invalid-region", + embed={ + "model": EmbedModel.Multilingual_E5_Large, + "field_map": {"text": "my-sample-text"}, + "metric": Metric.COSINE, + }, + timeout=-1, + ) + assert "invalid-region not found" in str(e.value) + + def test_create_index_for_model_with_invalid_field_map(self, client, index_name): + with pytest.raises(PineconeApiException) as e: + client.create_index_for_model( + name=index_name, + cloud=CloudProvider.AWS, + region=AwsRegion.US_EAST_1, + embed={ + "model": EmbedModel.Multilingual_E5_Large, + "field_map": {"invalid_field": "my-sample-text"}, + "metric": Metric.COSINE, + }, + timeout=-1, + ) + assert "Missing required key 'text'" in str(e.value) + + def test_create_index_for_model_with_invalid_metric(self, client, index_name): + with pytest.raises(PineconeApiValueError) as e: + client.create_index_for_model( + name=index_name, + cloud=CloudProvider.AWS, + region=AwsRegion.US_EAST_1, + embed={ + "model": EmbedModel.Multilingual_E5_Large, + 
"field_map": {"text": "my-sample-text"}, + "metric": "invalid-metric", + }, + timeout=-1, + ) + assert "Invalid value for `metric`" in str(e.value) + + def test_create_index_for_model_with_missing_name(self, client): + with pytest.raises(TypeError) as e: + client.create_index_for_model( + cloud=CloudProvider.AWS, + region=AwsRegion.US_EAST_1, + embed={ + "model": EmbedModel.Multilingual_E5_Large, + "field_map": {"text": "my-sample-text"}, + "metric": Metric.EUCLIDEAN, + }, + timeout=-1, + ) + assert "name" in str(e.value) + + def test_create_index_with_missing_model(self, client, index_name): + with pytest.raises(ValueError) as e: + client.create_index_for_model( + name=index_name, + cloud=CloudProvider.AWS, + region=AwsRegion.US_EAST_1, + embed={"field_map": {"text": "my-sample-text"}, "metric": Metric.COSINE}, + timeout=-1, + ) + assert "model is required" in str(e.value) + + def test_create_index_with_missing_field_map(self, client, index_name): + with pytest.raises(ValueError) as e: + client.create_index_for_model( + name=index_name, + cloud=CloudProvider.AWS, + region=AwsRegion.US_EAST_1, + embed={"model": EmbedModel.Multilingual_E5_Large, "metric": Metric.COSINE}, + timeout=-1, + ) + assert "field_map is required" in str(e.value) diff --git a/tests/unit/models/test_index_embed.py b/tests/unit/models/test_index_embed.py new file mode 100644 index 000000000..3a372dfe9 --- /dev/null +++ b/tests/unit/models/test_index_embed.py @@ -0,0 +1,57 @@ +from pinecone import IndexEmbed, EmbedModel, Metric + + +def test_initialization_required_fields(): + embed = IndexEmbed(model="test-model", field_map={"text": "my_text_field"}) + + assert embed.model == "test-model" + assert embed.field_map == {"text": "my_text_field"} + + +def test_initialization_with_optional_fields(): + embed = IndexEmbed( + model="test-model", + field_map={"text": "my_text_field"}, + metric="cosine", + read_parameters={"param1": "value1"}, + write_parameters={"param2": "value2"}, + ) + + assert 
embed.model == "test-model" + assert embed.field_map == {"text": "my_text_field"} + assert embed.metric == "cosine" + assert embed.read_parameters == {"param1": "value1"} + assert embed.write_parameters == {"param2": "value2"} + + +def test_as_dict_method(): + embed = IndexEmbed( + model="test-model", + field_map={"text": "my_text_field"}, + metric="cosine", + read_parameters={"param1": "value1"}, + write_parameters={"param2": "value2"}, + ) + embed_dict = embed.as_dict() + + expected_dict = { + "model": "test-model", + "field_map": {"text": "my_text_field"}, + "metric": "cosine", + "read_parameters": {"param1": "value1"}, + "write_parameters": {"param2": "value2"}, + } + + assert embed_dict == expected_dict + + +def test_when_passed_enums(): + embed = IndexEmbed( + model=EmbedModel.Multilingual_E5_Large, + field_map={"text": "my_text_field"}, + metric=Metric.COSINE, + ) + + assert embed.model == EmbedModel.Multilingual_E5_Large.value + assert embed.field_map == {"text": "my_text_field"} + assert embed.metric == Metric.COSINE.value diff --git a/tests/unit/test_control.py b/tests/unit/test_control.py index 1115948d6..47a0d3f8a 100644 --- a/tests/unit/test_control.py +++ b/tests/unit/test_control.py @@ -12,12 +12,32 @@ PodIndexEnvironment, PodType, ) -from pinecone.core.openapi.db_control.models import IndexList, IndexModel, DeletionProtection +from pinecone.core.openapi.db_control.models import ( + IndexList, + IndexModel, + DeletionProtection, + IndexModelSpec, + ServerlessSpec as ServerlessSpecOpenApi, + IndexModelStatus, +) from pinecone.core.openapi.db_control.api.manage_indexes_api import ManageIndexesApi import time +def description_with_status(status: bool): + state = "Ready" if status else "Initializing" + return IndexModel( + name="foo", + status=IndexModelStatus(ready=status, state=state), + dimension=10, + deletion_protection=DeletionProtection("enabled"), + host="https://foo", + metric="euclidean", + 
spec=IndexModelSpec(serverless=ServerlessSpecOpenApi(cloud="aws", region="us-west1")), + ) + + @pytest.fixture def index_list_response(): return IndexList( @@ -117,20 +137,20 @@ def test_set_source_tag_in_useragent_via_config(self): "timeout_value, describe_index_responses, expected_describe_index_calls, expected_sleep_calls", [ # When timeout=None, describe_index is called until ready - (None, [{"status": {"ready": False}}, {"status": {"ready": True}}], 2, 1), - # Timeout of 10 seconds, describe_index called 3 times, sleep twice + (None, [description_with_status(False), description_with_status(True)], 2, 1), + # Timeout of 10 seconds, describe_index called 3 times, sleep twice ( 10, [ - {"status": {"ready": False}}, - {"status": {"ready": False}}, - {"status": {"ready": True}}, + description_with_status(False), + description_with_status(False), + description_with_status(True), ], 3, 2, ), - # When timeout=-1, create_index returns immediately without calling describe_index or sleep - (-1, [{"status": {"ready": False}}], 0, 0), + # When timeout=-1, create_index returns immediately without calling sleep + (-1, [description_with_status(False)], 0, 0), ], ) def test_create_index_with_timeout( @@ -213,20 +233,20 @@ def test_create_index_with_spec_dictionary(self, mocker, index_spec): "timeout_value, describe_index_responses, expected_describe_index_calls, expected_sleep_calls", [ # When timeout=None, describe_index is called until ready - (None, [{"status": {"ready": False}}, {"status": {"ready": True}}], 2, 1), + (None, [description_with_status(False), description_with_status(True)], 2, 1), # Timeout of 10 seconds, describe_index called 3 times, sleep twice ( 10, [ - {"status": {"ready": False}}, - {"status": {"ready": False}}, - {"status": {"ready": True}}, + description_with_status(False), + description_with_status(False), + description_with_status(True), ], 3, 2, ), - # When timeout=-1, create_index returns immediately without calling describe_index or sleep - 
(-1, [{"status": {"ready": False}}], 0, 0), + # When timeout=-1, create_index returns immediately without calling sleep + (-1, [description_with_status(False)], 0, 0), ], ) def test_create_index_from_source_collection( @@ -258,7 +278,7 @@ def test_create_index_when_timeout_exceeded(self, mocker): p = Pinecone(api_key="123-456-789") mocker.patch.object(p.index_api, "create_index") - describe_index_response = [{"status": {"ready": False}}] * 5 + describe_index_response = [description_with_status(False)] * 5 mocker.patch.object(p.index_api, "describe_index", side_effect=describe_index_response) mocker.patch("time.sleep")