Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 142 additions & 59 deletions pinecone/control/pinecone.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import time
import logging
from typing import Optional, Dict, Any, Union
from enum import Enum

from .index_host_store import IndexHostStore
from .pinecone_interface import PineconeDBControlInterface
Expand All @@ -14,6 +15,8 @@
from pinecone.utils import normalize_host, setup_openapi_client, build_plugin_setup_client
from pinecone.core.openapi.db_control.models import (
CreateCollectionRequest,
CreateIndexForModelRequest,
CreateIndexForModelRequestEmbed,
CreateIndexRequest,
ConfigureIndexRequest,
ConfigureIndexRequestSpec,
Expand All @@ -26,12 +29,29 @@
PodSpecMetadataConfig,
)
from pinecone.core.openapi.db_control import API_VERSION
from pinecone.models import ServerlessSpec, PodSpec, IndexModel, IndexList, CollectionList
from pinecone.models import (
ServerlessSpec,
PodSpec,
IndexModel,
IndexList,
CollectionList,
IndexEmbed,
)
from .langchain_import_warnings import _build_langchain_attribute_error_message
from pinecone.utils import parse_non_empty_args, docslinks

from pinecone.data import _Index, _AsyncioIndex, _Inference
from pinecone.enums import Metric, VectorType, DeletionProtection, PodType
from pinecone.enums import (
Metric,
VectorType,
DeletionProtection,
PodType,
CloudProvider,
AwsRegion,
GcpRegion,
AzureRegion,
)
from .types import CreateIndexForModelEmbedTypedDict

from pinecone_plugin_interface import load_and_install as install_plugins

Expand Down Expand Up @@ -127,7 +147,27 @@ def load_plugins(self):
except Exception as e:
logger.error(f"Error loading plugins: {e}")

def _parse_index_spec(self, spec: Union[Dict, ServerlessSpec, PodSpec]) -> IndexSpec:
def __parse_tags(self, tags: Optional[Dict[str, str]]) -> IndexTags:
    """Convert a plain ``{str: str}`` mapping into an ``IndexTags`` model.

    A ``None`` input yields an empty ``IndexTags`` object so callers can
    always forward the result to the request builder without a None check.
    """
    return IndexTags() if tags is None else IndexTags(**tags)

def __parse_deletion_protection(
    self, deletion_protection: Union[DeletionProtection, str]
) -> DeletionProtectionModel:
    """Normalize a deletion-protection setting into the OpenAPI model.

    Accepts either the ``DeletionProtection`` enum or its string value;
    anything other than ``'enabled'``/``'disabled'`` is rejected.
    """
    value = self.__parse_enum_to_string(deletion_protection)
    if value not in ("enabled", "disabled"):
        raise ValueError("deletion_protection must be either 'enabled' or 'disabled'")
    return DeletionProtectionModel(value)

def __parse_enum_to_string(self, value: Union[Enum, str]) -> str:
    """Return the underlying string of an Enum member; pass plain strings through."""
    return value.value if isinstance(value, Enum) else value

def __parse_index_spec(self, spec: Union[Dict, ServerlessSpec, PodSpec]) -> IndexSpec:
if isinstance(spec, dict):
if "serverless" in spec:
index_spec = IndexSpec(serverless=ServerlessSpecModel(**spec["serverless"]))
Expand Down Expand Up @@ -185,85 +225,133 @@ def create_index(
deletion_protection: Optional[Union[DeletionProtection, str]] = DeletionProtection.DISABLED,
vector_type: Optional[Union[VectorType, str]] = VectorType.DENSE,
tags: Optional[Dict[str, str]] = None,
):
# Convert Enums to their string values if necessary
metric = metric.value if isinstance(metric, Metric) else str(metric)
vector_type = vector_type.value if isinstance(vector_type, VectorType) else str(vector_type)
deletion_protection = (
deletion_protection.value
if isinstance(deletion_protection, DeletionProtection)
else str(deletion_protection)
)
) -> IndexModel:
if metric is not None:
metric = self.__parse_enum_to_string(metric)
if vector_type is not None:
vector_type = self.__parse_enum_to_string(vector_type)
if deletion_protection is not None:
dp = self.__parse_deletion_protection(deletion_protection)
else:
dp = None

tags_obj = self.__parse_tags(tags)
index_spec = self.__parse_index_spec(spec)

if vector_type == VectorType.SPARSE.value and dimension is not None:
raise ValueError("dimension should not be specified for sparse indexes")

if deletion_protection in ["enabled", "disabled"]:
dp = DeletionProtectionModel(deletion_protection)
args = parse_non_empty_args(
[
("name", name),
("dimension", dimension),
("metric", metric),
("spec", index_spec),
("deletion_protection", dp),
("vector_type", vector_type),
("tags", tags_obj),
]
)

