diff --git a/integration/conftest.py b/integration/conftest.py index f41c7ebaf..d19f3830d 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -47,13 +47,14 @@ def __call__( vector_index_config: Optional[_VectorIndexConfigCreate] = None, description: Optional[str] = None, reranker_config: Optional[_RerankerConfigCreate] = None, - ) -> Collection[Any, Any]: + return_client: bool = False + ) -> Union[Collection[Any, Any], Tuple[Collection[Any, Any], weaviate.WeaviateClient]]: """Typing for fixture.""" ... @pytest.fixture -def collection_factory(request: SubRequest) -> Generator[CollectionFactory, None, None]: +def collection_factory(request: SubRequest) -> Generator[Union[Collection[Any, Any], Tuple[Collection[Any, Any], weaviate.WeaviateClient]], None, None]: name_fixture: Optional[str] = None client_fixture: Optional[weaviate.WeaviateClient] = None @@ -75,6 +76,7 @@ def _factory( vector_index_config: Optional[_VectorIndexConfigCreate] = None, description: Optional[str] = None, reranker_config: Optional[_RerankerConfigCreate] = None, + return_client: bool = False ) -> Collection[Any, Any]: nonlocal client_fixture, name_fixture name_fixture = _sanitize_collection_name(request.node.name) + name @@ -101,7 +103,10 @@ def _factory( vector_index_config=vector_index_config, reranker_config=reranker_config, ) - return collection + if return_client: + return collection, client_fixture + else: + return collection try: yield _factory diff --git a/integration/test_collection_config.py b/integration/test_collection_config.py index 7c099734d..1c9b4d7bf 100644 --- a/integration/test_collection_config.py +++ b/integration/test_collection_config.py @@ -27,6 +27,7 @@ GenerativeSearches, Rerankers, _RerankerConfigCreate, + Tokenization ) from weaviate.collections.classes.tenants import Tenant @@ -589,7 +590,76 @@ def test_collection_config_get_shards_multi_tenancy(collection_factory: Collecti assert "tenant1" in [shard.name for shard in shards] assert "tenant2" in [shard.name for shard in shards] - +def test_collection_config_create_from_dict(collection_factory: CollectionFactory) -> None: + collection, client = collection_factory( + inverted_index_config=Configure.inverted_index(bm25_b=0.8, bm25_k1=1.3), + multi_tenancy_config=Configure.multi_tenancy(enabled=True), + generative_config=Configure.Generative.openai(model="gpt-4"), + vectorizer_config=Configure.Vectorizer.text2vec_openai( + model="text-embedding-3-small", + base_url="http://weaviate.io", + vectorize_collection_name=False, + dimensions=512 + ), + vector_index_config=Configure.VectorIndex.flat( + vector_cache_max_objects=234, + quantizer=Configure.VectorIndex.Quantizer.bq(rescore_limit=456), + ), + description="Some description", + reranker_config=Configure.Reranker.cohere(model="rerank-english-v2.0"), + properties=[ + Property(name="field_tokenization", data_type=DataType.TEXT, tokenization=Tokenization.FIELD), + Property(name="field_description", data_type=DataType.TEXT, + tokenization=Tokenization.FIELD, description="field desc"), + Property(name="field_index_filterable", data_type=DataType.TEXT, + index_filterable=False), + Property(name="field_skip_vectorization", data_type=DataType.TEXT, + skip_vectorization=True), + Property(name="text", data_type=DataType.TEXT), + Property(name="texts", data_type=DataType.TEXT_ARRAY), + Property(name="number", data_type=DataType.NUMBER), + Property(name="numbers", data_type=DataType.NUMBER_ARRAY), + Property(name="int", data_type=DataType.INT), + Property(name="ints", data_type=DataType.INT_ARRAY), + Property(name="date", data_type=DataType.DATE), + Property(name="dates", data_type=DataType.DATE_ARRAY), + Property(name="boolean", data_type=DataType.BOOL), + Property(name="booleans", data_type=DataType.BOOL_ARRAY), + Property(name="geo", data_type=DataType.GEO_COORDINATES), + Property(name="phone", data_type=DataType.PHONE_NUMBER), + Property(name="vectorize_property_name", data_type=DataType.TEXT, + vectorize_property_name=False), + Property(name="field_index_searchable", data_type=DataType.TEXT, + index_searchable=False), + # TODO: this will fail + # Property( + # name="name", + # data_type=DataType.OBJECT, + # nested_properties=[ + # Property(name="first", data_type=DataType.TEXT), + # Property(name="last", data_type=DataType.TEXT), + # ], + # ), + ], + return_client=True + ) + old_dict = collection.config.get().to_dict() + new_dict = old_dict + new_collection_name = collection.name + "_FROM_DICT" + client.collections.delete(new_collection_name) + new_dict["class"] = new_collection_name + new_collection = client.collections.create_from_dict(new_dict) + new_collection_dict = new_collection.config.get().to_dict() + # make the same name for collections + new_collection_dict["class"] = collection.name + old_dict["class"] = collection.name + # check if both dict are the same + #print("old", old_dict) + #print("new", new_collection_dict) + assert new_collection_dict == old_dict + # remove the created collection + client.collections.delete(new_collection_name) + def test_config_vector_index_flat_and_quantizer_bq(collection_factory: CollectionFactory) -> None: collection = collection_factory( vector_index_config=Configure.VectorIndex.flat( diff --git a/weaviate/client.py b/weaviate/client.py index 8a84e16a4..963841143 100644 --- a/weaviate/client.py +++ b/weaviate/client.py @@ -157,6 +157,7 @@ def __init__( additional_headers: Optional[dict] = None, additional_config: Optional[AdditionalConfig] = None, skip_init_checks: bool = False, + disable_ssl_verification: bool = False, ) -> None: """Initialise a WeaviateClient class instance to use when interacting with Weaviate. @@ -191,6 +192,7 @@ def __init__( config = additional_config or AdditionalConfig() self.__skip_init_checks = skip_init_checks + self.__disable_ssl_verification = disable_ssl_verification self._connection = ConnectionV4( # pyright: ignore reportIncompatibleVariableOverride connection_params=connection_params, @@ -284,7 +286,7 @@ def connect(self) -> None: """ if self._connection.is_connected(): return - self._connection.connect(self.__skip_init_checks) + self._connection.connect(self.__skip_init_checks, self.__disable_ssl_verification) def is_connected(self) -> bool: """Check if the client is connected to Weaviate. diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py index 8a837a21e..65f549c28 100644 --- a/weaviate/collections/batch/base.py +++ b/weaviate/collections/batch/base.py @@ -285,6 +285,7 @@ def _shutdown(self) -> None: def __batch_send(self) -> None: loop = self.__start_new_event_loop() + # TODO: figure a way to pass disable_verification_process to aopen future = asyncio.run_coroutine_threadsafe(self.__connection.aopen(), loop) future.result() # Wait for self._connection.aopen() to finish refresh_time: float = 0.01 diff --git a/weaviate/collections/classes/config.py b/weaviate/collections/classes/config.py index e0b7700e3..e57b71127 100644 --- a/weaviate/collections/classes/config.py +++ b/weaviate/collections/classes/config.py @@ -1016,8 +1016,8 @@ def to_dict(self) -> Dict[str, Any]: out = super().to_dict() out["dataType"] = [self.data_type.value] out["indexFilterable"] = self.index_filterable - out["indexVector"] = self.index_searchable - out["tokenizer"] = self.tokenization.value if self.tokenization else None + out["indexSearchable"] = self.index_searchable + out["tokenization"] = self.tokenization.value if self.tokenization else None module_config: Dict[str, Any] = {} if self.vectorizer is not None: diff --git a/weaviate/connect/helpers.py b/weaviate/connect/helpers.py index ce79fe1ad..9568ac7a0 100644 --- a/weaviate/connect/helpers.py +++ b/weaviate/connect/helpers.py @@ -17,6 +17,7 @@ def connect_to_weaviate_cloud( headers: Optional[Dict[str, str]] = None, additional_config: Optional[AdditionalConfig] = None, skip_init_checks: bool = False, + disable_ssl_verification: bool = False, ) -> WeaviateClient: """ Connect to a Weaviate Cloud (WCD) instance. @@ -81,6 +82,7 @@ def connect_to_weaviate_cloud( additional_headers=headers, additional_config=additional_config, skip_init_checks=skip_init_checks, + disable_ssl_verification=disable_ssl_verification, ) return __connect(client) @@ -91,6 +93,7 @@ def connect_to_wcs( headers: Optional[Dict[str, str]] = None, additional_config: Optional[AdditionalConfig] = None, skip_init_checks: bool = False, + disable_ssl_verification: bool = False, ) -> WeaviateClient: """ Connect to a Weaviate Cloud (WCD) instance. @@ -137,7 +140,12 @@ def connect_to_wcs( >>> # The connection is automatically closed when the context is exited. """ return connect_to_weaviate_cloud( - cluster_url, auth_credentials, headers, additional_config, skip_init_checks + cluster_url, + auth_credentials, + headers, + additional_config, + skip_init_checks, + disable_ssl_verification, ) @@ -148,6 +156,7 @@ def connect_to_local( headers: Optional[Dict[str, str]] = None, additional_config: Optional[AdditionalConfig] = None, skip_init_checks: bool = False, + disable_ssl_verification: bool = False, auth_credentials: Optional[AuthCredentials] = None, ) -> WeaviateClient: """ @@ -208,6 +217,7 @@ def connect_to_local( additional_headers=headers, additional_config=additional_config, skip_init_checks=skip_init_checks, + disable_ssl_verification=disable_ssl_verification, auth_client_secret=auth_credentials, ) return __connect(client) @@ -310,6 +320,7 @@ def connect_to_custom( additional_config: Optional[AdditionalConfig] = None, auth_credentials: Optional[AuthCredentials] = None, skip_init_checks: bool = False, + disable_ssl_verification: bool = False, ) -> WeaviateClient: """ Connect to a Weaviate instance with custom connection parameters. @@ -388,6 +399,7 @@ def connect_to_custom( additional_headers=headers, additional_config=additional_config, skip_init_checks=skip_init_checks, + disable_ssl_verification=disable_ssl_verification, ) return __connect(client) diff --git a/weaviate/connect/v4.py b/weaviate/connect/v4.py index 1f46dbd3a..1c95701dd 100644 --- a/weaviate/connect/v4.py +++ b/weaviate/connect/v4.py @@ -132,10 +132,10 @@ def __init__( if auth_client_secret is not None and isinstance(auth_client_secret, AuthApiKey): self._headers["authorization"] = "Bearer " + auth_client_secret.api_key - def connect(self, skip_init_checks: bool) -> None: + def connect(self, skip_init_checks: bool, disable_ssl_verification: bool) -> None: if self.embedded_db is not None: self.embedded_db.start() - self._create_clients(self._auth, skip_init_checks) + self._create_clients(self._auth, skip_init_checks, disable_ssl_verification) self.__connected = True if self.embedded_db is not None: try: @@ -214,46 +214,51 @@ def __make_mounts( if key != "grpc" } - def __make_sync_client(self) -> Client: + def __make_sync_client(self, disable_ssl_verification: bool) -> Client: return Client( headers=self._headers, timeout=Timeout( None, connect=self.timeout_config.query, read=self.timeout_config.insert ), mounts=self.__make_mounts("sync"), + verify=not disable_ssl_verification, ) - def __make_async_client(self) -> AsyncClient: + def __make_async_client(self, disable_ssl_verification: bool) -> AsyncClient: return AsyncClient( headers=self._headers, timeout=Timeout( None, connect=self.timeout_config.query, read=self.timeout_config.insert ), mounts=self.__make_mounts("async"), + verify=not disable_ssl_verification, ) - def __make_clients(self) -> None: - self._client = self.__make_sync_client() + def __make_clients(self, disable_ssl_verification: bool) -> None: + self._client = self.__make_sync_client(disable_ssl_verification) def _create_clients( - self, auth_client_secret: Optional[AuthCredentials], skip_init_checks: bool + self, + auth_client_secret: Optional[AuthCredentials], + skip_init_checks: bool, + disable_ssl_verification: bool, ) -> None: # API keys are separate from OIDC and do not need any config from weaviate if auth_client_secret is not None and isinstance(auth_client_secret, AuthApiKey): - self.__make_clients() + self.__make_clients(disable_ssl_verification) return if "authorization" in self._headers and auth_client_secret is None: - self.__make_clients() + self.__make_clients(disable_ssl_verification) return # no need to check OIDC if no auth is provided and users dont want any checks at initialization time if skip_init_checks and auth_client_secret is None: - self.__make_clients() + self.__make_clients(disable_ssl_verification) return oidc_url = self.url + self._api_version_path + "/.well-known/openid-configuration" - with self.__make_sync_client() as client: + with self.__make_sync_client(disable_ssl_verification=disable_ssl_verification) as client: try: response = client.get(oidc_url) except Exception as e: @@ -269,7 +274,7 @@ def _create_clients( resp = response.json() except Exception: _Warnings.auth_cannot_parse_oidc_config(oidc_url) - self.__make_clients() + self.__make_clients(disable_ssl_verification=disable_ssl_verification) return if auth_client_secret is not None: @@ -309,9 +314,9 @@ def _create_clients( raise AuthenticationFailedError(msg) elif response.status_code == 404 and auth_client_secret is not None: _Warnings.auth_with_anon_weaviate() - self.__make_clients() + self.__make_clients(disable_ssl_verification) else: - self.__make_clients() + self.__make_clients(disable_ssl_verification) def get_current_bearer_token(self) -> str: if not self.is_connected(): @@ -376,9 +381,9 @@ def periodic_refresh_token(refresh_time: int, _auth: Optional[_Auth[OAuth2Client ) demon.start() - async def aopen(self) -> None: + async def aopen(self, disable_ssl_verification: bool = False) -> None: if self._aclient is None: - self._aclient = await self.__make_async_client().__aenter__() + self._aclient = await self.__make_async_client(disable_ssl_verification).__aenter__() if self._grpc_stub_async is None: self._grpc_channel_async = self._connection_params._grpc_channel( async_channel=True, proxies=self._proxies @@ -453,7 +458,7 @@ def __send( except RuntimeError as e: raise WeaviateClosedClientError() from e except ConnectError as conn_err: - raise WeaviateConnectionError(error_msg) from conn_err + raise WeaviateConnectionError(f"{conn_err} {error_msg}") def delete( self, @@ -707,8 +712,8 @@ def _ping_grpc(self) -> None: f"v{self.server_version}", self._connection_params._grpc_address ) from e - def connect(self, skip_init_checks: bool) -> None: - super().connect(skip_init_checks) + def connect(self, skip_init_checks: bool, disable_ssl_verification: bool) -> None: + super().connect(skip_init_checks, disable_ssl_verification) # create GRPC channel. If Weaviate does not support GRPC then error now. self._grpc_channel = self._connection_params._grpc_channel( async_channel=False, proxies=self._proxies diff --git a/weaviate/exceptions.py b/weaviate/exceptions.py index 0d6161e41..5510b5bf5 100644 --- a/weaviate/exceptions.py +++ b/weaviate/exceptions.py @@ -182,7 +182,35 @@ def __init__(self, type_: str) -> None: class WeaviateStartUpError(WeaviateBaseError): - """Is raised if weaviate is not available on the given url+port.""" + """Is raised if weaviate is not available on the given url+port or due to ssl verification.""" + + def __init__(self, message: str = ""): + """ + Weaviate base exception initializer. + + Arguments: + `message`: + An error message specific to the context in which the error occurred. + """ + + self.message = message + if "SSL: CERTIFICATE_VERIFY_FAILED" in str(message): + msg = """ + We have identified a SSL CERTIFICATE_VERIFY_FAILED error. + + This error could be due to one of several reasons: + - Weaviate client is under a corporate network that terminates ssl and issues it's own certificates. + - You have a self signed certificate + + Weaviate python client uses certifi, and because of that, it will not be able to trust + Potential fixes: + - disable ssl verification by setting using `disable_ssl_verification=True` in client initialization + - note that Weaviate will trust any certificate + - Replace certifi cacert with the same cacert that is issued by your corporate network. + - for example: cat MyCompanyRootCA.pem >> $(python -m certifi) + """ + message = message + msg + super().__init__(message) class WeaviateEmbeddedInvalidVersionError(WeaviateBaseError):