diff --git a/docs/docs/in_depth/join_data.md b/docs/docs/in_depth/join_data.md index 12e97ba0..8ee1cef7 100644 --- a/docs/docs/in_depth/join_data.md +++ b/docs/docs/in_depth/join_data.md @@ -5,6 +5,7 @@ Combining datasets from various feature groups is crucial for building comprehen - **Different Compute Frameworks**: Merging data from feature groups that utilize different underlying compute technologies. - **Same Compute Framework, Different Sources**: Combining datasets that use the same compute framework but originate from different data sources. - **Same Feature Group, Different Feature Options**: Integrating data from the same feature group configured with different feature options. +- **Multiple Links on the Same Compute Framework**: Joining a single base feature group to multiple right-side feature groups (of the same class, subclasses of a common base, or fully distinct classes), all within the same compute framework. _**If we have a feature, which is dependent on a aforementioned setup,**_ @@ -221,6 +222,8 @@ link = Link.inner( The balanced inheritance rule ensures that joins only occur between "parallel" subclasses - both sides must be at the same level in the inheritance hierarchy relative to the link definition. +When multiple batches match a link (e.g. three subclasses all matching a base-class link), the engine disambiguates using the link's `right_index`: first by an exact match between the feature's `index` and `right_index`, then by checking whether the feature name appears in the join key columns of `right_index`. If neither check leaves exactly one candidate, a `ValueError` is raised, so ensure each link's `right_index` contains the column name that uniquely identifies its right-side batch. 
+ #### mlodaAPI ``` python diff --git a/mloda/core/abstract_plugins/components/base_artifact.py b/mloda/core/abstract_plugins/components/base_artifact.py index 76242b03..a0777e3a 100644 --- a/mloda/core/abstract_plugins/components/base_artifact.py +++ b/mloda/core/abstract_plugins/components/base_artifact.py @@ -54,7 +54,7 @@ def custom_loader(cls, features: FeatureSet) -> Optional[Any]: options = cls.get_singular_option_from_options(features) if options is None or features.name_of_one_feature is None: return None - return options.get(features.name_of_one_feature.name) + return options[features.name_of_one_feature.name] @classmethod def get_singular_option_from_options(cls, features: FeatureSet) -> Options | None: diff --git a/mloda/core/abstract_plugins/components/input_data/api/api_input_data.py b/mloda/core/abstract_plugins/components/input_data/api/api_input_data.py index cded05eb..3196fab7 100644 --- a/mloda/core/abstract_plugins/components/input_data/api/api_input_data.py +++ b/mloda/core/abstract_plugins/components/input_data/api/api_input_data.py @@ -23,7 +23,7 @@ def matches( if not _data_access_name: raise ValueError(f"Data access name was not set for ApiInputData class {self.__class__.__name__}.") - api_input_data_column_names = options.get(_data_access_name) + api_input_data_column_names = options[_data_access_name] if api_input_data_column_names is None: return False diff --git a/mloda/core/abstract_plugins/components/options.py b/mloda/core/abstract_plugins/components/options.py index 8b7426d5..8b946e9a 100644 --- a/mloda/core/abstract_plugins/components/options.py +++ b/mloda/core/abstract_plugins/components/options.py @@ -120,6 +120,9 @@ def get(self, key: str) -> Any: return self.group[key] return self.context.get(key, None) + def __getitem__(self, key: str) -> Any: + return self.get(key) + def items(self) -> list[tuple[str, Any]]: """ Get all key-value pairs from both group and context. 
@@ -160,6 +163,9 @@ def set(self, key: str, value: Any) -> None: # New key, add to group by default self.group[key] = value + def __setitem__(self, key: str, value: Any) -> None: + self.set(key, value) + def get_in_features(self) -> "frozenset[Feature]": val = self.get(DefaultOptionKeys.in_features) diff --git a/mloda/core/filter/global_filter.py b/mloda/core/filter/global_filter.py index ed2d4761..181cd3de 100644 --- a/mloda/core/filter/global_filter.py +++ b/mloda/core/filter/global_filter.py @@ -103,14 +103,12 @@ def identity_matched_filters( def unify_options(self, feat_options: Options, filter_options: Options) -> Options: for key, value in feat_options.items(): if key not in filter_options: - filter_options.set(key, value) + filter_options[key] = value else: - if filter_options.get(key) == value: + if filter_options[key] == value: continue else: - logger.warning( - f"Options are not the same. {key} is different. {filter_options.get(key)} != {value}" - ) + logger.warning(f"Options are not the same. {key} is different. {filter_options[key]} != {value}") return filter_options def criteria( diff --git a/mloda/core/prepare/execution_plan.py b/mloda/core/prepare/execution_plan.py index 67b380a4..973c6685 100644 --- a/mloda/core/prepare/execution_plan.py +++ b/mloda/core/prepare/execution_plan.py @@ -625,9 +625,7 @@ def case_link_fw_is_equal_to_children_fw( if len(feature_set_collection_per_uuid) == 0: raise ValueError("Feature set collection per uuid is None. 
This should not happen.") - unique_solution_counter = 0 - left_uuids = None - right_uuids = None + valid_pairs: List[Tuple[Set[UUID], Set[UUID]]] = [] for uuid, uuid_complete in feature_set_collection_per_uuid.items(): # get the feature set collection, where feature cfw = left link cfw @@ -651,27 +649,46 @@ def case_link_fw_is_equal_to_children_fw( if not issubclass(graph.nodes[_uuid].feature_group_class, link_fw[0].right_feature_group): continue - if left_uuids is None: - left_uuids = uuid_complete - right_uuids = _uuid_complete - unique_solution_counter += 1 - continue - - if left_uuids == uuid_complete and right_uuids == _uuid_complete: - continue - - unique_solution_counter += 1 + # Deduplicate using set equality + if not any(l == uuid_complete and r == _uuid_complete for l, r in valid_pairs): + valid_pairs.append((uuid_complete, _uuid_complete)) - if unique_solution_counter == 1: - if left_uuids is None or right_uuids is None: - raise ValueError("This should not happen.") - return (left_uuids, right_uuids) - elif unique_solution_counter == 0: + if len(valid_pairs) == 1: + return valid_pairs[0] + elif len(valid_pairs) == 0: return False - else: - raise ValueError( - "There are more than one solution for the join. This should not happen. If you have this occurence, please check your logic, but you can also contact the developers, as we skipped this algorithm part for now." 
- ) + + # Secondary disambiguation: use right_index to pick the correct right batch + right_index = link_fw[0].right_index + if right_index is not None: + # First pass: match by feature.index == link.right_index + filtered = [ + (l, r) + for l, r in valid_pairs + if any( + graph.nodes[u].feature.index is not None and graph.nodes[u].feature.index == right_index + for u in r + if u in graph.nodes + ) + ] + if len(filtered) == 1: + return filtered[0] + + # Second pass: match by feature name appearing in right_index columns + if not filtered: + filtered = [ + (l, r) + for l, r in valid_pairs + if any(graph.nodes[u].feature.name in right_index.index for u in r if u in graph.nodes) + ] + if len(filtered) == 1: + return filtered[0] + + raise ValueError( + "There are more than one solution for the join. " + "If you encounter this, check your links and feature group configuration, " + "or contact the mloda developers." + ) def case_link_equal_feature_groups( self, diff --git a/mloda/core/prepare/resolve_links.py b/mloda/core/prepare/resolve_links.py index 1a7104c2..99a85ed3 100644 --- a/mloda/core/prepare/resolve_links.py +++ b/mloda/core/prepare/resolve_links.py @@ -1,5 +1,5 @@ from collections import OrderedDict, defaultdict -from typing import Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union from uuid import UUID from mloda.core.abstract_plugins.compute_framework import ComputeFramework @@ -262,7 +262,7 @@ def go_through_each_child_and_its_parents_and_look_for_links(self) -> None: right_fg = r_right.feature_group_class # Two-pass matching: exact match first, then polymorphic - matched_links = self._find_matching_links(left_fg, right_fg) + matched_links = self._find_matching_links(left_fg, right_fg, r_right.feature) for matched_link in matched_links: key = self.create_link_trekker_key( @@ -270,7 +270,7 @@ def go_through_each_child_and_its_parents_and_look_for_links(self) -> None: ) 
self.set_link_trekker(key, child) - def _find_matching_links(self, left_fg: type, right_fg: type) -> List[Link]: + def _find_matching_links(self, left_fg: type, right_fg: type, right_feature: Any = None) -> List[Link]: """Find all matching links using two-pass matching: exact first, then polymorphic. Returns all exact matches if any exist, otherwise returns the most specific @@ -282,6 +282,15 @@ def _find_matching_links(self, left_fg: type, right_fg: type) -> List[Link]: # Pass 1: Collect all exact matches exact_matches = [link for link in self.links if link.matches_exact(left_fg, right_fg)] if exact_matches: + if len(exact_matches) > 1 and right_feature is not None: + # Multiple links share the same (left_fg, right_fg): disambiguate by right_index + index_filtered = [ + link + for link in exact_matches + if right_feature.index is not None and link.right_index == right_feature.index + ] + if index_filtered: + return index_filtered return exact_matches # Pass 2: If no exact matches, find most specific polymorphic matches diff --git a/mloda_plugins/feature_group/experimental/aggregated_feature_group/base.py b/mloda_plugins/feature_group/experimental/aggregated_feature_group/base.py index c63562c7..f0a3284f 100644 --- a/mloda_plugins/feature_group/experimental/aggregated_feature_group/base.py +++ b/mloda_plugins/feature_group/experimental/aggregated_feature_group/base.py @@ -145,7 +145,7 @@ def _extract_aggregation_type(cls, feature: Feature) -> Optional[str]: return aggregation_type # Fall back to configuration - aggregation_type = feature.options.get(cls.AGGREGATION_TYPE) + aggregation_type = feature.options[cls.AGGREGATION_TYPE] return str(aggregation_type) if aggregation_type is not None else None @classmethod diff --git a/mloda_plugins/feature_group/experimental/clustering/base.py b/mloda_plugins/feature_group/experimental/clustering/base.py index a2f10625..c813e8d1 100644 --- a/mloda_plugins/feature_group/experimental/clustering/base.py +++ 
b/mloda_plugins/feature_group/experimental/clustering/base.py @@ -237,7 +237,7 @@ def _extract_clustering_params(cls, feature: Feature) -> tuple[Optional[str], Op return algorithm, k_value # Fall back to configuration-based - algorithm = feature.options.get(cls.ALGORITHM) + algorithm = feature.options[cls.ALGORITHM] k_value_raw = feature.options.get(cls.K_VALUE) if k_value_raw is None: diff --git a/mloda_plugins/feature_group/experimental/sklearn/pipeline/base.py b/mloda_plugins/feature_group/experimental/sklearn/pipeline/base.py index db4b047a..4a82d3c3 100644 --- a/mloda_plugins/feature_group/experimental/sklearn/pipeline/base.py +++ b/mloda_plugins/feature_group/experimental/sklearn/pipeline/base.py @@ -287,7 +287,7 @@ def _extract_pipeline_name(cls, feature: Feature) -> Optional[str]: return prefix_part # Fall back to configuration-based approach - pipeline_name = feature.options.get(cls.PIPELINE_NAME) + pipeline_name = feature.options[cls.PIPELINE_NAME] pipeline_steps = feature.options.get(cls.PIPELINE_STEPS) # Handle mutual exclusivity: either PIPELINE_NAME or PIPELINE_STEPS diff --git a/mloda_plugins/feature_group/experimental/sklearn/scaling/base.py b/mloda_plugins/feature_group/experimental/sklearn/scaling/base.py index eed711cc..bb7106ac 100644 --- a/mloda_plugins/feature_group/experimental/sklearn/scaling/base.py +++ b/mloda_plugins/feature_group/experimental/sklearn/scaling/base.py @@ -216,7 +216,7 @@ def _extract_scaler_type(cls, feature: Feature) -> Optional[str]: return scaler_type # Fall back to configuration-based approach - scaler_type = feature.options.get(cls.SCALER_TYPE) + scaler_type = feature.options[cls.SCALER_TYPE] if scaler_type is not None and scaler_type not in cls.SUPPORTED_SCALERS: raise ValueError( diff --git a/mloda_plugins/feature_group/input_data/read_db.py b/mloda_plugins/feature_group/input_data/read_db.py index 0a64885f..939ce166 100644 --- a/mloda_plugins/feature_group/input_data/read_db.py +++ 
b/mloda_plugins/feature_group/input_data/read_db.py @@ -47,7 +47,7 @@ def init_reader(self, options: Optional[Options]) -> Tuple["ReadDB", Any]: if options is None: raise ValueError("Options were not set.") - reader_data_access = options.get("BaseInputData") + reader_data_access = options["BaseInputData"] if reader_data_access is None: raise ValueError("Reader data access was not set.") diff --git a/tests/test_core/test_abstract_plugins/test_components/test_options.py b/tests/test_core/test_abstract_plugins/test_components/test_options.py index 7e437acd..900fc7fd 100644 --- a/tests/test_core/test_abstract_plugins/test_components/test_options.py +++ b/tests/test_core/test_abstract_plugins/test_components/test_options.py @@ -180,6 +180,39 @@ def test_backward_compatibility_with_feature_class(self) -> None: assert feature1 != feature3 + def test_getitem_group_key(self) -> None: + """options["key"] returns same value as options.get("key") for group key.""" + options = Options(group={"group_key": "group_value"}, context={"context_key": "context_value"}) + assert options["group_key"] == options.get("group_key") + + def test_getitem_context_key(self) -> None: + """options["key"] returns same value as options.get("key") for context key.""" + options = Options(group={"group_key": "group_value"}, context={"context_key": "context_value"}) + assert options["context_key"] == options.get("context_key") + + def test_getitem_missing_key_returns_none(self) -> None: + """options["missing"] returns None, consistent with get().""" + options = Options(group={"key": "value"}) + assert options["missing"] is None + + def test_setitem_adds_new_key_to_group(self) -> None: + """options["new_key"] = value adds key to group.""" + options = Options(group={"existing": "value"}) + options["new_key"] = "new_value" + assert options.group["new_key"] == "new_value" + + def test_setitem_updates_existing_group_key(self) -> None: + """options["existing_group_key"] = new_value updates group key.""" + 
options = Options(group={"key": "old_value"}) + options["key"] = "new_value" + assert options.group["key"] == "new_value" + + def test_setitem_updates_existing_context_key(self) -> None: + """options["existing_context_key"] = new_value updates context key.""" + options = Options(group={"g": 1}, context={"ctx_key": "old_value"}) + options["ctx_key"] = "new_value" + assert options.context["ctx_key"] == "new_value" + def test_migration_scenario(self) -> None: """Test typical migration scenario: all options start in group.""" # Current usage pattern (all options in group during migration) diff --git a/tests/test_plugins/compute_framework/test_multi_link_same_cfw.py b/tests/test_plugins/compute_framework/test_multi_link_same_cfw.py new file mode 100644 index 00000000..b9907ee3 --- /dev/null +++ b/tests/test_plugins/compute_framework/test_multi_link_same_cfw.py @@ -0,0 +1,260 @@ +from typing import Any, Optional, Set + +import pytest + +from mloda.provider import FeatureGroup +from mloda.user import Feature +from mloda.user import FeatureName +from mloda.provider import FeatureSet +from mloda.user import Index +from mloda.provider import BaseInputData +from mloda.provider import DataCreator +from mloda.user import Link, JoinSpec +from mloda.user import Options +from mloda.user import ParallelizationMode +from mloda.user import PluginCollector +from mloda.user import mloda + + +# === Scenario 1: Same class used 3x with different feature names per batch === + + +class ReadFGS1(FeatureGroup): + @classmethod + def input_data(cls) -> Optional[BaseInputData]: + return DataCreator(supports_features={cls.get_class_name()}) + + @classmethod + def calculate_feature(cls, data: Any, features: FeatureSet) -> Any: + return {cls.get_class_name(): ["v1"]} + + +class AggFGS1(FeatureGroup): + @classmethod + def input_data(cls) -> Optional[BaseInputData]: + return DataCreator(supports_features={"agg_sum_s1", "agg_count_s1", "agg_avg_s1"}) + + @classmethod + def calculate_feature(cls, data: 
Any, features: FeatureSet) -> Any: + name = next(iter(features.get_all_names())) + return {name: ["v1"]} + + +class ConsumerFGS1(FeatureGroup): + def input_features(self, options: Options, feature_name: FeatureName) -> Optional[Set[Feature]]: + return { + Feature(name=ReadFGS1.get_class_name()), + Feature(name="agg_sum_s1", options={"agg_type": "sum"}), + Feature(name="agg_count_s1", options={"agg_type": "count"}), + Feature(name="agg_avg_s1", options={"agg_type": "avg"}), + } + + @classmethod + def calculate_feature(cls, data: Any, features: FeatureSet) -> Any: + return {cls.get_class_name(): ["ok"]} + + +# === Scenario 2: Subclasses with base-class Links === + + +class BaseAggS2(FeatureGroup): + @classmethod + def input_data(cls) -> Optional[BaseInputData]: + return DataCreator(supports_features={cls.get_class_name()}) + + @classmethod + def calculate_feature(cls, data: Any, features: FeatureSet) -> Any: + return {cls.get_class_name(): ["v1"]} + + +class SumAggS2(BaseAggS2): + pass + + +class CountAggS2(BaseAggS2): + pass + + +class AvgAggS2(BaseAggS2): + pass + + +class ReadFGS2(FeatureGroup): + @classmethod + def input_data(cls) -> Optional[BaseInputData]: + return DataCreator(supports_features={cls.get_class_name()}) + + @classmethod + def calculate_feature(cls, data: Any, features: FeatureSet) -> Any: + return {cls.get_class_name(): ["v1"]} + + +class ConsumerFGS2(FeatureGroup): + def input_features(self, options: Options, feature_name: FeatureName) -> Optional[Set[Feature]]: + return { + Feature(name=ReadFGS2.get_class_name()), + Feature(name=SumAggS2.get_class_name()), + Feature(name=CountAggS2.get_class_name()), + Feature(name=AvgAggS2.get_class_name()), + } + + @classmethod + def calculate_feature(cls, data: Any, features: FeatureSet) -> Any: + return {cls.get_class_name(): ["ok"]} + + +# === Scenario 3: Fully distinct classes (verify still works with same cfw) === + + +class ReadFGS3(FeatureGroup): + @classmethod + def input_data(cls) -> 
Optional[BaseInputData]: + return DataCreator(supports_features={cls.get_class_name()}) + + @classmethod + def calculate_feature(cls, data: Any, features: FeatureSet) -> Any: + return {cls.get_class_name(): ["v1"]} + + +class SumFGS3(FeatureGroup): + @classmethod + def input_data(cls) -> Optional[BaseInputData]: + return DataCreator(supports_features={cls.get_class_name()}) + + @classmethod + def calculate_feature(cls, data: Any, features: FeatureSet) -> Any: + return {cls.get_class_name(): ["v1"]} + + +class CountFGS3(FeatureGroup): + @classmethod + def input_data(cls) -> Optional[BaseInputData]: + return DataCreator(supports_features={cls.get_class_name()}) + + @classmethod + def calculate_feature(cls, data: Any, features: FeatureSet) -> Any: + return {cls.get_class_name(): ["v1"]} + + +class AvgFGS3(FeatureGroup): + @classmethod + def input_data(cls) -> Optional[BaseInputData]: + return DataCreator(supports_features={cls.get_class_name()}) + + @classmethod + def calculate_feature(cls, data: Any, features: FeatureSet) -> Any: + return {cls.get_class_name(): ["v1"]} + + +class ConsumerFGS3(FeatureGroup): + def input_features(self, options: Options, feature_name: FeatureName) -> Optional[Set[Feature]]: + return { + Feature(name=ReadFGS3.get_class_name()), + Feature(name=SumFGS3.get_class_name()), + Feature(name=CountFGS3.get_class_name()), + Feature(name=AvgFGS3.get_class_name()), + } + + @classmethod + def calculate_feature(cls, data: Any, features: FeatureSet) -> Any: + return {cls.get_class_name(): ["ok"]} + + +@pytest.mark.parametrize( + "modes", + [ + ({ParallelizationMode.SYNC}), + ({ParallelizationMode.THREADING}), + ], +) +class TestMultiLinkSameCfw: + def test_multi_link_same_class_same_cfw(self, modes: Set[ParallelizationMode], flight_server: Any) -> None: + """Scenario 1: same FG class used 3x (different feature names per batch), same compute framework.""" + feature = Feature(name=ConsumerFGS1.get_class_name()) + links = { + Link.inner( + 
JoinSpec(ReadFGS1, Index(("ReadFGS1",))), + JoinSpec(AggFGS1, Index(("agg_sum_s1",))), + ), + Link.inner( + JoinSpec(ReadFGS1, Index(("ReadFGS1",))), + JoinSpec(AggFGS1, Index(("agg_count_s1",))), + ), + Link.inner( + JoinSpec(ReadFGS1, Index(("ReadFGS1",))), + JoinSpec(AggFGS1, Index(("agg_avg_s1",))), + ), + } + result = mloda.run_all( + [feature], + links=links, + compute_frameworks=["PandasDataFrame"], + plugin_collector=PluginCollector.enabled_feature_groups({ReadFGS1, AggFGS1, ConsumerFGS1}), + flight_server=flight_server, + parallelization_modes=modes, + ) + for res in result: + assert len(res) == 1 + assert ConsumerFGS1.get_class_name() in res.columns + + def test_multi_link_subclass_same_cfw(self, modes: Set[ParallelizationMode], flight_server: Any) -> None: + """Scenario 2: subclasses of a common base, Links defined on the base class, same compute framework.""" + feature = Feature(name=ConsumerFGS2.get_class_name()) + links = { + Link.inner( + JoinSpec(ReadFGS2, Index(("ReadFGS2",))), + JoinSpec(BaseAggS2, Index(("SumAggS2",))), + ), + Link.inner( + JoinSpec(ReadFGS2, Index(("ReadFGS2",))), + JoinSpec(BaseAggS2, Index(("CountAggS2",))), + ), + Link.inner( + JoinSpec(ReadFGS2, Index(("ReadFGS2",))), + JoinSpec(BaseAggS2, Index(("AvgAggS2",))), + ), + } + result = mloda.run_all( + [feature], + links=links, + compute_frameworks=["PandasDataFrame"], + plugin_collector=PluginCollector.enabled_feature_groups( + {ReadFGS2, SumAggS2, CountAggS2, AvgAggS2, ConsumerFGS2} + ), + flight_server=flight_server, + parallelization_modes=modes, + ) + for res in result: + assert len(res) == 1 + assert ConsumerFGS2.get_class_name() in res.columns + + def test_multi_link_distinct_class_same_cfw(self, modes: Set[ParallelizationMode], flight_server: Any) -> None: + """Scenario 3: fully distinct FG classes, same compute framework (verify no regression).""" + feature = Feature(name=ConsumerFGS3.get_class_name()) + links = { + Link.inner( + JoinSpec(ReadFGS3, 
Index(("ReadFGS3",))), + JoinSpec(SumFGS3, Index(("SumFGS3",))), + ), + Link.inner( + JoinSpec(ReadFGS3, Index(("ReadFGS3",))), + JoinSpec(CountFGS3, Index(("CountFGS3",))), + ), + Link.inner( + JoinSpec(ReadFGS3, Index(("ReadFGS3",))), + JoinSpec(AvgFGS3, Index(("AvgFGS3",))), + ), + } + result = mloda.run_all( + [feature], + links=links, + compute_frameworks=["PandasDataFrame"], + plugin_collector=PluginCollector.enabled_feature_groups( + {ReadFGS3, SumFGS3, CountFGS3, AvgFGS3, ConsumerFGS3} + ), + flight_server=flight_server, + parallelization_modes=modes, + ) + for res in result: + assert len(res) == 1 + assert ConsumerFGS3.get_class_name() in res.columns