diff --git a/lnn/symbolic/_gm.py b/lnn/symbolic/_gm.py index dac8bc1..d0679da 100644 --- a/lnn/symbolic/_gm.py +++ b/lnn/symbolic/_gm.py @@ -152,6 +152,8 @@ def _operational_bounds( joined = operand_dfs[0] else: joined = ft.reduce(_full_outer_join, operand_dfs) + if hasattr(operator, 'filter_valid_groundings'): + joined = operator.filter_valid_groundings(joined) operator_groundings = _operator_groundings(joined, operator) ground_objects = _operand_groundings(joined, operator, bindings)