Skip to content

Commit

Permalink
Allow metrics in filters (#274)
Browse files Browse the repository at this point in the history
Resolves #273
Resolves SL-1849

<!---
Include the number of the issue addressed by this PR above if
applicable.
  PRs for code changes without an associated issue *will not be merged*.
  See CONTRIBUTING.md for more information.
-->

### Description
Allow users to reference metrics in where filters for metrics, measures,
and saved queries. This uses syntax like:
`{{ Metric('metric_name', group_by=['entity_name', 'dimension_name']) }}
= 10`
This unlocks new types of metrics that users have been asking for. Some
examples can be found in the linked issue and in [this design
doc](https://www.notion.so/dbtlabs/Metrics-as-Dimensions-55718e9516a7462787ffd6e3e8c1237e?pvs=4).

### Checklist

- [x] I have read [the contributing
guide](https://github.com/dbt-labs/dbt-semantic-interfaces/blob/main/CONTRIBUTING.md)
and understand what's expected of me
- [x] I have signed the
[CLA](https://docs.getdbt.com/docs/contributor-license-agreements)
- [x] This PR includes tests, or tests are not required/relevant for
this PR
- [x] I have run `changie new` to [create a changelog
entry](https://github.com/dbt-labs/dbt-semantic-interfaces/blob/main/CONTRIBUTING.md#adding-a-changelog-entry)
  • Loading branch information
courtneyholcomb authored Mar 19, 2024
1 parent f258b58 commit e4f029b
Show file tree
Hide file tree
Showing 13 changed files with 242 additions and 6 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240318-130949.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Allow metrics in filters.
time: 2024-03-18T13:09:49.730653-07:00
custom:
Author: courtneyholcomb
Issue: "273"
11 changes: 11 additions & 0 deletions dbt_semantic_interfaces/call_parameter_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from dbt_semantic_interfaces.references import (
DimensionReference,
EntityReference,
LinkableElementReference,
MetricReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.type_enums import TimeGranularity
Expand Down Expand Up @@ -38,13 +40,22 @@ class EntityCallParameterSet:
entity_reference: EntityReference


@dataclass(frozen=True)
class MetricCallParameterSet:
"""When 'Metric(...)' is used in the Jinja template of the where filter, the parameters to that call."""

metric_reference: MetricReference
group_by: Tuple[LinkableElementReference, ...]


@dataclass(frozen=True)
class FilterCallParameterSets:
"""The calls for metric items made in the Jinja template of the where filter."""

dimension_call_parameter_sets: Tuple[DimensionCallParameterSet, ...] = ()
time_dimension_call_parameter_sets: Tuple[TimeDimensionCallParameterSet, ...] = ()
entity_call_parameter_sets: Tuple[EntityCallParameterSet, ...] = ()
metric_call_parameter_sets: Tuple[MetricCallParameterSet, ...] = ()


class ParseWhereFilterException(Exception): # noqa: D
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dbt_semantic_interfaces.call_parameter_sets import (
DimensionCallParameterSet,
EntityCallParameterSet,
MetricCallParameterSet,
ParseWhereFilterException,
TimeDimensionCallParameterSet,
)
Expand All @@ -14,6 +15,8 @@
from dbt_semantic_interfaces.references import (
DimensionReference,
EntityReference,
LinkableElementReference,
MetricReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.type_enums import TimeGranularity
Expand Down Expand Up @@ -101,3 +104,16 @@ def create_entity(entity_name: str, entity_path: Sequence[str] = ()) -> EntityCa
entity_path=additional_entity_path_elements + structured_dundered_name.entity_links,
entity_reference=EntityReference(element_name=structured_dundered_name.element_name),
)

@staticmethod
def create_metric(metric_name: str, group_by: Sequence[str] = ()) -> MetricCallParameterSet:
"""Gets called by Jinja when rendering {{ Metric(...) }}."""
if not group_by:
raise ParseWhereFilterException(
"`group_by` parameter is required for Metric in where filter. This is needed to determine 1) the "
"granularity to aggregate the metric to and 2) how to join the metric to the rest of the query."
)
return MetricCallParameterSet(
metric_reference=MetricReference(element_name=metric_name),
group_by=tuple([LinkableElementReference(element_name=group_by_name) for group_by_name in group_by]),
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,20 @@

from typing_extensions import override

from dbt_semantic_interfaces.call_parameter_sets import EntityCallParameterSet
from dbt_semantic_interfaces.call_parameter_sets import (
EntityCallParameterSet,
MetricCallParameterSet,
)
from dbt_semantic_interfaces.errors import InvalidQuerySyntax
from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import (
ParameterSetFactory,
)
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.protocols.query_interface import (
QueryInterfaceEntity,
QueryInterfaceEntityFactory,
QueryInterfaceMetric,
QueryInterfaceMetricFactory,
)


Expand All @@ -27,6 +33,20 @@ def _implements_protocol(self) -> QueryInterfaceEntity:
return self


class MetricStub(ProtocolHint[QueryInterfaceMetric]):
"""A Metric implementation that just satisfies the protocol.
QueryInterfaceMetric currently has no methods and the parameter set is created in the factory.
"""

@override
def _implements_protocol(self) -> QueryInterfaceMetric:
return self

def descending(self, _is_descending: bool) -> QueryInterfaceMetric: # noqa: D
raise InvalidQuerySyntax("descending is invalid in the where parameter and filter spec")


class WhereFilterEntityFactory(ProtocolHint[QueryInterfaceEntityFactory]):
"""Executes in the Jinja sandbox to produce parameter sets and append them to a list."""

Expand All @@ -41,3 +61,20 @@ def create(self, entity_name: str, entity_path: Sequence[str] = ()) -> EntityStu
"""Gets called by Jinja when rendering {{ Entity(...) }}."""
self.entity_call_parameter_sets.append(ParameterSetFactory.create_entity(entity_name, entity_path))
return EntityStub()


class WhereFilterMetricFactory(ProtocolHint[QueryInterfaceMetricFactory]):
"""Executes in the Jinja sandbox to produce parameter sets and append them to a list."""

@override
def _implements_protocol(self) -> QueryInterfaceMetricFactory:
return self

def __init__(self) -> None: # noqa: D
self.metric_call_parameter_sets: List[MetricCallParameterSet] = []

def create(self, metric_name: str, group_by: Sequence[str] = ()) -> MetricStub: # noqa: D
self.metric_call_parameter_sets.append(
ParameterSetFactory.create_metric(metric_name=metric_name, group_by=group_by)
)
return MetricStub()
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from dbt_semantic_interfaces.parsing.where_filter.where_filter_entity import (
WhereFilterEntityFactory,
WhereFilterMetricFactory,
)
from dbt_semantic_interfaces.parsing.where_filter.where_filter_time_dimension import (
WhereFilterTimeDimensionFactory,
Expand All @@ -31,13 +32,15 @@ def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSet
time_dimension_factory = WhereFilterTimeDimensionFactory()
dimension_factory = WhereFilterDimensionFactory()
entity_factory = WhereFilterEntityFactory()
metric_factory = WhereFilterMetricFactory()

try:
# the string that the sandbox renders is unused
SandboxedEnvironment(undefined=StrictUndefined).from_string(where_sql_template).render(
Dimension=dimension_factory.create,
TimeDimension=time_dimension_factory.create,
Entity=entity_factory.create,
Metric=metric_factory.create,
)
except (UndefinedError, TemplateSyntaxError, SecurityError) as e:
raise ParseWhereFilterException(f"Error while parsing Jinja template:\n{where_sql_template}") from e
Expand All @@ -63,4 +66,5 @@ def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSet
dimension_call_parameter_sets=tuple(dimension_call_parameter_sets),
time_dimension_call_parameter_sets=tuple(time_dimension_factory.time_dimension_call_parameter_sets),
entity_call_parameter_sets=tuple(entity_factory.entity_call_parameter_sets),
metric_call_parameter_sets=tuple(metric_factory.metric_call_parameter_sets),
)
12 changes: 12 additions & 0 deletions dbt_semantic_interfaces/protocols/query_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,15 @@ class QueryInterfaceEntityFactory(Protocol):
def create(self, entity_name: str, entity_path: Sequence[str] = ()) -> QueryInterfaceEntity:
"""Create an Entity."""
pass


class QueryInterfaceMetricFactory(Protocol):
"""Creates an Metric for the query interface.
Represented as the Metric constructor in the Jinja sandbox.
"""

@abstractmethod
def create(self, metric_name: str, group_by: Sequence[str] = ()) -> QueryInterfaceMetric:
"""Create a Metric."""
pass
4 changes: 2 additions & 2 deletions dbt_semantic_interfaces/validations/saved_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _check_group_bys(valid_group_by_element_names: Set[str], saved_query: SavedQ
[x.entity_reference.element_name for x in parameter_sets.entity_call_parameter_sets]
+ [x.dimension_reference.element_name for x in parameter_sets.dimension_call_parameter_sets]
+ [x.time_dimension_reference.element_name for x in parameter_sets.time_dimension_call_parameter_sets]
+ [x.metric_reference.element_name for x in parameter_sets.metric_call_parameter_sets]
)

if len(element_names_in_group_by) != 1 or element_names_in_group_by[0] not in valid_group_by_element_names:
Expand Down Expand Up @@ -129,7 +130,7 @@ def _check_where(saved_query: SavedQuery) -> Sequence[ValidationIssue]:
def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]: # noqa: D
issues: List[ValidationIssue] = []
valid_metric_names = {metric.name for metric in semantic_manifest.metrics}
valid_group_by_element_names = {METRIC_TIME_ELEMENT_NAME}
valid_group_by_element_names = valid_metric_names.union({METRIC_TIME_ELEMENT_NAME})
for semantic_model in semantic_manifest.semantic_models:
for dimension in semantic_model.dimensions:
valid_group_by_element_names.add(dimension.name)
Expand All @@ -146,5 +147,4 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati
saved_query=saved_query,
)
issues += SavedQueryRule._check_where(saved_query=saved_query)

return issues
Original file line number Diff line number Diff line change
Expand Up @@ -535,3 +535,23 @@ metric:
- name: bookings
offset_window: 5 days
alias: bookings_5_days_ago
---
metric:
name: "ever_active_listings"
description: |
number of listings that have had at least 2 bookings ever
type: simple
type_params:
measure:
name: listings
filter: "{{ Metric('bookings', group_by=['listing']) }} > 2"
---
metric:
name: "active_listings"
description: |
number of listings that had at least 2 bookings on given date
type: simple
type_params:
measure:
name: listings
filter: "{{ Metric('bookings', group_by=['listing', 'metric_time']) }} > 2"
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,20 @@ saved_query:
export_as: table
schema: exports_schema
alias: bookings_export_table
---
saved_query:
name: highly_active_listings
description: Booking-related metrics that are of the highest priority.
query_params:
metrics:
- listings
group_by:
- TimeDimension('metric_time', 'DAY')
where:
- "{{ Metric('bookings', group_by=['listing', 'metric_time']) }} > 5"
exports:
- name: highly_active_listings
config:
export_as: table
schema: exports_schema
alias: highly_active_listings_export_table
38 changes: 38 additions & 0 deletions tests/implementations/where_filter/test_parse_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DimensionCallParameterSet,
EntityCallParameterSet,
FilterCallParameterSets,
MetricCallParameterSet,
ParseWhereFilterException,
TimeDimensionCallParameterSet,
)
Expand All @@ -16,6 +17,8 @@
from dbt_semantic_interfaces.references import (
DimensionReference,
EntityReference,
LinkableElementReference,
MetricReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.type_enums import TimeGranularity
Expand Down Expand Up @@ -132,6 +135,41 @@ def test_extract_entity_call_parameter_sets() -> None: # noqa: D
)


def test_extract_metric_call_parameter_sets() -> None: # noqa: D
parse_result = PydanticWhereFilter(
where_sql_template=("{{ Metric('bookings', group_by=['listing']) }} > 2")
).call_parameter_sets

assert parse_result == FilterCallParameterSets(
dimension_call_parameter_sets=(),
entity_call_parameter_sets=(),
metric_call_parameter_sets=(
MetricCallParameterSet(
metric_reference=MetricReference("bookings"),
group_by=(LinkableElementReference("listing"),),
),
),
)

parse_result = PydanticWhereFilter(
where_sql_template=("{{ Metric('bookings', group_by=['listing', 'metric_time']) }} > 2")
).call_parameter_sets

assert parse_result == FilterCallParameterSets(
dimension_call_parameter_sets=(),
entity_call_parameter_sets=(),
metric_call_parameter_sets=(
MetricCallParameterSet(
metric_reference=MetricReference("bookings"),
group_by=(LinkableElementReference("listing"), LinkableElementReference("metric_time")),
),
),
)

with pytest.raises(ParseWhereFilterException):
PydanticWhereFilter(where_sql_template=("{{ Metric('bookings') }} > 2")).call_parameter_sets


def test_metric_time_in_dimension_call_error() -> None: # noqa: D
with pytest.raises(ParseWhereFilterException, match="so it should be referenced using TimeDimension"):
assert (
Expand Down
2 changes: 2 additions & 0 deletions tests/parsing/test_metric_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def test_constraint_list_metric_parsing() -> None:
filter:
- "{{ dimension('some_dimension') }} IN ('value1', 'value2')"
- "1 > 0"
- "{{ metric('some_metric', group_by=['some_entity']) }} > 1"
"""
)
file = YamlConfigFile(filepath="inline_for_test", contents=yaml_contents)
Expand All @@ -376,6 +377,7 @@ def test_constraint_list_metric_parsing() -> None:
where_filters=[
PydanticWhereFilter(where_sql_template="{{ dimension('some_dimension') }} IN ('value1', 'value2')"),
PydanticWhereFilter(where_sql_template="1 > 0"),
PydanticWhereFilter(where_sql_template="{{ metric('some_metric', group_by=['some_entity']) }} > 1"),
]
)

Expand Down
31 changes: 28 additions & 3 deletions tests/parsing/test_where_filter_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
on inputs to parse_obj or parse_raw, as that is what the pydantic models will generally encounter.
"""


import pytest

from dbt_semantic_interfaces.call_parameter_sets import EntityCallParameterSet
from dbt_semantic_interfaces.call_parameter_sets import (
EntityCallParameterSet,
MetricCallParameterSet,
)
from dbt_semantic_interfaces.implementations.base import HashableBaseModel
from dbt_semantic_interfaces.implementations.filters.where_filter import (
PydanticWhereFilter,
Expand All @@ -22,7 +24,11 @@
from dbt_semantic_interfaces.parsing.where_filter.where_filter_parser import (
WhereFilterParser,
)
from dbt_semantic_interfaces.references import EntityReference
from dbt_semantic_interfaces.references import (
EntityReference,
LinkableElementReference,
MetricReference,
)
from dbt_semantic_interfaces.type_enums.date_part import DatePart

__BOOLEAN_EXPRESSION__ = "1 > 0"
Expand Down Expand Up @@ -187,3 +193,22 @@ def test_entity() -> None: # noqa
),
entity_reference=EntityReference(element_name="entity_2"),
)


def test_metric() -> None: # noqa
where = "{{ Metric('metric', group_by=['dimension']) }} = 10"
param_sets = WhereFilterParser.parse_call_parameter_sets(where)
assert len(param_sets.metric_call_parameter_sets) == 1
assert param_sets.metric_call_parameter_sets[0] == MetricCallParameterSet(
group_by=(LinkableElementReference(element_name="dimension"),),
metric_reference=MetricReference(element_name="metric"),
)

# Without kwarg syntax
where = "{{ Metric('metric', ['dimension']) }} = 10"
param_sets = WhereFilterParser.parse_call_parameter_sets(where)
assert len(param_sets.metric_call_parameter_sets) == 1
assert param_sets.metric_call_parameter_sets[0] == MetricCallParameterSet(
group_by=(LinkableElementReference(element_name="dimension"),),
metric_reference=MetricReference(element_name="metric"),
)
Loading

0 comments on commit e4f029b

Please sign in to comment.