diff --git a/.github/workflows/pr_local_integration_tests.yml b/.github/workflows/pr_local_integration_tests.yml
index 2825b96f482..d3488cd08c3 100644
--- a/.github/workflows/pr_local_integration_tests.yml
+++ b/.github/workflows/pr_local_integration_tests.yml
@@ -50,7 +50,7 @@ jobs:
         uses: actions/cache@v4
         with:
           path: ${{ steps.uv-cache.outputs.dir }}
-          key: ${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-uv-${{ hashFiles(format('**/py{0}-ci-requirements.txt', env.PYTHON)) }}
+          key: ${{ runner.os }}-${{ matrix.python-version }}-uv-${{ hashFiles(format('**/py{0}-ci-requirements.txt', matrix.python-version)) }}
       - name: Install dependencies
         run: make install-python-dependencies-ci
       - name: Test local integration tests
diff --git a/Makefile b/Makefile
index de2ee568b68..bef7437bc8a 100644
--- a/Makefile
+++ b/Makefile
@@ -268,7 +268,7 @@ test-python-universal-postgres-online:
 				not test_snowflake" \
  			sdk/python/tests
 
- test-python-universal-mysql-online:
+test-python-universal-mysql-online:
 	PYTHONPATH='.' \
 		FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.mysql_online_store.mysql_repo_configuration \
 		PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.mysql \
@@ -292,7 +292,11 @@ test-python-universal-cassandra:
 	FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.cassandra_online_store.cassandra_repo_configuration \
 	PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.cassandra \
 	python -m pytest -x --integration \
-	sdk/python/tests
+	sdk/python/tests/integration/offline_store/test_feature_logging.py \
+		--ignore=sdk/python/tests/integration/offline_store/test_validation.py \
+		-k "not test_snowflake and \
+			not test_spark_materialization_consistency and \
+			not test_universal_materialization"
 
 test-python-universal-hazelcast:
 	PYTHONPATH='.' \
@@ -330,7 +334,7 @@ test-python-universal-cassandra-no-cloud-providers:
 	  not test_snowflake" \
 	sdk/python/tests
 
- test-python-universal-elasticsearch-online:
+test-python-universal-elasticsearch-online:
 	PYTHONPATH='.' \
 		FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.elasticsearch_online_store.elasticsearch_repo_configuration \
 		PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.elasticsearch \
@@ -349,6 +353,14 @@ test-python-universal-cassandra-no-cloud-providers:
 				not test_snowflake" \
  			sdk/python/tests
 
+test-python-universal-milvus-online:
+	PYTHONPATH='.' \
+		FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.milvus_online_store.milvus_repo_configuration \
+		PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.milvus \
+		python -m pytest -n 8 --integration \
+		-k "test_retrieve_online_milvus_ocuments" \
+ 			sdk/python/tests --ignore=sdk/python/tests/integration/offline_store/test_dqm_validation.py
+
 test-python-universal-singlestore-online:
 	PYTHONPATH='.' \
 		FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.singlestore_repo_configuration \
diff --git a/docs/how-to-guides/customizing-feast/adding-support-for-a-new-online-store.md b/docs/how-to-guides/customizing-feast/adding-support-for-a-new-online-store.md
index 5e26f133cef..ee75aa6b74f 100644
--- a/docs/how-to-guides/customizing-feast/adding-support-for-a-new-online-store.md
+++ b/docs/how-to-guides/customizing-feast/adding-support-for-a-new-online-store.md
@@ -25,7 +25,7 @@ OnlineStore class names must end with the OnlineStore suffix!
 
 ### Contrib online stores
 
-New online stores go in `sdk/python/feast/infra/online_stores/contrib/`.
+New online stores go in `sdk/python/feast/infra/online_stores/`.
 
 #### What is a contrib plugin?
 
diff --git a/sdk/python/docs/source/feast.infra.online_stores.milvus_online_store.rst b/sdk/python/docs/source/feast.infra.online_stores.milvus_online_store.rst
index ee9faa55dc0..5ae3015bf37 100644
--- a/sdk/python/docs/source/feast.infra.online_stores.milvus_online_store.rst
+++ b/sdk/python/docs/source/feast.infra.online_stores.milvus_online_store.rst
@@ -4,6 +4,14 @@ feast.infra.online\_stores.milvus\_online\_store package
 Submodules
 ----------
 
+feast.infra.online\_stores.milvus\_online\_store.milvus module
+--------------------------------------------------------------
+
+.. automodule:: feast.infra.online_stores.milvus_online_store.milvus
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 feast.infra.online\_stores.milvus\_online\_store.milvus\_repo\_configuration module
 -----------------------------------------------------------------------------------
 
diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
new file mode 100644
index 00000000000..a1a4a3a5fe5
--- /dev/null
+++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
@@ -0,0 +1,428 @@
+from datetime import datetime
+from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
+
+from pydantic import StrictStr
+from pymilvus import (
+    Collection,
+    CollectionSchema,
+    DataType,
+    FieldSchema,
+    connections,
+)
+from pymilvus.orm.connections import Connections
+
+from feast import Entity
+from feast.feature_view import FeatureView
+from feast.infra.infra_object import InfraObject
+from feast.infra.key_encoding_utils import (
+    serialize_entity_key,
+)
+from feast.infra.online_stores.online_store import OnlineStore
+from feast.infra.online_stores.vector_store import VectorStoreConfig
+from feast.protos.feast.core.InfraObject_pb2 import InfraObject as InfraObjectProto
+from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
+from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
+from feast.protos.feast.types.Value_pb2 import Value as ValueProto
+from feast.repo_config import FeastConfigBaseModel, RepoConfig
+from feast.type_map import PROTO_VALUE_TO_VALUE_TYPE_MAP
+from feast.types import (
+    VALUE_TYPES_TO_FEAST_TYPES,
+    Array,
+    ComplexFeastType,
+    PrimitiveFeastType,
+    ValueType,
+)
+from feast.utils import (
+    _build_retrieve_online_document_record,
+    _serialize_vector_to_float_list,
+    to_naive_utc,
+)
+
+PROTO_TO_MILVUS_TYPE_MAPPING: Dict[ValueType, DataType] = {
+    PROTO_VALUE_TO_VALUE_TYPE_MAP["bytes_val"]: DataType.VARCHAR,
+    PROTO_VALUE_TO_VALUE_TYPE_MAP["bool_val"]: DataType.BOOL,
+    PROTO_VALUE_TO_VALUE_TYPE_MAP["string_val"]: DataType.VARCHAR,
+    PROTO_VALUE_TO_VALUE_TYPE_MAP["float_val"]: DataType.FLOAT,
+    PROTO_VALUE_TO_VALUE_TYPE_MAP["double_val"]: DataType.DOUBLE,
+    PROTO_VALUE_TO_VALUE_TYPE_MAP["int32_val"]: DataType.INT32,
+    PROTO_VALUE_TO_VALUE_TYPE_MAP["int64_val"]: DataType.INT64,
+    PROTO_VALUE_TO_VALUE_TYPE_MAP["float_list_val"]: DataType.FLOAT_VECTOR,
+    PROTO_VALUE_TO_VALUE_TYPE_MAP["int32_list_val"]: DataType.FLOAT_VECTOR,
+    PROTO_VALUE_TO_VALUE_TYPE_MAP["int64_list_val"]: DataType.FLOAT_VECTOR,
+    PROTO_VALUE_TO_VALUE_TYPE_MAP["double_list_val"]: DataType.FLOAT_VECTOR,
+    PROTO_VALUE_TO_VALUE_TYPE_MAP["bool_list_val"]: DataType.BINARY_VECTOR,
+}
+
+FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING: Dict[
+    Union[PrimitiveFeastType, Array, ComplexFeastType], DataType
+] = {}
+
+for value_type, feast_type in VALUE_TYPES_TO_FEAST_TYPES.items():
+    if isinstance(feast_type, PrimitiveFeastType):
+        milvus_type = PROTO_TO_MILVUS_TYPE_MAPPING.get(value_type)
+        if milvus_type:
+            FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = milvus_type
+    elif isinstance(feast_type, Array):
+        base_type = feast_type.base_type
+        base_value_type = base_type.to_value_type()
+        if base_value_type in [
+            ValueType.INT32,
+            ValueType.INT64,
+            ValueType.FLOAT,
+            ValueType.DOUBLE,
+        ]:
+            FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.FLOAT_VECTOR
+        elif base_value_type == ValueType.STRING:
+            FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.VARCHAR
+        elif base_value_type == ValueType.BOOL:
+            FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.BINARY_VECTOR
+
+
+class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
+    """
+    Configuration for the Milvus online store.
+    NOTE: The class *must* end with the `OnlineStoreConfig` suffix.
+    """
+
+    type: Literal["milvus"] = "milvus"
+
+    host: Optional[StrictStr] = "localhost"
+    port: Optional[int] = 19530
+    index_type: Optional[str] = "IVF_FLAT"
+    metric_type: Optional[str] = "L2"
+    embedding_dim: Optional[int] = 128
+    vector_enabled: Optional[bool] = True
+    nlist: Optional[int] = 128
+
+
+class MilvusOnlineStore(OnlineStore):
+    """
+    Milvus implementation of the online store interface.
+
+    Attributes:
+        _collections: Dictionary to cache Milvus collections.
+    """
+
+    _conn: Optional[Connections] = None
+    _collections: Dict[str, Collection] = {}
+
+    def _connect(self, config: RepoConfig) -> connections:
+        if not self._conn:
+            if not connections.has_connection("feast"):
+                self._conn = connections.connect(
+                    alias="feast",
+                    host=config.online_store.host,
+                    port=str(config.online_store.port),
+                )
+        return self._conn
+
+    def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection:
+        collection_name = _table_id(config.project, table)
+        if collection_name not in self._collections:
+            self._connect(config)
+
+            # Create a composite key by combining entity fields
+            composite_key_name = (
+                "_".join([field.name for field in table.entity_columns]) + "_pk"
+            )
+
+            fields = [
+                FieldSchema(
+                    name=composite_key_name,
+                    dtype=DataType.VARCHAR,
+                    max_length=512,
+                    is_primary=True,
+                ),
+                FieldSchema(name="event_ts", dtype=DataType.INT64),
+                FieldSchema(name="created_ts", dtype=DataType.INT64),
+            ]
+            fields_to_exclude = [
+                "event_ts",
+                "created_ts",
+            ]
+            fields_to_add = [f for f in table.schema if f.name not in fields_to_exclude]
+            for field in fields_to_add:
+                dtype = FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING.get(field.dtype)
+                if dtype:
+                    if dtype == DataType.FLOAT_VECTOR:
+                        fields.append(
+                            FieldSchema(
+                                name=field.name,
+                                dtype=dtype,
+                                dim=config.online_store.embedding_dim,
+                            )
+                        )
+                    elif dtype == DataType.VARCHAR:
+                        fields.append(
+                            FieldSchema(
+                                name=field.name,
+                                dtype=dtype,
+                                max_length=512,
+                            )
+                        )
+                    else:
+                        fields.append(FieldSchema(name=field.name, dtype=dtype))
+
+            schema = CollectionSchema(
+                fields=fields, description="Feast feature view data"
+            )
+            collection = Collection(name=collection_name, schema=schema, using="feast")
+            if not collection.has_index():
+                index_params = {
+                    "index_type": config.online_store.index_type,
+                    "metric_type": config.online_store.metric_type,
+                    "params": {"nlist": config.online_store.nlist},
+                }
+            for vector_field in schema.fields:
+                if vector_field.dtype in [
+                    DataType.FLOAT_VECTOR,
+                    DataType.BINARY_VECTOR,
+                ]:
+                    collection.create_index(
+                        field_name=vector_field.name, index_params=index_params
+                    )
+            collection.load()
+            self._collections[collection_name] = collection
+        return self._collections[collection_name]
+
+    def online_write_batch(
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        data: List[
+            Tuple[
+                EntityKeyProto,
+                Dict[str, ValueProto],
+                datetime,
+                Optional[datetime],
+            ]
+        ],
+        progress: Optional[Callable[[int], Any]],
+    ) -> None:
+        collection = self._get_collection(config, table)
+        entity_batch_to_insert = []
+        for entity_key, values_dict, timestamp, created_ts in data:
+            # need to construct the composite primary key also need to handle the fact that entities are a list
+            entity_key_str = serialize_entity_key(
+                entity_key,
+                entity_key_serialization_version=config.entity_key_serialization_version,
+            ).hex()
+            composite_key_name = (
+                "_".join([str(value) for value in entity_key.join_keys]) + "_pk"
+            )
+            timestamp_int = int(to_naive_utc(timestamp).timestamp() * 1e6)
+            created_ts_int = (
+                int(to_naive_utc(created_ts).timestamp() * 1e6) if created_ts else 0
+            )
+            values_dict = _extract_proto_values_to_dict(values_dict)
+            entity_dict = _extract_proto_values_to_dict(
+                dict(zip(entity_key.join_keys, entity_key.entity_values))
+            )
+            values_dict.update(entity_dict)
+
+            single_entity_record = {
+                composite_key_name: entity_key_str,
+                "event_ts": timestamp_int,
+                "created_ts": created_ts_int,
+            }
+            single_entity_record.update(values_dict)
+            entity_batch_to_insert.append(single_entity_record)
+
+            if progress:
+                progress(1)
+
+        collection.insert(entity_batch_to_insert)
+        collection.flush()
+
+    def online_read(
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        entity_keys: List[EntityKeyProto],
+        requested_features: Optional[List[str]] = None,
+    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
+        raise NotImplementedError
+
+    def update(
+        self,
+        config: RepoConfig,
+        tables_to_delete: Sequence[FeatureView],
+        tables_to_keep: Sequence[FeatureView],
+        entities_to_delete: Sequence[Entity],
+        entities_to_keep: Sequence[Entity],
+        partial: bool,
+    ):
+        self._connect(config)
+        for table in tables_to_keep:
+            self._get_collection(config, table)
+        for table in tables_to_delete:
+            collection_name = _table_id(config.project, table)
+            collection = Collection(name=collection_name)
+            if collection.exists():
+                collection.drop()
+                self._collections.pop(collection_name, None)
+
+    def plan(
+        self, config: RepoConfig, desired_registry_proto: RegistryProto
+    ) -> List[InfraObject]:
+        raise NotImplementedError
+
+    def teardown(
+        self,
+        config: RepoConfig,
+        tables: Sequence[FeatureView],
+        entities: Sequence[Entity],
+    ):
+        self._connect(config)
+        for table in tables:
+            collection = self._get_collection(config, table)
+            if collection:
+                collection.drop()
+                self._collections.pop(collection.name, None)
+
+    def retrieve_online_documents(
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        requested_feature: Optional[str],
+        requested_features: Optional[List[str]],
+        embedding: List[float],
+        top_k: int,
+        distance_metric: Optional[str] = None,
+    ) -> List[
+        Tuple[
+            Optional[datetime],
+            Optional[EntityKeyProto],
+            Optional[ValueProto],
+            Optional[ValueProto],
+            Optional[ValueProto],
+        ]
+    ]:
+        collection = self._get_collection(config, table)
+        if not config.online_store.vector_enabled:
+            raise ValueError("Vector search is not enabled in the online store config")
+
+        search_params = {
+            "metric_type": distance_metric or config.online_store.metric_type,
+            "params": {"nprobe": 10},
+        }
+        expr = f"feature_name == '{requested_feature}'"
+
+        composite_key_name = (
+            "_".join([str(field.name) for field in table.entity_columns]) + "_pk"
+        )
+        if requested_features:
+            features_str = ", ".join([f"'{f}'" for f in requested_features])
+            expr += f" && feature_name in [{features_str}]"
+
+        output_fields = (
+            [composite_key_name]
+            + (requested_features if requested_features else [])
+            + ["created_ts", "event_ts"]
+        )
+        assert all(
+            field
+            for field in output_fields
+            if field in [f.name for f in collection.schema.fields]
+        ), f"field(s) [{[field for field in output_fields if field not in [f.name for f in collection.schema.fields]]}'] not found in collection schema"
+
+        # Note we choose the first vector field as the field to search on. Not ideal but it's something.
+        ann_search_field = None
+        for field in collection.schema.fields:
+            if (
+                field.dtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]
+                and field.name in output_fields
+            ):
+                ann_search_field = field.name
+                break
+
+        results = collection.search(
+            data=[embedding],
+            anns_field=ann_search_field,
+            param=search_params,
+            limit=top_k,
+            output_fields=output_fields,
+            consistency_level="Strong",
+        )
+
+        result_list = []
+        for hits in results:
+            for hit in hits:
+                single_record = {}
+                for field in output_fields:
+                    single_record[field] = hit.entity.get(field)
+
+                entity_key_bytes = bytes.fromhex(hit.entity.get(composite_key_name))
+                embedding = hit.entity.get(ann_search_field)
+                serialized_embedding = _serialize_vector_to_float_list(embedding)
+                distance = hit.distance
+                event_ts = datetime.fromtimestamp(hit.entity.get("event_ts") / 1e6)
+                prepared_result = _build_retrieve_online_document_record(
+                    entity_key_bytes,
+                    # This may have a bug
+                    serialized_embedding.SerializeToString(),
+                    embedding,
+                    distance,
+                    event_ts,
+                    config.entity_key_serialization_version,
+                )
+                result_list.append(prepared_result)
+        return result_list
+
+
+def _table_id(project: str, table: FeatureView) -> str:
+    return f"{project}_{table.name}"
+
+
+def _extract_proto_values_to_dict(input_dict: Dict[str, Any]) -> Dict[str, Any]:
+    numeric_vector_list_types = [
+        k
+        for k in PROTO_VALUE_TO_VALUE_TYPE_MAP.keys()
+        if k is not None and "list" in k and "string" not in k
+    ]
+    output_dict = {}
+    for feature_name, feature_values in input_dict.items():
+        for proto_val_type in PROTO_VALUE_TO_VALUE_TYPE_MAP:
+            if feature_values.HasField(proto_val_type):
+                if proto_val_type in numeric_vector_list_types:
+                    vector_values = getattr(feature_values, proto_val_type).val
+                else:
+                    vector_values = getattr(feature_values, proto_val_type)
+                output_dict[feature_name] = vector_values
+    return output_dict
+
+
+class MilvusTable(InfraObject):
+    """
+    A Milvus collection managed by Feast.
+
+    Attributes:
+        host: The host of the Milvus server.
+        port: The port of the Milvus server.
+        name: The name of the collection.
+    """
+
+    host: str
+    port: int
+
+    def __init__(self, host: str, port: int, name: str):
+        super().__init__(name)
+        self.host = host
+        self.port = port
+        self._connect()
+
+    def _connect(self):
+        return connections.connect(alias="default", host=self.host, port=str(self.port))
+
+    def to_infra_object_proto(self) -> InfraObjectProto:
+        # Implement serialization if needed
+        raise NotImplementedError
+
+    def update(self):
+        # Implement update logic if needed
+        raise NotImplementedError
+
+    def teardown(self):
+        collection = Collection(name=self.name)
+        if collection.exists():
+            collection.drop()
diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py
index fe34a12adf8..2b8d5174e1f 100644
--- a/sdk/python/feast/repo_config.py
+++ b/sdk/python/feast/repo_config.py
@@ -81,6 +81,7 @@
     "singlestore": "feast.infra.online_stores.singlestore_online_store.singlestore.SingleStoreOnlineStore",
     "qdrant": "feast.infra.online_stores.cqdrant.QdrantOnlineStore",
     "couchbase": "feast.infra.online_stores.couchbase_online_store.couchbase.CouchbaseOnlineStore",
+    "milvus": "feast.infra.online_stores.milvus_online_store.milvus.MilvusOnlineStore",
     **LEGACY_ONLINE_STORE_CLASS_FOR_TYPE,
 }
 
diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py
index 8a88c24ffc1..000e9cdae4e 100644
--- a/sdk/python/feast/type_map.py
+++ b/sdk/python/feast/type_map.py
@@ -523,6 +523,24 @@ def python_values_to_proto_values(
     return proto_values
 
 
+PROTO_VALUE_TO_VALUE_TYPE_MAP: Dict[str, ValueType] = {
+    "int32_val": ValueType.INT32,
+    "int64_val": ValueType.INT64,
+    "double_val": ValueType.DOUBLE,
+    "float_val": ValueType.FLOAT,
+    "string_val": ValueType.STRING,
+    "bytes_val": ValueType.BYTES,
+    "bool_val": ValueType.BOOL,
+    "int32_list_val": ValueType.INT32_LIST,
+    "int64_list_val": ValueType.INT64_LIST,
+    "double_list_val": ValueType.DOUBLE_LIST,
+    "float_list_val": ValueType.FLOAT_LIST,
+    "string_list_val": ValueType.STRING_LIST,
+    "bytes_list_val": ValueType.BYTES_LIST,
+    "bool_list_val": ValueType.BOOL_LIST,
+}
+
+
 def _proto_value_to_value_type(proto_value: ProtoValue) -> ValueType:
     """
     Returns Feast ValueType given Feast ValueType string.
@@ -534,25 +552,9 @@ def _proto_value_to_value_type(proto_value: ProtoValue) -> ValueType:
         A variant of ValueType.
     """
     proto_str = proto_value.WhichOneof("val")
-    type_map = {
-        "int32_val": ValueType.INT32,
-        "int64_val": ValueType.INT64,
-        "double_val": ValueType.DOUBLE,
-        "float_val": ValueType.FLOAT,
-        "string_val": ValueType.STRING,
-        "bytes_val": ValueType.BYTES,
-        "bool_val": ValueType.BOOL,
-        "int32_list_val": ValueType.INT32_LIST,
-        "int64_list_val": ValueType.INT64_LIST,
-        "double_list_val": ValueType.DOUBLE_LIST,
-        "float_list_val": ValueType.FLOAT_LIST,
-        "string_list_val": ValueType.STRING_LIST,
-        "bytes_list_val": ValueType.BYTES_LIST,
-        "bool_list_val": ValueType.BOOL_LIST,
-        None: ValueType.NULL,
-    }
-
-    return type_map[proto_str]
+    if proto_str is None:
+        return ValueType.UNKNOWN
+    return PROTO_VALUE_TO_VALUE_TYPE_MAP[proto_str]
 
 
 def pa_to_feast_value_type(pa_type_as_str: str) -> ValueType:
diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py
index 570a6d4f8d5..3d1f9219991 100644
--- a/sdk/python/tests/foo_provider.py
+++ b/sdk/python/tests/foo_provider.py
@@ -150,6 +150,7 @@ def retrieve_online_documents(
         config: RepoConfig,
         table: FeatureView,
         requested_feature: str,
+        requested_features: Optional[List[str]],
         query: List[float],
         top_k: int,
         distance_metric: Optional[str] = None,
diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py
index 4074dcb194e..d337d365e9b 100644
--- a/sdk/python/tests/integration/online_store/test_universal_online.py
+++ b/sdk/python/tests/integration/online_store/test_universal_online.py
@@ -614,6 +614,10 @@ def eventually_apply() -> Tuple[None, bool]:
     online_features = fs.get_online_features(
         features=features, entity_rows=entity_rows
     ).to_dict()
+
+    # Debugging print statement
+    print("Online features values:", online_features["value"])
+
     assert all(v is None for v in online_features["value"])
 
 
@@ -891,3 +895,28 @@ def test_retrieve_online_documents(vectordb_environment, fake_document_data):
             top_k=2,
             distance_metric="wrong",
         ).to_dict()
+
+
+@pytest.mark.integration
+@pytest.mark.universal_online_stores(only=["milvus"])
+def test_retrieve_online_milvus_documents(vectordb_environment, fake_document_data):
+    fs = vectordb_environment.feature_store
+    df, data_source = fake_document_data
+    item_embeddings_feature_view = create_item_embeddings_feature_view(data_source)
+    fs.apply([item_embeddings_feature_view, item()])
+    fs.write_to_online_store("item_embeddings", df)
+    documents = fs.retrieve_online_documents(
+        feature=None,
+        features=[
+            "item_embeddings:embedding_float",
+            "item_embeddings:item_id",
+            "item_embeddings:string_feature",
+        ],
+        query=[1.0, 2.0],
+        top_k=2,
+        distance_metric="L2",
+    ).to_dict()
+    assert len(documents["embedding_float"]) == 2
+
+    assert len(documents["item_id"]) == 2
+    assert documents["item_id"] == [2, 3]