Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/docs/in_depth/join_data.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Combining datasets from various feature groups is crucial for building comprehen
- **Different Compute Frameworks**: Merging data from feature groups that utilize different underlying compute technologies.
- **Same Compute Framework, Different Sources**: Combining datasets that use the same compute framework but originate from different data sources.
- **Same Feature Group, Different Feature Options**: Integrating data from the same feature group configured with different feature options.
- **Multiple Links on the Same Compute Framework**: Joining a single base FG to multiple right-side FGs (same class, subclasses, or distinct) all within the same compute framework.


_**If we have a feature, which is dependent on an aforementioned setup,**_
Expand Down Expand Up @@ -221,6 +222,8 @@ link = Link.inner(

The balanced inheritance rule ensures that joins only occur between "parallel" subclasses - both sides must be at the same level in the inheritance hierarchy relative to the link definition.

When multiple batches match (e.g. three subclasses all matching a base-class link), the engine disambiguates using `right_index`: first by exact `feature.index` match, then by checking whether the feature name appears in the join key columns of `right_index`. Ensure each link's `right_index` contains the column name that uniquely identifies its right-side batch.

#### mlodaAPI

``` python
Expand Down
2 changes: 1 addition & 1 deletion mloda/core/abstract_plugins/components/base_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def custom_loader(cls, features: FeatureSet) -> Optional[Any]:
options = cls.get_singular_option_from_options(features)
if options is None or features.name_of_one_feature is None:
return None
return options.get(features.name_of_one_feature.name)
return options[features.name_of_one_feature.name]

@classmethod
def get_singular_option_from_options(cls, features: FeatureSet) -> Options | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def matches(
if not _data_access_name:
raise ValueError(f"Data access name was not set for ApiInputData class {self.__class__.__name__}.")

api_input_data_column_names = options.get(_data_access_name)
api_input_data_column_names = options[_data_access_name]
if api_input_data_column_names is None:
return False

Expand Down
6 changes: 6 additions & 0 deletions mloda/core/abstract_plugins/components/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ def get(self, key: str) -> Any:
return self.group[key]
return self.context.get(key, None)

def __getitem__(self, key: str) -> Any:
    """Subscript read access (``options[key]``).

    Delegates to :meth:`get`, so group keys win over context keys and a
    missing key yields ``None`` rather than raising ``KeyError``.
    """
    value = self.get(key)
    return value

def items(self) -> list[tuple[str, Any]]:
"""
Get all key-value pairs from both group and context.
Expand Down Expand Up @@ -160,6 +163,9 @@ def set(self, key: str, value: Any) -> None:
# New key, add to group by default
self.group[key] = value

def __setitem__(self, key: str, value: Any) -> None:
    """Subscript write access (``options[key] = value``).

    Delegates to :meth:`set`; presumably existing keys are updated in
    place and new keys land in ``group`` — see ``set`` for the rules.
    """
    self.set(key, value)

def get_in_features(self) -> "frozenset[Feature]":
val = self.get(DefaultOptionKeys.in_features)

Expand Down
8 changes: 3 additions & 5 deletions mloda/core/filter/global_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,12 @@ def identity_matched_filters(
def unify_options(self, feat_options: Options, filter_options: Options) -> Options:
    """Merge feature options into filter options, mutating ``filter_options``.

    Keys absent from ``filter_options`` are copied over from
    ``feat_options``. Keys present in both with equal values are left
    alone; keys present in both with differing values keep the existing
    filter value, and the mismatch is surfaced via a warning.

    Args:
        feat_options: Options attached to the feature.
        filter_options: Options attached to the filter; updated in place.

    Returns:
        The (mutated) ``filter_options`` object.
    """
    # NOTE: the pasted diff contained both the pre- and post-change lines of
    # this body (duplicate set/warning statements); this is the clean
    # post-change version. The redundant if/continue/else is collapsed to elif.
    for key, value in feat_options.items():
        if key not in filter_options:
            filter_options[key] = value
        elif filter_options[key] != value:
            # Conflicting values: keep the filter's value, but warn so the
            # divergence is visible in logs.
            logger.warning(f"Options are not the same. {key} is different. {filter_options[key]} != {value}")
    return filter_options

def criteria(
Expand Down
61 changes: 39 additions & 22 deletions mloda/core/prepare/execution_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,9 +625,7 @@ def case_link_fw_is_equal_to_children_fw(
if len(feature_set_collection_per_uuid) == 0:
raise ValueError("Feature set collection per uuid is None. This should not happen.")

unique_solution_counter = 0
left_uuids = None
right_uuids = None
valid_pairs: List[Tuple[Set[UUID], Set[UUID]]] = []

for uuid, uuid_complete in feature_set_collection_per_uuid.items():
# get the feature set collection, where feature cfw = left link cfw
Expand All @@ -651,27 +649,46 @@ def case_link_fw_is_equal_to_children_fw(
if not issubclass(graph.nodes[_uuid].feature_group_class, link_fw[0].right_feature_group):
continue

if left_uuids is None:
left_uuids = uuid_complete
right_uuids = _uuid_complete
unique_solution_counter += 1
continue

if left_uuids == uuid_complete and right_uuids == _uuid_complete:
continue

unique_solution_counter += 1
# Deduplicate using set equality
if not any(l == uuid_complete and r == _uuid_complete for l, r in valid_pairs):
valid_pairs.append((uuid_complete, _uuid_complete))

if unique_solution_counter == 1:
if left_uuids is None or right_uuids is None:
raise ValueError("This should not happen.")
return (left_uuids, right_uuids)
elif unique_solution_counter == 0:
if len(valid_pairs) == 1:
return valid_pairs[0]
elif len(valid_pairs) == 0:
return False
else:
raise ValueError(
"There are more than one solution for the join. This should not happen. If you have this occurence, please check your logic, but you can also contact the developers, as we skipped this algorithm part for now."
)

# Secondary disambiguation: use right_index to pick the correct right batch
right_index = link_fw[0].right_index
if right_index is not None:
# First pass: match by feature.index == link.right_index
filtered = [
(l, r)
for l, r in valid_pairs
if any(
graph.nodes[u].feature.index is not None and graph.nodes[u].feature.index == right_index
for u in r
if u in graph.nodes
)
]
if len(filtered) == 1:
return filtered[0]

# Second pass: match by feature name appearing in right_index columns
if not filtered:
filtered = [
(l, r)
for l, r in valid_pairs
if any(graph.nodes[u].feature.name in right_index.index for u in r if u in graph.nodes)
]
if len(filtered) == 1:
return filtered[0]

raise ValueError(
"There are more than one solution for the join. "
"If you encounter this, check your links and feature group configuration, "
"or contact the mloda developers."
)

def case_link_equal_feature_groups(
self,
Expand Down
15 changes: 12 additions & 3 deletions mloda/core/prepare/resolve_links.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import OrderedDict, defaultdict
from typing import Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
from uuid import UUID

from mloda.core.abstract_plugins.compute_framework import ComputeFramework
Expand Down Expand Up @@ -262,15 +262,15 @@ def go_through_each_child_and_its_parents_and_look_for_links(self) -> None:
right_fg = r_right.feature_group_class

# Two-pass matching: exact match first, then polymorphic
matched_links = self._find_matching_links(left_fg, right_fg)
matched_links = self._find_matching_links(left_fg, right_fg, r_right.feature)

for matched_link in matched_links:
key = self.create_link_trekker_key(
matched_link, r_left.feature.compute_frameworks, r_right.feature.compute_frameworks
)
self.set_link_trekker(key, child)

def _find_matching_links(self, left_fg: type, right_fg: type) -> List[Link]:
def _find_matching_links(self, left_fg: type, right_fg: type, right_feature: Any = None) -> List[Link]:
"""Find all matching links using two-pass matching: exact first, then polymorphic.

Returns all exact matches if any exist, otherwise returns the most specific
Expand All @@ -282,6 +282,15 @@ def _find_matching_links(self, left_fg: type, right_fg: type) -> List[Link]:
# Pass 1: Collect all exact matches
exact_matches = [link for link in self.links if link.matches_exact(left_fg, right_fg)]
if exact_matches:
if len(exact_matches) > 1 and right_feature is not None:
# Multiple links share the same (left_fg, right_fg): disambiguate by right_index
index_filtered = [
link
for link in exact_matches
if right_feature.index is not None and link.right_index == right_feature.index
]
if index_filtered:
return index_filtered
return exact_matches

# Pass 2: If no exact matches, find most specific polymorphic matches
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _extract_aggregation_type(cls, feature: Feature) -> Optional[str]:
return aggregation_type

# Fall back to configuration
aggregation_type = feature.options.get(cls.AGGREGATION_TYPE)
aggregation_type = feature.options[cls.AGGREGATION_TYPE]
return str(aggregation_type) if aggregation_type is not None else None

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def _extract_clustering_params(cls, feature: Feature) -> tuple[Optional[str], Op
return algorithm, k_value

# Fall back to configuration-based
algorithm = feature.options.get(cls.ALGORITHM)
algorithm = feature.options[cls.ALGORITHM]
k_value_raw = feature.options.get(cls.K_VALUE)

if k_value_raw is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def _extract_pipeline_name(cls, feature: Feature) -> Optional[str]:
return prefix_part

# Fall back to configuration-based approach
pipeline_name = feature.options.get(cls.PIPELINE_NAME)
pipeline_name = feature.options[cls.PIPELINE_NAME]
pipeline_steps = feature.options.get(cls.PIPELINE_STEPS)

# Handle mutual exclusivity: either PIPELINE_NAME or PIPELINE_STEPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def _extract_scaler_type(cls, feature: Feature) -> Optional[str]:
return scaler_type

# Fall back to configuration-based approach
scaler_type = feature.options.get(cls.SCALER_TYPE)
scaler_type = feature.options[cls.SCALER_TYPE]

if scaler_type is not None and scaler_type not in cls.SUPPORTED_SCALERS:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion mloda_plugins/feature_group/input_data/read_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def init_reader(self, options: Optional[Options]) -> Tuple["ReadDB", Any]:
if options is None:
raise ValueError("Options were not set.")

reader_data_access = options.get("BaseInputData")
reader_data_access = options["BaseInputData"]

if reader_data_access is None:
raise ValueError("Reader data access was not set.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,39 @@ def test_backward_compatibility_with_feature_class(self) -> None:

assert feature1 != feature3

def test_getitem_group_key(self) -> None:
    """Subscript lookup of a group key matches get()."""
    opts = Options(group={"group_key": "group_value"}, context={"context_key": "context_value"})
    assert opts.get("group_key") == opts["group_key"]

def test_getitem_context_key(self) -> None:
    """Subscript lookup of a context key matches get()."""
    opts = Options(group={"group_key": "group_value"}, context={"context_key": "context_value"})
    assert opts.get("context_key") == opts["context_key"]

def test_getitem_missing_key_returns_none(self) -> None:
    """Subscripting an absent key yields None, mirroring get()."""
    opts = Options(group={"key": "value"})
    result = opts["missing"]
    assert result is None

def test_setitem_adds_new_key_to_group(self) -> None:
    """Subscript assignment of an unknown key stores it in group."""
    opts = Options(group={"existing": "value"})
    opts["new_key"] = "new_value"
    assert opts.group.get("new_key") == "new_value"

def test_setitem_updates_existing_group_key(self) -> None:
    """Subscript assignment overwrites an existing group key."""
    opts = Options(group={"key": "old_value"})
    opts["key"] = "new_value"
    assert opts.group.get("key") == "new_value"

def test_setitem_updates_existing_context_key(self) -> None:
    """Subscript assignment overwrites an existing context key."""
    opts = Options(group={"g": 1}, context={"ctx_key": "old_value"})
    opts["ctx_key"] = "new_value"
    assert opts.context.get("ctx_key") == "new_value"

def test_migration_scenario(self) -> None:
"""Test typical migration scenario: all options start in group."""
# Current usage pattern (all options in group during migration)
Expand Down
Loading