req = CreateIndexRequest(**args)
resp = self.index_api.create_index(create_index_request=req)

if timeout == -1:
return IndexModel(resp)
return self.__poll_describe_index_until_ready(name, timeout)

def create_index_for_model(
self,
name: str,
cloud: Union[CloudProvider, str],
region: Union[AwsRegion, GcpRegion, AzureRegion, str],
embed: Union[IndexEmbed, CreateIndexForModelEmbedTypedDict],
tags: Optional[Dict[str, str]] = None,
deletion_protection: Optional[Union[DeletionProtection, str]] = DeletionProtection.DISABLED,
timeout: Optional[int] = None,
) -> IndexModel:
cloud = self.__parse_enum_to_string(cloud)
region = self.__parse_enum_to_string(region)
if deletion_protection is not None:
dp = self.__parse_deletion_protection(deletion_protection)
else:
raise ValueError("deletion_protection must be either 'enabled' or 'disabled'")
dp = None
tags_obj = self.__parse_tags(tags)

if tags is None:
tags_obj = None
if isinstance(embed, IndexEmbed):
parsed_embed = embed.as_dict()
else:
tags_obj = IndexTags(**tags)
# if dict, we need to parse enum values, if any, to string
# and verify required fields are present
required_fields = ["model", "field_map"]
for field in required_fields:
if field not in embed:
raise ValueError(f"{field} is required in embed")
parsed_embed = {}
for key, value in embed.items():
if isinstance(value, Enum):
parsed_embed[key] = value.value
else:
parsed_embed[key] = value

args = parse_non_empty_args(
[
("name", name),
("cloud", cloud),
("region", region),
("embed", CreateIndexForModelRequestEmbed(**parsed_embed)),
("deletion_protection", dp),
("tags", tags_obj),
]
)

index_spec = self._parse_index_spec(spec)
req = CreateIndexForModelRequest(**args)
resp = self.index_api.create_index_for_model(req)

api_instance = self.index_api
api_instance.create_index(
create_index_request=CreateIndexRequest(
**parse_non_empty_args(
[
("name", name),
("dimension", dimension),
("metric", metric),
("spec", index_spec),
("deletion_protection", dp),
("vector_type", vector_type),
("tags", tags_obj),
]
)
)
)
if timeout == -1:
return IndexModel(resp)
return self.__poll_describe_index_until_ready(name, timeout)

def __poll_describe_index_until_ready(self, name: str, timeout: Optional[int] = None):
description = None

def is_ready():
status = self._get_status(name)
ready = status["ready"]
return ready
nonlocal description
description = self.describe_index(name=name)
return description.status.ready

total_wait_time = 0
if timeout == -1:
logger.debug(f"Skipping wait for index {name} to be ready")
return
if timeout is None:
# Wait indefinitely
while not is_ready():
logger.debug(
f"Waiting for index {name} to be ready. Total wait time: {total_wait_time}"
)
total_wait_time += 5
time.sleep(5)

else:
while (not is_ready()) and timeout >= 0:
# Wait for a maximum of timeout seconds
while not is_ready():
if timeout < 0:
logger.error(f"Index {name} is not ready. Timeout reached.")
link = docslinks["API_DESCRIBE_INDEX"]
timeout_msg = (
f"Please call describe_index() to confirm index status. See docs at {link}"
)
raise TimeoutError(timeout_msg)

logger.debug(
f"Waiting for index {name} to be ready. Total wait time: {total_wait_time}"
)
total_wait_time += 5
time.sleep(5)
timeout -= 5
if timeout and timeout < 0:
logger.error(f"Index {name} is not ready. Timeout reached.")
raise (
TimeoutError(
"Please call the describe_index API ({}) to confirm index status.".format(
docslinks["API_DESCRIBE_INDEX"]
)
)
)

return description

def delete_index(self, name: str, timeout: Optional[int] = None):
api_instance = self.index_api
api_instance.delete_index(name)
self.index_api.delete_index(name)
self.index_host_store.delete_host(self.config, name)

def get_remaining():
Expand Down Expand Up @@ -292,7 +380,7 @@ def list_indexes(self) -> IndexList:
response = self.index_api.list_indexes()
return IndexList(response)

def describe_index(self, name: str):
def describe_index(self, name: str) -> IndexModel:
api_instance = self.index_api
description = api_instance.describe_index(name)
host = description.host
Expand Down Expand Up @@ -373,11 +461,6 @@ def describe_collection(self, name: str):
api_instance = self.index_api
return api_instance.describe_collection(name).to_dict()

