From 285a6777b165cd75cb350dc439e62fbdde33c8ba Mon Sep 17 00:00:00 2001
From: eavanvalkenburg <github@vanvalkenburg.eu>
Date: Thu, 23 Jan 2025 14:38:22 +0100
Subject: [PATCH] small test updates

---
 .../mongodb_atlas/mongodb_atlas_collection.py | 33 ++++++++--------
 .../mongodb_atlas/mongodb_atlas_settings.py   | 11 +++---
 .../test_mongodb_atlas_collection.py          | 38 ++++++++++++++++++-
 3 files changed, 61 insertions(+), 21 deletions(-)

diff --git a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py
index fc2a4446dbbb..bbd524019510 100644
--- a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py
+++ b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py
@@ -17,7 +17,11 @@
 from pymongo.asynchronous.database import AsyncDatabase
 from pymongo.driver_info import DriverInfo
 
-from semantic_kernel.connectors.memory.mongodb_atlas.const import MONGODB_ID_FIELD
+from semantic_kernel.connectors.memory.mongodb_atlas.const import (
+    DEFAULT_DB_NAME,
+    DEFAULT_SEARCH_INDEX_NAME,
+    MONGODB_ID_FIELD,
+)
 from semantic_kernel.connectors.memory.mongodb_atlas.utils import create_index_definition
 from semantic_kernel.data.filter_clauses import AnyTagsEqualTo, EqualTo
 from semantic_kernel.data.kernel_search_results import KernelSearchResults
