From 67b44daa8d740a24fe882d81ac0376895278b7a5 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Wed, 7 Jun 2023 08:24:51 -0400 Subject: [PATCH 01/21] Port transformed_data functionality from VegaFusion --- altair/utils/transformed_data.py | 536 +++++++++++++++++++++++++++++++ altair/vegalite/v5/api.py | 158 ++++++++- tools/update_init_file.py | 3 +- 3 files changed, 695 insertions(+), 2 deletions(-) create mode 100644 altair/utils/transformed_data.py diff --git a/altair/utils/transformed_data.py b/altair/utils/transformed_data.py new file mode 100644 index 000000000..9004d0dea --- /dev/null +++ b/altair/utils/transformed_data.py @@ -0,0 +1,536 @@ +from typing import List, Optional, Tuple, Dict, Iterable, overload + +import pandas as pd + +from altair import ( + Chart, + FacetChart, + LayerChart, + HConcatChart, + VConcatChart, + ConcatChart, + data_transformers, +) +from typing import Union +from altair.utils.schemapi import Undefined + +Scope = Tuple[int, ...] +FacetMapping = Dict[Tuple[str, Scope], Tuple[str, Scope]] + +MAGIC_CHART_NAME = "_vf_mark{}" + + +@overload +def transformed_data( + chart: Union[Chart, FacetChart], + row_limit: Optional[int] = None, + exclude: Optional[Iterable[str]] = None, +) -> Optional[pd.DataFrame]: + ... + + +@overload +def transformed_data( + chart: Union[LayerChart, HConcatChart, VConcatChart, ConcatChart], + row_limit: Optional[int] = None, + exclude: Optional[Iterable[str]] = None, +) -> List[pd.DataFrame]: + ... + + +def transformed_data(chart, row_limit=None, exclude=None): + """Evaluate a Chart's transforms + + Evaluate the data transforms associated with a Chart and return the + transformed data as one or more DataFrames + + Parameters + ---------- + chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart + Altair chart to evaluate transforms on + row_limit : int (optional) + Maximum number of rows to return for each DataFrame. 
None (default) for unlimited + exclude : iterable of str + Set of the names of charts to exclude + + Returns + ------- + DataFrame or list of DataFrame + If input chart is a Chart or Facet Chart, returns a DataFrame of the transformed data + Otherwise, returns a list of DataFrames of the transformed data + """ + from vegafusion import runtime, get_local_tz, get_inline_datasets_for_spec # type: ignore + + if isinstance(chart, Chart): + # Add dummy mark if None specified to satisfy Vega-Lite + if chart.mark == Undefined: + chart = chart.mark_point() + + # Deep copy chart so that we can rename marks without affecting caller + chart = chart.copy(deep=True) + + # Rename chart or subcharts with magic names that we can look up in the + # resulting Vega specification + chart_names = name_chart(chart, 0, exclude=exclude) + + # Compile to Vega and extract inline DataFrames + with data_transformers.enable("vegafusion-inline"): + vega_spec = chart.to_dict(format="vega") + inline_datasets = get_inline_datasets_for_spec(vega_spec) + + # Build mapping from mark names to vega datasets + facet_mapping = get_facet_mapping(vega_spec) + dataset_mapping = get_datasets_for_chart_names( + vega_spec, chart_names, facet_mapping + ) + + # Build a list of vega dataset names that corresponds to the order + # of the chart components + dataset_names = [] + for chart_name in chart_names: + if chart_name in dataset_mapping: + dataset_names.append(dataset_mapping[chart_name]) + else: + raise ValueError("Failed to locate all datasets") + + # Extract transformed datasets with VegaFusion + datasets, warnings = runtime.pre_transform_datasets( + vega_spec, + dataset_names, + get_local_tz(), + row_limit=row_limit, + inline_datasets=inline_datasets, + ) + + if isinstance(chart, (Chart, FacetChart)): + # Return DataFrame (or None if it was excluded) if input was a simple Chart + if not datasets: + return None + else: + return datasets[0] + else: + # Otherwise return the list of DataFrames + return 
datasets + + +def make_magic_chart_name(i: int) -> str: + """Make magic chart name for chart number i + + Parameters + ---------- + i : int + Mark number + + Returns + ------- + str + Mark name + """ + return MAGIC_CHART_NAME.format(i) + + +def name_chart( + chart: Union[ + Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, ConcatChart + ], + i: int = 0, + exclude: Optional[Iterable[str]] = None, +) -> List[str]: + """Name unnamed charts and subcharts + + Name unnamed charts and subcharts so that we can look them up later in + the compiled Vega spec. + + Note: This function mutates the input chart by applying names to + unnamed charts. + + Parameters + ---------- + chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart + Altair chart to apply names to + i : int (default 0) + Starting chart index + exclude : iterable of str + Set of the names of charts to exclude + + Returns + ------- + list of str + List of the names of the charts and subcharts + """ + exclude = set(exclude) if exclude is not None else set() + if isinstance(chart, (Chart, FacetChart)): + # Perform shallow copy of chart so that we can change + # the name + if chart.name not in exclude: + if chart.name in (None, Undefined): + name = make_magic_chart_name(i) + chart.name = name + return [chart.name] + else: + return [] + else: + if isinstance(chart, LayerChart): + subcharts = chart.layer + elif isinstance(chart, HConcatChart): + subcharts = chart.hconcat + elif isinstance(chart, VConcatChart): + subcharts = chart.vconcat + elif isinstance(chart, ConcatChart): + subcharts = chart.concat + else: + raise ValueError( + "transformed_data accepts an instance of " + "Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart\n" + f"Received value of type: {type(chart)}" + ) + + chart_names: List[str] = [] + for subchart in subcharts: + for name in name_chart(subchart, i=i + len(chart_names), exclude=exclude): + chart_names.append(name) + return chart_names + 
+ +def get_group_mark_for_scope(vega_spec: dict, scope: Scope) -> Optional[dict]: + """Get the group mark at a particular scope + + Parameters + ---------- + vega_spec : dict + Top-level Vega specification dictionary + scope : tuple of int + Scope tuple. If empty, the original Vega specification is returned. + Otherwise, the nested group mark at the scope specified is returned. + + Returns + ------- + dict or None + Top-level Vega spec (if scope is empty) + or group mark (if scope is non-empty) + or None (if group mark at scope does not exist) + + Examples + -------- + >>> spec = { + ... "marks": [ + ... { + ... "type": "group", + ... "marks": [{"type": "symbol"}] + ... }, + ... { + ... "type": "group", + ... "marks": [{"type": "rect"}]} + ... ] + ... } + >>> get_group_mark_for_scope(spec, (1,)) + {'type': 'group', 'marks': [{'type': 'rect'}]} + """ + group = vega_spec + + # Find group at scope + for scope_value in scope: + group_index = 0 + child_group = None + for mark in group.get("marks", []): + if mark.get("type") == "group": + if group_index == scope_value: + child_group = mark + break + group_index += 1 + if child_group is None: + return None + group = child_group + + return group + + +def get_datasets_for_scope(vega_spec: dict, scope: Scope) -> List[str]: + """Get the names of the datasets that are defined at a given scope + + Parameters + ---------- + vega_spec : dict + Top-leve Vega specification + scope : tuple of int + Scope tuple. If empty, the names of top-level datasets are returned + Otherwise, the names of the datasets defined in the nested group mark + at the specified scope are returned. + + Returns + ------- + list of str + List of the names of the datasets defined at the specified scope + + Examples + -------- + >>> spec = { + ... "data": [ + ... {"name": "data1"} + ... ], + ... "marks": [ + ... { + ... "type": "group", + ... "data": [ + ... {"name": "data2"} + ... ], + ... "marks": [{"type": "symbol"}] + ... }, + ... { + ... 
"type": "group", + ... "data": [ + ... {"name": "data3"}, + ... {"name": "data4"}, + ... ], + ... "marks": [{"type": "rect"}] + ... } + ... ] + ... } + + >>> get_datasets_for_scope(spec, ()) + ['data1'] + + >>> get_datasets_for_scope(spec, (0,)) + ['data2'] + + >>> get_datasets_for_scope(spec, (1,)) + ['data3', 'data4'] + + Returns empty when no group mark exists at scope + >>> get_datasets_for_scope(spec, (1, 3)) + [] + """ + group = get_group_mark_for_scope(vega_spec, scope) or {} + + # get datasets from group + datasets = [] + for dataset in group.get("data", []): + datasets.append(dataset["name"]) + + # Add facet dataset + facet_dataset = group.get("from", {}).get("facet", {}).get("name", None) + if facet_dataset: + datasets.append(facet_dataset) + return datasets + + +def get_definition_scope_for_data_reference( + vega_spec: dict, data_name: str, usage_scope: Scope +) -> Optional[Scope]: + """Return the scope that a dataset is defined at, for a given usage scope + + Parameters + ---------- + vega_spec: dict + Top-level Vega specification + data_name: str + The name of a dataset reference + usage_scope: tuple of int + The scope that the dataset is referenced in + + Returns + ------- + tuple of int + The scope where the referenced dataset is defined, + or None if no such dataset is found + + Examples + -------- + >>> spec = { + ... "data": [ + ... {"name": "data1"} + ... ], + ... "marks": [ + ... { + ... "type": "group", + ... "data": [ + ... {"name": "data2"} + ... ], + ... "marks": [{ + ... "type": "symbol", + ... "encode": { + ... "update": { + ... "x": {"field": "x", "data": "data1"}, + ... "y": {"field": "y", "data": "data2"}, + ... } + ... } + ... }] + ... } + ... ] + ... 
} + + data1 is referenced at scope [0] and defined at scope [] + >>> get_definition_scope_for_data_reference(spec, "data1", (0,)) + () + + data2 is referenced at scope [0] and defined at scope [0] + >>> get_definition_scope_for_data_reference(spec, "data2", (0,)) + (0,) + + If data2 is not visible at scope [] (the top level), + because it's defined in scope [0] + >>> repr(get_definition_scope_for_data_reference(spec, "data2", ())) + 'None' + """ + for i in reversed(range(len(usage_scope) + 1)): + scope = usage_scope[:i] + datasets = get_datasets_for_scope(vega_spec, scope) + if data_name in datasets: + return scope + return None + + +def get_facet_mapping(group: dict, scope: Scope = ()) -> FacetMapping: + """Create mapping from facet definitions to source datasets + + Parameters + ---------- + group : dict + Top-level Vega spec or nested group mark + scope : tuple of int + Scope of the group dictionary within a top-level Vega spec + + Returns + ------- + dict + Dictionary from (facet_name, facet_scope) to (dataset_name, dataset_scope) + + Examples + -------- + >>> spec = { + ... "data": [ + ... {"name": "data1"} + ... ], + ... "marks": [ + ... { + ... "type": "group", + ... "from": { + ... "facet": { + ... "name": "facet1", + ... "data": "data1", + ... "groupby": ["colA"] + ... } + ... } + ... } + ... ] + ... 
} + >>> get_facet_mapping(spec) + {('facet1', (0,)): ('data1', ())} + """ + facet_mapping = {} + group_index = 0 + mark_group = get_group_mark_for_scope(group, scope) or {} + for mark in mark_group.get("marks", []): + if mark.get("type", None) == "group": + # Get facet for this group + group_scope = scope + (group_index,) + facet = mark.get("from", {}).get("facet", None) + if facet is not None: + facet_name = facet.get("name", None) + facet_data = facet.get("data", None) + if facet_name is not None and facet_data is not None: + definition_scope = get_definition_scope_for_data_reference( + group, facet_data, scope + ) + if definition_scope is not None: + facet_mapping[(facet_name, group_scope)] = ( + facet_data, + definition_scope, + ) + + # Handle children recursively + child_mapping = get_facet_mapping(group, scope=group_scope) + facet_mapping.update(child_mapping) + group_index += 1 + + return facet_mapping + + +def get_from_facet_mapping( + scoped_dataset: Tuple[str, Scope], facet_mapping: FacetMapping +) -> Tuple[str, Scope]: + """Apply facet mapping to a scoped dataset + + Parameters + ---------- + scoped_dataset : (str, tuple of int) + A dataset name and scope tuple + facet_mapping : dict from (str, tuple of int) to (str, tuple of int) + The facet mapping produced by get_facet_mapping + + Returns + ------- + (str, tuple of int) + Dataset name and scope tuple that has been mapped as many times as possible + + Examples + -------- + Facet mapping as produced by get_facet_mapping + >>> facet_mapping = {("facet1", (0,)): ("data1", ()), ("facet2", (0, 1)): ("facet1", (0,))} + >>> get_from_facet_mapping(("facet2", (0, 1)), facet_mapping) + ('data1', ()) + """ + while scoped_dataset in facet_mapping: + scoped_dataset = facet_mapping[scoped_dataset] + return scoped_dataset + + +def get_datasets_for_chart_names( + group: dict, + vl_chart_names: List[str], + facet_mapping: FacetMapping, + scope: Scope = (), +) -> Dict[str, Tuple[str, Scope]]: + """Get the Vega datasets 
that correspond to the provided Altair chart names + + Parameters + ---------- + group : dict + Top-level Vega spec or nested group mark + vl_chart_names : list of str + List of the Vega-Lite + facet_mapping : dict from (str, tuple of int) to (str, tuple of int) + The facet mapping produced by get_facet_mapping + scope : tuple of int + Scope of the group dictionary within a top-level Vega spec + + Returns + ------- + dict from str to (str, tuple of int) + Dict from Altair chart names to scoped datasets + """ + datasets = {} + group_index = 0 + mark_group = get_group_mark_for_scope(group, scope) or {} + for mark in mark_group.get("marks", []): + for vl_chart_name in vl_chart_names: + if mark.get("name", "") == f"{vl_chart_name}_cell": + data_name = mark.get("from", {}).get("facet", None).get("data", None) + scoped_data_name = (data_name, scope) + datasets[vl_chart_name] = get_from_facet_mapping( + scoped_data_name, facet_mapping + ) + break + + name = mark.get("name", "") + if mark.get("type", "") == "group": + group_data_names = get_datasets_for_chart_names( + group, vl_chart_names, facet_mapping, scope=scope + (group_index,) + ) + for k, v in group_data_names.items(): + datasets.setdefault(k, v) + group_index += 1 + else: + for vl_chart_name in vl_chart_names: + if name.startswith(vl_chart_name) and name.endswith("_marks"): + data_name = mark.get("from", {}).get("data", None) + scoped_data = get_definition_scope_for_data_reference( + group, data_name, scope + ) + if scoped_data is not None: + datasets[vl_chart_name] = get_from_facet_mapping( + (data_name, scoped_data), facet_mapping + ) + break + + return datasets diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 33844a6a2..f7cdfdd2d 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -8,7 +8,7 @@ from toolz.curried import pipe as _pipe import itertools import sys -from typing import cast, List, Optional, Any +from typing import cast, List, Optional, Any, Iterable # 
Have to rename it here as else it overlaps with schema.core.Type from typing import Type as TypingType @@ -2657,6 +2657,32 @@ def to_dict( validate=validate, format=format, ignore=ignore, context=context ) + def transformed_data( + self, + row_limit: Optional[int] = None, + exclude: Optional[Iterable[str]] = None, + ) -> Optional[pd.DataFrame]: + """Evaluate a Chart's transforms + + Evaluate the data transforms associated with a Chart and return the + transformed data a DataFrame + + Parameters + ---------- + row_limit : int (optional) + Maximum number of rows to return for each DataFrame. None (default) for unlimited + exclude : iterable of str + Set of the names of charts to exclude + + Returns + ------- + DataFrame + Transformed data as a DataFrame + """ + from ...utils.transformed_data import transformed_data + + return transformed_data(self, row_limit=row_limit, exclude=exclude) + def add_params(self, *params) -> Self: """Add one or more parameters to the chart.""" if not params: @@ -2917,6 +2943,32 @@ def __or__(self, other): copy |= other return copy + def transformed_data( + self, + row_limit: Optional[int] = None, + exclude: Optional[Iterable[str]] = None, + ) -> List[pd.DataFrame]: + """Evaluate a ConcatChart's transforms + + Evaluate the data transforms associated with a ConcatChart and return the + transformed data for each subplot as a list of DataFrames + + Parameters + ---------- + row_limit : int (optional) + Maximum number of rows to return for each DataFrame. 
None (default) for unlimited + exclude : iterable of str + Set of the names of charts to exclude + + Returns + ------- + list of DataFrame + Transformed data for each subplot as a list of DataFrames + """ + from ...utils.transformed_data import transformed_data + + return transformed_data(self, row_limit=row_limit, exclude=exclude) + def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: """Make chart axes scales interactive @@ -2988,6 +3040,32 @@ def __or__(self, other): copy |= other return copy + def transformed_data( + self, + row_limit: Optional[int] = None, + exclude: Optional[Iterable[str]] = None, + ) -> List[pd.DataFrame]: + """Evaluate a HConcatChart's transforms + + Evaluate the data transforms associated with a HConcatChart and return the + transformed data for each subplot as a list of DataFrames + + Parameters + ---------- + row_limit : int (optional) + Maximum number of rows to return for each DataFrame. None (default) for unlimited + exclude : iterable of str + Set of the names of charts to exclude + + Returns + ------- + list of DataFrame + Transformed data for each subplot as a list of DataFrames + """ + from ...utils.transformed_data import transformed_data + + return transformed_data(self, row_limit=row_limit, exclude=exclude) + def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: """Make chart axes scales interactive @@ -3059,6 +3137,32 @@ def __and__(self, other): copy &= other return copy + def transformed_data( + self, + row_limit: Optional[int] = None, + exclude: Optional[Iterable[str]] = None, + ) -> List[pd.DataFrame]: + """Evaluate a VConcatChart's transforms + + Evaluate the data transforms associated with a VConcatChart and return the + transformed data for each subplot as a list of DataFrames + + Parameters + ---------- + row_limit : int (optional) + Maximum number of rows to return for each DataFrame. 
None (default) for unlimited + exclude : iterable of str + Set of the names of charts to exclude + + Returns + ------- + list of DataFrame + Transformed data for each subplot as a list of DataFrames + """ + from ...utils.transformed_data import transformed_data + + return transformed_data(self, row_limit=row_limit, exclude=exclude) + def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: """Make chart axes scales interactive @@ -3129,6 +3233,32 @@ def __init__(self, data=Undefined, layer=(), **kwargs): for prop in combined_dict: self[prop] = combined_dict[prop] + def transformed_data( + self, + row_limit: Optional[int] = None, + exclude: Optional[Iterable[str]] = None, + ) -> List[pd.DataFrame]: + """Evaluate a LayerChart's transforms + + Evaluate the data transforms associated with a LayerChart and return the + transformed data for each layer as a list of DataFrames + + Parameters + ---------- + row_limit : int (optional) + Maximum number of rows to return for each DataFrame. None (default) for unlimited + exclude : iterable of str + Set of the names of charts to exclude + + Returns + ------- + list of DataFrame + Transformed data for each layer as a list of DataFrames + """ + from ...utils.transformed_data import transformed_data + + return transformed_data(self, row_limit=row_limit, exclude=exclude) + def __iadd__(self, other): _check_if_valid_subspec(other, "LayerChart") _check_if_can_be_layered(other) @@ -3218,6 +3348,32 @@ def __init__( data=data, spec=spec, facet=facet, params=params, **kwargs ) + def transformed_data( + self, + row_limit: Optional[int] = None, + exclude: Optional[Iterable[str]] = None, + ) -> Optional[pd.DataFrame]: + """Evaluate a FacetChart's transforms + + Evaluate the data transforms associated with a FacetChart and return the + transformed data a DataFrame + + Parameters + ---------- + row_limit : int (optional) + Maximum number of rows to return for each DataFrame. 
None (default) for unlimited + exclude : iterable of str + Set of the names of charts to exclude + + Returns + ------- + DataFrame + Transformed data as a DataFrame + """ + from ...utils.transformed_data import transformed_data + + return transformed_data(self, row_limit=row_limit, exclude=exclude) + def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: """Make chart axes scales interactive diff --git a/tools/update_init_file.py b/tools/update_init_file.py index 41bb8f4d8..712d3f57e 100644 --- a/tools/update_init_file.py +++ b/tools/update_init_file.py @@ -6,7 +6,7 @@ import sys from pathlib import Path from os.path import abspath, dirname, join -from typing import TypeVar, Type, cast, List, Any, Optional +from typing import TypeVar, Type, cast, List, Any, Optional, Iterable import black @@ -81,6 +81,7 @@ def _is_relevant_attribute(attr_name): or attr is Any or attr is Literal or attr is Optional + or attr is Iterable or attr_name == "TypingDict" ): return False From 2993959aa1ad4424a33b11d5d06a6e346372a7fd Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Thu, 8 Jun 2023 10:29:43 -0400 Subject: [PATCH 02/21] Add initial transformed_data tests --- tests/test_transformed_data.py | 102 +++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tests/test_transformed_data.py diff --git a/tests/test_transformed_data.py b/tests/test_transformed_data.py new file mode 100644 index 000000000..3e198d9a7 --- /dev/null +++ b/tests/test_transformed_data.py @@ -0,0 +1,102 @@ +from altair.utils.execeval import eval_block +from tests import examples_methods_syntax +import pkgutil +import pytest + + +@pytest.mark.parametrize("filename,rows,cols", [ + ("annual_weather_heatmap.py", 366, ["monthdate_date_end", "max_temp_max"]), + ("anscombe_plot.py", 44, ["Series", "X", "Y"]), + ("bar_chart_sorted.py", 6, ["site", "sum_yield"]), + ("bar_chart_trellis_compact.py", 27, ["p", "p_end"]), + ("beckers_barley_trellis_plot.py", 120, ["year", "site"]), + 
("beckers_barley_wrapped_facet.py", 120, ["site", "median_yield"]), + ("bump_chart.py", 100, ["rank", "yearmonth_date"]), + ("comet_chart.py", 120, ["variety", "delta"]), + ("connected_scatterplot.py", 55, ["miles", "gas"]), + ("diverging_stacked_bar_chart.py", 40, ["value", "percentage_start"]), + ("donut_chart.py", 6, ["value_start", "value_end"]), + ("gapminder_bubble_plot.py", 187, ["income", "population"]), + ("grouped_bar_chart2.py", 9, ["Group", "Value_start"]), + ("hexbins.py", 84, ["xFeaturePos", "mean_temp_max"]), + ("histogram_heatmap.py", 378, ["bin_maxbins_40_Rotten_Tomatoes_Rating", "__count"]), + ("histogram_scatterplot.py", 64, ["bin_maxbins_10_Rotten_Tomatoes_Rating", "__count"]), + ("interactive_legend.py", 1708, ["sum_count_start", "series"]), + ("iowa_electricity.py", 51, ["net_generation_start", "year"]), + ("isotype.py", 37, ["animal", "x"]), + ("isotype_grid.py", 100, ["row", "col"]), + ("lasagna_plot.py", 492, ["yearmonthdate_date", "sum_price"]), + ("layered_area_chart.py", 51, ["source", "net_generation"]), + ("layered_bar_chart.py", 51, ["source", "net_generation"]), + ("layered_histogram.py", 113, ["bin_maxbins_100_Measurement"]), + ("line_chart_with_cumsum.py", 52, ["cumulative_wheat"]), + ("line_percent.py", 30, ["sex", "perc"]), + ("line_with_log_scale.py", 15, ["year", "sum_people"]), + ("multifeature_scatter_plot.py", 150, ["petalWidth", "species"]), + ("natural_disasters.py", 686, ["Deaths", "Year"]), + ("normalized_stacked_area_chart.py", 51, ["source", "net_generation_start"]), + ("normalized_stacked_bar_chart.py", 60, ["site", "sum_yield_start"]), + ("parallel_coordinates.py", 600, ["key", "value"]), + ("percentage_of_total.py", 5, ["PercentOfTotal", "TotalTime"]), + ("pie_chart.py", 6, ["category", "value_start"]), + ("pyramid.py", 3, ["category", "value_start"]), + ("stacked_bar_chart_sorted_segments.py", 60, ["variety", "site"]), + ("stem_and_leaf.py", 100, ["stem", "leaf"]), + ("streamgraph.py", 1708, ["series", 
"sum_count"]), + ("top_k_items.py", 10, ["rank", "IMDB_Rating_start"]), + ("top_k_letters.py", 9, ["rank", "letters"]), + ("top_k_with_others.py", 10, ["ranked_director", "mean_aggregate_gross"]), + ("trellis_area_sort_array.py", 492, ["date", "price"]), + ("trellis_histogram.py", 20, ["Origin", "__count"]), + ("us_population_over_time.py", 38, ["sex", "people_start"]), + ("us_population_over_time_facet.py", 285, ["year", "sum_people"]), + ("wilkinson-dot-plot.py", 21, ["data", "id"]), + ("window_rank.py", 12, ["team", "diff"]), +]) +def test_primitive_chart_examples(filename, rows, cols): + source = pkgutil.get_data(examples_methods_syntax.__name__, filename) + chart = eval_block(source) + df = chart.transformed_data() + assert len(df) == rows + assert set(cols).issubset(set(df.columns)) + + +@pytest.mark.parametrize("filename,all_rows,all_cols", [ + ("errorbars_with_std.py", [10, 10], [["upper_yield"], ["extent_yield"]]), + ("candlestick_chart.py", [44, 44], [["low"], ["close"]]), + ("co2_concentration.py", [713, 7, 7], [["first_date"], ["scaled_date"], ["end"]]), + ("falkensee.py", [2, 38, 38], [["event"], ["population"], ["population"]]), + ("heat_lane.py", [10, 10], [["bin_count_start"], ["y2"]]), + ("histogram_responsive.py", [20, 20], [["__count"], ["__count"]]), + ("histogram_with_a_global_mean_overlay.py", [9, 1], [["__count"], ["mean_IMDB_Rating"]]), + ("horizon_graph.py", [20, 20], [["x"], ["ny"]]), + ("interactive_cross_highlight.py", [64, 64, 13], [["__count"], ["__count"], ["Major_Genre"]]), + ("interval_selection.py", [123, 123], [["price_start"], ["date"]]), + ("layered_chart_with_dual_axis.py", [12, 12], [["month_date"], ["average_precipitation"]]), + ("layered_heatmap_text.py", [9, 9], [["Cylinders"], ["mean_horsepower"]]), + ("multiline_highlight.py", [560, 560], [["price"], ["date"]]), + ("multiline_tooltip.py", [300, 300, 300, 0, 300], [["x"], ["y"], ["y"], ["x"], ["x"]]), + ("pie_chart_with_labels.py", [6, 6], [["category"], ["value"]]), + 
("radial_chart.py", [6, 6], [["values"], ["values_start"]]), + ("scatter_linked_table.py", [392, 14, 14, 14], [["Year"], ["Year"], ["Year"], ["Year"]]), + ("scatter_marginal_hist.py", [34, 150, 27], [["__count"], ["species"], ["__count"]]), + ("scatter_with_layered_histogram.py", [2, 19], [["gender"], ["__count"]]), + ("scatter_with_minimap.py", [1461, 1461], [["date"], ["date"]]), + ("scatter_with_rolling_mean.py", [1461, 1461], [["date"], ["rolling_mean"]]), + ("seattle_weather_interactive.py", [1461, 5], [["date"], ["__count"]]), + ("select_detail.py", [20, 1000], [["id"], ["x"]]), + ("simple_scatter_with_errorbars.py", [5, 5], [["x"], ["upper_ymin"]]), + ("stacked_bar_chart_with_text.py", [60, 60], [["site"], ["site"]]), + ("us_employment.py", [120, 1, 2], [["month"], ["president"], ["president"]]), + ("us_population_pyramid_over_time.py", [19, 38, 19], [["gender"], ["year"], ["gender"]]), +]) +def test_compound_chart_examples(filename, all_rows, all_cols): + source = pkgutil.get_data(examples_methods_syntax.__name__, filename) + chart = eval_block(source) + print(chart) + + dfs = chart.transformed_data() + assert len(dfs) == len(all_rows) + for df, rows, cols in zip(dfs, all_rows, all_cols): + assert len(df) == rows + assert set(cols).issubset(set(df.columns)) From 3ae6c7d080d1fab9b6e4bf983473add047cdc697 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Thu, 8 Jun 2023 10:32:02 -0400 Subject: [PATCH 03/21] skip black formatting for pytest.mark.parametrize --- tests/test_transformed_data.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_transformed_data.py b/tests/test_transformed_data.py index 3e198d9a7..31ea0465d 100644 --- a/tests/test_transformed_data.py +++ b/tests/test_transformed_data.py @@ -4,6 +4,7 @@ import pytest +# fmt: off @pytest.mark.parametrize("filename,rows,cols", [ ("annual_weather_heatmap.py", 366, ["monthdate_date_end", "max_temp_max"]), ("anscombe_plot.py", 44, ["Series", "X", "Y"]), @@ -53,6 +54,7 @@ 
("wilkinson-dot-plot.py", 21, ["data", "id"]), ("window_rank.py", 12, ["team", "diff"]), ]) +# fmt: on def test_primitive_chart_examples(filename, rows, cols): source = pkgutil.get_data(examples_methods_syntax.__name__, filename) chart = eval_block(source) @@ -61,6 +63,7 @@ def test_primitive_chart_examples(filename, rows, cols): assert set(cols).issubset(set(df.columns)) +# fmt: off @pytest.mark.parametrize("filename,all_rows,all_cols", [ ("errorbars_with_std.py", [10, 10], [["upper_yield"], ["extent_yield"]]), ("candlestick_chart.py", [44, 44], [["low"], ["close"]]), @@ -90,6 +93,7 @@ def test_primitive_chart_examples(filename, rows, cols): ("us_employment.py", [120, 1, 2], [["month"], ["president"], ["president"]]), ("us_population_pyramid_over_time.py", [19, 38, 19], [["gender"], ["year"], ["gender"]]), ]) +# fmt: on def test_compound_chart_examples(filename, all_rows, all_cols): source = pkgutil.get_data(examples_methods_syntax.__name__, filename) chart = eval_block(source) From eed32e9f5eec5cf5560d36bb1d9b0aa6da39a9f3 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Fri, 9 Jun 2023 08:29:26 -0400 Subject: [PATCH 04/21] Test exclude flag to transformed_data --- tests/test_transformed_data.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_transformed_data.py b/tests/test_transformed_data.py index 31ea0465d..928b539f5 100644 --- a/tests/test_transformed_data.py +++ b/tests/test_transformed_data.py @@ -1,5 +1,7 @@ from altair.utils.execeval import eval_block +import altair as alt from tests import examples_methods_syntax +from vega_datasets import data import pkgutil import pytest @@ -104,3 +106,23 @@ def test_compound_chart_examples(filename, all_rows, all_cols): for df, rows, cols in zip(dfs, all_rows, all_cols): assert len(df) == rows assert set(cols).issubset(set(df.columns)) + + +def test_transformed_data_exclude(): + source = data.wheat() + bar = alt.Chart(source).mark_bar().encode(x="year:O", y="wheat:Q") + rule = 
alt.Chart(source).mark_rule(color="red").encode(y="mean(wheat):Q") + some_annotation = ( + alt.Chart(name="some_annotation") + .mark_text(fontWeight="bold") + .encode(text=alt.value("Just some text"), y=alt.datum(85), x=alt.value(200)) + ) + + chart = (bar + rule + some_annotation).properties(width=600) + datasets = chart.transformed_data(exclude=["some_annotation"]) + + assert len(datasets) == 2 + assert len(datasets[0]) == 52 + assert "wheat_start" in datasets[0] + assert len(datasets[1]) == 1 + assert "mean_wheat" in datasets[1] From 6f61badabd8ac7201c0b6929c9774f9460ac1d0a Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Fri, 9 Jun 2023 08:33:02 -0400 Subject: [PATCH 05/21] chart.transformed_data -> chart._transformed_data Make method internal while still experimental --- altair/vegalite/v5/api.py | 12 ++++++------ tests/test_transformed_data.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index f7cdfdd2d..d78ee11c4 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -2657,7 +2657,7 @@ def to_dict( validate=validate, format=format, ignore=ignore, context=context ) - def transformed_data( + def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, @@ -2943,7 +2943,7 @@ def __or__(self, other): copy |= other return copy - def transformed_data( + def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, @@ -3040,7 +3040,7 @@ def __or__(self, other): copy |= other return copy - def transformed_data( + def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, @@ -3137,7 +3137,7 @@ def __and__(self, other): copy &= other return copy - def transformed_data( + def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, @@ -3233,7 +3233,7 @@ def __init__(self, data=Undefined, layer=(), 
**kwargs): for prop in combined_dict: self[prop] = combined_dict[prop] - def transformed_data( + def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, @@ -3348,7 +3348,7 @@ def __init__( data=data, spec=spec, facet=facet, params=params, **kwargs ) - def transformed_data( + def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, diff --git a/tests/test_transformed_data.py b/tests/test_transformed_data.py index 928b539f5..1604f2026 100644 --- a/tests/test_transformed_data.py +++ b/tests/test_transformed_data.py @@ -60,7 +60,7 @@ def test_primitive_chart_examples(filename, rows, cols): source = pkgutil.get_data(examples_methods_syntax.__name__, filename) chart = eval_block(source) - df = chart.transformed_data() + df = chart._transformed_data() assert len(df) == rows assert set(cols).issubset(set(df.columns)) @@ -101,7 +101,7 @@ def test_compound_chart_examples(filename, all_rows, all_cols): chart = eval_block(source) print(chart) - dfs = chart.transformed_data() + dfs = chart._transformed_data() assert len(dfs) == len(all_rows) for df, rows, cols in zip(dfs, all_rows, all_cols): assert len(df) == rows @@ -119,7 +119,7 @@ def test_transformed_data_exclude(): ) chart = (bar + rule + some_annotation).properties(width=600) - datasets = chart.transformed_data(exclude=["some_annotation"]) + datasets = chart._transformed_data(exclude=["some_annotation"]) assert len(datasets) == 2 assert len(datasets[0]) == 52 From 2be1f64aa07a40a656326dcf87ba375dbcfab23c Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Fri, 9 Jun 2023 08:44:50 -0400 Subject: [PATCH 06/21] Add VegaFusion as dev dependency --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 727690b22..c7c46d156 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,8 @@ dev = [ "mypy", "pandas-stubs", "types-jsonschema", - "types-setuptools" + 
"types-setuptools", + "vegafusion[embed]" ] doc = [ "sphinx", From 07c5a00026a76a9efd5ba56e9490825a20c7e486 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Fri, 9 Jun 2023 08:45:53 -0400 Subject: [PATCH 07/21] Add better error message when VegaFusion is not installed --- altair/utils/transformed_data.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/altair/utils/transformed_data.py b/altair/utils/transformed_data.py index 9004d0dea..6c0c9d5a4 100644 --- a/altair/utils/transformed_data.py +++ b/altair/utils/transformed_data.py @@ -59,7 +59,16 @@ def transformed_data(chart, row_limit=None, exclude=None): If input chart is a Chart or Facet Chart, returns a DataFrame of the transformed data Otherwise, returns a list of DataFrames of the transformed data """ - from vegafusion import runtime, get_local_tz, get_inline_datasets_for_spec # type: ignore + try: + from vegafusion import runtime, get_local_tz, get_inline_datasets_for_spec # type: ignore + except ImportError as err: + raise ImportError( + "transformed_data requires the vegafusion-python-embed and vegafusion packages\n" + "These can be installed with pip using:\n" + " pip install vegafusion[embed]\n" + "Or with conda using:\n" + " conda install -c conda-forge vegafusion-python-embed vegafusion" + ) from err if isinstance(chart, Chart): # Add dummy mark if None specified to satisfy Vega-Lite From f0b26eab3a09d2de33433c27a61cca07ab7a07cb Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 10 Jun 2023 12:09:54 -0400 Subject: [PATCH 08/21] Move import Co-authored-by: Stefan Binder --- altair/utils/transformed_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/altair/utils/transformed_data.py b/altair/utils/transformed_data.py index 6c0c9d5a4..58c6d245c 100644 --- a/altair/utils/transformed_data.py +++ b/altair/utils/transformed_data.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Dict, Iterable, overload +from typing import List, Optional, Tuple, 
Dict, Iterable, overload, Union import pandas as pd From b48f8d30462362f06cdc44bf61353563432b8389 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 10 Jun 2023 12:10:04 -0400 Subject: [PATCH 09/21] move import Co-authored-by: Stefan Binder --- altair/utils/transformed_data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/altair/utils/transformed_data.py b/altair/utils/transformed_data.py index 58c6d245c..79e5e1e59 100644 --- a/altair/utils/transformed_data.py +++ b/altair/utils/transformed_data.py @@ -11,7 +11,6 @@ ConcatChart, data_transformers, ) -from typing import Union from altair.utils.schemapi import Undefined Scope = Tuple[int, ...] From 75cf9580a0abb931bc9d13653f352c94be0432f2 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 10 Jun 2023 12:13:14 -0400 Subject: [PATCH 10/21] Docstring update --- altair/utils/transformed_data.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/altair/utils/transformed_data.py b/altair/utils/transformed_data.py index 79e5e1e59..e10e3ba1f 100644 --- a/altair/utils/transformed_data.py +++ b/altair/utils/transformed_data.py @@ -54,9 +54,10 @@ def transformed_data(chart, row_limit=None, exclude=None): Returns ------- - DataFrame or list of DataFrame - If input chart is a Chart or Facet Chart, returns a DataFrame of the transformed data - Otherwise, returns a list of DataFrames of the transformed data + pandas DataFrame or list of pandas DataFrames or None + If input chart is a Chart or Facet Chart, returns a pandas DataFrame of the + transformed data. 
Otherwise, returns a list of pandas DataFrames of the + transformed data """ try: from vegafusion import runtime, get_local_tz, get_inline_datasets_for_spec # type: ignore From a46ce1b80ac3a9585b15da6b94e5743fc3c60bcd Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 10 Jun 2023 12:16:36 -0400 Subject: [PATCH 11/21] Make utils.transformed_data internal, use absolute imports --- .../{transformed_data.py => _transformed_data.py} | 0 altair/vegalite/v5/api.py | 12 ++++++------ 2 files changed, 6 insertions(+), 6 deletions(-) rename altair/utils/{transformed_data.py => _transformed_data.py} (100%) diff --git a/altair/utils/transformed_data.py b/altair/utils/_transformed_data.py similarity index 100% rename from altair/utils/transformed_data.py rename to altair/utils/_transformed_data.py diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index d78ee11c4..683e5b155 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -2679,7 +2679,7 @@ def _transformed_data( DataFrame Transformed data as a DataFrame """ - from ...utils.transformed_data import transformed_data + from altair.utils._transformed_data import transformed_data return transformed_data(self, row_limit=row_limit, exclude=exclude) @@ -2965,7 +2965,7 @@ def _transformed_data( list of DataFrame Transformed data for each subplot as a list of DataFrames """ - from ...utils.transformed_data import transformed_data + from altair.utils._transformed_data import transformed_data return transformed_data(self, row_limit=row_limit, exclude=exclude) @@ -3062,7 +3062,7 @@ def _transformed_data( list of DataFrame Transformed data for each subplot as a list of DataFrames """ - from ...utils.transformed_data import transformed_data + from altair.utils._transformed_data import transformed_data return transformed_data(self, row_limit=row_limit, exclude=exclude) @@ -3159,7 +3159,7 @@ def _transformed_data( list of DataFrame Transformed data for each subplot as a list of DataFrames """ - from 
...utils.transformed_data import transformed_data + from altair.utils._transformed_data import transformed_data return transformed_data(self, row_limit=row_limit, exclude=exclude) @@ -3255,7 +3255,7 @@ def _transformed_data( list of DataFrame Transformed data for each layer as a list of DataFrames """ - from ...utils.transformed_data import transformed_data + from altair.utils._transformed_data import transformed_data return transformed_data(self, row_limit=row_limit, exclude=exclude) @@ -3370,7 +3370,7 @@ def _transformed_data( DataFrame Transformed data as a DataFrame """ - from ...utils.transformed_data import transformed_data + from altair.utils._transformed_data import transformed_data return transformed_data(self, row_limit=row_limit, exclude=exclude) From 48f802ced44c03781c3c200ec0592c4a97a9e488 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 10 Jun 2023 12:17:07 -0400 Subject: [PATCH 12/21] Reword docstring Co-authored-by: Stefan Binder --- altair/utils/transformed_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/altair/utils/transformed_data.py b/altair/utils/transformed_data.py index 79e5e1e59..9bdbe7d58 100644 --- a/altair/utils/transformed_data.py +++ b/altair/utils/transformed_data.py @@ -159,7 +159,7 @@ def name_chart( i : int (default 0) Starting chart index exclude : iterable of str - Set of the names of charts to exclude + Names of charts to exclude Returns ------- From 280eb0f034ef14073ff873875b90d94289da8931 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 10 Jun 2023 13:39:47 -0400 Subject: [PATCH 13/21] Remove magic, use "view" instead of chart or mark --- altair/utils/_transformed_data.py | 39 ++++++++++++++++--------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py index bc470d497..6c0dad409 100644 --- a/altair/utils/_transformed_data.py +++ b/altair/utils/_transformed_data.py @@ -16,7 +16,7 @@ Scope = Tuple[int, ...] 
FacetMapping = Dict[Tuple[str, Scope], Tuple[str, Scope]] -MAGIC_CHART_NAME = "_vf_mark{}" +VIEW_NAME = "altair_view_{}" @overload @@ -78,9 +78,9 @@ def transformed_data(chart, row_limit=None, exclude=None): # Deep copy chart so that we can rename marks without affecting caller chart = chart.copy(deep=True) - # Rename chart or subcharts with magic names that we can look up in the + # Ensure that all views are named so that we can look them up in the # resulting Vega specification - chart_names = name_chart(chart, 0, exclude=exclude) + chart_names = name_views(chart, 0, exclude=exclude) # Compile to Vega and extract inline DataFrames with data_transformers.enable("vegafusion-inline"): @@ -89,7 +89,7 @@ def transformed_data(chart, row_limit=None, exclude=None): # Build mapping from mark names to vega datasets facet_mapping = get_facet_mapping(vega_spec) - dataset_mapping = get_datasets_for_chart_names( + dataset_mapping = get_datasets_for_view_names( vega_spec, chart_names, facet_mapping ) @@ -122,36 +122,36 @@ def transformed_data(chart, row_limit=None, exclude=None): return datasets -def make_magic_chart_name(i: int) -> str: - """Make magic chart name for chart number i +def make_view_name(i: int) -> str: + """Make view name for view number i Parameters ---------- i : int - Mark number + View number Returns ------- str - Mark name + View name """ - return MAGIC_CHART_NAME.format(i) + return VIEW_NAME.format(i) -def name_chart( +def name_views( chart: Union[ Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, ConcatChart ], i: int = 0, exclude: Optional[Iterable[str]] = None, ) -> List[str]: - """Name unnamed charts and subcharts + """Name unnamed chart views - Name unnamed charts and subcharts so that we can look them up later in + Name unnamed chart views so that we can look them up later in the compiled Vega spec. Note: This function mutates the input chart by applying names to - unnamed charts. + unnamed views.
Parameters ---------- @@ -173,7 +173,8 @@ def name_chart( # the name if chart.name not in exclude: if chart.name in (None, Undefined): - name = make_magic_chart_name(i) + # Add name since none is specified + name = make_view_name(i) chart.name = name return [chart.name] else: @@ -196,7 +197,7 @@ def name_chart( chart_names: List[str] = [] for subchart in subcharts: - for name in name_chart(subchart, i=i + len(chart_names), exclude=exclude): + for name in name_views(subchart, i=i + len(chart_names), exclude=exclude): chart_names.append(name) return chart_names @@ -484,13 +485,13 @@ def get_from_facet_mapping( return scoped_dataset -def get_datasets_for_chart_names( +def get_datasets_for_view_names( group: dict, vl_chart_names: List[str], facet_mapping: FacetMapping, scope: Scope = (), ) -> Dict[str, Tuple[str, Scope]]: - """Get the Vega datasets that correspond to the provided Altair chart names + """Get the Vega datasets that correspond to the provided Altair view names Parameters ---------- @@ -506,7 +507,7 @@ def get_datasets_for_chart_names( Returns ------- dict from str to (str, tuple of int) - Dict from Altair chart names to scoped datasets + Dict from Altair view names to scoped datasets """ datasets = {} group_index = 0 @@ -523,7 +524,7 @@ def get_datasets_for_chart_names( name = mark.get("name", "") if mark.get("type", "") == "group": - group_data_names = get_datasets_for_chart_names( + group_data_names = get_datasets_for_view_names( group, vl_chart_names, facet_mapping, scope=scope + (group_index,) ) for k, v in group_data_names.items(): From aabf5d6e4f7fc62fc49c4bb21b4c7a0f867e85a4 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 10 Jun 2023 13:40:04 -0400 Subject: [PATCH 14/21] Reword --- altair/utils/_transformed_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py index 6c0dad409..84684f542 100644 --- a/altair/utils/_transformed_data.py +++ 
b/altair/utils/_transformed_data.py @@ -71,7 +71,7 @@ def transformed_data(chart, row_limit=None, exclude=None): ) from err if isinstance(chart, Chart): - # Add dummy mark if None specified to satisfy Vega-Lite + # Add mark if none is specified to satisfy Vega-Lite if chart.mark == Undefined: chart = chart.mark_point() From 16250fd7f4d12246e37e6a7a2f9f70dc391fe3f8 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 10 Jun 2023 13:40:22 -0400 Subject: [PATCH 15/21] Remove incorrect comment --- altair/utils/_transformed_data.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py index 84684f542..8dd0e449b 100644 --- a/altair/utils/_transformed_data.py +++ b/altair/utils/_transformed_data.py @@ -169,8 +169,6 @@ def name_views( """ exclude = set(exclude) if exclude is not None else set() if isinstance(chart, (Chart, FacetChart)): - # Perform shallow copy of chart so that we can change - # the name if chart.name not in exclude: if chart.name in (None, Undefined): # Add name since none is specified From 8ab1dce21aeef38b1f1932238ffb7d2c77d5d024 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 10 Jun 2023 14:00:05 -0400 Subject: [PATCH 16/21] black --- altair/utils/_transformed_data.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py index 8dd0e449b..a343269e3 100644 --- a/altair/utils/_transformed_data.py +++ b/altair/utils/_transformed_data.py @@ -89,9 +89,7 @@ def transformed_data(chart, row_limit=None, exclude=None): # Build mapping from mark names to vega datasets facet_mapping = get_facet_mapping(vega_spec) - dataset_mapping = get_datasets_for_view_names( - vega_spec, chart_names, facet_mapping - ) + dataset_mapping = get_datasets_for_view_names(vega_spec, chart_names, facet_mapping) # Build a list of vega dataset names that corresponds to the order # of the chart components From 
6f43d6b94614134792671d2c1a90aff588682c94 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 10 Jun 2023 15:37:00 -0400 Subject: [PATCH 17/21] Use DataFrameLike protocol for the transformed_data signature The returned DataFrames can be arrow or Polars as well if those are used as the input to the Chart. --- altair/__init__.py | 1 + altair/utils/_transformed_data.py | 13 ++++++------- altair/utils/core.py | 9 +++++++-- altair/vegalite/v5/api.py | 13 +++++++------ 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/altair/__init__.py b/altair/__init__.py index 1a66239a5..3c0a19f13 100644 --- a/altair/__init__.py +++ b/altair/__init__.py @@ -124,6 +124,7 @@ "Cyclical", "Data", "DataFormat", + "DataFrameLike", "DataSource", "Datasets", "DateTime", diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py index a343269e3..fa89f8727 100644 --- a/altair/utils/_transformed_data.py +++ b/altair/utils/_transformed_data.py @@ -1,7 +1,5 @@ from typing import List, Optional, Tuple, Dict, Iterable, overload, Union -import pandas as pd - from altair import ( Chart, FacetChart, @@ -11,6 +9,7 @@ ConcatChart, data_transformers, ) +from altair.utils.core import DataFrameLike from altair.utils.schemapi import Undefined Scope = Tuple[int, ...] @@ -24,7 +23,7 @@ def transformed_data( chart: Union[Chart, FacetChart], row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, -) -> Optional[pd.DataFrame]: +) -> Optional[DataFrameLike]: ... @@ -33,7 +32,7 @@ def transformed_data( chart: Union[LayerChart, HConcatChart, VConcatChart, ConcatChart], row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, -) -> List[pd.DataFrame]: +) -> List[DataFrameLike]: ... @@ -54,9 +53,9 @@ def transformed_data(chart, row_limit=None, exclude=None): Returns ------- - pandas DataFrame or list of pandas DataFrames or None - If input chart is a Chart or Facet Chart, returns a pandas DataFrame of the - transformed data. 
Otherwise, returns a list of pandas DataFrames of the + DataFrame or list of DataFrames or None + If input chart is a Chart or Facet Chart, returns a DataFrame of the + transformed data. Otherwise, returns a list of DataFrames of the transformed data """ try: diff --git a/altair/utils/core.py b/altair/utils/core.py index 41e886001..4821de3c6 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -18,9 +18,9 @@ from altair.utils.schemapi import SchemaBase if sys.version_info >= (3, 10): - from typing import ParamSpec + from typing import ParamSpec, Protocol else: - from typing_extensions import ParamSpec + from typing_extensions import ParamSpec, Protocol try: from pandas.api.types import infer_dtype as _infer_dtype @@ -32,6 +32,11 @@ _P = ParamSpec("_P") +class DataFrameLike(Protocol): + def __dataframe__(self, *args, **kwargs): + ... + + def infer_dtype(value): """Infer the dtype of the value. diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 683e5b155..ba7801143 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -21,6 +21,7 @@ from .display import renderers, VEGALITE_VERSION, VEGAEMBED_VERSION, VEGA_VERSION from .theme import themes from .compiler import vegalite_compilers +from ...utils.core import DataFrameLike if sys.version_info >= (3, 11): from typing import Self @@ -2661,7 +2662,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> Optional[pd.DataFrame]: + ) -> Optional[DataFrameLike]: """Evaluate a Chart's transforms Evaluate the data transforms associated with a Chart and return the @@ -2947,7 +2948,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[pd.DataFrame]: + ) -> List[DataFrameLike]: """Evaluate a ConcatChart's transforms Evaluate the data transforms associated with a ConcatChart and return the @@ -3044,7 +3045,7 @@ def _transformed_data( self, row_limit: 
Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[pd.DataFrame]: + ) -> List[DataFrameLike]: """Evaluate a HConcatChart's transforms Evaluate the data transforms associated with a HConcatChart and return the @@ -3141,7 +3142,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[pd.DataFrame]: + ) -> List[DataFrameLike]: """Evaluate a VConcatChart's transforms Evaluate the data transforms associated with a VConcatChart and return the @@ -3237,7 +3238,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[pd.DataFrame]: + ) -> List[DataFrameLike]: """Evaluate a LayerChart's transforms Evaluate the data transforms associated with a LayerChart and return the @@ -3352,7 +3353,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> Optional[pd.DataFrame]: + ) -> Optional[DataFrameLike]: """Evaluate a FacetChart's transforms Evaluate the data transforms associated with a FacetChart and return the From a738408a8055bb9c7f3346d58411e4336d641346 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 10 Jun 2023 15:39:31 -0400 Subject: [PATCH 18/21] Add NotImplementedError for RepeatChart --- altair/vegalite/v5/api.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index ba7801143..86de9b2ac 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -2859,6 +2859,32 @@ def __init__( **kwds, ) + def _transformed_data( + self, + row_limit: Optional[int] = None, + exclude: Optional[Iterable[str]] = None, + ) -> Optional[DataFrameLike]: + """Evaluate a RepeatChart's transforms + + Evaluate the data transforms associated with a RepeatChart and return the + transformed data as a DataFrame + + Parameters + ---------- + row_limit : int (optional) + Maximum number of rows to return
for each DataFrame. None (default) for unlimited + exclude : iterable of str + Set of the names of charts to exclude + + Raises + ------ + NotImplementedError + RepeatChart does not yet support transformed_data + """ + raise NotImplementedError( + "transformed_data is not yet implemented for RepeatChart" + ) + def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: """Make chart axes scales interactive From 88fceb56d5f3fd8d6f5d9c8bfd258727758471de Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 10 Jun 2023 15:43:50 -0400 Subject: [PATCH 19/21] Use Chart._get_name to name subcharts --- altair/utils/_transformed_data.py | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py index fa89f8727..287afba85 100644 --- a/altair/utils/_transformed_data.py +++ b/altair/utils/_transformed_data.py @@ -15,8 +15,6 @@ Scope = Tuple[int, ...] FacetMapping = Dict[Tuple[str, Scope], Tuple[str, Scope]] -VIEW_NAME = "altair_view_{}" - @overload def transformed_data( @@ -119,22 +117,6 @@ def transformed_data(chart, row_limit=None, exclude=None): return datasets -def make_view_name(i: int) -> str: - """Make view name for view number i - - Parameters - ---------- - i : int - View number - - Returns - ------- - str - View name - """ - return VIEW_NAME.format(i) - - def name_views( chart: Union[ Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, ConcatChart @@ -169,8 +151,7 @@ def name_views( if chart.name not in exclude: if chart.name in (None, Undefined): # Add name since none is specified - name = make_view_name(i) - chart.name = name + chart.name = Chart._get_name() return [chart.name] else: return [] From 1416f4d40b7db9557c62a0fca41662d4b3b093bc Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Mon, 12 Jun 2023 09:09:26 -0400 Subject: [PATCH 20/21] Protocol is available in Python 3.8 --- altair/utils/core.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 
deletions(-) diff --git a/altair/utils/core.py b/altair/utils/core.py index 4821de3c6..89249719a 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -18,9 +18,14 @@ from altair.utils.schemapi import SchemaBase if sys.version_info >= (3, 10): - from typing import ParamSpec, Protocol + from typing import ParamSpec else: - from typing_extensions import ParamSpec, Protocol + from typing_extensions import ParamSpec + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol try: from pandas.api.types import infer_dtype as _infer_dtype From c665e8f3e214269e6b1d6c08de6b9544d44bfcc9 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Mon, 12 Jun 2023 15:55:33 -0400 Subject: [PATCH 21/21] Make DataFrameLike private for now --- altair/__init__.py | 1 - altair/utils/_transformed_data.py | 6 +++--- altair/utils/core.py | 2 +- altair/vegalite/v5/api.py | 16 ++++++++-------- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/altair/__init__.py b/altair/__init__.py index 3c0a19f13..1a66239a5 100644 --- a/altair/__init__.py +++ b/altair/__init__.py @@ -124,7 +124,6 @@ "Cyclical", "Data", "DataFormat", - "DataFrameLike", "DataSource", "Datasets", "DateTime", diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py index 287afba85..25e2b86b2 100644 --- a/altair/utils/_transformed_data.py +++ b/altair/utils/_transformed_data.py @@ -9,7 +9,7 @@ ConcatChart, data_transformers, ) -from altair.utils.core import DataFrameLike +from altair.utils.core import _DataFrameLike from altair.utils.schemapi import Undefined Scope = Tuple[int, ...] @@ -21,7 +21,7 @@ def transformed_data( chart: Union[Chart, FacetChart], row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, -) -> Optional[DataFrameLike]: +) -> Optional[_DataFrameLike]: ... 
@@ -30,7 +30,7 @@ def transformed_data( chart: Union[LayerChart, HConcatChart, VConcatChart, ConcatChart], row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, -) -> List[DataFrameLike]: +) -> List[_DataFrameLike]: ... diff --git a/altair/utils/core.py b/altair/utils/core.py index 89249719a..61e370b1d 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -37,7 +37,7 @@ _P = ParamSpec("_P") -class DataFrameLike(Protocol): +class _DataFrameLike(Protocol): def __dataframe__(self, *args, **kwargs): ... diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 86de9b2ac..9e97e3dff 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -21,7 +21,7 @@ from .display import renderers, VEGALITE_VERSION, VEGAEMBED_VERSION, VEGA_VERSION from .theme import themes from .compiler import vegalite_compilers -from ...utils.core import DataFrameLike +from ...utils.core import _DataFrameLike if sys.version_info >= (3, 11): from typing import Self @@ -2662,7 +2662,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> Optional[DataFrameLike]: + ) -> Optional[_DataFrameLike]: """Evaluate a Chart's transforms Evaluate the data transforms associated with a Chart and return the @@ -2863,7 +2863,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> Optional[DataFrameLike]: + ) -> Optional[_DataFrameLike]: """Evaluate a RepeatChart's transforms Evaluate the data transforms associated with a RepeatChart and return the @@ -2974,7 +2974,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[DataFrameLike]: + ) -> List[_DataFrameLike]: """Evaluate a ConcatChart's transforms Evaluate the data transforms associated with a ConcatChart and return the @@ -3071,7 +3071,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: 
Optional[Iterable[str]] = None, - ) -> List[DataFrameLike]: + ) -> List[_DataFrameLike]: """Evaluate a HConcatChart's transforms Evaluate the data transforms associated with a HConcatChart and return the @@ -3168,7 +3168,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[DataFrameLike]: + ) -> List[_DataFrameLike]: """Evaluate a VConcatChart's transforms Evaluate the data transforms associated with a VConcatChart and return the @@ -3264,7 +3264,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[DataFrameLike]: + ) -> List[_DataFrameLike]: """Evaluate a LayerChart's transforms Evaluate the data transforms associated with a LayerChart and return the @@ -3379,7 +3379,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> Optional[DataFrameLike]: + ) -> Optional[_DataFrameLike]: """Evaluate a FacetChart's transforms Evaluate the data transforms associated with a FacetChart and return the