From 285a6777b165cd75cb350dc439e62fbdde33c8ba Mon Sep 17 00:00:00 2001 From: eavanvalkenburg <github@vanvalkenburg.eu> Date: Thu, 23 Jan 2025 14:38:22 +0100 Subject: [PATCH] small test updates --- .../mongodb_atlas/mongodb_atlas_collection.py | 33 ++++++++-------- .../mongodb_atlas/mongodb_atlas_settings.py | 11 +++--- .../test_mongodb_atlas_collection.py | 38 ++++++++++++++++++- 3 files changed, 61 insertions(+), 21 deletions(-) diff --git a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py index fc2a4446dbbb..bbd524019510 100644 --- a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py +++ b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py @@ -17,7 +17,11 @@ from pymongo.asynchronous.database import AsyncDatabase from pymongo.driver_info import DriverInfo -from semantic_kernel.connectors.memory.mongodb_atlas.const import MONGODB_ID_FIELD +from semantic_kernel.connectors.memory.mongodb_atlas.const import ( + DEFAULT_DB_NAME, + DEFAULT_SEARCH_INDEX_NAME, + MONGODB_ID_FIELD, +) from semantic_kernel.connectors.memory.mongodb_atlas.utils import create_index_definition from semantic_kernel.data.filter_clauses import AnyTagsEqualTo, EqualTo from semantic_kernel.data.kernel_search_results import KernelSearchResults @@ -57,9 +61,9 @@ class MongoDBAtlasCollection( def __init__( self, + collection_name: str, data_model_type: type[TModel], data_model_definition: VectorStoreRecordDefinition | None = None, - collection_name: str | None = None, index_name: str | None = None, mongo_client: AsyncMongoClient | None = None, **kwargs: Any, @@ -81,17 +85,16 @@ def __init__( env_file_encoding: str | None = None """ - if not collection_name: - raise VectorStoreInitializationException("Collection name is required.") - if mongo_client and "database_name" in kwargs: + managed_client = not mongo_client + if mongo_client: super().__init__( data_model_type=data_model_type, data_model_definition=data_model_definition, mongo_client=mongo_client, collection_name=collection_name, - database_name=kwargs["database_name"], - index_name=index_name or f"{collection_name}_idx", - managed_client=False, + database_name=kwargs.get("database_name", DEFAULT_DB_NAME), + index_name=index_name or DEFAULT_SEARCH_INDEX_NAME, + managed_client=managed_client, ) return @@ -103,17 +106,15 @@ def __init__( env_file_encoding=kwargs.get("env_file_encoding"), connection_string=kwargs.get("connection_string"), database_name=kwargs.get("database_name"), + index_name=index_name, ) except ValidationError as exc: raise VectorStoreInitializationException("Failed to create MongoDB Atlas settings.") from exc - managed_client = not mongo_client if not mongo_client: mongo_client = AsyncMongoClient( mongodb_atlas_settings.connection_string.get_secret_value(), driver=DriverInfo("Microsoft Semantic Kernel", metadata.version("semantic-kernel")), ) - if not mongodb_atlas_settings.database_name: - raise VectorStoreInitializationException("Database name is required.") super().__init__( data_model_type=data_model_type, @@ -122,7 +123,7 @@ def __init__( mongo_client=mongo_client, managed_client=managed_client, database_name=mongodb_atlas_settings.database_name, - index_name=index_name or f"{collection_name}_idx", + index_name=mongodb_atlas_settings.index_name, ) def _get_database(self) -> AsyncDatabase: @@ -186,16 +187,18 @@ def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: A async def create_collection(self, **kwargs) -> None: """Create a new collection in MongoDB Atlas. + This first creates a collection, with the kwargs. + Then creates a search index based on the data model definition. + Args: **kwargs: Additional keyword arguments. """ - database = self._get_database() - collection = await database.create_collection(self.collection_name, **kwargs) + collection = await self._get_database().create_collection(self.collection_name, **kwargs) await collection.create_search_index(create_index_definition(self.data_model_definition, self.index_name)) @override async def does_collection_exist(self, **kwargs) -> bool: - return self.collection_name in await self._get_database().list_collection_names() + return bool(await self._get_database().list_collection_names(filter={"name": self.collection_name})) @override async def delete_collection(self, **kwargs) -> None: diff --git a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_settings.py b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_settings.py index 118cd8f5c267..11a21183fcf2 100644 --- a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_settings.py +++ b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_settings.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import Annotated, ClassVar +from typing import ClassVar -from pydantic import Field, SecretStr +from pydantic import SecretStr from semantic_kernel.connectors.memory.mongodb_atlas.const import DEFAULT_DB_NAME, DEFAULT_SEARCH_INDEX_NAME from semantic_kernel.kernel_pydantic import KernelBaseSettings @@ -17,12 +17,13 @@ class MongoDBAtlasSettings(KernelBaseSettings): - connection_string: str - MongoDB Atlas connection string (Env var MONGODB_ATLAS_CONNECTION_STRING) - database_name: str - MongoDB Atlas database name, defaults to 'default' + (Env var MONGODB_ATLAS_DATABASE_NAME) + - index_name: str - MongoDB Atlas search index name, defaults to 'default' + (Env var MONGODB_ATLAS_INDEX_NAME) """ env_prefix: ClassVar[str] = "MONGODB_ATLAS_" connection_string: SecretStr database_name: str = DEFAULT_DB_NAME - index_name: Annotated[str, Field(deprecated="This field is not used with the new store and collection")] = ( - DEFAULT_SEARCH_INDEX_NAME - ) + index_name: str = DEFAULT_SEARCH_INDEX_NAME diff --git a/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_collection.py b/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_collection.py index f0bf621bfa50..00afe491e2a3 100644 --- a/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_collection.py +++ b/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_collection.py @@ -1,13 +1,15 @@ # Copyright (c) Microsoft. All rights reserved. - from unittest.mock import AsyncMock, patch from pymongo import AsyncMongoClient from pymongo.asynchronous.cursor import AsyncCursor from pymongo.results import UpdateResult +from pytest import mark, raises +from semantic_kernel.connectors.memory.mongodb_atlas.const import DEFAULT_DB_NAME, DEFAULT_SEARCH_INDEX_NAME from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_collection import MongoDBAtlasCollection +from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreInitializationException def test_mongodb_atlas_collection_initialization(mongodb_atlas_unit_test_env, data_model_definition, mock_mongo_client): @@ -21,6 +23,27 @@ def test_mongodb_atlas_collection_initialization(mongodb_atlas_unit_test_env, da assert isinstance(collection.mongo_client, AsyncMongoClient) +@mark.parametrize("exclude_list", [["MONGODB_ATLAS_CONNECTION_STRING"]], indirect=True) +def test_mongodb_atlas_collection_initialization_fail(mongodb_atlas_unit_test_env, data_model_definition): + with raises(VectorStoreInitializationException): + MongoDBAtlasCollection( + collection_name="test_collection", + data_model_type=dict, + data_model_definition=data_model_definition, + ) + + +@mark.parametrize("exclude_list", [["MONGODB_ATLAS_DATABASE_NAME", "MONGODB_ATLAS_INDEX_NAME"]], indirect=True) +def test_mongodb_atlas_collection_initialization_defaults(mongodb_atlas_unit_test_env, data_model_definition): + collection = MongoDBAtlasCollection( + collection_name="test_collection", + data_model_type=dict, + data_model_definition=data_model_definition, + ) + assert collection.database_name == DEFAULT_DB_NAME + assert collection.index_name == DEFAULT_SEARCH_INDEX_NAME + + async def test_mongodb_atlas_collection_upsert(mongodb_atlas_unit_test_env, data_model_definition, mock_get_collection): collection = MongoDBAtlasCollection( data_model_type=dict, @@ -58,3 +81,16 @@ async def test_mongodb_atlas_collection_delete(mongodb_atlas_unit_test_env, data with patch.object(collection, "_get_collection", new=mock_get_collection) as mock_get: await collection._inner_delete(["test_id"]) mock_get.return_value.delete_many.assert_called_with({"_id": {"$in": ["test_id"]}}) + + +async def test_mongodb_atlas_collection_collection_exists( + mongodb_atlas_unit_test_env, data_model_definition, mock_get_database +): + collection = MongoDBAtlasCollection( + data_model_type=dict, + data_model_definition=data_model_definition, + collection_name="test_collection", + ) + with patch.object(collection, "_get_database", new=mock_get_database) as mock_get: + mock_get.return_value.list_collection_names.return_value = ["test_collection"] + assert await collection.does_collection_exist()