Skip to content

Add disable ssl verification to client instantiation #1260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,14 @@ def __call__(
vector_index_config: Optional[_VectorIndexConfigCreate] = None,
description: Optional[str] = None,
reranker_config: Optional[_RerankerConfigCreate] = None,
) -> Collection[Any, Any]:
return_client: bool = False
) -> Union[Collection[Any, Any], Tuple[Collection[Any, Any], weaviate.WeaviateClient]]:
"""Typing for fixture."""
...


@pytest.fixture
def collection_factory(request: SubRequest) -> Generator[CollectionFactory, None, None]:
def collection_factory(request: SubRequest) -> Generator[Union[Collection[Any, Any], Tuple[Collection[Any, Any], weaviate.WeaviateClient]], None, None]:
name_fixture: Optional[str] = None
client_fixture: Optional[weaviate.WeaviateClient] = None

Expand All @@ -75,6 +76,7 @@ def _factory(
vector_index_config: Optional[_VectorIndexConfigCreate] = None,
description: Optional[str] = None,
reranker_config: Optional[_RerankerConfigCreate] = None,
return_client: bool = False
) -> Collection[Any, Any]:
nonlocal client_fixture, name_fixture
name_fixture = _sanitize_collection_name(request.node.name) + name
Expand All @@ -101,7 +103,10 @@ def _factory(
vector_index_config=vector_index_config,
reranker_config=reranker_config,
)
return collection
if return_client:
return collection, client_fixture
else:
return collection

try:
yield _factory
Expand Down
72 changes: 71 additions & 1 deletion integration/test_collection_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
GenerativeSearches,
Rerankers,
_RerankerConfigCreate,
Tokenization
)
from weaviate.collections.classes.tenants import Tenant

