Skip to content

Commit 3bb0dca

Browse files
fix: Validate entities when running get_online_features (#5031)
1 parent ec6f1b7 commit 3bb0dca

File tree

3 files changed

+137
-19
lines changed

3 files changed

+137
-19
lines changed

sdk/python/feast/utils.py

+38-15
Original file line numberDiff line numberDiff line change
@@ -687,25 +687,48 @@ def _get_unique_entities(
687687
entity_name_to_join_key_map,
688688
join_key_values,
689689
)
690+
# Validate that all expected join keys exist and have non-empty values.
691+
expected_keys = set(entity_name_to_join_key_map.values())
692+
expected_keys.discard("__dummy_id")
693+
missing_keys = sorted(
694+
list(set([key for key in expected_keys if key not in table_entity_values]))
695+
)
696+
empty_keys = sorted(
697+
list(set([key for key in expected_keys if not table_entity_values.get(key)]))
698+
)
690699

691-
# Convert back to rowise.
692-
keys = table_entity_values.keys()
693-
# Sort the rowise data to allow for grouping but keep original index. This lambda is
694-
# sufficient as Entity types cannot be complex (ie. lists).
700+
if missing_keys or empty_keys:
701+
if not any(table_entity_values.values()):
702+
raise KeyError(
703+
f"Missing join key values for keys: {missing_keys}. "
704+
f"No values provided for keys: {empty_keys}. "
705+
f"Provided join_key_values: {list(join_key_values.keys())}"
706+
)
707+
708+
# Convert the column-oriented table_entity_values into row-wise data.
709+
keys = list(table_entity_values.keys())
710+
# Each row is a tuple of ValueProto objects corresponding to the join keys.
695711
rowise = list(enumerate(zip(*table_entity_values.values())))
712+
713+
# If there are no rows, return empty tuples.
714+
if not rowise:
715+
return (), ()
716+
717+
# Sort rowise so that rows with the same join key values are adjacent.
696718
rowise.sort(key=lambda row: tuple(getattr(x, x.WhichOneof("val")) for x in row[1]))
697719

698-
# Identify unique entities and the indexes at which they occur.
699-
unique_entities: Tuple[Dict[str, ValueProto], ...]
700-
indexes: Tuple[List[int], ...]
701-
unique_entities, indexes = tuple(
702-
zip(
703-
*[
704-
(dict(zip(keys, k)), [_[0] for _ in g])
705-
for k, g in itertools.groupby(rowise, key=lambda x: x[1])
706-
]
707-
)
708-
)
720+
# Group rows by their composite join key value.
721+
groups = [
722+
(dict(zip(keys, key_tuple)), [idx for idx, _ in group])
723+
for key_tuple, group in itertools.groupby(rowise, key=lambda row: row[1])
724+
]
725+
726+
# If no groups were formed (should not happen for valid input), return empty tuples.
727+
if not groups:
728+
return (), ()
729+
730+
# Unpack the unique entities and their original row indexes.
731+
unique_entities, indexes = tuple(zip(*groups))
709732
return unique_entities, indexes
710733

711734

sdk/python/tests/unit/online_store/test_online_retrieval.py

+15
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,21 @@ def test_get_online_features() -> None:
137137

138138
assert "trips" in result
139139

140+
with pytest.raises(KeyError) as excinfo:
141+
_ = store.get_online_features(
142+
features=["driver_locations:lon"],
143+
entity_rows=[{"customer_id": 0}],
144+
full_feature_names=False,
145+
).to_dict()
146+
147+
error_message = str(excinfo.value)
148+
assert "Missing join key values for keys:" in error_message
149+
assert (
150+
"Missing join key values for keys: ['customer_id', 'driver_id', 'item_id']."
151+
in error_message
152+
)
153+
assert "Provided join_key_values: ['customer_id']" in error_message
154+
140155
result = store.get_online_features(
141156
features=["customer_profile_pandas_odfv:on_demand_age"],
142157
entity_rows=[{"driver_id": 1, "customer_id": "5"}],

sdk/python/tests/unit/test_unit_feature_store.py

+84-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass
22
from typing import Dict, List
33

4+
import pytest
5+
46
from feast import utils
57
from feast.protos.feast.types.Value_pb2 import Value
68

@@ -17,7 +19,7 @@ class MockFeatureView:
1719
projection: MockFeatureViewProjection
1820

1921

20-
def test_get_unique_entities():
22+
def test_get_unique_entities_success():
2123
entity_values = {
2224
"entity_1": [Value(int64_val=1), Value(int64_val=2), Value(int64_val=1)],
2325
"entity_2": [
@@ -41,9 +43,87 @@ def test_get_unique_entities():
4143
join_key_values=entity_values,
4244
entity_name_to_join_key_map=entity_name_to_join_key_map,
4345
)
44-
45-
assert unique_entities == (
46+
expected_entities = (
4647
{"entity_1": Value(int64_val=1), "entity_2": Value(string_val="1")},
4748
{"entity_1": Value(int64_val=2), "entity_2": Value(string_val="2")},
4849
)
49-
assert indexes == ([0, 2], [1])
50+
expected_indexes = ([0, 2], [1])
51+
52+
assert unique_entities == expected_entities
53+
assert indexes == expected_indexes
54+
55+
56+
def test_get_unique_entities_missing_join_key_success():
57+
"""
58+
Tests that _get_unique_entities raises a KeyError when a required join key is missing.
59+
"""
60+
# Here, we omit the required key for "entity_1"
61+
entity_values = {
62+
"entity_2": [
63+
Value(string_val="1"),
64+
Value(string_val="2"),
65+
Value(string_val="1"),
66+
],
67+
}
68+
69+
entity_name_to_join_key_map = {"entity_1": "entity_1", "entity_2": "entity_2"}
70+
71+
fv = MockFeatureView(
72+
name="fv_1",
73+
entities=["entity_1", "entity_2"],
74+
projection=MockFeatureViewProjection(join_key_map={}),
75+
)
76+
77+
unique_entities, indexes = utils._get_unique_entities(
78+
table=fv,
79+
join_key_values=entity_values,
80+
entity_name_to_join_key_map=entity_name_to_join_key_map,
81+
)
82+
expected_entities = (
83+
{"entity_2": Value(string_val="1")},
84+
{"entity_2": Value(string_val="2")},
85+
)
86+
expected_indexes = ([0, 2], [1])
87+
88+
assert unique_entities == expected_entities
89+
assert indexes == expected_indexes
90+
# We're not say anything about the entity_1 missing from the unique_entities list
91+
assert "entity_1" not in [entity.keys() for entity in unique_entities]
92+
93+
94+
def test_get_unique_entities_missing_all_join_keys_error():
95+
"""
96+
Tests that _get_unique_entities raises a KeyError when all required join keys are missing.
97+
"""
98+
entity_values_not_in_feature_view = {
99+
"entity_3": [Value(string_val="3")],
100+
}
101+
entity_name_to_join_key_map = {
102+
"entity_1": "entity_1",
103+
"entity_2": "entity_2",
104+
"entity_3": "entity_3",
105+
}
106+
107+
fv = MockFeatureView(
108+
name="fv_1",
109+
entities=["entity_1", "entity_2"],
110+
projection=MockFeatureViewProjection(join_key_map={}),
111+
)
112+
113+
with pytest.raises(KeyError) as excinfo:
114+
utils._get_unique_entities(
115+
table=fv,
116+
join_key_values=entity_values_not_in_feature_view,
117+
entity_name_to_join_key_map=entity_name_to_join_key_map,
118+
)
119+
120+
error_message = str(excinfo.value)
121+
assert (
122+
"Missing join key values for keys: ['entity_1', 'entity_2', 'entity_3']"
123+
in error_message
124+
)
125+
assert (
126+
"No values provided for keys: ['entity_1', 'entity_2', 'entity_3']"
127+
in error_message
128+
)
129+
assert "Provided join_key_values: ['entity_3']" in error_message

0 commit comments

Comments
 (0)