def _get_status(self, name: str):
    """Fetch only the ``status`` sub-object of an index description.

    :param name: Name of the index to look up.
    :return: The ``status`` field of the control-plane describe response.
    """
    return self.index_api.describe_index(name)["status"]

@staticmethod
def from_texts(*args, **kwargs):
    """Reject langchain-style ``from_texts`` calls with a guidance message.

    Always raises; the message is built by the langchain import-warning
    helper rather than implementing the langchain interface here.
    """
    message = _build_langchain_attribute_error_message("from_texts")
    raise AttributeError(message)
Expand Down
61 changes: 58 additions & 3 deletions pinecone/control/pinecone_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,25 @@
from pinecone.core.openapi.db_control.api.manage_indexes_api import ManageIndexesApi


from pinecone.models import ServerlessSpec, PodSpec, IndexList, CollectionList
from pinecone.enums import Metric, VectorType, DeletionProtection, PodType
from pinecone.models import (
ServerlessSpec,
PodSpec,
IndexList,
CollectionList,
IndexModel,
IndexEmbed,
)
from pinecone.enums import (
Metric,
VectorType,
DeletionProtection,
PodType,
CloudProvider,
AwsRegion,
GcpRegion,
AzureRegion,
)
from .types import CreateIndexForModelEmbedTypedDict


class PineconeDBControlInterface(ABC):
Expand Down Expand Up @@ -246,6 +263,44 @@ def create_index(
"""
pass

@abstractmethod
def create_index_for_model(
    self,
    name: str,
    cloud: Union[CloudProvider, str],
    region: Union[AwsRegion, GcpRegion, AzureRegion, str],
    embed: Union[IndexEmbed, CreateIndexForModelEmbedTypedDict],
    tags: Optional[Dict[str, str]] = None,
    deletion_protection: Optional[Union[DeletionProtection, str]] = DeletionProtection.DISABLED,
    timeout: Optional[int] = None,
) -> IndexModel:
    """
    Create a serverless index whose embeddings are generated by a hosted model.

    The resulting index stores embeddings produced by the configured model and
    can then be used for search and retrieval.

    :param name: The name of the index to create. Must be unique within your project and
        cannot be changed once created. Allowed characters are lowercase letters, numbers,
        and hyphens and the name may not begin or end with hyphens. Maximum length is 45 characters.
    :type name: str
    :param cloud: The cloud provider to use for the index.
    :type cloud: Union[CloudProvider, str]
    :param region: The region to use for the index.
    :type region: Union[AwsRegion, GcpRegion, AzureRegion, str]
    :param embed: The embedding configuration for the index (model and field mapping).
    :type embed: Union[IndexEmbed, CreateIndexForModelEmbedTypedDict]
    :param tags: A dictionary of tags to associate with the index.
    :type tags: Optional[Dict[str, str]]
    :param deletion_protection: If enabled, the index cannot be deleted. If disabled, the index can be deleted. Default: "disabled"
    :type deletion_protection: Optional[Union[DeletionProtection, str]]
    :param timeout: Specify the number of seconds to wait until index gets ready. If None, wait indefinitely; if >=0, time out after this many seconds;
        if -1, return immediately and do not wait. Default: None
    :type timeout: Optional[int]
    :return: The index that was created.
    :rtype: IndexModel
    """
    pass

@abstractmethod
def delete_index(self, name: str, timeout: Optional[int] = None):
"""Deletes a Pinecone index.
Expand Down Expand Up @@ -316,7 +371,7 @@ def list_indexes(self) -> IndexList:
pass

@abstractmethod
def describe_index(self, name: str):
def describe_index(self, name: str) -> IndexModel:
"""Describes a Pinecone index.

:param name: the name of the index to describe.
Expand Down
1 change: 1 addition & 0 deletions pinecone/control/types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .create_index_for_model_embed import CreateIndexForModelEmbedTypedDict
11 changes: 11 additions & 0 deletions pinecone/control/types/create_index_for_model_embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import TypedDict, Dict, Union
from ...enums import Metric
from ...data.features.inference import EmbedModel


class _CreateIndexForModelEmbedRequired(TypedDict):
    # Keys the client validates as mandatory before building the request:
    # the hosted embedding model and the field mapping for source records.
    model: Union[EmbedModel, str]
    field_map: Dict


class CreateIndexForModelEmbedTypedDict(_CreateIndexForModelEmbedRequired, total=False):
    """Embed configuration accepted by ``create_index_for_model``.

    Only ``model`` and ``field_map`` are required (the consumer raises
    ``ValueError`` when either is missing); ``metric``, ``read_parameters``
    and ``write_parameters`` are optional, so they are declared in a
    ``total=False`` section instead of as required keys.
    """

    metric: Union[Metric, str]
    read_parameters: Dict
    write_parameters: Dict
Loading
Loading