Expand Down Expand Up @@ -589,7 +590,76 @@ def test_collection_config_get_shards_multi_tenancy(collection_factory: Collecti
assert "tenant1" in [shard.name for shard in shards]
assert "tenant2" in [shard.name for shard in shards]


def test_collection_config_create_from_dict(collection_factory: CollectionFactory) -> None:
collection, client = collection_factory(
inverted_index_config=Configure.inverted_index(bm25_b=0.8, bm25_k1=1.3),
multi_tenancy_config=Configure.multi_tenancy(enabled=True),
generative_config=Configure.Generative.openai(model="gpt-4"),
vectorizer_config=Configure.Vectorizer.text2vec_openai(
model="text-embedding-3-small",
base_url="http://weaviate.io",
vectorize_collection_name=False,
dimensions=512
),
vector_index_config=Configure.VectorIndex.flat(
vector_cache_max_objects=234,
quantizer=Configure.VectorIndex.Quantizer.bq(rescore_limit=456),
),
description="Some description",
reranker_config=Configure.Reranker.cohere(model="rerank-english-v2.0"),
properties=[
Property(name="field_tokenization", data_type=DataType.TEXT, tokenization=Tokenization.FIELD),
Property(name="field_description", data_type=DataType.TEXT,
tokenization=Tokenization.FIELD, description="field desc"),
Property(name="field_index_filterable", data_type=DataType.TEXT,
index_filterable=False),
Property(name="field_skip_vectorization", data_type=DataType.TEXT,
skip_vectorization=True),
Property(name="text", data_type=DataType.TEXT),
Property(name="texts", data_type=DataType.TEXT_ARRAY),
Property(name="number", data_type=DataType.NUMBER),
Property(name="numbers", data_type=DataType.NUMBER_ARRAY),
Property(name="int", data_type=DataType.INT),
Property(name="ints", data_type=DataType.INT_ARRAY),
Property(name="date", data_type=DataType.DATE),
Property(name="dates", data_type=DataType.DATE_ARRAY),
Property(name="boolean", data_type=DataType.BOOL),
Property(name="booleans", data_type=DataType.BOOL_ARRAY),
Property(name="geo", data_type=DataType.GEO_COORDINATES),
Property(name="phone", data_type=DataType.PHONE_NUMBER),
Property(name="vectorize_property_name", data_type=DataType.TEXT,
vectorize_property_name=False),
Property(name="field_index_searchable", data_type=DataType.TEXT,
index_searchable=False),
# TODO: this will fail
# Property(
# name="name",
# data_type=DataType.OBJECT,
# nested_properties=[
# Property(name="first", data_type=DataType.TEXT),
# Property(name="last", data_type=DataType.TEXT),
# ],
# ),
],
return_client=True
)
old_dict = collection.config.get().to_dict()
new_dict = old_dict
new_collection_name = collection.name + "_FROM_DICT"
client.collections.delete(new_collection_name)
new_dict["class"] = new_collection_name
new_collection = client.collections.create_from_dict(new_dict)
new_collection_dict = new_collection.config.get().to_dict()
# make the same name for collections
new_collection_dict["class"] = collection.name
old_dict["class"] = collection.name
# check if both dict are the same
#print("old", old_dict)
#print("new", new_collection_dict)
assert new_collection_dict == old_dict
# remove the created collection
client.collections.delete(new_collection_name)

def test_config_vector_index_flat_and_quantizer_bq(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
vector_index_config=Configure.VectorIndex.flat(
Expand Down
4 changes: 3 additions & 1 deletion weaviate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def __init__(
additional_headers: Optional[dict] = None,
additional_config: Optional[AdditionalConfig] = None,
skip_init_checks: bool = False,
disable_ssl_verification: bool = False,
) -> None:
"""Initialise a WeaviateClient class instance to use when interacting with Weaviate.

Expand Down Expand Up @@ -191,6 +192,7 @@ def __init__(
config = additional_config or AdditionalConfig()

self.__skip_init_checks = skip_init_checks
self.__disable_ssl_verification = disable_ssl_verification

self._connection = ConnectionV4( # pyright: ignore reportIncompatibleVariableOverride
connection_params=connection_params,
Expand Down Expand Up @@ -284,7 +286,7 @@ def connect(self) -> None:
"""
if self._connection.is_connected():
return
self._connection.connect(self.__skip_init_checks)
self._connection.connect(self.__skip_init_checks, self.__disable_ssl_verification)

def is_connected(self) -> bool:
"""Check if the client is connected to Weaviate.
Expand Down
1 change: 1 addition & 0 deletions weaviate/collections/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def _shutdown(self) -> None:

def __batch_send(self) -> None:
loop = self.__start_new_event_loop()
# TODO: figure a way to pass disable_verification_process to aopen
future = asyncio.run_coroutine_threadsafe(self.__connection.aopen(), loop)
future.result() # Wait for self._connection.aopen() to finish
refresh_time: float = 0.01
Expand Down
4 changes: 2 additions & 2 deletions weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,8 +1016,8 @@ def to_dict(self) -> Dict[str, Any]:
out = super().to_dict()
out["dataType"] = [self.data_type.value]
out["indexFilterable"] = self.index_filterable
out["indexVector"] = self.index_searchable
out["tokenizer"] = self.tokenization.value if self.tokenization else None
out["indexSearchable"] = self.index_searchable
out["tokenization"] = self.tokenization.value if self.tokenization else None

module_config: Dict[str, Any] = {}
if self.vectorizer is not None:
Expand Down
14 changes: 13 additions & 1 deletion weaviate/connect/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def connect_to_weaviate_cloud(
headers: Optional[Dict[str, str]] = None,
additional_config: Optional[AdditionalConfig] = None,
skip_init_checks: bool = False,
disable_ssl_verification: bool = False,
) -> WeaviateClient:
"""
Connect to a Weaviate Cloud (WCD) instance.
Expand Down Expand Up @@ -81,6 +82,7 @@ def connect_to_weaviate_cloud(
additional_headers=headers,
additional_config=additional_config,
skip_init_checks=skip_init_checks,
disable_ssl_verification=disable_ssl_verification,
)
return __connect(client)

Expand All @@ -91,6 +93,7 @@ def connect_to_wcs(
headers: Optional[Dict[str, str]] = None,
additional_config: Optional[AdditionalConfig] = None,
skip_init_checks: bool = False,
disable_ssl_verification: bool = False,
) -> WeaviateClient:
"""
Connect to a Weaviate Cloud (WCD) instance.
Expand Down Expand Up @@ -137,7 +140,12 @@ def connect_to_wcs(
>>> # The connection is automatically closed when the context is exited.
"""
return connect_to_weaviate_cloud(
cluster_url, auth_credentials, headers, additional_config, skip_init_checks
cluster_url,
auth_credentials,
headers,
additional_config,
skip_init_checks,
disable_ssl_verification,
)


Expand All @@ -148,6 +156,7 @@ def connect_to_local(
headers: Optional[Dict[str, str]] = None,
additional_config: Optional[AdditionalConfig] = None,
skip_init_checks: bool = False,
disable_ssl_verification: bool = False,
auth_credentials: Optional[AuthCredentials] = None,
) -> WeaviateClient:
"""
Expand Down Expand Up @@ -208,6 +217,7 @@ def connect_to_local(
additional_headers=headers,
additional_config=additional_config,
skip_init_checks=skip_init_checks,
disable_ssl_verification=disable_ssl_verification,
auth_client_secret=auth_credentials,
)
return __connect(client)
Expand Down Expand Up @@ -310,6 +320,7 @@ def connect_to_custom(
additional_config: Optional[AdditionalConfig] = None,
auth_credentials: Optional[AuthCredentials] = None,
skip_init_checks: bool = False,
disable_ssl_verification: bool = False,
) -> WeaviateClient:
"""
Connect to a Weaviate instance with custom connection parameters.
Expand Down Expand Up @@ -388,6 +399,7 @@ def connect_to_custom(
additional_headers=headers,
additional_config=additional_config,
skip_init_checks=skip_init_checks,
disable_ssl_verification=disable_ssl_verification,
)
return __connect(client)

Expand Down
43 changes: 24 additions & 19 deletions weaviate/connect/v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,10 @@ def __init__(
if auth_client_secret is not None and isinstance(auth_client_secret, AuthApiKey):
self._headers["authorization"] = "Bearer " + auth_client_secret.api_key

def connect(self, skip_init_checks: bool) -> None:
def connect(self, skip_init_checks: bool, disable_ssl_verification: bool) -> None:
if self.embedded_db is not None:
self.embedded_db.start()
self._create_clients(self._auth, skip_init_checks)
self._create_clients(self._auth, skip_init_checks, disable_ssl_verification)
self.__connected = True
if self.embedded_db is not None:
try:
Expand Down Expand Up @@ -214,46 +214,51 @@ def __make_mounts(
if key != "grpc"
}

def __make_sync_client(self) -> Client:
def __make_sync_client(self, disable_ssl_verification: bool) -> Client:
return Client(
headers=self._headers,
timeout=Timeout(
None, connect=self.timeout_config.query, read=self.timeout_config.insert
),
mounts=self.__make_mounts("sync"),
verify=not disable_ssl_verification,
)

def __make_async_client(self) -> AsyncClient:
def __make_async_client(self, disable_ssl_verification: bool) -> AsyncClient:
return AsyncClient(
headers=self._headers,
timeout=Timeout(
None, connect=self.timeout_config.query, read=self.timeout_config.insert
),
mounts=self.__make_mounts("async"),
verify=not disable_ssl_verification,
)

def __make_clients(self) -> None:
self._client = self.__make_sync_client()
def __make_clients(self, disable_ssl_verification: bool) -> None:
self._client = self.__make_sync_client(disable_ssl_verification)

def _create_clients(
self, auth_client_secret: Optional[AuthCredentials], skip_init_checks: bool
self,
auth_client_secret: Optional[AuthCredentials],
skip_init_checks: bool,
disable_ssl_verification: bool,
) -> None:
# API keys are separate from OIDC and do not need any config from weaviate
if auth_client_secret is not None and isinstance(auth_client_secret, AuthApiKey):
self.__make_clients()
self.__make_clients(disable_ssl_verification)
return

if "authorization" in self._headers and auth_client_secret is None:
self.__make_clients()
self.__make_clients(disable_ssl_verification)
return

# no need to check OIDC if no auth is provided and users dont want any checks at initialization time
if skip_init_checks and auth_client_secret is None:
self.__make_clients()
self.__make_clients(disable_ssl_verification)
return

oidc_url = self.url + self._api_version_path + "/.well-known/openid-configuration"
with self.__make_sync_client() as client:
with self.__make_sync_client(disable_ssl_verification=disable_ssl_verification) as client:
try:
response = client.get(oidc_url)
except Exception as e:
Expand All @@ -269,7 +274,7 @@ def _create_clients(
resp = response.json()
except Exception:
_Warnings.auth_cannot_parse_oidc_config(oidc_url)
self.__make_clients()
self.__make_clients(disable_ssl_verification=disable_ssl_verification)
return

if auth_client_secret is not None:
Expand Down Expand Up @@ -309,9 +314,9 @@ def _create_clients(
raise AuthenticationFailedError(msg)
elif response.status_code == 404 and auth_client_secret is not None:
_Warnings.auth_with_anon_weaviate()
self.__make_clients()
self.__make_clients(disable_ssl_verification)
else:
self.__make_clients()
self.__make_clients(disable_ssl_verification)

def get_current_bearer_token(self) -> str:
if not self.is_connected():
Expand Down Expand Up @@ -376,9 +381,9 @@ def periodic_refresh_token(refresh_time: int, _auth: Optional[_Auth[OAuth2Client
)
demon.start()

async def aopen(self) -> None:
async def aopen(self, disable_ssl_verification: bool = False) -> None:
if self._aclient is None:
self._aclient = await self.__make_async_client().__aenter__()
self._aclient = await self.__make_async_client(disable_ssl_verification).__aenter__()
if self._grpc_stub_async is None:
self._grpc_channel_async = self._connection_params._grpc_channel(
async_channel=True, proxies=self._proxies
Expand Down Expand Up @@ -453,7 +458,7 @@ def __send(
except RuntimeError as e:
raise WeaviateClosedClientError() from e
except ConnectError as conn_err:
raise WeaviateConnectionError(error_msg) from conn_err
raise WeaviateConnectionError(f"{conn_err} {error_msg}")

def delete(
self,
Expand Down Expand Up @@ -707,8 +712,8 @@ def _ping_grpc(self) -> None:
f"v{self.server_version}", self._connection_params._grpc_address
) from e

def connect(self, skip_init_checks: bool) -> None:
super().connect(skip_init_checks)
def connect(self, skip_init_checks: bool, disable_ssl_verification: bool) -> None:
super().connect(skip_init_checks, disable_ssl_verification)
# create GRPC channel. If Weaviate does not support GRPC then error now.
self._grpc_channel = self._connection_params._grpc_channel(
async_channel=False, proxies=self._proxies
Expand Down
Loading