@@ -57,9 +61,9 @@ class MongoDBAtlasCollection(
 
     def __init__(
         self,
+        collection_name: str,
         data_model_type: type[TModel],
         data_model_definition: VectorStoreRecordDefinition | None = None,
-        collection_name: str | None = None,
         index_name: str | None = None,
         mongo_client: AsyncMongoClient | None = None,
         **kwargs: Any,
@@ -81,17 +85,16 @@ def __init__(
                     env_file_encoding: str | None = None
 
         """
-        if not collection_name:
-            raise VectorStoreInitializationException("Collection name is required.")
-        if mongo_client and "database_name" in kwargs:
+        managed_client = not mongo_client
+        if mongo_client:
             super().__init__(
                 data_model_type=data_model_type,
                 data_model_definition=data_model_definition,
                 mongo_client=mongo_client,
                 collection_name=collection_name,
-                database_name=kwargs["database_name"],
-                index_name=index_name or f"{collection_name}_idx",
-                managed_client=False,
+                database_name=kwargs.get("database_name", DEFAULT_DB_NAME),
+                index_name=index_name or DEFAULT_SEARCH_INDEX_NAME,
+                managed_client=managed_client,
             )
             return
 
@@ -103,17 +106,15 @@ def __init__(
                 env_file_encoding=kwargs.get("env_file_encoding"),
                 connection_string=kwargs.get("connection_string"),
                 database_name=kwargs.get("database_name"),
+                index_name=index_name,
             )
         except ValidationError as exc:
             raise VectorStoreInitializationException("Failed to create MongoDB Atlas settings.") from exc
-        managed_client = not mongo_client
         if not mongo_client:
             mongo_client = AsyncMongoClient(
                 mongodb_atlas_settings.connection_string.get_secret_value(),
                 driver=DriverInfo("Microsoft Semantic Kernel", metadata.version("semantic-kernel")),
             )
-        if not mongodb_atlas_settings.database_name:
-            raise VectorStoreInitializationException("Database name is required.")
 
         super().__init__(
             data_model_type=data_model_type,
@@ -122,7 +123,7 @@ def __init__(
             mongo_client=mongo_client,
             managed_client=managed_client,
             database_name=mongodb_atlas_settings.database_name,
-            index_name=index_name or f"{collection_name}_idx",
+            index_name=mongodb_atlas_settings.index_name,
         )
 
     def _get_database(self) -> AsyncDatabase:
@@ -186,16 +187,18 @@ def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: A
     async def create_collection(self, **kwargs) -> None:
         """Create a new collection in MongoDB Atlas.
 
+        This first creates a collection, with the kwargs.
+        Then creates a search index based on the data model definition.
+
         Args:
             **kwargs: Additional keyword arguments.
         """
-        database = self._get_database()
-        collection = await database.create_collection(self.collection_name, **kwargs)
+        collection = await self._get_database().create_collection(self.collection_name, **kwargs)
         await collection.create_search_index(create_index_definition(self.data_model_definition, self.index_name))
 
     @override
     async def does_collection_exist(self, **kwargs) -> bool:
-        return self.collection_name in await self._get_database().list_collection_names()
+        return bool(await self._get_database().list_collection_names(filter={"name": self.collection_name}))
 
     @override
     async def delete_collection(self, **kwargs) -> None:
diff --git a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_settings.py b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_settings.py
index 118cd8f5c267..11a21183fcf2 100644
--- a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_settings.py
+++ b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_settings.py
@@ -1,8 +1,8 @@
 # Copyright (c) Microsoft. All rights reserved.
 
-from typing import Annotated, ClassVar
+from typing import ClassVar
 
-from pydantic import Field, SecretStr
+from pydantic import SecretStr
 
 from semantic_kernel.connectors.memory.mongodb_atlas.const import DEFAULT_DB_NAME, DEFAULT_SEARCH_INDEX_NAME
 from semantic_kernel.kernel_pydantic import KernelBaseSettings
@@ -17,12 +17,13 @@ class MongoDBAtlasSettings(KernelBaseSettings):
     - connection_string: str - MongoDB Atlas connection string
         (Env var MONGODB_ATLAS_CONNECTION_STRING)
     - database_name: str - MongoDB Atlas database name, defaults to 'default'
+        (Env var MONGODB_ATLAS_DATABASE_NAME)
+    - index_name: str - MongoDB Atlas search index name, defaults to 'default'
+        (Env var MONGODB_ATLAS_INDEX_NAME)
     """
 
     env_prefix: ClassVar[str] = "MONGODB_ATLAS_"
 
     connection_string: SecretStr
     database_name: str = DEFAULT_DB_NAME
-    index_name: Annotated[str, Field(deprecated="This field is not used with the new store and collection")] = (
-        DEFAULT_SEARCH_INDEX_NAME
-    )
+    index_name: str = DEFAULT_SEARCH_INDEX_NAME
diff --git a/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_collection.py b/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_collection.py
index f0bf621bfa50..00afe491e2a3 100644
--- a/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_collection.py
+++ b/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_collection.py
@@ -1,13 +1,15 @@
 # Copyright (c) Microsoft. All rights reserved.
 
-
 from unittest.mock import AsyncMock, patch
 
 from pymongo import AsyncMongoClient
 from pymongo.asynchronous.cursor import AsyncCursor
 from pymongo.results import UpdateResult
+from pytest import mark, raises
 
+from semantic_kernel.connectors.memory.mongodb_atlas.const import DEFAULT_DB_NAME, DEFAULT_SEARCH_INDEX_NAME
 from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_collection import MongoDBAtlasCollection
+from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreInitializationException
 
 
 def test_mongodb_atlas_collection_initialization(mongodb_atlas_unit_test_env, data_model_definition, mock_mongo_client):
@@ -21,6 +23,27 @@ def test_mongodb_atlas_collection_initialization(mongodb_atlas_unit_test_env, da
     assert isinstance(collection.mongo_client, AsyncMongoClient)
 
 
+@mark.parametrize("exclude_list", [["MONGODB_ATLAS_CONNECTION_STRING"]], indirect=True)
+def test_mongodb_atlas_collection_initialization_fail(mongodb_atlas_unit_test_env, data_model_definition):
+    with raises(VectorStoreInitializationException):
+        MongoDBAtlasCollection(
+            collection_name="test_collection",
+            data_model_type=dict,
+            data_model_definition=data_model_definition,
+        )
+
+
+@mark.parametrize("exclude_list", [["MONGODB_ATLAS_DATABASE_NAME", "MONGODB_ATLAS_INDEX_NAME"]], indirect=True)
+def test_mongodb_atlas_collection_initialization_defaults(mongodb_atlas_unit_test_env, data_model_definition):
+    collection = MongoDBAtlasCollection(
+        collection_name="test_collection",
+        data_model_type=dict,
+        data_model_definition=data_model_definition,
+    )
+    assert collection.database_name == DEFAULT_DB_NAME
+    assert collection.index_name == DEFAULT_SEARCH_INDEX_NAME
+
+
 async def test_mongodb_atlas_collection_upsert(mongodb_atlas_unit_test_env, data_model_definition, mock_get_collection):
     collection = MongoDBAtlasCollection(
         data_model_type=dict,
@@ -58,3 +81,16 @@ async def test_mongodb_atlas_collection_delete(mongodb_atlas_unit_test_env, data
     with patch.object(collection, "_get_collection", new=mock_get_collection) as mock_get:
         await collection._inner_delete(["test_id"])
         mock_get.return_value.delete_many.assert_called_with({"_id": {"$in": ["test_id"]}})
+
+
+async def test_mongodb_atlas_collection_collection_exists(
+    mongodb_atlas_unit_test_env, data_model_definition, mock_get_database
+):
+    collection = MongoDBAtlasCollection(
+        data_model_type=dict,
+        data_model_definition=data_model_definition,
+        collection_name="test_collection",
+    )
+    with patch.object(collection, "_get_database", new=mock_get_database) as mock_get:
+        mock_get.return_value.list_collection_names.return_value = ["test_collection"]
+        assert await collection.does_collection_exist()