Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Implement MongoDB Atlas store #10177

Merged
merged 8 commits into from
Feb 6, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
working tests
eavanvalkenburg committed Jan 29, 2025
commit 8d7220e022bcfd61dd923cb03516584ace382199
2 changes: 1 addition & 1 deletion python/.coveragerc
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ omit =
semantic_kernel/connectors/memory/azure_cosmosdb_no_sql/*
semantic_kernel/connectors/memory/chroma/*
semantic_kernel/connectors/memory/milvus/*
semantic_kernel/connectors/memory/mongodb_atlas/*
semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_memory_store.py
semantic_kernel/connectors/memory/pinecone/*
semantic_kernel/connectors/memory/postgres/*
semantic_kernel/connectors/memory/qdrant/qdrant_memory_store.py
Original file line number Diff line number Diff line change
@@ -10,3 +10,5 @@
DistanceFunction.DOT_PROD: "dotProduct",
}
# Name of the id field in stored documents.
MONGODB_ID_FIELD: Final[str] = "_id"
# Fallback names for the database and the search index.
DEFAULT_DB_NAME = "default"
DEFAULT_SEARCH_INDEX_NAME = "default"
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import logging
import sys
from collections.abc import Sequence
from importlib import metadata
from typing import Any, ClassVar, Generic, TypeVar

if sys.version_info >= (3, 12):
@@ -14,6 +15,7 @@
from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.driver_info import DriverInfo

from semantic_kernel.connectors.memory.mongodb_atlas.const import MONGODB_ID_FIELD
from semantic_kernel.connectors.memory.mongodb_atlas.utils import create_index_definition
@@ -58,9 +60,8 @@ def __init__(
data_model_type: type[TModel],
data_model_definition: VectorStoreRecordDefinition | None = None,
collection_name: str | None = None,
database_name: str | None = None,
mongo_client: AsyncMongoClient | None = None,
index_name: str | None = None,
mongo_client: AsyncMongoClient | None = None,
**kwargs: Any,
) -> None:
"""Initializes a new instance of the MongoDBAtlasCollection class.
@@ -69,26 +70,26 @@ def __init__(
data_model_type: The type of the data model.
data_model_definition: The model definition, optional.
collection_name: The name of the collection, optional.
database_name: The name of the database, will be filled from the env when this is not set.
mongo_client: The MongoDB client for interacting with MongoDB Atlas,
used for creating and deleting collections.
index_name: The name of the index to use for searching, when not passed, will use <collection_name>_idx.
**kwargs: Additional keyword arguments, including:
The same keyword arguments used for MongoDBAtlasStore:
database_name: The name of the database, will be filled from the env when this is not set.
connection_string: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None

"""
if not collection_name:
raise VectorStoreInitializationException("Collection name is required.")
if mongo_client and database_name:
if mongo_client and "database_name" in kwargs:
super().__init__(
data_model_type=data_model_type,
data_model_definition=data_model_definition,
mongo_client=mongo_client,
collection_name=collection_name,
database_name=database_name,
database_name=kwargs["database_name"],
index_name=index_name or f"{collection_name}_idx",
managed_client=False,
)
@@ -99,15 +100,18 @@ def __init__(
try:
mongodb_atlas_settings = MongoDBAtlasSettings.create(
env_file_path=kwargs.get("env_file_path"),
connection_string=kwargs.get("connection_string"),
database_name=database_name,
env_file_encoding=kwargs.get("env_file_encoding"),
connection_string=kwargs.get("connection_string"),
database_name=kwargs.get("database_name"),
)
except ValidationError as exc:
raise VectorStoreInitializationException("Failed to create MongoDB Atlas settings.") from exc
managed_client = not mongo_client
if not mongo_client:
mongo_client = AsyncMongoClient(mongodb_atlas_settings.connection_string)
mongo_client = AsyncMongoClient(
mongodb_atlas_settings.connection_string.get_secret_value(),
driver=DriverInfo("Microsoft Semantic Kernel", metadata.version("semantic-kernel")),
)
if not mongodb_atlas_settings.database_name:
raise VectorStoreInitializationException("Database name is required.")

@@ -122,11 +126,17 @@ def __init__(
)

def _get_database(self) -> AsyncDatabase:
"""Get the database."""
"""Get the database.

If you need control over things like read preference, you can override this method.
"""
return self.mongo_client.get_database(self.database_name)

def _get_collection(self) -> AsyncCollection:
"""Get the collection."""
"""Get the collection.

If you need control over things like read preference, you can override this method.
"""
return self.mongo_client.get_database(self.database_name).get_collection(self.collection_name)

@override
@@ -203,6 +213,7 @@ async def _inner_search(
collection = self._get_collection()
vector_search_query: dict[str, Any] = {
"limit": options.top + options.skip,
"index": self.index_name,
}
if options.filter.filters:
vector_search_query["filter"] = self._build_filter_dict(options.filter)
@@ -253,5 +264,10 @@ def _get_score_from_result(self, result: dict[str, Any]) -> float | None:
@override
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
"""Exit the context manager."""
if self.managed_mongo_client:
if self.managed_client:
await self.mongo_client.close()

async def __aenter__(self) -> "MongoDBAtlasCollection":
    """Connect the underlying Mongo client and return this collection.

    Returns:
        The collection itself, so it can be bound with ``async with ... as``.
    """
    client = self.mongo_client
    await client.aconnect()
    return self
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Copyright (c) Microsoft. All rights reserved.

from typing import ClassVar
from typing import Annotated, ClassVar

from pydantic import SecretStr
from pydantic import Field, SecretStr

from semantic_kernel.connectors.memory.mongodb_atlas.utils import DEFAULT_DB_NAME, DEFAULT_SEARCH_INDEX_NAME
from semantic_kernel.connectors.memory.mongodb_atlas.const import DEFAULT_DB_NAME, DEFAULT_SEARCH_INDEX_NAME
from semantic_kernel.kernel_pydantic import KernelBaseSettings
from semantic_kernel.utils.experimental_decorator import experimental_class

@@ -16,10 +16,13 @@ class MongoDBAtlasSettings(KernelBaseSettings):
Args:
- connection_string: str - MongoDB Atlas connection string
(Env var MONGODB_ATLAS_CONNECTION_STRING)
- database_name: str - MongoDB Atlas database name, defaults to 'default'
"""

env_prefix: ClassVar[str] = "MONGODB_ATLAS_"

connection_string: SecretStr
database_name: str = DEFAULT_DB_NAME
index_name: str = DEFAULT_SEARCH_INDEX_NAME
index_name: Annotated[str, Field(deprecated="This field is not used with the new store and collection")] = (
DEFAULT_SEARCH_INDEX_NAME
)
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@

import logging
import sys
from importlib import metadata
from typing import TYPE_CHECKING, Any, TypeVar

if sys.version_info >= (3, 12):
@@ -12,6 +13,7 @@
from pydantic import ValidationError
from pymongo import AsyncMongoClient
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.driver_info import DriverInfo

from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_collection import (
MongoDBAtlasCollection,
@@ -61,18 +63,27 @@ def __init__(
MongoDBAtlasSettings,
)

if mongo_client and database_name:
super().__init__(
mongo_client=mongo_client,
managed_client=False,
database_name=database_name,
)
managed_client: bool = False
try:
mongodb_atlas_settings = MongoDBAtlasSettings.create(
env_file_path=env_file_path,
connection_string=connection_string,
database_name=database_name,
env_file_encoding=env_file_encoding,
)
except ValidationError as exc:
raise VectorStoreInitializationException("Failed to create MongoDB Atlas settings.") from exc
if not mongo_client:
try:
mongodb_atlas_settings = MongoDBAtlasSettings.create(
env_file_path=env_file_path,
connection_string=connection_string,
database_name=database_name,
env_file_encoding=env_file_encoding,
)
except ValidationError as exc:
raise VectorStoreInitializationException("Failed to create MongoDB Atlas settings.") from exc
mongo_client = AsyncMongoClient(mongodb_atlas_settings.connection_string)
mongo_client = AsyncMongoClient(
mongodb_atlas_settings.connection_string.get_secret_value(),
driver=DriverInfo("Microsoft Semantic Kernel", metadata.version("semantic-kernel")),
)
managed_client = True

super().__init__(
@@ -117,3 +128,8 @@ async def __aexit__(self, exc_type, exc_value, traceback) -> None:
"""Exit the context manager."""
if self.managed_client:
await self.mongo_client.close()

async def __aenter__(self) -> "MongoDBAtlasStore":
    """Connect the underlying Mongo client and return this store.

    Returns:
        The store itself, so it can be bound with ``async with ... as``.
    """
    client = self.mongo_client
    await client.aconnect()
    return self
Original file line number Diff line number Diff line change
@@ -12,8 +12,6 @@
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
from semantic_kernel.memory.memory_record import MemoryRecord

DEFAULT_DB_NAME = "default"
DEFAULT_SEARCH_INDEX_NAME = "default"
NUM_CANDIDATES_SCALAR = 10

MONGODB_FIELD_ID = "_id"
22 changes: 22 additions & 0 deletions python/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -367,6 +367,28 @@ def azure_ai_search_unit_test_env(monkeypatch, exclude_list, override_env_param_
return env_vars


@fixture()
def mongodb_atlas_unit_test_env(monkeypatch, exclude_list, override_env_param_dict):
    """Set up the MongoDB Atlas environment variables for unit tests.

    Args:
        monkeypatch: The pytest monkeypatch helper used to set/delete env vars.
        exclude_list: Names of variables that must be removed from the
            environment instead of set; ``None`` is treated as empty.
        override_env_param_dict: Values that replace or extend the defaults;
            ``None`` is treated as empty.

    Returns:
        The full mapping of variable names to values (defaults merged with
        overrides), including any names that were excluded from the environment.
    """
    excluded = [] if exclude_list is None else exclude_list
    overrides = {} if override_env_param_dict is None else override_env_param_dict

    env_vars = {
        "MONGODB_ATLAS_CONNECTION_STRING": "mongodb://test",
        "MONGODB_ATLAS_DATABASE_NAME": "test-database",
    }
    env_vars.update(overrides)

    for name, setting in env_vars.items():
        if name in excluded:
            # Excluded variables are actively removed so a value leaking in
            # from the real environment cannot satisfy the settings.
            monkeypatch.delenv(name, raising=False)
            continue
        monkeypatch.setenv(name, setting)

    return env_vars


@fixture()
def bing_unit_test_env(monkeypatch, exclude_list, override_env_param_dict):
"""Fixture to set environment variables for BingConnector."""
37 changes: 37 additions & 0 deletions python/tests/unit/connectors/memory/mongodb_atlas/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft. All rights reserved.


from unittest.mock import patch

import pytest
from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.database import AsyncDatabase

# Dotted import paths of the pymongo async classes; the fixtures in this
# module pass them to `patch` as the targets to replace.
BASE_PATH = "pymongo.asynchronous.mongo_client.AsyncMongoClient"
DATABASE_PATH = "pymongo.asynchronous.database.AsyncDatabase"
COLLECTION_PATH = "pymongo.asynchronous.collection.AsyncCollection"


@pytest.fixture(autouse=True)
def mock_mongo_client():
    """Patch the async Mongo client class and yield the resulting mock.

    The patch stays active for the duration of the test and is undone
    when the fixture is torn down.
    """
    client_patcher = patch(BASE_PATH, spec=AsyncMongoClient)
    with client_patcher as patched_client:
        yield patched_client


@pytest.fixture(autouse=True)
def mock_get_database(mock_mongo_client):
    """Patch the async database class and install it as the client's ``get_database``.

    Args:
        mock_mongo_client: The patched client mock whose ``get_database``
            attribute is replaced.

    Yields:
        The object installed as ``get_database`` (the patched database mock).
    """
    with patch(DATABASE_PATH, spec=AsyncDatabase) as database_double:
        # `new_callable` hands back the already-created database mock, so the
        # attribute on the client and the patched class are the same object.
        getter_patcher = patch.object(
            mock_mongo_client,
            "get_database",
            new_callable=lambda: database_double,
        )
        with getter_patcher as patched_getter:
            yield patched_getter


@pytest.fixture(autouse=True)
def mock_get_collection(mock_get_database):
    """Patch the async collection class and install it as the database's ``get_collection``.

    Args:
        mock_get_database: The patched database mock whose ``get_collection``
            attribute is replaced.

    Yields:
        The object installed as ``get_collection`` (the patched collection mock).
    """
    with patch(COLLECTION_PATH, spec=AsyncCollection) as collection_double:
        # `new_callable` hands back the already-created collection mock, so the
        # attribute on the database and the patched class are the same object.
        getter_patcher = patch.object(
            mock_get_database,
            "get_collection",
            new_callable=lambda: collection_double,
        )
        with getter_patcher as patched_getter:
            yield patched_getter
Original file line number Diff line number Diff line change
@@ -1,63 +1,60 @@
# Copyright (c) Microsoft. All rights reserved.

from unittest.mock import MagicMock, patch

import pytest
from pymongo import MongoClient
from unittest.mock import AsyncMock, patch

from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_collection import MongoDBAtlasCollection


@pytest.fixture
def mock_mongo_client():
with patch("pymongo.AsyncMongoClient") as mock:
yield mock
from pymongo import AsyncMongoClient
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.results import UpdateResult


@pytest.fixture
def mock_mongo_db(mock_mongo_client):
mock_db = MagicMock()
mock_mongo_client.return_value.get_database.return_value = mock_db
yield mock_db
from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_collection import MongoDBAtlasCollection


def test_mongodb_atlas_collection_initialization(mock_mongo_client):
def test_mongodb_atlas_collection_initialization(mongodb_atlas_unit_test_env, data_model_definition, mock_mongo_client):
collection = MongoDBAtlasCollection(
data_model_type=dict,
data_model_definition=data_model_definition,
collection_name="test_collection",
mongo_client=mock_mongo_client,
)
assert collection.mongo_client is not None
assert isinstance(collection.mongo_client, MongoClient)
assert isinstance(collection.mongo_client, AsyncMongoClient)


def test_mongodb_atlas_collection_upsert(mock_mongo_db):
async def test_mongodb_atlas_collection_upsert(mongodb_atlas_unit_test_env, data_model_definition, mock_get_collection):
collection = MongoDBAtlasCollection(
data_model_type=dict,
data_model_definition=data_model_definition,
collection_name="test_collection",
mongo_client=mock_mongo_db,
)
mock_mongo_db.get_collection.return_value.insert_one.return_value.inserted_id = "test_id"
result = collection._inner_upsert([{"_id": "test_id", "data": "test_data"}])
assert result == ["test_id"]
with patch.object(collection, "_get_collection", new=mock_get_collection) as mock_get:
result_mock = AsyncMock(spec=UpdateResult)
result_mock.upserted_id = ["test_id"]
mock_get.return_value.update_many.return_value = result_mock
result = await collection._inner_upsert([{"_id": "test_id", "data": "test_data"}])
assert result == ["test_id"]


def test_mongodb_atlas_collection_get(mock_mongo_db):
async def test_mongodb_atlas_collection_get(mongodb_atlas_unit_test_env, data_model_definition, mock_get_collection):
collection = MongoDBAtlasCollection(
data_model_type=dict,
data_model_definition=data_model_definition,
collection_name="test_collection",
mongo_client=mock_mongo_db,
)
mock_mongo_db.get_collection.return_value.find_one.return_value = {"_id": "test_id", "data": "test_data"}
result = collection._inner_get(["test_id"])
assert result == [{"_id": "test_id", "data": "test_data"}]
with patch.object(collection, "_get_collection", new=mock_get_collection) as mock_get:
result_mock = AsyncMock(spec=AsyncCursor)
result_mock.to_list.return_value = [{"_id": "test_id", "data": "test_data"}]
mock_get.return_value.find.return_value = result_mock
result = await collection._inner_get(["test_id"])
assert result == [{"_id": "test_id", "data": "test_data"}]


def test_mongodb_atlas_collection_delete(mock_mongo_db):
async def test_mongodb_atlas_collection_delete(mongodb_atlas_unit_test_env, data_model_definition, mock_get_collection):
collection = MongoDBAtlasCollection(
data_model_type=dict,
data_model_definition=data_model_definition,
collection_name="test_collection",
mongo_client=mock_mongo_db,
)
collection._inner_delete(["test_id"])
mock_mongo_db.get_collection.return_value.delete_one.assert_called_with({"_id": "test_id"})
with patch.object(collection, "_get_collection", new=mock_get_collection) as mock_get:
await collection._inner_delete(["test_id"])
mock_get.return_value.delete_many.assert_called_with({"_id": {"$in": ["test_id"]}})
Original file line number Diff line number Diff line change
@@ -1,47 +1,31 @@
# Copyright (c) Microsoft. All rights reserved.

from unittest.mock import MagicMock, patch

import pytest
from pymongo import MongoClient
from pymongo import AsyncMongoClient

from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_collection import MongoDBAtlasCollection
from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_store import MongoDBAtlasStore
from semantic_kernel.data.record_definition import VectorStoreRecordDefinition


@pytest.fixture
def mock_mongo_client():
with patch("pymongo.AsyncMongoClient") as mock:
yield mock


@pytest.fixture
def mock_mongo_db(mock_mongo_client):
mock_db = MagicMock()
mock_mongo_client.return_value.get_database.return_value = mock_db
yield mock_db


def test_mongodb_atlas_store_initialization(mock_mongo_client):
store = MongoDBAtlasStore(connection_string="mongodb://test", database_name="test_db")
def test_mongodb_atlas_store_initialization(mongodb_atlas_unit_test_env):
store = MongoDBAtlasStore()
assert store.mongo_client is not None
assert isinstance(store.mongo_client, MongoClient)
assert isinstance(store.mongo_client, AsyncMongoClient)


def test_mongodb_atlas_store_get_collection(mock_mongo_client):
store = MongoDBAtlasStore(connection_string="mongodb://test", database_name="test_db")
def test_mongodb_atlas_store_get_collection(mongodb_atlas_unit_test_env, data_model_definition):
store = MongoDBAtlasStore()
collection = store.get_collection(
collection_name="test_collection",
data_model_type=dict,
data_model_definition=VectorStoreRecordDefinition(),
data_model_definition=data_model_definition,
)
assert collection is not None
assert isinstance(collection, MongoDBAtlasCollection)


def test_mongodb_atlas_store_list_collection_names(mock_mongo_db):
store = MongoDBAtlasStore(connection_string="mongodb://test", database_name="test_db")
mock_mongo_db.list_collection_names.return_value = ["test_collection"]
result = store.list_collection_names()
async def test_mongodb_atlas_store_list_collection_names(mongodb_atlas_unit_test_env, mock_mongo_client):
store = MongoDBAtlasStore(mongo_client=mock_mongo_client, database_name="test_db")
store.mongo_client.get_database().list_collection_names.return_value = ["test_collection"]
result = await store.list_collection_names()
assert result == ["test_collection"]