From 6ccebeb6bed9c976d7eed23cb6bfa7ea70081c6e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 11 Aug 2021 17:44:25 +0200 Subject: [PATCH 001/159] no need to wrap pandas index in lazy index adapter --- xarray/core/indexes.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 429c37af588..6a04952195e 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -16,11 +16,7 @@ import pandas as pd from . import formatting, utils -from .indexing import ( - LazilyIndexedArray, - PandasIndexingAdapter, - PandasMultiIndexingAdapter, -) +from .indexing import PandasIndexingAdapter, PandasMultiIndexingAdapter from .utils import is_dict_like, is_scalar if TYPE_CHECKING: @@ -269,9 +265,7 @@ def _create_variables_from_multiindex(index, dim, level_meta=None): variables = {} dim_coord_adapter = PandasMultiIndexingAdapter(index) - variables[dim] = IndexVariable( - dim, LazilyIndexedArray(dim_coord_adapter), fastpath=True - ) + variables[dim] = IndexVariable(dim, dim_coord_adapter, fastpath=True) for level in index.names: meta = level_meta.get(level, {}) From 6d5c79df9de9d6a71c4526d6d44e4077ecab760e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 11 Aug 2021 17:46:26 +0200 Subject: [PATCH 002/159] multi-index default level names --- xarray/core/indexes.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 6a04952195e..13ea53a245c 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -314,6 +314,11 @@ def from_variables(cls, variables: Mapping[Hashable, "Variable"]): @classmethod def from_pandas_index(cls, index: pd.MultiIndex, dim: Hashable): + levels = [ + name if name is not None else f"{dim}_level_{i}" + for i, name in enumerate(index.names) + ] + index = index.rename(levels) index_vars = _create_variables_from_multiindex(index, dim) return cls(index, dim), index_vars From 4742100752cdad61b6cab5107aca2a37632b9c4e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 11 Aug 2021 17:47:18 +0200 Subject: [PATCH 003/159] refactor setting Dataset/DataArray default indexes --- xarray/core/dataarray.py | 6 ++---- xarray/core/merge.py | 32 +++++++++++++++++++++++++++----- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 900af885319..8f0f265fa85 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -53,7 +53,7 @@ from .formatting import format_item from .indexes import Index, Indexes, default_indexes, propagate_indexes from .indexing import is_fancy_indexer -from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords +from .merge import PANDAS_TYPES, MergeError, _create_indexes_from_coords from .options import OPTIONS, _get_keep_attrs from .utils import ( Default, @@ -403,9 +403,7 @@ def __init__( data = as_compatible_data(data) coords, dims = _infer_coords_and_dims(data.shape, coords, dims) variable = Variable(dims, data, attrs, fastpath=True) - indexes = dict( - _extract_indexes_from_coords(coords) - ) # needed for to_dataset + indexes, coords = _create_indexes_from_coords(coords) # These fully describe a DataArray self._variable = variable diff --git a/xarray/core/merge.py b/xarray/core/merge.py index b8b32bdaa01..25d95a88a49 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -20,7 +20,7 @@ from . 
import dtypes, pdcompat from .alignment import deep_align from .duck_array_ops import lazy_array_equiv -from .indexes import Index, PandasIndex +from .indexes import Index, PandasIndex, PandasMultiIndex from .utils import Frozen, compat_dict_union, dict_equiv, equivalent from .variable import Variable, as_variable, assert_unique_multiindex_level_names @@ -478,20 +478,42 @@ def merge_coords( def merge_data_and_coords(data, coords, compat="broadcast_equals", join="outer"): """Used in Dataset.__init__.""" + indexes, coords = _create_indexes_from_coords(coords) objects = [data, coords] explicit_coords = coords.keys() - indexes = dict(_extract_indexes_from_coords(coords)) return merge_core( objects, compat, join, explicit_coords=explicit_coords, indexes=indexes ) -def _extract_indexes_from_coords(coords): - """Yields the name & index of valid indexes from a mapping of coords""" +def _create_indexes_from_coords(coords): + """Maybe create default indexes from a mapping of coordinates. + + Return those indexes and updated coordinates. + """ + indexes = {} + updated_coords = {} + for name, variable in coords.items(): variable = as_variable(variable, name=name) + print(variable._data) + if variable.dims == (name,): - yield name, variable._to_xindex() + array = getattr(variable._data, "array", None) + if isinstance(array, pd.MultiIndex): + # TODO: benbovy - explicit indexes: depreciate passing multi-indexes as coords? + # (instead pass them explicitly as indexes once it is supported via public API) + index, index_vars = PandasMultiIndex.from_pandas_index(array, name) + else: + index, index_vars = PandasIndex.from_variables({name: variable}) + + indexes[name] = index + updated_coords.update(index_vars) + + else: + updated_coords[name] = variable + + return indexes, updated_coords def assert_valid_explicit_coords(variables, dims, explicit_coords): From 55dd9d97daf4951e6ee0fd29fa89e8668cce483e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 11 Aug 2021 17:50:03 +0200 Subject: [PATCH 004/159] update multi-index (text) repr Notes: - move the multi-index formatting logic into PandasMultiIndexingAdapter._repr_inline_ - inline repr: check for _repr_inline_ implementation first --- xarray/core/formatting.py | 67 +++++++++++++++------------------------ xarray/core/indexing.py | 18 +++++++++++ 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 7f292605e63..63d9382aaaa 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -4,7 +4,7 @@ import functools from datetime import datetime, timedelta from itertools import chain, zip_longest -from typing import Hashable +from typing import Hashable, Mapping import numpy as np import pandas as pd @@ -256,10 +256,10 @@ def inline_sparse_repr(array): def inline_variable_array_repr(var, max_width): """Build a one-line summary of a variable's data.""" - if var._in_memory: - return format_array_flat(var, max_width) - elif hasattr(var._data, "_repr_inline_"): + if hasattr(var._data, "_repr_inline_"): return var._data._repr_inline_(max_width) + elif var._in_memory: + return format_array_flat(var, max_width) elif isinstance(var._data, dask_array_type): return inline_dask_repr(var.data) elif isinstance(var._data, sparse_array_type): @@ -294,43 +294,12 @@ def summarize_variable( return front_str + values_str -def _summarize_coord_multiindex(coord, col_width, marker): - first_col = pretty_print(f" {marker} {coord.name} ", col_width) - return "{}({}) MultiIndex".format(first_col, 
str(coord.dims[0])) - - -def _summarize_coord_levels(coord, col_width, marker="-"): - if len(coord) > 100 and col_width < len(coord): - n_values = col_width - indices = list(range(0, n_values)) + list(range(-n_values, 0)) - subset = coord[indices] - else: - subset = coord - - return "\n".join( - summarize_variable( - lname, subset.get_level_variable(lname), col_width, marker=marker - ) - for lname in subset.level_names - ) - - def summarize_datavar(name, var, col_width): return summarize_variable(name, var.variable, col_width) -def summarize_coord(name: Hashable, var, col_width: int): - is_index = name in var.dims - marker = "*" if is_index else " " - if is_index: - coord = var.variable.to_index_variable() - if coord.level_names is not None: - return "\n".join( - [ - _summarize_coord_multiindex(coord, col_width, marker), - _summarize_coord_levels(coord, col_width), - ] - ) +def summarize_coord(name: Hashable, var, col_width: int, indexes: Mapping): + marker = "*" if name in indexes else " " return summarize_variable(name, var.variable, col_width, marker) @@ -373,12 +342,20 @@ def _calculate_col_width(col_items): def _mapping_repr( - mapping, title, summarizer, expand_option_name, col_width=None, max_rows=None + mapping, + title, + summarizer, + expand_option_name, + col_width=None, + max_rows=None, + summarizer_kwargs=None, ): if col_width is None: col_width = _calculate_col_width(mapping) if max_rows is None: max_rows = OPTIONS["display_max_rows"] + if summarizer_kwargs is None: + summarizer_kwargs = {} summary = [f"{title}:"] if mapping: len_mapping = len(mapping) @@ -388,15 +365,22 @@ def _mapping_repr( summary = [f"{summary[0]} ({max_rows}/{len_mapping})"] first_rows = max_rows // 2 + max_rows % 2 keys = list(mapping.keys()) - summary += [summarizer(k, mapping[k], col_width) for k in keys[:first_rows]] + summary += [ + summarizer(k, mapping[k], col_width, **summarizer_kwargs) + for k in keys[:first_rows] + ] if max_rows > 1: last_rows = max_rows // 2 summary += [pretty_print(" ...", col_width) + " ..."] summary += [ - summarizer(k, mapping[k], col_width) for k in keys[-last_rows:] + summarizer(k, mapping[k], col_width, **summarizer_kwargs) + for k in keys[-last_rows:] ] else: - summary += [summarizer(k, v, col_width) for k, v in mapping.items()] + summary += [ + summarizer(k, v, col_width, **summarizer_kwargs) + for k, v in mapping.items() + ] else: summary += [EMPTY_REPR] return "\n".join(summary) @@ -427,6 +411,7 @@ def coords_repr(coords, col_width=None): summarizer=summarize_coord, expand_option_name="display_expand_coords", col_width=col_width, + summarizer_kwargs={"indexes": coords.indexes}, ) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 70994a36ac8..54b89b0f51a 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1406,3 +1406,21 @@ def __repr__(self) -> str: else: props = "(array={self.array!r}, level={self.level!r}, dtype={self.dtype!r})" return f"{type(self).__name__}{props}" + + def _repr_inline_(self, max_width) -> str: + # special implementation to speed-up the repr for big multi-indexes + if self.level is None: + return "MultiIndex" + else: + from .formatting import format_array_flat + + if self.size > 100 and max_width < self.size: + n_values = max_width + indices = np.concatenate( + [np.arange(0, n_values), np.arange(-n_values, 0)] + ) + subset = self[indices] + else: + subset = self + + return format_array_flat(np.asarray(subset), max_width) From 752a405a23096e9f403c334f52287db1ecb14aad Mon Sep 17 00:00:00 2001 From: Benoit Bovy 
Date: Wed, 11 Aug 2021 18:01:05 +0200 Subject: [PATCH 005/159] remove print --- xarray/core/merge.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 25d95a88a49..f4f8d2b4c55 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -496,7 +496,6 @@ def _create_indexes_from_coords(coords): for name, variable in coords.items(): variable = as_variable(variable, name=name) - print(variable._data) if variable.dims == (name,): array = getattr(variable._data, "array", None) From 1250ccfdc722ae0d44e3d70fc6df74f7ffcd9948 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 12 Aug 2021 13:47:02 +0200 Subject: [PATCH 006/159] minor fixes and improvements --- xarray/core/indexing.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 54b89b0f51a..3a28919ffac 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -71,13 +71,6 @@ def group_indexers_by_index(data_obj, indexers, method=None, tolerance=None): try: index = xindexes[key] coord = data_obj.coords[key] - dim = coord.dims[0] - if dim not in indexes: - indexes[dim] = index - - label = maybe_cast_to_coords_dtype(label, coord.dtype) - grouped_indexers[dim][key] = label - except KeyError: if key in data_obj.coords: raise KeyError(f"no index found for coordinate {key}") @@ -91,6 +84,13 @@ def group_indexers_by_index(data_obj, indexers, method=None, tolerance=None): "an associated coordinate." ) grouped_indexers[None][key] = label + else: + dim = coord.dims[0] + if dim not in indexes: + indexes[dim] = index + + label = maybe_cast_to_coords_dtype(label, coord.dtype) + grouped_indexers[dim][key] = label return indexes, grouped_indexers @@ -1404,7 +1404,9 @@ def __repr__(self) -> str: if self.level is None: return super().__repr__() else: - props = "(array={self.array!r}, level={self.level!r}, dtype={self.dtype!r})" + props = ( + f"(array={self.array!r}, level={self.level!r}, dtype={self.dtype!r})" + ) return f"{type(self).__name__}{props}" def _repr_inline_(self, max_width) -> str: From 1bb61d95a225b058c84bb05b27c542f29d8b9afc Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 12 Aug 2021 14:35:44 +0200 Subject: [PATCH 007/159] fix dtype of index variables created from Index --- xarray/core/indexes.py | 17 ++++++++++------- xarray/tests/test_indexes.py | 17 ++++++++++++----- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 13ea53a245c..9223631915f 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -169,7 +169,7 @@ def from_variables(cls, variables: Mapping[Hashable, "Variable"]): obj = cls(var.data, dim) - data = PandasIndexingAdapter(obj.index) + data = PandasIndexingAdapter(obj.index, dtype=var.dtype) index_var = IndexVariable( dim, data, attrs=var.attrs, encoding=var.encoding, fastpath=True ) @@ -314,12 +314,15 @@ def from_variables(cls, variables: Mapping[Hashable, "Variable"]): @classmethod def from_pandas_index(cls, index: pd.MultiIndex, dim: Hashable): - levels = [ - name if name is not None else f"{dim}_level_{i}" - for i, name in enumerate(index.names) - ] - index = index.rename(levels) - index_vars = _create_variables_from_multiindex(index, dim) + level_meta = {} + for i, idx in enumerate(index.levels): + name = idx.name or f"{dim}_level_{i}" + level_meta[name] = {"dtype": idx.dtype} + + index = index.rename(level_meta.keys()) + index_vars = _create_variables_from_multiindex( + index, dim, 
level_meta=level_meta + ) return cls(index, dim), index_vars def query(self, labels, method=None, tolerance=None): diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index c8ba72a253f..e753da5de0e 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -28,12 +28,15 @@ def test_constructor(self): assert index.dim == "x" def test_from_variables(self): + # pandas has only Float64Index but variable dtype should be preserved + data = np.array([1.1, 2.2, 3.3], dtype=np.float32) var = xr.Variable( - "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32} + "x", data, attrs={"unit": "m"}, encoding={"dtype": np.float64} ) index, index_vars = PandasIndex.from_variables({"x": var}) xr.testing.assert_identical(var.to_index_variable(), index_vars["x"]) + assert index_vars["x"].dtype == var.dtype assert index.dim == "x" assert index.index.equals(index_vars["x"].to_index()) @@ -166,16 +169,20 @@ def test_from_variables(self): PandasMultiIndex.from_variables({"level1": v_level1, "level3": v_level3}) def test_from_pandas_index(self): - pd_idx = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]], names=("foo", "bar")) + foo_data = np.array([0, 0, 1], dtype="int") + bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") + pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data], names=("foo", "bar")) index, index_vars = PandasMultiIndex.from_pandas_index(pd_idx, "x") assert index.dim == "x" - assert index.index is pd_idx + assert index.index.equals(pd_idx) assert index.index.names == ("foo", "bar") xr.testing.assert_identical(index_vars["x"], IndexVariable("x", pd_idx)) - xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", [1, 2, 3])) - xr.testing.assert_identical(index_vars["bar"], IndexVariable("x", [4, 5, 6])) + xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", foo_data)) + xr.testing.assert_identical(index_vars["bar"], IndexVariable("x", bar_data)) + assert index_vars["foo"].dtype == foo_data.dtype + assert index_vars["bar"].dtype == bar_data.dtype def test_query(self): index = PandasMultiIndex( From a551c7f05abf90a492fb59068b59ebb2bac8cb4c Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 12 Aug 2021 14:37:42 +0200 Subject: [PATCH 008/159] fix multi-index selection regression See https://github.com/pydata/xarray/issues/5691 --- xarray/core/indexes.py | 22 +++++++++++++++------- xarray/tests/test_dataarray.py | 14 ++++++++++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 9223631915f..f1cc7dfaed1 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -129,6 +129,15 @@ def _is_nested_tuple(possible_tuple): ) +def normalize_label(value, extract_scalar=False): + if getattr(value, "ndim", 1) <= 1: + value = _asarray_tuplesafe(value) + if extract_scalar: + # see https://github.com/pydata/xarray/pull/4292 for details + value = value[()] if value.dtype.kind in "mM" else value.item() + return value + + def get_indexer_nd(index, labels, method=None, tolerance=None): """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional labels @@ -207,14 +216,9 @@ def query(self, labels, method=None, tolerance=None): "a dimension that does not have a MultiIndex" ) else: - label = ( - label - if getattr(label, "ndim", 1) > 1 # vectorized-indexing - else _asarray_tuplesafe(label) - ) + label = normalize_label(label) if label.ndim == 0: - # see https://github.com/pydata/xarray/pull/4292 for details - label_value = label[()] if 
label.dtype.kind in "mM" else label.item() + label_value = normalize_label(label, extract_scalar=True) if isinstance(self.index, pd.CategoricalIndex): if method is not None: raise ValueError( @@ -336,6 +340,10 @@ def query(self, labels, method=None, tolerance=None): # label(s) given for multi-index level(s) if all([lbl in self.index.names for lbl in labels]): is_nested_vals = _is_nested_tuple(tuple(labels.values())) + labels = { + k: normalize_label(v, extract_scalar=True) for k, v in labels.items() + } + if len(labels) == self.index.nlevels and not is_nested_vals: indexer = self.index.get_loc(tuple(labels[k] for k in self.index.names)) else: diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8ab8bc872da..5205c1b59ab 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1006,6 +1006,20 @@ def test_sel_float(self): assert_equal(expected_scalar, actual_scalar) assert_equal(expected_16, actual_16) + def test_sel_float_multiindex(self): + # regression test https://github.com/pydata/xarray/issues/5691 + midx = pd.MultiIndex.from_arrays( + [["a", "a", "b", "b"], [0.1, 0.2, 0.3, 0.4]], names=["lvl1", "lvl2"] + ) + da = xr.DataArray([1, 2, 3, 4], coords={"x": midx}, dims="x") + + actual = da.sel(lvl1="a", lvl2=0.1) + expected = da.isel(x=0) + + assert_equal(actual, expected) + + # TODO: test multi-index created from coordinates, one with dtype=float32 + def test_sel_no_index(self): array = DataArray(np.arange(10), dims="x") assert_identical(array[0], array.sel(x=0)) From abc338474dd25f48472d8e2cf4a47d443cdd3568 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 13 Aug 2021 18:11:18 +0200 Subject: [PATCH 009/159] check conflicting multi-index level names --- xarray/core/indexes.py | 4 ++++ xarray/core/merge.py | 21 +++++++++++++++++---- xarray/tests/test_indexes.py | 3 +++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index f1cc7dfaed1..eb4e6c368a1 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -321,6 +321,10 @@ def from_pandas_index(cls, index: pd.MultiIndex, dim: Hashable): level_meta = {} for i, idx in enumerate(index.levels): name = idx.name or f"{dim}_level_{i}" + if name == dim: + raise ValueError( + f"conflicting multi-index level name {name!r} with dimension {dim!r}" + ) level_meta[name] = {"dtype": idx.dtype} index = index.rename(level_meta.keys()) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index f4f8d2b4c55..214bdc8b1c9 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -476,21 +476,25 @@ def merge_coords( return variables, out_indexes -def merge_data_and_coords(data, coords, compat="broadcast_equals", join="outer"): +def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="outer"): """Used in Dataset.__init__.""" - indexes, coords = _create_indexes_from_coords(coords) - objects = [data, coords] + indexes, coords = _create_indexes_from_coords(coords, data_vars) + objects = [data_vars, coords] explicit_coords = coords.keys() return merge_core( objects, compat, join, explicit_coords=explicit_coords, indexes=indexes ) -def _create_indexes_from_coords(coords): +def _create_indexes_from_coords(coords, data_vars=None): """Maybe create default indexes from a mapping of coordinates. Return those indexes and updated coordinates. 
""" + all_var_names = list(coords.keys()) + if data_vars is not None: + all_var_names += list(data_vars.keys()) + indexes = {} updated_coords = {} @@ -503,6 +507,15 @@ def _create_indexes_from_coords(coords): # TODO: benbovy - explicit indexes: depreciate passing multi-indexes as coords? # (instead pass them explicitly as indexes once it is supported via public API) index, index_vars = PandasMultiIndex.from_pandas_index(array, name) + # check for conflict between level names and variable names + duplicate_names = [ + k for k in index_vars if k in all_var_names and k != name + ] + if duplicate_names: + conflict_str = "\n".join(duplicate_names) + raise ValueError( + f"conflicting MultiIndex level name(s):\n{conflict_str}" + ) else: index, index_vars = PandasIndex.from_variables({name: variable}) diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index e753da5de0e..410fbb61b26 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -184,6 +184,9 @@ def test_from_pandas_index(self): assert index_vars["foo"].dtype == foo_data.dtype assert index_vars["bar"].dtype == bar_data.dtype + with pytest.raises(ValueError, match=".*conflicting multi-index level name.*"): + PandasMultiIndex.from_pandas_index(pd_idx, "foo") + def test_query(self): index = PandasMultiIndex( pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), "x" From 92c8ca4ed1ba0c99933d6e8feff800d46fe6201e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 13 Aug 2021 18:12:47 +0200 Subject: [PATCH 010/159] update formatting (text and html) --- xarray/core/formatting.py | 97 ++++++++++++++++++---------- xarray/core/formatting_html.py | 46 ++++--------- xarray/tests/test_dataarray.py | 12 ++-- xarray/tests/test_dataset.py | 12 ++-- xarray/tests/test_formatting_html.py | 8 --- 5 files changed, 88 insertions(+), 87 deletions(-) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 63d9382aaaa..66f685aa9fd 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -2,9 +2,10 @@ """ import contextlib import functools +from collections import defaultdict from datetime import datetime, timedelta from itertools import chain, zip_longest -from typing import Hashable, Mapping +from typing import Hashable import numpy as np import pandas as pd @@ -272,37 +273,33 @@ def inline_variable_array_repr(var, max_width): def summarize_variable( - name: Hashable, var, col_width: int, marker: str = " ", max_width: int = None + name: Hashable, var, col_width: int, max_width: int = None, is_index: bool = False ): """Summarize a variable in one line, e.g., for the Dataset.__repr__.""" + variable = var.variable if hasattr(var, "variable") else var + if max_width is None: max_width_options = OPTIONS["display_width"] if not isinstance(max_width_options, int): raise TypeError(f"`max_width` value of `{max_width}` is not a valid int") else: max_width = max_width_options + + marker = "*" if is_index else " " first_col = pretty_print(f" {marker} {name} ", col_width) - if var.dims: - dims_str = "({}) ".format(", ".join(map(str, var.dims))) + + if variable.dims: + dims_str = "({}) ".format(", ".join(map(str, variable.dims))) else: dims_str = "" - front_str = f"{first_col}{dims_str}{var.dtype} " + front_str = f"{first_col}{dims_str}{variable.dtype} " values_width = max_width - len(front_str) - values_str = inline_variable_array_repr(var, values_width) + values_str = inline_variable_array_repr(variable, values_width) return front_str + values_str -def summarize_datavar(name, 
var, col_width): - return summarize_variable(name, var.variable, col_width) - - -def summarize_coord(name: Hashable, var, col_width: int, indexes: Mapping): - marker = "*" if name in indexes else " " - return summarize_variable(name, var.variable, col_width, marker) - - def summarize_attr(key, value, col_width=None): """Summary for __repr__ - use ``X.attrs[key]`` for full value.""" # Indent key and add ':', then right-pad if col_width is not None @@ -348,14 +345,17 @@ def _mapping_repr( expand_option_name, col_width=None, max_rows=None, - summarizer_kwargs=None, + indexes=None, ): if col_width is None: col_width = _calculate_col_width(mapping) if max_rows is None: max_rows = OPTIONS["display_max_rows"] - if summarizer_kwargs is None: - summarizer_kwargs = {} + + summarizer_kwargs = defaultdict(dict) + if indexes is not None: + summarizer_kwargs = {k: {"is_index": k in indexes} for k in mapping} + summary = [f"{title}:"] if mapping: len_mapping = len(mapping) @@ -366,19 +366,19 @@ def _mapping_repr( first_rows = max_rows // 2 + max_rows % 2 keys = list(mapping.keys()) summary += [ - summarizer(k, mapping[k], col_width, **summarizer_kwargs) + summarizer(k, mapping[k], col_width, **summarizer_kwargs[k]) for k in keys[:first_rows] ] if max_rows > 1: last_rows = max_rows // 2 summary += [pretty_print(" ...", col_width) + " ..."] summary += [ - summarizer(k, mapping[k], col_width, **summarizer_kwargs) + summarizer(k, mapping[k], col_width, **summarizer_kwargs[k]) for k in keys[-last_rows:] ] else: summary += [ - summarizer(k, v, col_width, **summarizer_kwargs) + summarizer(k, v, col_width, **summarizer_kwargs[k]) for k, v in mapping.items() ] else: @@ -389,7 +389,7 @@ def _mapping_repr( data_vars_repr = functools.partial( _mapping_repr, title="Data variables", - summarizer=summarize_datavar, + summarizer=summarize_variable, expand_option_name="display_expand_data_vars", ) @@ -408,10 +408,10 @@ def coords_repr(coords, col_width=None): return _mapping_repr( coords, title="Coordinates", - summarizer=summarize_coord, + summarizer=summarize_variable, expand_option_name="display_expand_coords", col_width=col_width, - summarizer_kwargs={"indexes": coords.indexes}, + indexes=coords.indexes, ) @@ -557,9 +557,20 @@ def diff_dim_summary(a, b): return "" -def _diff_mapping_repr(a_mapping, b_mapping, compat, title, summarizer, col_width=None): - def extra_items_repr(extra_keys, mapping, ab_side): - extra_repr = [summarizer(k, mapping[k], col_width) for k in extra_keys] +def _diff_mapping_repr( + a_mapping, + b_mapping, + compat, + title, + summarizer, + col_width=None, + a_indexes=None, + b_indexes=None, +): + def extra_items_repr(extra_keys, mapping, ab_side, kwargs): + extra_repr = [ + summarizer(k, mapping[k], col_width, **kwargs[k]) for k in extra_keys + ] if extra_repr: header = f"{title} only on the {ab_side} object:" return [header] + extra_repr @@ -573,6 +584,13 @@ def extra_items_repr(extra_keys, mapping, ab_side): diff_items = [] + a_summarizer_kwargs = defaultdict(dict) + if a_indexes is not None: + a_summarizer_kwargs = {k: {"is_index": k in a_indexes} for k in a_mapping} + b_summarizer_kwargs = defaultdict(dict) + if b_indexes is not None: + b_summarizer_kwargs = {k: {"is_index": k in b_indexes} for k in b_mapping} + for k in a_keys & b_keys: try: # compare xarray variable @@ -592,7 +610,8 @@ def extra_items_repr(extra_keys, mapping, ab_side): if not compatible: temp = [ - summarizer(k, vars[k], col_width) for vars in (a_mapping, b_mapping) + summarizer(k, a_mapping[k], col_width, 
**a_summarizer_kwargs[k]), + summarizer(k, b_mapping[k], col_width, **b_summarizer_kwargs[k]), ] if compat == "identical" and is_variable: @@ -614,19 +633,29 @@ def extra_items_repr(extra_keys, mapping, ab_side): if diff_items: summary += [f"Differing {title.lower()}:"] + diff_items - summary += extra_items_repr(a_keys - b_keys, a_mapping, "left") - summary += extra_items_repr(b_keys - a_keys, b_mapping, "right") + summary += extra_items_repr(a_keys - b_keys, a_mapping, "left", a_summarizer_kwargs) + summary += extra_items_repr( + b_keys - a_keys, b_mapping, "right", b_summarizer_kwargs + ) return "\n".join(summary) -diff_coords_repr = functools.partial( - _diff_mapping_repr, title="Coordinates", summarizer=summarize_coord -) +def diff_coords_repr(a, b, compat, col_width=None): + return _diff_mapping_repr( + a, + b, + compat, + "Coordinates", + summarize_variable, + col_width=col_width, + a_indexes=a.indexes, + b_indexes=b.indexes, + ) diff_data_vars_repr = functools.partial( - _diff_mapping_repr, title="Data variables", summarizer=summarize_datavar + _diff_mapping_repr, title="Data variables", summarizer=summarize_variable ) diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 2a480427d4e..ac5b79b287d 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -64,38 +64,7 @@ def _icon(icon_name): ) -def _summarize_coord_multiindex(name, coord): - preview = f"({', '.join(escape(l) for l in coord.level_names)})" - return summarize_variable( - name, coord, is_index=True, dtype="MultiIndex", preview=preview - ) - - -def summarize_coord(name, var): - is_index = name in var.dims - if is_index: - coord = var.variable.to_index_variable() - if coord.level_names is not None: - coords = {name: _summarize_coord_multiindex(name, coord)} - for lname in coord.level_names: - var = coord.get_level_variable(lname) - coords[lname] = summarize_variable(lname, var) - return coords - - return {name: summarize_variable(name, var, is_index)} - - -def summarize_coords(variables): - coords = {} - for k, v in variables.items(): - coords.update(**summarize_coord(k, v)) - - vars_li = "".join(f"
<li class='xr-var-item'>{v}</li>" for v in coords.values())
-
-    return f"<ul class='xr-var-list'>{vars_li}</ul>"
-
-
-def summarize_variable(name, var, is_index=False, dtype=None, preview=None):
+def summarize_variable(name, var, is_index=False, dtype=None):
     variable = var.variable if hasattr(var, "variable") else var
 
     cssclass_idx = " class='xr-has-index'" if is_index else ""
@@ -108,7 +77,7 @@ def summarize_variable(name, var, is_index=False, dtype=None, preview=None):
     data_id = "data-" + str(uuid.uuid4())
     disabled = "" if len(var.attrs) else "disabled"
 
-    preview = preview or escape(inline_variable_array_repr(variable, 35))
+    preview = escape(inline_variable_array_repr(variable, 35))
     attrs_ul = summarize_attrs(var.attrs)
     data_repr = short_data_repr_html(variable)
@@ -132,6 +101,17 @@ def summarize_variable(name, var, is_index=False, dtype=None, preview=None):
     )
 
 
+def summarize_coords(variables):
+    li_items = []
+    for k, v in variables.items():
+        li_content = summarize_variable(k, v, is_index=k in variables.indexes)
+        li_items.append(f"<li class='xr-var-item'>{li_content}</li>")
+
+    vars_li = "".join(li_items)
+
+    return f"<ul class='xr-var-list'>{vars_li}</ul>"
+
+
 def summarize_vars(variables):
     vars_li = "".join(
         f"<li class='xr-var-item'>{summarize_variable(k, v)}</li>"
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 5205c1b59ab..3bbad4ff6cf 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -92,9 +92,9 @@ def test_repr_multiindex(self):
         array([0, 1, 2, 3])
         Coordinates:
-          * x        (x) MultiIndex
-          - level_1  (x) object 'a' 'a' 'b' 'b'
-          - level_2  (x) int64 1 2 1 2"""
+          * x        (x) object MultiIndex
+            level_1  (x) object 'a' 'a' 'b' 'b'
+            level_2  (x) int64 1 2 1 2"""
         )
         assert expected == repr(self.mda)
@@ -114,9 +114,9 @@ def test_repr_multiindex_long(self):
         array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
                17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
         Coordinates:
-          * x        (x) MultiIndex
-          - level_1  (x) object 'a' 'a' 'a' 'a' 'a' 'a' 'a' ... 'd' 'd' 'd' 'd' 'd' 'd'
-          - level_2  (x) int64 1 2 3 4 5 6 7 8 1 2 3 4 5 6 ... 4 5 6 7 8 1 2 3 4 5 6 7 8"""
+          * x        (x) object MultiIndex
+            level_1  (x) object 'a' 'a' 'a' 'a' 'a' 'a' 'a' ... 'd' 'd' 'd' 'd' 'd' 'd'
+            level_2  (x) int64 1 2 3 4 5 6 7 8 1 2 3 4 5 6 ... 4 5 6 7 8 1 2 3 4 5 6 7 8"""
         )
         assert expected == repr(mda_long)
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index 8e39bbdd83e..b2e03f5c950 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -239,9 +239,9 @@ def test_repr_multiindex(self):
         Dimensions:  (x: 4)
         Coordinates:
-          * x        (x) MultiIndex
-          - level_1  (x) object 'a' 'a' 'b' 'b'
-          - level_2  (x) int64 1 2 1 2
+          * x        (x) object MultiIndex
+            level_1  (x) object 'a' 'a' 'b' 'b'
+            level_2  (x) int64 1 2 1 2
         Data variables:
             *empty*"""
         )
@@ -259,9 +259,9 @@ def test_repr_multiindex(self):
         Dimensions:                  (x: 4)
         Coordinates:
-          * x                        (x) MultiIndex
-          - a_quite_long_level_name  (x) object 'a' 'a' 'b' 'b'
-          - level_2                  (x) int64 1 2 1 2
+          * x                        (x) object MultiIndex
+            a_quite_long_level_name  (x) object 'a' 'a' 'b' 'b'
+            level_2                  (x) int64 1 2 1 2
         Data variables:
             *empty*"""
         )
diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py
index 09c6fa0cf3c..85666f3913b 100644
--- a/xarray/tests/test_formatting_html.py
+++ b/xarray/tests/test_formatting_html.py
@@ -115,14 +115,6 @@ def test_repr_of_dataarray(dataarray):
     )
 
 
-def test_summary_of_multiindex_coord(multiindex):
-    idx = multiindex.x.variable.to_index_variable()
-    formatted = fh._summarize_coord_multiindex("foo", idx)
-    assert "(level_1, level_2)" in formatted
-    assert "MultiIndex" in formatted
-    assert "foo" in formatted
-
-
 def test_repr_of_multiindex(multiindex):
     formatted = fh.dataset_repr(multiindex)
     assert "(x)" in formatted
From b9e1cbafce48488636a08313f3336ee3c8f6718c Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Mon, 23 Aug 2021 14:29:53 +0200
Subject: [PATCH 011/159] check level name conflicts for midx given as coord

---
 xarray/core/merge.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/xarray/core/merge.py b/xarray/core/merge.py
index 214bdc8b1c9..f2d582a2392 100644
--- a/xarray/core/merge.py
+++ b/xarray/core/merge.py
@@ -514,13 +514,14 @@ def _create_indexes_from_coords(coords, data_vars=None):
             if duplicate_names:
                 conflict_str = "\n".join(duplicate_names)
                 raise ValueError(
-                    f"conflicting MultiIndex level name(s):\n{conflict_str}"
+                    f"conflicting MultiIndex level / variable name(s):\n{conflict_str}"
                 )
             else:
                 index, index_vars = PandasIndex.from_variables({name: variable})
 
             indexes[name] = index
             updated_coords.update(index_vars)
+            all_var_names += list(index_vars.keys())
 
         else:
             updated_coords[name] = variable
From 60935ac2667c09221e08339267441ba1ddeb8689 Mon Sep 17
00:00:00 2001 From: Benoit Bovy Date: Mon, 23 Aug 2021 17:21:51 +0200 Subject: [PATCH 012/159] intended behavior or unwanted side effect? see #5732 --- xarray/core/merge.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index f2d582a2392..dd0111477a2 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -496,7 +496,7 @@ def _create_indexes_from_coords(coords, data_vars=None): all_var_names += list(data_vars.keys()) indexes = {} - updated_coords = {} + updated_coords = {k: v for k, v in coords.items()} for name, variable in coords.items(): variable = as_variable(variable, name=name) @@ -523,9 +523,6 @@ def _create_indexes_from_coords(coords, data_vars=None): updated_coords.update(index_vars) all_var_names += list(index_vars.keys()) - else: - updated_coords[name] = variable - return indexes, updated_coords From 3086d4e2547a1c9044e9469e2bee003045b5ec6a Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 23 Aug 2021 17:45:37 +0200 Subject: [PATCH 013/159] get rid of multi-index virtual coordinates Not totally yet: need to refactor set_index / reset_index --- xarray/core/dataarray.py | 7 ++---- xarray/core/dataset.py | 46 +++++++++++----------------------- xarray/core/formatting.py | 25 +++--------------- xarray/tests/test_dataarray.py | 5 ---- xarray/tests/test_dataset.py | 28 --------------------- 5 files changed, 19 insertions(+), 92 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 8f0f265fa85..0ddce30bf4a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -712,7 +712,7 @@ def _level_coords(self) -> Dict[Hashable, Hashable]: """ level_coords: Dict[Hashable, Hashable] = {} - for cname, var in self._coords.items(): + for _, var in self._coords.items(): if var.ndim == 1 and isinstance(var, IndexVariable): level_names = var.level_names if level_names is not None: @@ -727,9 +727,7 @@ def _getitem_coord(self, key): var = self._coords[key] except KeyError: dim_sizes = dict(zip(self.dims, self.shape)) - _, key, var = _get_virtual_variable( - self._coords, key, self._level_coords, dim_sizes - ) + _, key, var = _get_virtual_variable(self._coords, key, dim_sizes) return self._replace_maybe_drop_dims(var, name=key) @@ -774,7 +772,6 @@ def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: # virtual coordinates # uses empty dict -- everything here can already be found in self.coords. 
yield HybridMappingProxy(keys=self.dims, mapping={}) - yield HybridMappingProxy(keys=self._level_coords, mapping={}) def __contains__(self, key: Any) -> bool: return key in self.data diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 533ecadbae5..009e4497a2d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -138,13 +138,12 @@ def _get_virtual_variable( - variables, key: Hashable, level_vars: Mapping = None, dim_sizes: Mapping = None + variables, key: Hashable, dim_sizes: Mapping = None ) -> Tuple[Hashable, Hashable, Variable]: - """Get a virtual variable (e.g., 'time.year' or a MultiIndex level) - from a dict of xarray.Variable objects (if possible) + """Get a virtual variable (e.g., 'time.year') from a dict of xarray.Variable + objects (if possible) + """ - if level_vars is None: - level_vars = {} if dim_sizes is None: dim_sizes = {} @@ -157,30 +156,18 @@ def _get_virtual_variable( raise KeyError(key) split_key = key.split(".", 1) - var_name: Optional[str] - if len(split_key) == 2: - ref_name, var_name = split_key - elif len(split_key) == 1: - ref_name, var_name = key, None - else: + if len(split_key) != 2: raise KeyError(key) - if ref_name in level_vars: - dim_var = variables[level_vars[ref_name]] - ref_var = dim_var.to_index_variable().get_level_variable(ref_name) - else: - ref_var = variables[ref_name] + ref_name, var_name = split_key + ref_var = variables[ref_name] - if var_name is None: - virtual_var = ref_var - var_name = key + if _contains_datetime_like_objects(ref_var): + ref_var = xr.DataArray(ref_var) + data = getattr(ref_var.dt, var_name).data else: - if _contains_datetime_like_objects(ref_var): - ref_var = xr.DataArray(ref_var) - data = getattr(ref_var.dt, var_name).data - else: - data = getattr(ref_var, var_name).data - virtual_var = Variable(ref_var.dims, data) + data = getattr(ref_var, var_name).data + virtual_var = Variable(ref_var.dims, data) return ref_name, var_name, virtual_var @@ -1362,7 +1349,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset": variables[name] = self._variables[name] except KeyError: ref_name, var_name, var = _get_virtual_variable( - self._variables, name, self._level_coords, self.dims + self._variables, name, self.dims ) variables[var_name] = var if ref_name in self._coord_names or ref_name in self.dims: @@ -1396,9 +1383,7 @@ def _construct_dataarray(self, name: Hashable) -> "DataArray": try: variable = self._variables[name] except KeyError: - _, name, variable = _get_virtual_variable( - self._variables, name, self._level_coords, self.dims - ) + _, name, variable = _get_virtual_variable(self._variables, name, self.dims) needed_dims = set(variable.dims) @@ -1438,9 +1423,6 @@ def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: # virtual coordinates yield HybridMappingProxy(keys=self.dims, mapping=self) - # uses empty dict -- everything here can already be found in self.coords. - yield HybridMappingProxy(keys=self._level_coords, mapping={}) - def __contains__(self, key: object) -> bool: """The 'in' operator will return true or false depending on whether 'key' is an array in the dataset or not. diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 66f685aa9fd..a62910d1638 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -315,23 +315,6 @@ def summarize_attr(key, value, col_width=None): EMPTY_REPR = " *empty*" -def _get_col_items(mapping): - """Get all column items to format, including both keys of `mapping` - and MultiIndex levels if any. 
- """ - from .variable import IndexVariable - - col_items = [] - for k, v in mapping.items(): - col_items.append(k) - var = getattr(v, "variable", v) - if isinstance(var, IndexVariable): - level_names = var.to_index_variable().level_names - if level_names is not None: - col_items += list(level_names) - return col_items - - def _calculate_col_width(col_items): max_name_length = max(len(str(s)) for s in col_items) if col_items else 0 col_width = max(max_name_length, 7) + 6 @@ -404,7 +387,7 @@ def _mapping_repr( def coords_repr(coords, col_width=None): if col_width is None: - col_width = _calculate_col_width(_get_col_items(coords)) + col_width = _calculate_col_width(coords) return _mapping_repr( coords, title="Coordinates", @@ -528,7 +511,7 @@ def array_repr(arr): def dataset_repr(ds): summary = ["".format(type(ds).__name__)] - col_width = _calculate_col_width(_get_col_items(ds.variables)) + col_width = _calculate_col_width(ds.variables) dims_start = pretty_print("Dimensions:", col_width) summary.append("{}({})".format(dims_start, dim_summary(ds))) @@ -717,9 +700,7 @@ def diff_dataset_repr(a, b, compat): ) ] - col_width = _calculate_col_width( - set(_get_col_items(a.variables) + _get_col_items(b.variables)) - ) + col_width = _calculate_col_width(set(list(a.variables) + list(b.variables))) summary.append(diff_dim_summary(a, b)) summary.append(diff_coords_repr(a.coords, b.coords, compat, col_width=col_width)) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 3bbad4ff6cf..bf6576a96c9 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -768,11 +768,6 @@ def test_contains(self): assert 1 in data_array assert 3 not in data_array - def test_attr_sources_multiindex(self): - # make sure attr-style access for multi-index levels - # returns DataArray objects - assert isinstance(self.mda.level_1, DataArray) - def test_pickle(self): data = DataArray(np.random.random((3, 3)), dims=("id", "time")) roundtripped = pickle.loads(pickle.dumps(data)) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index b2e03f5c950..d07f4ad6639 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -3323,34 +3323,6 @@ def test_virtual_variable_same_name(self): expected = DataArray(times.time, [("time", times)], name="time") assert_identical(actual, expected) - def test_virtual_variable_multiindex(self): - # access multi-index levels as virtual variables - data = create_test_multiindex() - expected = DataArray( - ["a", "a", "b", "b"], - name="level_1", - coords=[data["x"].to_index()], - dims="x", - ) - assert_identical(expected, data["level_1"]) - - # combine multi-index level and datetime - dr_index = pd.date_range("1/1/2011", periods=4, freq="H") - mindex = pd.MultiIndex.from_arrays( - [["a", "a", "b", "b"], dr_index], names=("level_str", "level_date") - ) - data = Dataset({}, {"x": mindex}) - expected = DataArray( - mindex.get_level_values("level_date").hour, - name="hour", - coords=[mindex], - dims="x", - ) - assert_identical(expected, data["level_date.hour"]) - - # attribute style access - assert_identical(data.level_str, data["level_str"]) - def test_time_season(self): ds = Dataset({"t": pd.date_range("2000-01-01", periods=12, freq="M")}) seas = ["DJF"] * 2 + ["MAM"] * 3 + ["JJA"] * 3 + ["SON"] * 3 + ["DJF"] From 54a04f7de6e5200022f870a0039a3ef90021d280 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 25 Aug 2021 15:20:20 +0200 Subject: [PATCH 014/159] add level coords in indexes & keep coord order --- 
 xarray/core/merge.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/xarray/core/merge.py b/xarray/core/merge.py
index dd0111477a2..2f4dbcd2bd0 100644
--- a/xarray/core/merge.py
+++ b/xarray/core/merge.py
@@ -491,21 +491,20 @@ def _create_indexes_from_coords(coords, data_vars=None):
 
     Return those indexes and updated coordinates.
     """
-    all_var_names = list(coords.keys())
+    all_var_names = set(coords.keys())
     if data_vars is not None:
-        all_var_names += list(data_vars.keys())
+        all_var_names |= set(data_vars.keys())
 
     indexes = {}
-    updated_coords = {k: v for k, v in coords.items()}
+    updated_coords = {}
 
-    for name, variable in coords.items():
-        variable = as_variable(variable, name=name)
+    for name, obj in coords.items():
+        variable = as_variable(obj, name=name)
 
         if variable.dims == (name,):
             array = getattr(variable._data, "array", None)
             if isinstance(array, pd.MultiIndex):
                 # TODO: benbovy - explicit indexes: depreciate passing multi-indexes as coords?
                 index, index_vars = PandasMultiIndex.from_pandas_index(array, name)
                 # check for conflict between level names and variable names
                 duplicate_names = [
                     k for k in index_vars if k in all_var_names and k != name
                 ]
                 if duplicate_names:
                     conflict_str = "\n".join(duplicate_names)
                     raise ValueError(
                         f"conflicting MultiIndex level / variable name(s):\n{conflict_str}"
                     )
+                all_var_names |= set(index_vars.keys())
             else:
                 index, index_vars = PandasIndex.from_variables({name: variable})
 
-            indexes[name] = index
+            indexes.update({k: index for k in index_vars})
             updated_coords.update(index_vars)
-            all_var_names += list(index_vars.keys())
+
+        else:
+            updated_coords[name] = obj
 
     return indexes, updated_coords
From e3181d43372eaf606cb7ff9d42934af33e07ad41 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Wed, 25 Aug 2021 15:21:48 +0200
Subject: [PATCH 015/159] fix copying multi-index level variable data

---
 xarray/core/indexing.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index 3a28919ffac..858303c2ad9 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -1426,3 +1426,10 @@ def _repr_inline_(self, max_width) -> str:
             subset = self
 
         return format_array_flat(np.asarray(subset), max_width)
+
+    def copy(self, deep: bool = True) -> "PandasMultiIndexingAdapter":
+        # see PandasIndexingAdapter.copy
+        array = self.array.copy(deep=True) if deep else self.array
+        # do not use indexing cache if deep=True
+        adapter = None if deep else self.adapter
+        return type(self)(array, self._dtype, self.level, adapter)
From 81a60f619761f91db533c56ceac379740f992033 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Wed, 25 Aug 2021 17:35:17 +0200
Subject: [PATCH 016/159] collect index for multi-index level variables

Avoid re-creating the indexes for dimension variables. Collect them
directly instead.

Note: the change here is working for building new Datasets but I haven't
checked other cases like merging different objects, etc. So I'm not sure
this is the right approach.
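
A rough sketch of the intended outcome (untested; hypothetical data):

    import pandas as pd
    import xarray as xr

    midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two"))
    ds = xr.Dataset(coords={"x": midx})

    # the index built once from the coordinates should now be collected
    # as-is for the dimension coordinate and both level coordinates,
    # instead of being re-created from the dimension variable
    assert ds.xindexes["x"] is ds.xindexes["one"] is ds.xindexes["two"]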
--- xarray/core/merge.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 2f4dbcd2bd0..2d58f9038e5 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -254,6 +254,7 @@ def merge_collected( def collect_variables_and_indexes( list_of_mappings: "List[DatasetLike]", + indexes: Optional[Mapping[Hashable, Any]] = None, ) -> Dict[Hashable, List[MergeElement]]: """Collect variables and indexes from list of mappings of xarray objects. @@ -263,10 +264,17 @@ def collect_variables_and_indexes( - a tuple `(dims, data[, attrs[, encoding]])` that can be converted in an xarray.Variable - or an xarray.DataArray + + If a mapping of indexes is given, those indexes are assigned to all variables + with a matching key/name. + """ from .dataarray import DataArray from .dataset import Dataset + if indexes is None: + indexes = {} + grouped: Dict[Hashable, List[Tuple[Variable, Optional[Index]]]] = {} def append(name, variable, index): @@ -292,7 +300,11 @@ def append_all(variables, indexes): append_all(coords, indexes) variable = as_variable(variable, name=name) - if variable.dims == (name,): + if name in indexes: + index = indexes[name] + elif variable.dims == (name,): + # TODO: benbovy - explicit indexes: do we still need this? + # default "dimension" indexes are already created elsewhere variable = variable.to_index_variable() index = variable._to_xindex() else: @@ -664,8 +676,7 @@ def merge_core( aligned = deep_align( coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value ) - collected = collect_variables_and_indexes(aligned) - + collected = collect_variables_and_indexes(aligned, indexes=indexes) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) variables, out_indexes = merge_collected( collected, prioritized, compat=compat, combine_attrs=combine_attrs From 2e4a041adc127549b9c534293b9709e1a473cc87 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 26 Aug 2021 18:07:07 +0200 Subject: [PATCH 017/159] wip refactor label based selection - Index.query must now return a mapping of {dim_name: positional_indexer} as indexes may be based on several coordinates with different dimensions - Added `group_coords_by_index` utility function (not used yet, not sure we'll need it) TODO: - Update DataArray selection - Update .loc and other places using remap_label_indexers - Fix selection of multi-index that returns only scalar coordinates --- xarray/core/coordinates.py | 16 +++- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 49 ++++++------ xarray/core/indexes.py | 18 +++-- xarray/core/indexing.py | 140 +++++++++++++++++++++------------- xarray/tests/test_indexing.py | 65 +++++++++++++--- 6 files changed, 194 insertions(+), 96 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 767b76d0d12..ca2db45d1fe 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -398,7 +398,9 @@ def remap_label_indexers( method: str = None, tolerance=None, **indexers_kwargs: Any, -) -> Tuple[dict, dict]: # TODO more precise return type after annotations in indexing +) -> Tuple[ + dict, dict, dict, list +]: # TODO more precise return type after annotations in indexing """Remap indexers from obj.coords. If indexer is an instance of DataArray and it has coordinate, then this coordinate will be attached to pos_indexers. @@ -408,6 +410,7 @@ def remap_label_indexers( pos_indexers: Same type of indexers. 
np.ndarray or Variable or DataArray new_indexes: mapping of new dimensional-coordinate. + """ from .dataarray import DataArray @@ -418,9 +421,15 @@ def remap_label_indexers( for k, v in indexers.items() } - pos_indexers, new_indexes = indexing.remap_label_indexers( + ( + pos_indexers, + new_indexes, + new_variables, + drop_variables, + ) = indexing.remap_label_indexers( obj, v_indexers, method=method, tolerance=tolerance ) + # attach indexer's coordinate to pos_indexers for k, v in indexers.items(): if isinstance(v, Variable): @@ -430,4 +439,5 @@ def remap_label_indexers( # ensures alignments coords = {k: var for k, var in v._coords.items() if k not in indexers} pos_indexers[k] = DataArray(pos_indexers[k], coords=coords, dims=v.dims) - return pos_indexers, new_indexes + + return pos_indexers, new_indexes, new_variables, drop_variables diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0ddce30bf4a..68c8330842b 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -204,7 +204,7 @@ def __setitem__(self, key, value) -> None: labels = indexing.expanded_indexer(key, self.data_array.ndim) key = dict(zip(self.data_array.dims, labels)) - pos_indexers, _ = remap_label_indexers(self.data_array, key) + pos_indexers, _, _, _ = remap_label_indexers(self.data_array, key) self.data_array[pos_indexers] = value diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 009e4497a2d..73a88fded8d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -559,7 +559,7 @@ def __setitem__(self, key, value) -> None: ) # set new values - pos_indexers, _ = remap_label_indexers(self.dataset, key) + pos_indexers, _, _, _ = remap_label_indexers(self.dataset, key) self.dataset[pos_indexers] = value @@ -1163,26 +1163,34 @@ def _replace_vars_and_dims( variables, coord_names, dims, attrs, indexes=None, inplace=inplace ) - def _overwrite_indexes(self, indexes: Mapping[Any, Index]) -> "Dataset": + def _overwrite_indexes( + self, + indexes: Mapping[Any, Index], + variables: Mapping[Any, Variable], + drop_variables: List, + ) -> "Dataset": + """Maybe replace indexes and their corresponding index variables.""" if not indexes: return self - variables = self._variables.copy() + assert indexes.keys() == variables.keys() + + new_variables = self._variables.copy() + new_coord_names = self._coord_names.copy() new_indexes = dict(self.xindexes) - for name, idx in indexes.items(): - variables[name] = IndexVariable(name, idx.to_pandas_index()) - new_indexes[name] = idx - obj = self._replace(variables, indexes=new_indexes) - - # switch from dimension to level names, if necessary - dim_names: Dict[Hashable, str] = {} - for dim, idx in indexes.items(): - pd_idx = idx.to_pandas_index() - if not isinstance(pd_idx, pd.MultiIndex) and pd_idx.name != dim: - dim_names[dim] = pd_idx.name - if dim_names: - obj = obj.rename(dim_names) - return obj + + for name in indexes: + new_variables[name] = variables[name] + new_indexes[name] = indexes[name] + + for name in drop_variables: + new_variables.pop(name) + new_indexes.pop(name) + new_coord_names.remove(name) + + return self._replace_with_new_dims( + variables=new_variables, coord_names=new_coord_names, indexes=new_indexes + ) def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset": """Returns a copy of this dataset. 
@@ -2452,15 +2460,12 @@ def sel( DataArray.sel """ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") - pos_indexers, new_indexes = remap_label_indexers( + pos_indexers, new_indexes, new_variables, drop_variables = remap_label_indexers( self, indexers=indexers, method=method, tolerance=tolerance ) - # TODO: benbovy - flexible indexes: also use variables returned by Index.query - # (temporary dirty fix). - new_indexes = {k: v[0] for k, v in new_indexes.items()} result = self.isel(indexers=pos_indexers, drop=drop) - return result._overwrite_indexes(new_indexes) + return result._overwrite_indexes(new_indexes, new_variables, drop_variables) def head( self, diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index eb4e6c368a1..636af44ad78 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -23,6 +23,7 @@ from .variable import IndexVariable, Variable IndexVars = Dict[Hashable, "IndexVariable"] +IndexWithVars = Tuple["Index", Optional[IndexVars]] class Index: @@ -31,7 +32,7 @@ class Index: @classmethod def from_variables( cls, variables: Mapping[Hashable, "Variable"] - ) -> Tuple["Index", Optional[IndexVars]]: # pragma: no cover + ) -> IndexWithVars: # pragma: no cover raise NotImplementedError() def to_pandas_index(self) -> pd.Index: @@ -46,7 +47,7 @@ def to_pandas_index(self) -> pd.Index: def query( self, labels: Dict[Hashable, Any] - ) -> Tuple[Any, Optional[Tuple["Index", IndexVars]]]: # pragma: no cover + ) -> Tuple[Mapping[str, Any], Optional[IndexWithVars]]: # pragma: no cover raise NotImplementedError() def equals(self, other): # pragma: no cover @@ -205,6 +206,9 @@ def to_pandas_index(self) -> pd.Index: return self.index def query(self, labels, method=None, tolerance=None): + if method is not None and not isinstance(method, str): + raise TypeError("``method`` must be a string") + assert len(labels) == 1 coord_name, label = next(iter(labels.items())) @@ -240,7 +244,7 @@ def query(self, labels, method=None, tolerance=None): if np.any(indexer < 0): raise KeyError(f"not all values found in index {coord_name!r}") - return indexer, None + return {self.dim: indexer}, None def equals(self, other): return self.index.equals(other.index) @@ -425,10 +429,12 @@ def query(self, labels, method=None, tolerance=None): new_index, self.dim ) else: - new_index, new_vars = PandasIndex.from_pandas_index(new_index, self.dim) - return indexer, (new_index, new_vars) + new_index, new_vars = PandasIndex.from_pandas_index( + new_index, new_index.name + ) + return {self.dim: indexer}, (new_index, new_vars) else: - return indexer, None + return {self.dim: indexer}, None def remove_unused_levels_categories(index: pd.Index) -> pd.Index: diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 858303c2ad9..f8a15ecf2da 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -4,7 +4,19 @@ from collections import defaultdict from contextlib import suppress from datetime import timedelta -from typing import Any, Callable, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Hashable, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, +) import numpy as np import pandas as pd @@ -20,6 +32,9 @@ ) from .utils import maybe_cast_to_coords_dtype +if TYPE_CHECKING: + from .indexes import Index + def expanded_indexer(key, ndim): """Given a key for indexing an ndarray, return an equivalent key which is a @@ -55,78 +70,99 @@ def _expand_slice(slice_, size): return np.arange(*slice_.indices(size)) -def 
group_indexers_by_index(data_obj, indexers, method=None, tolerance=None):
-    # TODO: benbovy - flexible indexes: indexers are still grouped by dimension
-    # - Make xarray.Index hashable so that it can be used as key in a mapping?
-    indexes = {}
-    grouped_indexers = defaultdict(dict)
+def group_coords_by_index(
+    indexes: Mapping[Hashable, "Index"]
+) -> Dict[Tuple[Hashable, ...], "Index"]:
+    """From a flat mapping of coordinate names to their corresponding index, return
+    a dictionary of unique index items, using the tuple of all their corresponding
+    coordinate name(s) as keys.
+
+    """
+    index_unique: Dict[int, "Index"] = {}
+    grouped_coord_names = defaultdict(list)
+
+    for coord_name, index_obj in indexes.items():
+        index_id = id(index_obj)
+        index_unique[index_id] = index_obj
+        grouped_coord_names[index_id].append(coord_name)
 
-    # TODO: data_obj.xindexes should eventually return the PandasIndex instance
-    # for each multi-index levels
-    xindexes = dict(data_obj.xindexes)
-    for level, dim in data_obj._level_coords.items():
-        xindexes[level] = xindexes[dim]
+    return {tuple(grouped_coord_names[k]): index_unique[k] for k in index_unique}
+
+
+def group_indexers_by_index(data_obj, indexers, **kwargs):
+    """Returns a dictionary of unique index items and another dictionary of label indexers
+    grouped by index (both using the same index ids as keys).
+
+    """
+    unique_indexes = {}
+    grouped_indexers = defaultdict(dict)
 
     for key, label in indexers.items():
-        try:
-            index = xindexes[key]
-            coord = data_obj.coords[key]
-        except KeyError:
-            if key in data_obj.coords:
-                raise KeyError(f"no index found for coordinate {key}")
-            elif key not in data_obj.dims:
-                raise KeyError(f"{key} is not a valid dimension or coordinate")
-            # key is a dimension without coordinate: we'll reuse the provided labels
-            elif method is not None or tolerance is not None:
-                raise ValueError(
-                    "cannot supply ``method`` or ``tolerance`` "
-                    "when the indexed dimension does not have "
-                    "an associated coordinate."
-                )
-            grouped_indexers[None][key] = label
-        else:
-            dim = coord.dims[0]
-            if dim not in indexes:
-                indexes[dim] = index
+        index = data_obj.xindexes.get(key, None)
+        coord = data_obj.coords.get(key, None)
+        if index is not None:
+            index_id = id(index)
+            unique_indexes[index_id] = index
             label = maybe_cast_to_coords_dtype(label, coord.dtype)
-            grouped_indexers[dim][key] = label
+            grouped_indexers[index_id][key] = label
+        elif coord is not None:
+            raise KeyError(f"no index found for coordinate {key}")
+        elif key not in data_obj.dims:
+            raise KeyError(f"{key} is not a valid dimension or coordinate")
+        elif len(kwargs):
+            raise ValueError(
+                "cannot supply selection options "
+                "when the indexed dimension does not have "
+                "an associated coordinate."
+            )
+        else:
+            # key is a dimension without coordinate
+            # failback to location-based selection
+            grouped_indexers[None][key] = label
 
-    return indexes, grouped_indexers
+    return unique_indexes, grouped_indexers
 
 
-def remap_label_indexers(data_obj, indexers, method=None, tolerance=None):
-    """Given an xarray data object and label based indexers, return a mapping
-    of equivalent location based indexers. Also return a mapping of updated
-    pandas index objects (in case of multi-index level drop).
- """ - if method is not None and not isinstance(method, str): - raise TypeError("``method`` must be a string") +def remap_label_indexers(data_obj, indexers, **kwargs): + """Given an xarray data object and label based indexers, returns: + + - a mapping of equivalent location based indexers + - a mapping of updated indexes (if any) + - a mapping of updated index variables (if any) + - a list of variables to drop (if any) + """ pos_indexers = {} new_indexes = {} + new_variables = {} + drop_variables = [] - indexes, grouped_indexers = group_indexers_by_index( - data_obj, indexers, method, tolerance - ) + indexes, grouped_indexers = group_indexers_by_index(data_obj, indexers, **kwargs) forward_pos_indexers = grouped_indexers.pop(None, None) if forward_pos_indexers is not None: for dim, label in forward_pos_indexers.items(): pos_indexers[dim] = label - for dim, index in indexes.items(): - labels = grouped_indexers[dim] - idxr, new_idx = index.query(labels, method=method, tolerance=tolerance) - pos_indexers[dim] = idxr - if new_idx is not None: - new_indexes[dim] = new_idx + for index_id, index in indexes.items(): + labels = grouped_indexers[index_id] + pos_idxr, new_idx_and_vars = index.query(labels, **kwargs) + pos_indexers.update(pos_idxr) + + if new_idx_and_vars is not None: + new_idx, new_vars = new_idx_and_vars + new_variables.update(new_vars) + for k in new_vars: + new_indexes[k] = new_idx + for k, idx in data_obj.xindexes.items(): + if id(idx) == index_id and k not in new_vars: + drop_variables.append(k) # TODO: benbovy - flexible indexes: support the following cases: - # - an index query returns positional indexers over multiple dimensions - # - check/combine positional indexers returned by multiple indexes over the same dimension + # - check/combine positional indexers returned by multiple indexes over the same dimension(s) - return pos_indexers, new_indexes + return pos_indexers, new_indexes, new_variables, drop_variables def _normalize_slice(sl, size): diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 6e4fd320029..40d93e33ed7 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -6,6 +6,7 @@ from xarray import DataArray, Dataset, Variable from xarray.core import indexing, nputils +from xarray.core.indexes import PandasIndex, PandasMultiIndex from . 
import IndexerMaker, ReturnItem, assert_array_equal @@ -80,13 +81,38 @@ def test_group_indexers_by_index(self): indexing.group_indexers_by_index(data, {"z": 1}, method="nearest") def test_remap_label_indexers(self): - def test_indexer(data, x, expected_pos, expected_idx=None): - pos, new_idx_vars = indexing.remap_label_indexers(data, {"x": x}) - idx, _ = new_idx_vars.get("x", (None, None)) - if idx is not None: - idx = idx.to_pandas_index() + def test_indexer( + data, + x, + expected_pos, + expected_idx=None, + expected_vars=None, + expected_drop=None, + ): + if expected_vars is None: + expected_vars = {} + if expected_idx is None: + expected_idx = {} + else: + expected_idx = {k: expected_idx for k in expected_vars} + if expected_drop is None: + expected_drop = [] + + pos, new_idx, new_vars, drop_vars = indexing.remap_label_indexers( + data, {"x": x} + ) + assert_array_equal(pos.get("x"), expected_pos) - assert_array_equal(idx, expected_idx) + + assert new_idx.keys() == expected_idx.keys() + for k in new_idx: + assert new_idx[k].equals(expected_idx[k]) + + assert new_vars.keys() == expected_vars.keys() + for k in new_vars: + assert_array_equal(new_vars[k], expected_vars[k]) + + assert drop_vars == expected_drop data = Dataset({"x": ("x", [1, 2, 3])}) mindex = pd.MultiIndex.from_product( @@ -102,19 +128,28 @@ def test_indexer(data, x, expected_pos, expected_idx=None): mdata, ("a", 1), [True, True, False, False, False, False, False, False], - [-1, -2], + *PandasIndex.from_pandas_index(pd.Index([-1, -2]), "three"), + ["x", "one", "two"], ) test_indexer( mdata, "a", slice(0, 4, None), - pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + *PandasMultiIndex.from_pandas_index( + pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), + "x", + ), + ["one"], ) test_indexer( mdata, ("a",), [True, True, True, True, False, False, False, False], - pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + *PandasMultiIndex.from_pandas_index( + pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), + "x", + ), + ["one"], ) test_indexer(mdata, [("a", 1, -1), ("b", 2, -2)], [0, 7]) test_indexer(mdata, slice("a", "b"), slice(0, 8, None)) @@ -124,19 +159,25 @@ def test_indexer(data, x, expected_pos, expected_idx=None): mdata, {"one": "a", "two": 1}, [True, True, False, False, False, False, False, False], - [-1, -2], + *PandasIndex.from_pandas_index(pd.Index([-1, -2]), "three"), + ["x", "one", "two"], ) test_indexer( mdata, {"one": "a", "three": -1}, [True, False, True, False, False, False, False, False], - [1, 2], + *PandasIndex.from_pandas_index(pd.Index([1, 2]), "two"), + ["x", "one", "three"], ) test_indexer( mdata, {"one": "a"}, [True, True, True, True, False, False, False, False], - pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + *PandasMultiIndex.from_pandas_index( + pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), + "x", + ), + ["one"], ) def test_read_only_view(self): From f387c70948ca73940f0baa5a499de04496698360 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 27 Aug 2021 09:47:08 +0200 Subject: [PATCH 018/159] fix index query tests --- xarray/tests/test_indexes.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 410fbb61b26..11a81cd0f3b 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -86,7 +86,7 @@ def test_query_datetime(self): pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]), "x" ) actual = index.query({"x": 
"2001-01-01"}) - expected = (1, None) + expected = ({"x": 1}, None) assert actual == expected actual = index.query({"x": index.to_pandas_index().to_numpy()[1]}) @@ -192,7 +192,10 @@ def test_query(self): pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), "x" ) # test tuples inside slice are considered as scalar indexer values - assert index.query({"x": slice(("a", 1), ("b", 2))}) == (slice(0, 4), None) + assert index.query({"x": slice(("a", 1), ("b", 2))}) == ( + {"x": slice(0, 4)}, + None, + ) with pytest.raises(KeyError, match=r"not all values found"): index.query({"x": [0]}) From 6a9dbd6943107ff4541ab596ee1e52fd234d4926 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 27 Aug 2021 09:48:38 +0200 Subject: [PATCH 019/159] fix multi-index adapter getitem scalar Multi-index level variables now return the scalar value that corresponds to the level instead of the multi-index tuple element (all levels). Also get rid of PandasMultiIndexingAdapter.__getitem__ cache optimization, which doesn't work with level scalar values and was premature optimization anyway. --- xarray/core/indexes.py | 4 +-- xarray/core/indexing.py | 61 +++++++++++++++++++---------------------- 2 files changed, 29 insertions(+), 36 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 636af44ad78..d58c97a20d5 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -277,9 +277,7 @@ def _create_variables_from_multiindex(index, dim, level_meta=None): for level in index.names: meta = level_meta.get(level, {}) - data = PandasMultiIndexingAdapter( - index, dtype=meta.get("dtype"), level=level, adapter=dim_coord_adapter - ) + data = PandasMultiIndexingAdapter(index, dtype=meta.get("dtype"), level=level) variables[level] = IndexVariable( dim, data, diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index f8a15ecf2da..ee7953de287 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1338,6 +1338,26 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray: def shape(self) -> Tuple[int]: return (len(self.array),) + def _convert_scalar(self, item): + if item is pd.NaT: + # work around the impossibility of casting NaT with asarray + # note: it probably would be better in general to return + # pd.Timestamp rather np.than datetime64 but this is easier + # (for now) + item = np.datetime64("NaT", "ns") + elif isinstance(item, timedelta): + item = np.timedelta64(getattr(item, "value", item), "ns") + elif isinstance(item, pd.Timestamp): + # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668 + # numpy fails to convert pd.Timestamp to np.datetime64[ns] + item = np.asarray(item.to_datetime64()) + elif self.dtype != object: + item = np.asarray(item, dtype=self.dtype) + + # as for numpy.ndarray indexing, we always want the result to be + # a NumPy array. 
+ return utils.to_0d_array(item) + def __getitem__( self, indexer ) -> Union[ @@ -1359,29 +1379,9 @@ def __getitem__( result = self.array[key] if isinstance(result, pd.Index): - result = type(self)(result, dtype=self.dtype) + return type(self)(result, dtype=self.dtype) else: - # result is a scalar - if result is pd.NaT: - # work around the impossibility of casting NaT with asarray - # note: it probably would be better in general to return - # pd.Timestamp rather np.than datetime64 but this is easier - # (for now) - result = np.datetime64("NaT", "ns") - elif isinstance(result, timedelta): - result = np.timedelta64(getattr(result, "value", result), "ns") - elif isinstance(result, pd.Timestamp): - # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668 - # numpy fails to convert pd.Timestamp to np.datetime64[ns] - result = np.asarray(result.to_datetime64()) - elif self.dtype != object: - result = np.asarray(result, dtype=self.dtype) - - # as for numpy.ndarray indexing, we always want the result to be - # a NumPy array. - result = utils.to_0d_array(result) - - return result + return self._convert_scalar(result) def transpose(self, order) -> pd.Index: return self.array # self.array should be always one-dimensional @@ -1417,11 +1417,9 @@ def __init__( array: pd.MultiIndex, dtype: DTypeLike = None, level: Optional[str] = None, - adapter: Optional[PandasIndexingAdapter] = None, ): super().__init__(array, dtype) self.level = level - self.adapter = adapter def __array__(self, dtype: DTypeLike = None) -> np.ndarray: if self.level is not None: @@ -1429,12 +1427,11 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray: else: return super().__array__(dtype) - @functools.lru_cache(1) - def __getitem__(self, indexer): - if self.adapter is None: - return super().__getitem__(indexer) - else: - return self.adapter.__getitem__(indexer) + def _convert_scalar(self, item): + if isinstance(item, tuple) and self.level is not None: + idx = tuple(self.array.names).index(self.level) + item = item[idx] + return super()._convert_scalar(item) def __repr__(self) -> str: if self.level is None: @@ -1466,6 +1463,4 @@ def _repr_inline_(self, max_width) -> str: def copy(self, deep: bool = True) -> "PandasMultiIndexingAdapter": # see PandasIndexingAdapter.copy array = self.array.copy(deep=True) if deep else self.array - # do not use indexing cache if deep=True - adapter = None if deep else self.adapter - return type(self)(array, self._dtype, self.level, adapter) + return type(self)(array, self._dtype, self.level) From 64b71c960815f710bd1b91851355ba56e6fe1cec Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 27 Aug 2021 11:17:32 +0200 Subject: [PATCH 020/159] wip refactor label based selection Fixed renamed dimension in the case of multi-index -> single index Updated DataArray._overwrite_indexes Dirty fix for alignment (not tested yet) --- xarray/core/alignment.py | 3 ++- xarray/core/dataarray.py | 53 ++++++++++++++++++++++++++++------------ xarray/core/dataset.py | 28 ++++++++++++++++++--- xarray/core/indexes.py | 3 +-- 4 files changed, 64 insertions(+), 23 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index a53ac094253..3db765dfb62 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -69,7 +69,8 @@ def _override_indexes(objects, all_indexes, exclude): dim: all_indexes[dim][0] for dim in obj.xindexes if dim not in exclude } - objects[idx + 1] = obj._overwrite_indexes(new_indexes) + # TODO: benbovy - explicit indexes: not refactored yet (dirty fix) + 
objects[idx + 1] = obj._overwrite_indexes(new_indexes, {}, []) return objects diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 68c8330842b..c43015ef3e7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -460,23 +460,44 @@ def _replace_maybe_drop_dims( ) return self._replace(variable, coords, name, indexes=indexes) - def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray": - if not len(indexes): + def _overwrite_indexes( + self, + indexes: Mapping[Hashable, Index], + coords: Mapping[Hashable, Variable], + drop_coords: List[Hashable], + ) -> "DataArray": + """Maybe replace indexes and their corresponding coordinates.""" + if not indexes: return self - coords = self._coords.copy() - for name, idx in indexes.items(): - coords[name] = IndexVariable(name, idx.to_pandas_index()) - obj = self._replace(coords=coords) - - # switch from dimension to level names, if necessary - dim_names: Dict[Any, str] = {} - for dim, idx in indexes.items(): - pd_idx = idx.to_pandas_index() - if not isinstance(idx, pd.MultiIndex) and pd_idx.name != dim: - dim_names[dim] = idx.name - if dim_names: - obj = obj.rename(dim_names) - return obj + + assert indexes.keys() == coords.keys() + + new_variable = self.variable.copy() + new_coords = self._coords.copy() + new_indexes = dict(self.xindexes) + dims_dict = {} + + for name in indexes: + # new coordinate variables may have renamed dimensions (e.g., level + # name of a multi-index converted to a single index) + old_vs_new_dims = zip(self._coords[name].dims, coords[name].dims) + for old_dim, new_dim in old_vs_new_dims: + if old_dim != new_dim: + dims_dict[old_dim] = new_dim + + new_coords[name] = coords[name] + new_indexes[name] = indexes[name] + + for name in drop_coords: + new_coords.pop(name) + new_indexes.pop(name) + + if dims_dict: + new_variable.dims = [dims_dict.get(d, d) for d in new_variable.dims] + + return self._replace( + variable=new_variable, coords=new_coords, indexes=new_indexes + ) def _to_temp_dataset(self) -> Dataset: return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 73a88fded8d..00df4acb76b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1165,9 +1165,9 @@ def _replace_vars_and_dims( def _overwrite_indexes( self, - indexes: Mapping[Any, Index], - variables: Mapping[Any, Variable], - drop_variables: List, + indexes: Mapping[Hashable, Index], + variables: Mapping[Hashable, Variable], + drop_variables: List[Hashable], ) -> "Dataset": """Maybe replace indexes and their corresponding index variables.""" if not indexes: @@ -1178,8 +1178,18 @@ def _overwrite_indexes( new_variables = self._variables.copy() new_coord_names = self._coord_names.copy() new_indexes = dict(self.xindexes) + dims_dict = {} for name in indexes: + # new coordinate variables may have renamed dimensions (e.g., level + # name of a multi-index converted to a single index) + # TODO: instead of infer renamed dimensions from the coordinates, + # should we require explicitly providing it from Index.query? 
+ old_vs_new_dims = zip(self._variables[name].dims, variables[name].dims) + for old_dim, new_dim in old_vs_new_dims: + if old_dim != new_dim: + dims_dict[old_dim] = new_dim + new_variables[name] = variables[name] new_indexes[name] = indexes[name] @@ -1188,10 +1198,20 @@ def _overwrite_indexes( new_indexes.pop(name) new_coord_names.remove(name) - return self._replace_with_new_dims( + replaced = self._replace( variables=new_variables, coord_names=new_coord_names, indexes=new_indexes ) + if dims_dict: + # skip rename indexes: they should already have the right name(s) + dims = replaced._rename_dims(dims_dict) + new_variables, new_coord_names = replaced._rename_vars({}, dims_dict) + return replaced._replace( + variables=new_variables, coord_names=new_coord_names, dims=dims + ) + else: + return replaced + def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset": """Returns a copy of this dataset. diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index d58c97a20d5..865df57fa53 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -46,7 +46,7 @@ def to_pandas_index(self) -> pd.Index: raise TypeError(f"{type(self)} cannot be cast to a pandas.Index object.") def query( - self, labels: Dict[Hashable, Any] + self, labels: Dict[Hashable, Any], **kwargs ) -> Tuple[Mapping[str, Any], Optional[IndexWithVars]]: # pragma: no cover raise NotImplementedError() @@ -63,7 +63,6 @@ def copy(self, deep: bool = True): # pragma: no cover raise NotImplementedError() def __getitem__(self, indexer: Any): - # if not implemented, index will be dropped from the Dataset or DataArray raise NotImplementedError() From f3116ac23418f4aab1fde52d65674ca1aac15ea1 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 30 Aug 2021 18:49:00 +0200 Subject: [PATCH 021/159] wip: deeper refactoring label-based sel Created QueryResult and MergedQueryResults classes for convenience. --- xarray/core/alignment.py | 4 +- xarray/core/coordinates.py | 20 ++- xarray/core/dataarray.py | 22 ++-- xarray/core/dataset.py | 32 ++--- xarray/core/indexes.py | 24 +++- xarray/core/indexing.py | 242 ++++++++++++++++++++++++++----------- 6 files changed, 224 insertions(+), 120 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 3db765dfb62..c50d6b54afb 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -69,8 +69,8 @@ def _override_indexes(objects, all_indexes, exclude): dim: all_indexes[dim][0] for dim in obj.xindexes if dim not in exclude } - # TODO: benbovy - explicit indexes: not refactored yet (dirty fix) - objects[idx + 1] = obj._overwrite_indexes(new_indexes, {}, []) + # TODO: benbovy - explicit indexes: not refactored yet! + objects[idx + 1] = obj._overwrite_indexes(new_indexes) return objects diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index ca2db45d1fe..5e09ac4fe19 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -398,9 +398,7 @@ def remap_label_indexers( method: str = None, tolerance=None, **indexers_kwargs: Any, -) -> Tuple[ - dict, dict, dict, list -]: # TODO more precise return type after annotations in indexing +) -> Any: """Remap indexers from obj.coords. If indexer is an instance of DataArray and it has coordinate, then this coordinate will be attached to pos_indexers. 
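
Note: in rough terms, each index involved in a selection now contributes
one QueryResult, and the merged results carry everything needed to update
the object. A minimal sketch of the intended flow (internal API as
introduced in this patch; `obj` stands for any Dataset or DataArray):

    # label-based queries, grouped and executed per index
    query_results = remap_label_indexers(obj, {"one": "a"})

    # positional indexers, one entry per affected dimension
    result = obj.isel(indexers=query_results.dim_indexers)

    # replace/drop indexes and coordinates, rename dimensions if needed
    result = result._overwrite_indexes(*query_results.to_tuple()[1:])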
@@ -421,23 +419,21 @@ def remap_label_indexers( for k, v in indexers.items() } - ( - pos_indexers, - new_indexes, - new_variables, - drop_variables, - ) = indexing.remap_label_indexers( + query_results = indexing.remap_label_indexers( obj, v_indexers, method=method, tolerance=tolerance ) # attach indexer's coordinate to pos_indexers for k, v in indexers.items(): + dim_indexer = query_results.dim_indexers.get(k, None) if isinstance(v, Variable): - pos_indexers[k] = Variable(v.dims, pos_indexers[k]) + query_results.dim_indexers[k] = Variable(v.dims, dim_indexer) elif isinstance(v, DataArray): # drop coordinates found in indexers since .sel() already # ensures alignments coords = {k: var for k, var in v._coords.items() if k not in indexers} - pos_indexers[k] = DataArray(pos_indexers[k], coords=coords, dims=v.dims) + query_results.dim_indexers[k] = DataArray( + dim_indexer, coords=coords, dims=v.dims + ) - return pos_indexers, new_indexes, new_variables, drop_variables + return query_results diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c43015ef3e7..31bfc4ce9e1 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -463,28 +463,24 @@ def _replace_maybe_drop_dims( def _overwrite_indexes( self, indexes: Mapping[Hashable, Index], - coords: Mapping[Hashable, Variable], - drop_coords: List[Hashable], + coords: Optional[Mapping[Hashable, Variable]] = None, + drop_coords: Optional[List[Hashable]] = None, + rename_dims: Optional[Mapping[Hashable, Hashable]] = None, ) -> "DataArray": """Maybe replace indexes and their corresponding coordinates.""" if not indexes: return self - assert indexes.keys() == coords.keys() + if coords is None: + coords = {} + if drop_coords is None: + drop_coords = [] new_variable = self.variable.copy() new_coords = self._coords.copy() new_indexes = dict(self.xindexes) - dims_dict = {} for name in indexes: - # new coordinate variables may have renamed dimensions (e.g., level - # name of a multi-index converted to a single index) - old_vs_new_dims = zip(self._coords[name].dims, coords[name].dims) - for old_dim, new_dim in old_vs_new_dims: - if old_dim != new_dim: - dims_dict[old_dim] = new_dim - new_coords[name] = coords[name] new_indexes[name] = indexes[name] @@ -492,8 +488,8 @@ def _overwrite_indexes( new_coords.pop(name) new_indexes.pop(name) - if dims_dict: - new_variable.dims = [dims_dict.get(d, d) for d in new_variable.dims] + if rename_dims: + new_variable.dims = [rename_dims.get(d, d) for d in new_variable.dims] return self._replace( variable=new_variable, coords=new_coords, indexes=new_indexes diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 00df4acb76b..2c7ca5a959e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1166,30 +1166,24 @@ def _replace_vars_and_dims( def _overwrite_indexes( self, indexes: Mapping[Hashable, Index], - variables: Mapping[Hashable, Variable], - drop_variables: List[Hashable], + variables: Optional[Mapping[Hashable, Variable]] = None, + drop_variables: Optional[List[Hashable]] = None, + rename_dims: Optional[Mapping[Hashable, Hashable]] = None, ) -> "Dataset": """Maybe replace indexes and their corresponding index variables.""" if not indexes: return self - assert indexes.keys() == variables.keys() + if variables is None: + variables = {} + if drop_variables is None: + drop_variables = [] new_variables = self._variables.copy() new_coord_names = self._coord_names.copy() new_indexes = dict(self.xindexes) - dims_dict = {} for name in indexes: - # new coordinate variables 
may have renamed dimensions (e.g., level - # name of a multi-index converted to a single index) - # TODO: instead of infer renamed dimensions from the coordinates, - # should we require explicitly providing it from Index.query? - old_vs_new_dims = zip(self._variables[name].dims, variables[name].dims) - for old_dim, new_dim in old_vs_new_dims: - if old_dim != new_dim: - dims_dict[old_dim] = new_dim - new_variables[name] = variables[name] new_indexes[name] = indexes[name] @@ -1202,10 +1196,10 @@ def _overwrite_indexes( variables=new_variables, coord_names=new_coord_names, indexes=new_indexes ) - if dims_dict: + if rename_dims: # skip rename indexes: they should already have the right name(s) - dims = replaced._rename_dims(dims_dict) - new_variables, new_coord_names = replaced._rename_vars({}, dims_dict) + dims = replaced._rename_dims(rename_dims) + new_variables, new_coord_names = replaced._rename_vars({}, rename_dims) return replaced._replace( variables=new_variables, coord_names=new_coord_names, dims=dims ) @@ -2480,12 +2474,12 @@ def sel( DataArray.sel """ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") - pos_indexers, new_indexes, new_variables, drop_variables = remap_label_indexers( + query_results = remap_label_indexers( self, indexers=indexers, method=method, tolerance=tolerance ) - result = self.isel(indexers=pos_indexers, drop=drop) - return result._overwrite_indexes(new_indexes, new_variables, drop_variables) + result = self.isel(indexers=query_results.dim_indexers, drop=drop) + return result._overwrite_indexes(*query_results.to_tuple()[1:]) def head( self, diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 865df57fa53..0dcba5d7c1c 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -16,7 +16,7 @@ import pandas as pd from . 
import formatting, utils
-from .indexing import PandasIndexingAdapter, PandasMultiIndexingAdapter
+from .indexing import PandasIndexingAdapter, PandasMultiIndexingAdapter, QueryResult
 from .utils import is_dict_like, is_scalar
 
 if TYPE_CHECKING:
@@ -47,7 +47,7 @@ def to_pandas_index(self) -> pd.Index:
 
     def query(
         self, labels: Dict[Hashable, Any], **kwargs
-    ) -> Tuple[Mapping[str, Any], Optional[IndexWithVars]]:  # pragma: no cover
+    ) -> QueryResult:  # pragma: no cover
         raise NotImplementedError()
 
     def equals(self, other):  # pragma: no cover
@@ -243,7 +243,7 @@ def query(self, labels, method=None, tolerance=None):
         if np.any(indexer < 0):
             raise KeyError(f"not all values found in index {coord_name!r}")
 
-        return {self.dim: indexer}, None
+        return QueryResult({self.dim: indexer})
 
     def equals(self, other):
         return self.index.equals(other.index)
@@ -425,13 +425,27 @@ def query(self, labels, method=None, tolerance=None):
                 new_index, new_vars = PandasMultiIndex.from_pandas_index(
                     new_index, self.dim
                 )
+                dims_dict = {}
+                drop_coords = set(self.index.names) - set(new_index.index.names)
             else:
                 new_index, new_vars = PandasIndex.from_pandas_index(
                     new_index, new_index.name
                 )
-            return {self.dim: indexer}, (new_index, new_vars)
+                dims_dict = {self.dim: new_index.index.name}
+                drop_coords = set(self.index.names) - {new_index.index.name} | {
+                    self.dim
+                }
+
+            return QueryResult(
+                {self.dim: indexer},
+                index=new_index,
+                index_vars=new_vars,
+                drop_coords=list(drop_coords),
+                rename_dims=dims_dict,
+            )
+
         else:
-            return {self.dim: indexer}, None
+            return QueryResult({self.dim: indexer})
 
 
 def remove_unused_levels_categories(index: pd.Index) -> pd.Index:
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index ee7953de287..b574673c4b4 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -1,19 +1,21 @@
 import enum
 import functools
 import operator
-from collections import defaultdict
+from collections import Counter, defaultdict
 from contextlib import suppress
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
+    DefaultDict,
     Dict,
     Hashable,
     Iterable,
     List,
     Mapping,
     Optional,
+    Sequence,
     Tuple,
     Union,
 )
 
 import numpy as np
 import pandas as pd
@@ -33,41 +35,118 @@
 from .utils import maybe_cast_to_coords_dtype
 
 if TYPE_CHECKING:
-    from .indexes import Index
+    from .dataarray import DataArray
+    from .dataset import Dataset
+    from .indexes import Index, IndexVars
 
 
-def expanded_indexer(key, ndim):
-    """Given a key for indexing an ndarray, return an equivalent key which is a
-    tuple with length equal to the number of dimensions.
+class QueryResult:
+    """Index query results.
+
+    Parameters
+    ----------
+    dim_indexers : dict
+        A dictionary where keys are array dimensions and values are
+        location-based indexers.
+    index : :class:`Index`, optional
+        A new index object to replace in the resulting DataArray or Dataset.
+    index_vars : dict, optional
+        New indexed variables to replace in the resulting DataArray or Dataset.
+    drop_coords : list, optional
+        Coordinate(s) to drop in the resulting DataArray or Dataset.
+    rename_dims : dict, optional
+        A dictionary in the form ``{old_dim: new_dim}`` for dimension(s) to
+        rename in the resulting DataArray or Dataset.
""" - if not isinstance(key, tuple): - # numpy treats non-tuple keys equivalent to tuples of length 1 - key = (key,) - new_key = [] - # handling Ellipsis right is a little tricky, see: - # http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing - found_ellipsis = False - for k in key: - if k is Ellipsis: - if not found_ellipsis: - new_key.extend((ndim + 1 - len(key)) * [slice(None)]) - found_ellipsis = True - else: - new_key.append(slice(None)) + + dim_indexers: Mapping[Hashable, Any] + indexes: Dict[Hashable, "Index"] + index_vars: "IndexVars" + drop_coords: List[Hashable] + rename_dims: Mapping[Hashable, Hashable] + + __slots__ = ("dim_indexers", "indexes", "index_vars", "drop_coords", "rename_dims") + + def __init__( + self, + dim_indexers: Mapping[Hashable, Any], + index: Optional["Index"] = None, + index_vars: Optional["IndexVars"] = None, + drop_coords: Optional[Sequence[Hashable]] = None, + rename_dims: Optional[Mapping[Hashable, Hashable]] = None, + ): + self.dim_indexers = dim_indexers + + if index_vars is None: + index_vars = {} + self.index_vars = index_vars + + # map the new index to all indexed variables + if index is not None: + self.indexes = {k: index for k in self.index_vars} else: - new_key.append(k) - if len(new_key) > ndim: - raise IndexError("too many indices") - new_key.extend((ndim - len(new_key)) * [slice(None)]) - return tuple(new_key) + self.indexes = {} + if drop_coords is None: + drop_coords = [] + self.drop_coords = list(drop_coords) -def _expand_slice(slice_, size): - return np.arange(*slice_.indices(size)) + if rename_dims is None: + rename_dims = {} + self.rename_dims = rename_dims + + +class MergedQueryResults: + """Results merged from all index queries executed during a single selection + operation.""" + + dim_indexers: Dict[Hashable, Any] + indexes: Dict[Hashable, "Index"] + index_vars: "IndexVars" + drop_coords: List[Hashable] + rename_dims: Dict[Hashable, Hashable] + + __slots__ = ("dim_indexers", "indexes", "index_vars", "drop_coords", "rename_dims") + + def __init__(self, query_results: List[QueryResult]): + all_dims_count = Counter( + [dim for res in query_results for dim in res.dim_indexers] + ) + duplicate_dims = {k: v for k, v in all_dims_count.items() if v > 1} + + if duplicate_dims: + fmt_dims = [ + f"{dim!r}: {count} indexes involved" + for dim, count in duplicate_dims.items() + ] + raise ValueError( + "Xarray does not support label-based selection with more than one index" + "over the following dimension(s):\n" + + "\n".join(fmt_dims) + + "Suggestion: use a multi-index for each of those dimension(s)." 
+ ) + + self.dim_indexers = { + k: v for res in query_results for k, v in res.dim_indexers.items() + } + self.indexes = {k: v for res in query_results for k, v in res.indexes.items()} + self.index_vars = { + k: v for res in query_results for k, v in res.index_vars.items() + } + self.drop_coords = [c for res in query_results for c in res.drop_coords] + self.rename_dims = { + k: v for res in query_results for k, v in res.rename_dims.items() + } + + def to_tuple(self): + return ( + self.dim_indexers, + self.indexes, + self.index_vars, + self.drop_coords, + self.rename_dims, + ) def group_coords_by_index( @@ -89,28 +168,34 @@ def group_coords_by_index( return {tuple(grouped_coord_names[k]): index_unique[k] for k in index_unique} -def group_indexers_by_index(data_obj, indexers, **kwargs): +def group_indexers_by_index( + obj: Union["DataArray", "Dataset"], + indexers: Mapping[Hashable, Any], + query_kwargs: Mapping[str, Any], +) -> Tuple[Dict[int, "Index"], Dict[Union[int, None], Dict[Hashable, Any]]]: """Returns a dictionary of unique index items and another dictionary of label indexers grouped by index (both using the same index ids as keys). """ unique_indexes = {} - grouped_indexers = defaultdict(dict) + grouped_indexers: DefaultDict[Union[int, None], Dict[Hashable, Any]] = defaultdict( + dict + ) for key, label in indexers.items(): - index = data_obj.xindexes.get(key, None) - coord = data_obj.coords.get(key, None) + index = obj.xindexes.get(key, None) + coord = obj.coords.get(key, None) if index is not None: index_id = id(index) unique_indexes[index_id] = index - label = maybe_cast_to_coords_dtype(label, coord.dtype) + label = maybe_cast_to_coords_dtype(label, coord.dtype) # type: ignore grouped_indexers[index_id][key] = label elif coord is not None: raise KeyError(f"no index found for coordinate {key}") - elif key not in data_obj.dims: + elif key not in obj.dims: raise KeyError(f"{key} is not a valid dimension or coordinate") - elif len(kwargs): + elif len(query_kwargs): raise ValueError( "cannot supply selection options " "when the indexed dimension does not have " @@ -121,48 +206,67 @@ def group_indexers_by_index(data_obj, indexers, **kwargs): # failback to location-based selection grouped_indexers[None][key] = label - return unique_indexes, grouped_indexers - + return unique_indexes, dict(grouped_indexers) -def remap_label_indexers(data_obj, indexers, **kwargs): - """Given an xarray data object and label based indexers, returns: - - a mapping of equivalent location based indexers - - a mapping of updated indexes (if any) - - a mapping of updated index variables (if any) - - a list of variables to drop (if any) +def remap_label_indexers( + obj: Union["DataArray", "Dataset"], + indexers: Mapping[Hashable, Any], + **query_kwargs, + # query_kwargs: Mapping[str, Any], + # **indexers_kwargs, +) -> MergedQueryResults: + """Execute index queries from a DataArray / Dataset and label-based indexers + and return the (merged) query results. 
""" - pos_indexers = {} - new_indexes = {} - new_variables = {} - drop_variables = [] + # indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "map_index_queries") + indexes, grouped_indexers = group_indexers_by_index(obj, indexers, query_kwargs) - indexes, grouped_indexers = group_indexers_by_index(data_obj, indexers, **kwargs) + results = [] - forward_pos_indexers = grouped_indexers.pop(None, None) - if forward_pos_indexers is not None: - for dim, label in forward_pos_indexers.items(): - pos_indexers[dim] = label + # forward dimension indexers with no index/coordinate + results.append(QueryResult(grouped_indexers.pop(None, {}))) for index_id, index in indexes.items(): labels = grouped_indexers[index_id] - pos_idxr, new_idx_and_vars = index.query(labels, **kwargs) - pos_indexers.update(pos_idxr) - - if new_idx_and_vars is not None: - new_idx, new_vars = new_idx_and_vars - new_variables.update(new_vars) - for k in new_vars: - new_indexes[k] = new_idx - for k, idx in data_obj.xindexes.items(): - if id(idx) == index_id and k not in new_vars: - drop_variables.append(k) - - # TODO: benbovy - flexible indexes: support the following cases: - # - check/combine positional indexers returned by multiple indexes over the same dimension(s) - - return pos_indexers, new_indexes, new_variables, drop_variables + results.append(index.query(labels, **query_kwargs)) + + return MergedQueryResults(results) + + +def expanded_indexer(key, ndim): + """Given a key for indexing an ndarray, return an equivalent key which is a + tuple with length equal to the number of dimensions. + + The expansion is done by replacing all `Ellipsis` items with the right + number of full slices and then padding the key with full slices so that it + reaches the appropriate dimensionality. + """ + if not isinstance(key, tuple): + # numpy treats non-tuple keys equivalent to tuples of length 1 + key = (key,) + new_key = [] + # handling Ellipsis right is a little tricky, see: + # http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing + found_ellipsis = False + for k in key: + if k is Ellipsis: + if not found_ellipsis: + new_key.extend((ndim + 1 - len(key)) * [slice(None)]) + found_ellipsis = True + else: + new_key.append(slice(None)) + else: + new_key.append(k) + if len(new_key) > ndim: + raise IndexError("too many indices") + new_key.extend((ndim - len(new_key)) * [slice(None)]) + return tuple(new_key) + + +def _expand_slice(slice_, size): + return np.arange(*slice_.indices(size)) def _normalize_slice(sl, size): From 3ca6a022d79991ac9ab427d14a22280ad4d7a3f6 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 31 Aug 2021 13:50:56 +0200 Subject: [PATCH 022/159] fix some tests + minor tweaks --- xarray/core/indexing.py | 41 ++++++++++++++++------------- xarray/tests/test_dataset.py | 2 +- xarray/tests/test_indexes.py | 14 +++++----- xarray/tests/test_indexing.py | 49 ++++++++++++++++++++--------------- 4 files changed, 59 insertions(+), 47 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index b574673c4b4..10c5bf2801f 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -127,17 +127,18 @@ def __init__(self, query_results: List[QueryResult]): + "Suggestion: use a multi-index for each of those dimension(s)." 
)
 
-        self.dim_indexers = {
-            k: v for res in query_results for k, v in res.dim_indexers.items()
-        }
-        self.indexes = {k: v for res in query_results for k, v in res.indexes.items()}
-        self.index_vars = {
-            k: v for res in query_results for k, v in res.index_vars.items()
-        }
-        self.drop_coords = [c for res in query_results for c in res.drop_coords]
-        self.rename_dims = {
-            k: v for res in query_results for k, v in res.rename_dims.items()
-        }
+        self.dim_indexers = {}
+        self.indexes = {}
+        self.index_vars = {}
+        self.drop_coords = []
+        self.rename_dims = {}
+
+        for res in query_results:
+            self.dim_indexers.update(res.dim_indexers)
+            self.indexes.update(res.indexes)
+            self.index_vars.update(res.index_vars)
+            self.drop_coords += res.drop_coords
+            self.rename_dims.update(res.rename_dims)
 
     def to_tuple(self):
         return (
@@ -191,15 +192,14 @@ def group_indexers_by_index(
             unique_indexes[index_id] = index
             label = maybe_cast_to_coords_dtype(label, coord.dtype)  # type: ignore
             grouped_indexers[index_id][key] = label
-        elif coord is not None:
+        elif key in obj.coords:
             raise KeyError(f"no index found for coordinate {key}")
         elif key not in obj.dims:
             raise KeyError(f"{key} is not a valid dimension or coordinate")
         elif len(query_kwargs):
             raise ValueError(
-                "cannot supply selection options "
-                "when the indexed dimension does not have "
-                "an associated coordinate."
+                f"cannot supply selection options {query_kwargs!r} for dimension {key!r} "
+                "that has no associated coordinate or index"
             )
         else:
             # key is a dimension without coordinate
@@ -212,14 +212,19 @@ def group_indexers_by_index(
 
 def remap_label_indexers(
     obj: Union["DataArray", "Dataset"],
     indexers: Mapping[Hashable, Any],
-    **query_kwargs,
-    # query_kwargs: Mapping[str, Any],
-    # **indexers_kwargs,
+    method=None,
+    tolerance=None,
 ) -> MergedQueryResults:
     """Execute index queries from a DataArray / Dataset and label-based indexers
     and return the (merged) query results.
""" + # TODO benbovy - flexible indexes: remove when custom index options are available + if method is None and tolerance is None: + query_kwargs = {} + else: + query_kwargs = {"method": method, "tolerance": tolerance} + # indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "map_index_queries") indexes, grouped_indexers = group_indexers_by_index(obj, indexers, query_kwargs) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index d07f4ad6639..94aabdf0e64 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1680,7 +1680,7 @@ def test_sel_method(self): with pytest.raises(TypeError, match=r"``method``"): # this should not pass silently - data.sel(method=data) + data.sel(dim2=1, method=data) # cannot pass method if there is no associated coordinate with pytest.raises(ValueError, match=r"cannot supply"): diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 11a81cd0f3b..8ed08ee4f4a 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -86,11 +86,11 @@ def test_query_datetime(self): pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]), "x" ) actual = index.query({"x": "2001-01-01"}) - expected = ({"x": 1}, None) - assert actual == expected + expected_dim_indexers = {"x": 1} + assert actual.dim_indexers == expected_dim_indexers actual = index.query({"x": index.to_pandas_index().to_numpy()[1]}) - assert actual == expected + assert actual.dim_indexers == expected_dim_indexers def test_query_unsorted_datetime_index_raises(self): index = PandasIndex(pd.to_datetime(["2001", "2000", "2002"]), "x") @@ -191,11 +191,11 @@ def test_query(self): index = PandasMultiIndex( pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), "x" ) + # test tuples inside slice are considered as scalar indexer values - assert index.query({"x": slice(("a", 1), ("b", 2))}) == ( - {"x": slice(0, 4)}, - None, - ) + actual = index.query({"x": slice(("a", 1), ("b", 2))}) + expected_dim_indexers = {"x": slice(0, 4)} + assert actual.dim_indexers == expected_dim_indexers with pytest.raises(KeyError, match=r"not all values found"): index.query({"x": [0]}) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 40d93e33ed7..b9c0499bfe3 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -64,21 +64,23 @@ def test_group_indexers_by_index(self): data.coords["y2"] = ("y", [2.0, 3.0]) indexes, grouped_indexers = indexing.group_indexers_by_index( - data, {"z": 0, "one": "a", "two": 1, "y": 0} + data, {"z": 0, "one": "a", "two": 1, "y": 0}, {} ) - assert indexes == {"x": data.xindexes["x"], "y": data.xindexes["y"]} - assert grouped_indexers == { - "x": {"one": "a", "two": 1}, - "y": {"y": 0}, - None: {"z": 0}, - } + for k in indexes: + if indexes[k].equals(data.xindexes["x"]): + assert grouped_indexers[k] == {"one": "a", "two": 1} + elif indexes[k].equals(data.xindexes["y"]): + assert grouped_indexers[k] == {"y": 0} + assert grouped_indexers[None] == {"z": 0} + grouped_indexers.pop(None) + assert indexes.keys() == grouped_indexers.keys() with pytest.raises(KeyError, match=r"no index found for coordinate y2"): - indexing.group_indexers_by_index(data, {"y2": 2.0}) + indexing.group_indexers_by_index(data, {"y2": 2.0}, {}) with pytest.raises(KeyError, match=r"w is not a valid dimension or coordinate"): - indexing.group_indexers_by_index(data, {"w": "a"}) + indexing.group_indexers_by_index(data, {"w": "a"}, {}) with pytest.raises(ValueError, match=r"cannot 
supply.*"): - indexing.group_indexers_by_index(data, {"z": 1}, method="nearest") + indexing.group_indexers_by_index(data, {"z": 1}, {"method": "nearest"}) def test_remap_label_indexers(self): def test_indexer( @@ -88,6 +90,7 @@ def test_indexer( expected_idx=None, expected_vars=None, expected_drop=None, + expected_rename_dims=None, ): if expected_vars is None: expected_vars = {} @@ -97,22 +100,23 @@ def test_indexer( expected_idx = {k: expected_idx for k in expected_vars} if expected_drop is None: expected_drop = [] + if expected_rename_dims is None: + expected_rename_dims = {} - pos, new_idx, new_vars, drop_vars = indexing.remap_label_indexers( - data, {"x": x} - ) + results = indexing.remap_label_indexers(data, {"x": x}) - assert_array_equal(pos.get("x"), expected_pos) + assert_array_equal(results.dim_indexers.get("x"), expected_pos) - assert new_idx.keys() == expected_idx.keys() - for k in new_idx: - assert new_idx[k].equals(expected_idx[k]) + assert results.indexes.keys() == expected_idx.keys() + for k in results.indexes: + assert results.indexes[k].equals(expected_idx[k]) - assert new_vars.keys() == expected_vars.keys() - for k in new_vars: - assert_array_equal(new_vars[k], expected_vars[k]) + assert results.index_vars.keys() == expected_vars.keys() + for k in results.index_vars: + assert_array_equal(results.index_vars[k], expected_vars[k]) - assert drop_vars == expected_drop + assert set(results.drop_coords) == set(expected_drop) + assert results.rename_dims == expected_rename_dims data = Dataset({"x": ("x", [1, 2, 3])}) mindex = pd.MultiIndex.from_product( @@ -130,6 +134,7 @@ def test_indexer( [True, True, False, False, False, False, False, False], *PandasIndex.from_pandas_index(pd.Index([-1, -2]), "three"), ["x", "one", "two"], + {"x": "three"}, ) test_indexer( mdata, @@ -161,6 +166,7 @@ def test_indexer( [True, True, False, False, False, False, False, False], *PandasIndex.from_pandas_index(pd.Index([-1, -2]), "three"), ["x", "one", "two"], + {"x": "three"}, ) test_indexer( mdata, @@ -168,6 +174,7 @@ def test_indexer( [True, False, True, False, False, False, False, False], *PandasIndex.from_pandas_index(pd.Index([1, 2]), "two"), ["x", "one", "three"], + {"x": "two"}, ) test_indexer( mdata, From 05c488d6744d2f0976355abd3bf10e757d751242 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 31 Aug 2021 13:55:53 +0200 Subject: [PATCH 023/159] fix indexing PandasMultiIndexingAdapater When level is not None: - if result is another adapter: propagate it properly - if result is a numpy adapter: use level values --- xarray/core/indexing.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 10c5bf2801f..415b382230b 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1483,7 +1483,7 @@ def __getitem__( (key,) = key if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional - return NumpyIndexingAdapter(self.array.values)[indexer] + return NumpyIndexingAdapter(np.asarray(self))[indexer] result = self.array[key] @@ -1542,6 +1542,13 @@ def _convert_scalar(self, item): item = item[idx] return super()._convert_scalar(item) + def __getitem__(self, indexer): + result = super().__getitem__(indexer) + if isinstance(result, type(self)): + result.level = self.level + + return result + def __repr__(self) -> str: if self.level is None: return super().__repr__() From cc2d9c9c90c9fa16e64901ca052dc7dae850f0be Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 31 Aug 2021 13:59:09 +0200 Subject: 
[PATCH 024/159] refactor cast label indexer to coord dtype Make the fix in #3153 specific to pandas indexes (i.e., do not apply it to other, custom indexes). See #5697 for details. This should also fix #5700 although no test has been added yet (we need to refactor set_index first). --- xarray/core/indexes.py | 35 ++++++++++++++++++++++++++++------- xarray/core/indexing.py | 3 --- xarray/core/utils.py | 6 ------ 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 0dcba5d7c1c..ee39873d2ec 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -129,9 +129,12 @@ def _is_nested_tuple(possible_tuple): ) -def normalize_label(value, extract_scalar=False): +def normalize_label(value, extract_scalar=False, dtype=None): if getattr(value, "ndim", 1) <= 1: value = _asarray_tuplesafe(value) + if dtype is not None and dtype.kind == "f": + # see https://github.com/pydata/xarray/pull/3153 for details + value = np.asarray(value, dtype=dtype) if extract_scalar: # see https://github.com/pydata/xarray/pull/4292 for details value = value[()] if value.dtype.kind in "mM" else value.item() @@ -151,12 +154,16 @@ def get_indexer_nd(index, labels, method=None, tolerance=None): class PandasIndex(Index): """Wrap a pandas.Index as an xarray compatible index.""" - __slots__ = ("index", "dim") + __slots__ = ("index", "dim", "coord_dtype") - def __init__(self, array: Any, dim: Hashable): + def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None): self.index = utils.safe_cast_to_index(array) self.dim = dim + if coord_dtype is None: + coord_dtype = self.index.dtype + self.coord_dtype = coord_dtype + @classmethod def from_variables(cls, variables: Mapping[Hashable, "Variable"]): from .variable import IndexVariable @@ -176,7 +183,7 @@ def from_variables(cls, variables: Mapping[Hashable, "Variable"]): dim = var.dims[0] - obj = cls(var.data, dim) + obj = cls(var.data, dim, coord_dtype=var.dtype) data = PandasIndexingAdapter(obj.index, dtype=var.dtype) index_var = IndexVariable( @@ -219,7 +226,7 @@ def query(self, labels, method=None, tolerance=None): "a dimension that does not have a MultiIndex" ) else: - label = normalize_label(label) + label = normalize_label(label, dtype=self.coord_dtype) if label.ndim == 0: label_value = normalize_label(label, extract_scalar=True) if isinstance(self.index, pd.CategoricalIndex): @@ -289,6 +296,16 @@ def _create_variables_from_multiindex(index, dim, level_meta=None): class PandasMultiIndex(PandasIndex): + + __slots__ = ("index", "dim", "coord_dtype", "level_coords_dtype") + + def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any = None): + super().__init__(array, dim) + + if level_coords_dtype is None: + level_coords_dtype = {idx.name: idx.dtype for idx in self.index.levels} + self.level_coords_dtype = level_coords_dtype + @classmethod def from_variables(cls, variables: Mapping[Hashable, "Variable"]): if any([var.ndim != 1 for var in variables.values()]): @@ -305,7 +322,8 @@ def from_variables(cls, variables: Mapping[Hashable, "Variable"]): index = pd.MultiIndex.from_arrays( [var.values for var in variables.values()], names=variables.keys() ) - obj = cls(index, dim) + level_coords_dtype = {name: var.dtype for name, var in variables.items()} + obj = cls(index, dim, level_coords_dtype=level_coords_dtype) level_meta = { name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding} @@ -346,7 +364,10 @@ def query(self, labels, method=None, tolerance=None): if all([lbl in 
self.index.names for lbl in labels]): is_nested_vals = _is_nested_tuple(tuple(labels.values())) labels = { - k: normalize_label(v, extract_scalar=True) for k, v in labels.items() + k: normalize_label( + v, extract_scalar=True, dtype=self.level_coords_dtype[k] + ) + for k, v in labels.items() } if len(labels) == self.index.nlevels and not is_nested_vals: diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 415b382230b..60de1a4d9f0 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -32,7 +32,6 @@ is_duck_dask_array, sparse_array_type, ) -from .utils import maybe_cast_to_coords_dtype if TYPE_CHECKING: from .dataarray import DataArray @@ -185,12 +184,10 @@ def group_indexers_by_index( for key, label in indexers.items(): index = obj.xindexes.get(key, None) - coord = obj.coords.get(key, None) if index is not None: index_id = id(index) unique_indexes[index_id] = index - label = maybe_cast_to_coords_dtype(label, coord.dtype) # type: ignore grouped_indexers[index_id][key] = label elif key in obj.coords: raise KeyError(f"no index found for coordinate {key}") diff --git a/xarray/core/utils.py b/xarray/core/utils.py index a139d2ef10a..668fe7b9bde 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -72,12 +72,6 @@ def _maybe_cast_to_cftimeindex(index: pd.Index) -> pd.Index: return index -def maybe_cast_to_coords_dtype(label, coords_dtype): - if coords_dtype.kind == "f" and not isinstance(label, slice): - label = np.asarray(label, dtype=coords_dtype) - return label - - def maybe_coerce_to_str(index, original_coords): """maybe coerce a pandas Index back to a nunpy array of type str From a8d84c7035b69f59cb4568d39d2c370f4a49e3e1 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 31 Aug 2021 17:15:21 +0200 Subject: [PATCH 025/159] better handling of multi-index level labels --- xarray/core/indexes.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 5975af8c158..9836657e538 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -7,7 +7,6 @@ Iterable, Mapping, Optional, - Sequence, Tuple, Union, ) @@ -362,26 +361,28 @@ def query(self, labels, method=None, tolerance=None): # label(s) given for multi-index level(s) if all([lbl in self.index.names for lbl in labels]): - is_nested_vals = _is_nested_tuple(tuple(labels.values())) - labels = { - k: normalize_label( - v, extract_scalar=True, dtype=self.level_coords_dtype[k] - ) - for k, v in labels.items() - } + label_values = {} + for k, v in labels.items(): + try: + label_values[k] = normalize_label( + v, extract_scalar=True, dtype=self.level_coords_dtype[k] + ) + except ValueError: + # label should be an item not an array-like + raise ValueError( + "Vectorized selection is not " + f"available along coordinate {k!r} (multi-index level)" + ) - if len(labels) == self.index.nlevels and not is_nested_vals: - indexer = self.index.get_loc(tuple(labels[k] for k in self.index.names)) + has_slice = any([isinstance(v, slice) for v in label_values.values()]) + + if len(label_values) == self.index.nlevels and not has_slice: + indexer = self.index.get_loc( + tuple(label_values[k] for k in self.index.names) + ) else: - for k, v in labels.items(): - # index should be an item (i.e. 
Hashable) not an array-like - if isinstance(v, Sequence) and not isinstance(v, str): - raise ValueError( - "Vectorized selection is not " - f"available along coordinate {k!r} (multi-index level)" - ) indexer, new_index = self.index.get_loc_level( - tuple(labels.values()), level=tuple(labels.keys()) + tuple(label_values.values()), level=tuple(label_values.keys()) ) # GH2619. Raise a KeyError if nothing is chosen if indexer.dtype.kind == "b" and indexer.sum() == 0: From a6bce37e0c1fa7c40e1a8cc8d5f9e82dae40cba7 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 9 Sep 2021 12:08:42 +0200 Subject: [PATCH 026/159] label-based selection tweaks and fixes - Use a dataclass for QueryResult - Pass Variables and DataArrays indexers un-normalized to Index.query(). Indexes have the responsibility of returning the expected types for positional (dimension) indexers by adding back dimensions and coordinates if needed. - Typing fixes and tweaks --- xarray/core/coordinates.py | 51 +--------- xarray/core/dataarray.py | 12 +-- xarray/core/dataset.py | 17 ++-- xarray/core/indexes.py | 77 +++++++++------ xarray/core/indexing.py | 174 +++++++++++++--------------------- xarray/tests/test_indexing.py | 6 +- 6 files changed, 130 insertions(+), 207 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 1e7ea584816..5bfbd347620 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -16,10 +16,10 @@ import numpy as np import pandas as pd -from . import formatting, indexing +from . import formatting from .indexes import Index, Indexes from .merge import merge_coordinates_without_align, merge_coords -from .utils import Frozen, ReprObject, either_dict_or_kwargs +from .utils import Frozen, ReprObject from .variable import Variable if TYPE_CHECKING: @@ -390,50 +390,3 @@ def assert_coordinate_consistent( f"dimension coordinate {k!r} conflicts between " f"indexed and indexing objects:\n{obj[k]}\nvs.\n{coords[k]}" ) - - -def remap_label_indexers( - obj: Union["DataArray", "Dataset"], - indexers: Mapping[Any, Any] = None, - method: str = None, - tolerance=None, - **indexers_kwargs: Any, -) -> Any: - """Remap indexers from obj.coords. - If indexer is an instance of DataArray and it has coordinate, then this coordinate - will be attached to pos_indexers. - - Returns - ------- - pos_indexers: Same type of indexers. - np.ndarray or Variable or DataArray - new_indexes: mapping of new dimensional-coordinate. 
- - """ - from .dataarray import DataArray - - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "remap_label_indexers") - - v_indexers = { - k: v.variable.data if isinstance(v, DataArray) else v - for k, v in indexers.items() - } - - query_results = indexing.remap_label_indexers( - obj, v_indexers, method=method, tolerance=tolerance - ) - - # attach indexer's coordinate to pos_indexers - for k, v in indexers.items(): - dim_indexer = query_results.dim_indexers.get(k, None) - if isinstance(v, Variable): - query_results.dim_indexers[k] = Variable(v.dims, dim_indexer) - elif isinstance(v, DataArray): - # drop coordinates found in indexers since .sel() already - # ensures alignments - coords = {k: var for k, var in v._coords.items() if k not in indexers} - query_results.dim_indexers[k] = DataArray( - dim_indexer, coords=coords, dims=v.dims - ) - - return query_results diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 670491f155a..bcef818238c 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -45,15 +45,11 @@ from .arithmetic import DataArrayArithmetic from .common import AbstractArray, DataWithCoords from .computation import unify_chunks -from .coordinates import ( - DataArrayCoordinates, - assert_coordinate_consistent, - remap_label_indexers, -) +from .coordinates import DataArrayCoordinates, assert_coordinate_consistent from .dataset import Dataset, split_indexes from .formatting import format_item from .indexes import Index, Indexes, default_indexes, propagate_indexes -from .indexing import is_fancy_indexer +from .indexing import is_fancy_indexer, map_index_queries from .merge import PANDAS_TYPES, MergeError, _create_indexes_from_coords from .options import OPTIONS, _get_keep_attrs from .utils import ( @@ -205,8 +201,8 @@ def __setitem__(self, key, value) -> None: labels = indexing.expanded_indexer(key, self.data_array.ndim) key = dict(zip(self.data_array.dims, labels)) - pos_indexers, _, _, _ = remap_label_indexers(self.data_array, key) - self.data_array[pos_indexers] = value + dim_indexers = map_index_queries(self.data_array, key).dim_indexers + self.data_array[dim_indexers] = value # Used as the key corresponding to a DataArray's variable when converting diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7c5453af152..6ceef16f122 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4,6 +4,7 @@ import sys import warnings from collections import defaultdict +from dataclasses import astuple from html import escape from numbers import Number from operator import methodcaller @@ -54,11 +55,7 @@ from .arithmetic import DatasetArithmetic from .common import DataWithCoords, _contains_datetime_like_objects from .computation import unify_chunks -from .coordinates import ( - DatasetCoordinates, - assert_coordinate_consistent, - remap_label_indexers, -) +from .coordinates import DatasetCoordinates, assert_coordinate_consistent from .duck_array_ops import datetime_to_numeric from .indexes import ( Index, @@ -71,7 +68,7 @@ remove_unused_levels_categories, roll_index, ) -from .indexing import is_fancy_indexer +from .indexing import is_fancy_indexer, map_index_queries from .merge import ( dataset_merge_method, dataset_update_method, @@ -557,8 +554,8 @@ def __setitem__(self, key, value) -> None: ) # set new values - pos_indexers, _, _, _ = remap_label_indexers(self.dataset, key) - self.dataset[pos_indexers] = value + dim_indexers = map_index_queries(self.dataset, key).dim_indexers + self.dataset[dim_indexers] = value class 
Dataset(DataWithCoords, DatasetArithmetic, Mapping): @@ -2477,12 +2474,12 @@ def sel( DataArray.sel """ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") - query_results = remap_label_indexers( + query_results = map_index_queries( self, indexers=indexers, method=method, tolerance=tolerance ) result = self.isel(indexers=query_results.dim_indexers, drop=drop) - return result._overwrite_indexes(*query_results.to_tuple()[1:]) + return result._overwrite_indexes(*astuple(query_results)[1:]) def head( self, diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 9836657e538..a6fd2887cab 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -21,17 +21,15 @@ if TYPE_CHECKING: from .variable import IndexVariable, Variable -IndexVars = Dict[Hashable, "IndexVariable"] -IndexWithVars = Tuple["Index", Optional[IndexVars]] +IndexVars = Dict[Any, "IndexVariable"] class Index: """Base class inherited by all xarray-compatible indexes.""" - @classmethod def from_variables( cls, variables: Mapping[Any, "Variable"] - ) -> IndexWithVars: # pragma: no cover + ) -> Tuple["Index", Optional[IndexVars]]: raise NotImplementedError() def to_pandas_index(self) -> pd.Index: @@ -44,9 +42,7 @@ def to_pandas_index(self) -> pd.Index: """ raise TypeError(f"{type(self)} cannot be cast to a pandas.Index object.") - def query( - self, labels: Dict[Hashable, Any], **kwargs - ) -> QueryResult: # pragma: no cover + def query(self, labels: Dict[Any, Any]) -> QueryResult: raise NotImplementedError() def equals(self, other): # pragma: no cover @@ -128,18 +124,21 @@ def _is_nested_tuple(possible_tuple): ) -def normalize_label(value, extract_scalar=False, dtype=None): +def normalize_label(value, dtype=None) -> np.ndarray: if getattr(value, "ndim", 1) <= 1: value = _asarray_tuplesafe(value) if dtype is not None and dtype.kind == "f": + # pd.Index built from coordinate with float precision != 64 # see https://github.com/pydata/xarray/pull/3153 for details value = np.asarray(value, dtype=dtype) - if extract_scalar: - # see https://github.com/pydata/xarray/pull/4292 for details - value = value[()] if value.dtype.kind in "mM" else value.item() return value +def as_scalar(value: np.ndarray): + # see https://github.com/pydata/xarray/pull/4292 for details + return value[()] if value.dtype.kind in "mM" else value.item() + + def get_indexer_nd(index, labels, method=None, tolerance=None): """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional labels @@ -153,6 +152,10 @@ def get_indexer_nd(index, labels, method=None, tolerance=None): class PandasIndex(Index): """Wrap a pandas.Index as an xarray compatible index.""" + index: pd.Index + dim: Hashable + coord_dtype: Any + __slots__ = ("index", "dim", "coord_dtype") def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None): @@ -164,7 +167,9 @@ def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None): self.coord_dtype = coord_dtype @classmethod - def from_variables(cls, variables: Mapping[Any, "Variable"]): + def from_variables( + cls, variables: Mapping[Any, "Variable"] + ) -> Tuple["PandasIndex", IndexVars]: from .variable import IndexVariable if len(variables) != 1: @@ -181,9 +186,7 @@ def from_variables(cls, variables: Mapping[Any, "Variable"]): ) dim = var.dims[0] - obj = cls(var.data, dim, coord_dtype=var.dtype) - data = PandasIndexingAdapter(obj.index, dtype=var.dtype) index_var = IndexVariable( dim, data, attrs=var.attrs, encoding=var.encoding, fastpath=True @@ -192,7 +195,9 @@ def from_variables(cls, 
variables: Mapping[Any, "Variable"]): return obj, {name: index_var} @classmethod - def from_pandas_index(cls, index: pd.Index, dim: Hashable): + def from_pandas_index( + cls, index: pd.Index, dim: Hashable + ) -> Tuple["PandasIndex", IndexVars]: from .variable import IndexVariable if index.name is None: @@ -210,7 +215,10 @@ def from_pandas_index(cls, index: pd.Index, dim: Hashable): def to_pandas_index(self) -> pd.Index: return self.index - def query(self, labels, method=None, tolerance=None): + def query(self, labels: Dict[Any, Any], method=None, tolerance=None) -> QueryResult: + from .dataarray import DataArray + from .variable import Variable + if method is not None and not isinstance(method, str): raise TypeError("``method`` must be a string") @@ -225,30 +233,36 @@ def query(self, labels, method=None, tolerance=None): "a dimension that does not have a MultiIndex" ) else: - label = normalize_label(label, dtype=self.coord_dtype) - if label.ndim == 0: - label_value = normalize_label(label, extract_scalar=True) + label_array = normalize_label(label, dtype=self.coord_dtype) + if label_array.ndim == 0: + label_value = as_scalar(label_array) if isinstance(self.index, pd.CategoricalIndex): if method is not None: raise ValueError( - "'method' is not a valid kwarg when indexing using a CategoricalIndex." + "'method' is not supported when indexing using a CategoricalIndex." ) if tolerance is not None: raise ValueError( - "'tolerance' is not a valid kwarg when indexing using a CategoricalIndex." + "'tolerance' is not supported when indexing using a CategoricalIndex." ) indexer = self.index.get_loc(label_value) else: indexer = self.index.get_loc( label_value, method=method, tolerance=tolerance ) - elif label.dtype.kind == "b": - indexer = label + elif label_array.dtype.kind == "b": + indexer = label_array else: - indexer = get_indexer_nd(self.index, label, method, tolerance) + indexer = get_indexer_nd(self.index, label_array, method, tolerance) if np.any(indexer < 0): raise KeyError(f"not all values found in index {coord_name!r}") + # attach dimension names and/or coordinates to positional indexer + if isinstance(label, Variable): + indexer = Variable(label.dims, indexer) + elif isinstance(label, DataArray): + indexer = DataArray(indexer, coords=label._coords, dims=label.dims) + return QueryResult({self.dim: indexer}) def equals(self, other): @@ -296,6 +310,8 @@ def _create_variables_from_multiindex(index, dim, level_meta=None): class PandasMultiIndex(PandasIndex): + level_coords_dtype: Dict[str, Any] + __slots__ = ("index", "dim", "coord_dtype", "level_coords_dtype") def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any = None): @@ -335,7 +351,9 @@ def from_variables(cls, variables: Mapping[Any, "Variable"]): return obj, index_vars @classmethod - def from_pandas_index(cls, index: pd.MultiIndex, dim: Hashable): + def from_pandas_index( + cls, index: pd.MultiIndex, dim: Hashable + ) -> Tuple["PandasMultiIndex", IndexVars]: level_meta = {} for i, idx in enumerate(index.levels): name = idx.name or f"{dim}_level_{i}" @@ -351,7 +369,7 @@ def from_pandas_index(cls, index: pd.MultiIndex, dim: Hashable): ) return cls(index, dim), index_vars - def query(self, labels, method=None, tolerance=None): + def query(self, labels, method=None, tolerance=None) -> QueryResult: if method is not None or tolerance is not None: raise ValueError( "multi-index does not support ``method`` and ``tolerance``" @@ -363,10 +381,9 @@ def query(self, labels, method=None, tolerance=None): if all([lbl in 
self.index.names for lbl in labels]): label_values = {} for k, v in labels.items(): + label_array = normalize_label(v, dtype=self.level_coords_dtype[k]) try: - label_values[k] = normalize_label( - v, extract_scalar=True, dtype=self.level_coords_dtype[k] - ) + label_values[k] = as_scalar(label_array) except ValueError: # label should be an item not an array-like raise ValueError( @@ -460,7 +477,7 @@ def query(self, labels, method=None, tolerance=None): return QueryResult( {self.dim: indexer}, - index=new_index, + indexes={k: new_index for k in new_vars}, index_vars=new_vars, drop_coords=list(drop_coords), rename_dims=dims_dict, diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 60de1a4d9f0..c52fe4557d4 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -3,19 +3,18 @@ import operator from collections import Counter, defaultdict from contextlib import suppress +from dataclasses import dataclass, field from datetime import timedelta from typing import ( TYPE_CHECKING, Any, Callable, - DefaultDict, Dict, Hashable, Iterable, List, Mapping, Optional, - Sequence, Tuple, Union, ) @@ -32,6 +31,7 @@ is_duck_dask_array, sparse_array_type, ) +from .utils import either_dict_or_kwargs if TYPE_CHECKING: from .dataarray import DataArray @@ -39,16 +39,17 @@ from .indexes import Index, IndexVars +@dataclass class QueryResult: """Index query results. - Parameters + Attributes ---------- dim_indexers: dict A dictionary where keys are array dimensions and values are location-based indexers. - index: :class:`Index`, optional - A new index object to replace in the resulting DataArray or Dataset. + indexes: dict, optional + New indexes to replace in the resulting DataArray or Dataset. index_vars : dict, optional New indexed variables to replace in the resulting DataArray or Dataset. 
drop_coords : list, optional @@ -59,98 +60,47 @@ class QueryResult: """ - dim_indexers: Mapping[Hashable, Any] - indexes: Dict[Hashable, "Index"] - index_vars: "IndexVars" - drop_coords: List[Hashable] - rename_dims: Mapping[Hashable, Hashable] - - __slots__ = ("dim_indexers", "indexes", "index_vars", "drop_coords", "rename_dims") - - def __init__( - self, - dim_indexers: Mapping[Hashable, Any], - index: Optional["Index"] = None, - index_vars: Optional["IndexVars"] = None, - drop_coords: Optional[Sequence[Hashable]] = None, - rename_dims: Optional[Mapping[Hashable, Hashable]] = None, - ): - self.dim_indexers = dim_indexers - - if index_vars is None: - index_vars = {} - self.index_vars = index_vars - - # map the new index to all indexed variables - if index is not None: - self.indexes = {k: index for k in self.index_vars} - else: - self.indexes = {} - - if drop_coords is None: - drop_coords = [] - self.drop_coords = list(drop_coords) - - if rename_dims is None: - rename_dims = {} - self.rename_dims = rename_dims + dim_indexers: Dict[Any, Any] + indexes: Dict[Hashable, "Index"] = field(default_factory=dict) + index_vars: "IndexVars" = field(default_factory=dict) + drop_coords: List[Hashable] = field(default_factory=list) + rename_dims: Dict[Any, Hashable] = field(default_factory=dict) -class MergedQueryResults: - """Results merged from all index queries executed during a single selection - operation.""" +def merge_query_results(results: List[QueryResult]) -> QueryResult: + all_dims_count = Counter([dim for res in results for dim in res.dim_indexers]) + duplicate_dims = {k: v for k, v in all_dims_count.items() if v > 1} - dim_indexers: Dict[Hashable, Any] - indexes: Dict[Hashable, "Index"] - index_vars: "IndexVars" - drop_coords: List[Hashable] - rename_dims: Dict[Hashable, Hashable] - - __slots__ = ("dim_indexers", "indexes", "index_vars", "drop_coords", "rename_dims") - - def __init__(self, query_results: List[QueryResult]): - all_dims_count = Counter( - [dim for res in query_results for dim in res.dim_indexers] + if duplicate_dims: + fmt_dims = [ + f"{dim!r}: {count} indexes involved" + for dim, count in duplicate_dims.items() + ] + raise ValueError( + "Xarray does not support label-based selection with more than one index" + "over the following dimension(s):\n" + + "\n".join(fmt_dims) + + "Suggestion: use a multi-index for each of those dimension(s)." ) - duplicate_dims = {k: v for k, v in all_dims_count.items() if v > 1} - if duplicate_dims: - fmt_dims = [ - f"{dim!r}: {count} indexes involved" - for dim, count in duplicate_dims.items() - ] - raise ValueError( - "Xarray does not support label-based selection with more than one index" - "over the following dimension(s):\n" - + "\n".join(fmt_dims) - + "Suggestion: use a multi-index for each of those dimension(s)." 
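(As a rough sketch of what that suggestion means in practice — illustrative only, the coordinate names below are invented:

    import xarray as xr

    ds = xr.Dataset(coords={"lat": ("x", [0.0, 0.0, 1.0]), "lon": ("x", [10.0, 20.0, 20.0])})
    # folding both coordinates into one multi-index leaves a single index
    # responsible for label-based selection along dimension "x"
    ds.set_index(x=["lat", "lon"]).sel(lat=0.0, lon=20.0)
)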
- ) + dim_indexers = {} + indexes = {} + index_vars = {} + drop_coords = [] + rename_dims = {} - for res in query_results: - self.dim_indexers.update(res.dim_indexers) - self.indexes.update(res.indexes) - self.index_vars.update(res.index_vars) - self.drop_coords += res.drop_coords - self.rename_dims.update(res.rename_dims) - - def to_tuple(self): - return ( - self.dim_indexers, - self.indexes, - self.index_vars, - self.drop_coords, - self.rename_dims, - ) + for res in results: + dim_indexers.update(res.dim_indexers) + indexes.update(res.indexes) + index_vars.update(res.index_vars) + drop_coords += res.drop_coords + rename_dims.update(res.rename_dims) + + return QueryResult(dim_indexers, indexes, index_vars, drop_coords, rename_dims) def group_coords_by_index( - indexes: Mapping[Hashable, "Index"] + indexes: Mapping[Any, "Index"] ) -> Dict[Tuple[Hashable, ...], "Index"]: """From a flat mapping of coordinate names to their corresponding index, return a dictionary of unique index items with the name(s) of all their corresponding coordinate(s) (tuple) as keys.

 """ index_unique: Dict[int, "Index"] = {} grouped_coord_names = defaultdict(list) for coord_name, index_obj in indexes.items(): index_id = id(index_obj) index_unique[index_id] = index_obj grouped_coord_names[index_id].append(coord_name) return {tuple(grouped_coord_names[k]): index_unique[k] for k in index_unique} def group_indexers_by_index( obj: Union["DataArray", "Dataset"], - indexers: Mapping[Hashable, Any], - query_kwargs: Mapping[str, Any], -) -> Tuple[Dict[int, "Index"], Dict[Union[int, None], Dict[Hashable, Any]]]: + indexers: Mapping[Any, Any], + options: Mapping[str, Any], +) -> Tuple[Dict[int, "Index"], Dict[Union[int, None], Dict]]: """Returns a dictionary of unique index items and another dictionary of label indexers grouped by index (both using the same index ids as keys). """ unique_indexes = {} - grouped_indexers: DefaultDict[Union[int, None], Dict[Hashable, Any]] = defaultdict( - dict - ) + grouped_indexers: Mapping[Union[int, None], Dict] = defaultdict(dict) for key, label in indexers.items(): - index = obj.xindexes.get(key, None) + index: "Index" = obj.xindexes.get(key, None) if index is not None: index_id = id(index) @@ -193,10 +141,10 @@ def group_indexers_by_index( raise KeyError(f"no index found for coordinate {key}") elif key not in obj.dims: raise KeyError(f"{key} is not a valid dimension or coordinate") - elif len(query_kwargs): + elif len(options): raise ValueError( - f"cannot supply selection options {query_kwargs!r} for dimension {key!r}" - "that has no asssociated coordinate or index" + f"cannot supply selection options {options!r} for dimension {key!r} " + "that has no associated coordinate or index" ) else: # key is a dimension without coordinate # failback to location-based selection @@ -206,24 +154,27 @@ def group_indexers_by_index( return unique_indexes, dict(grouped_indexers) -def remap_label_indexers( +def map_index_queries( obj: Union["DataArray", "Dataset"], - indexers: Mapping[Hashable, Any], + indexers: Mapping[Any, Any], method=None, tolerance=None, -) -> MergedQueryResults: + **indexers_kwargs: Any, +) -> QueryResult: """Execute index queries from a DataArray / Dataset and label-based indexers and return the (merged) query results.
""" + from .dataarray import DataArray + # TODO benbovy - flexible indexes: remove when custom index options are available if method is None and tolerance is None: - query_kwargs = {} + options = {} else: - query_kwargs = {"method": method, "tolerance": tolerance} + options = {"method": method, "tolerance": tolerance} - # indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "map_index_queries") - indexes, grouped_indexers = group_indexers_by_index(obj, indexers, query_kwargs) + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "map_index_queries") + indexes, grouped_indexers = group_indexers_by_index(obj, indexers, options) results = [] @@ -232,9 +183,18 @@ def remap_label_indexers( for index_id, index in indexes.items(): labels = grouped_indexers[index_id] - results.append(index.query(labels, **query_kwargs)) + results.append(index.query(labels, **options)) # type: ignore[call-arg] + + merged = merge_query_results(results) + + # drop dimension coordinates found in dimension indexers + # (.sel() already ensures alignment) + for k, v in merged.dim_indexers.items(): + if isinstance(v, DataArray): + drop_coords = [name for name in v._coords if name in merged.dim_indexers] + merged.dim_indexers[k] = v.drop_vars(drop_coords) - return MergedQueryResults(results) + return merged def expanded_indexer(key, ndim): diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 67885bbb777..920b301487a 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -82,7 +82,7 @@ def test_group_indexers_by_index(self) -> None: with pytest.raises(ValueError, match=r"cannot supply.*"): indexing.group_indexers_by_index(data, {"z": 1}, {"method": "nearest"}) - def test_remap_label_indexers(self): + def test_map_index_queries(self) -> None: def test_indexer( data, x, @@ -91,7 +91,7 @@ def test_indexer( expected_vars=None, expected_drop=None, expected_rename_dims=None, - ): + ) -> None: if expected_vars is None: expected_vars = {} if expected_idx is None: @@ -103,7 +103,7 @@ def test_indexer( if expected_rename_dims is None: expected_rename_dims = {} - results = indexing.remap_label_indexers(data, {"x": x}) + results = indexing.map_index_queries(data, {"x": x}) assert_array_equal(results.dim_indexers.get("x"), expected_pos) From ad17931618e4043b424ebe8df2562f232e8fc45d Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 9 Sep 2021 15:00:22 +0200 Subject: [PATCH 027/159] sel: propagate multi-index vars attrs/encoding Related to https://github.com/pydata/xarray/issues/1366 --- xarray/core/dataset.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6ceef16f122..043bb1650e0 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1178,8 +1178,16 @@ def _overwrite_indexes( new_coord_names = self._coord_names.copy() new_indexes = dict(self.xindexes) + for name, var in variables.items(): + old_var = self._variables.get(name) + if old_var is not None: + # propagate attrs and encoding + # TODO: needs a test + var.attrs = {**old_var.attrs, **var.attrs} + var.encoding = {**old_var.encoding, **var.encoding} + new_variables[name] = var + for name in indexes: - new_variables[name] = variables[name] new_indexes[name] = indexes[name] for name in drop_variables: @@ -2346,6 +2354,8 @@ def isel( indexes.pop(var_name, None) continue if indexes and var_name in indexes: + # TODO benbovy - flexible indexes: this won't be always desirable + # (e.g., 1-d out-of-core coordinate, 
"meta"-index, etc.) if var_value.ndim == 1: indexes[var_name] = var_value._to_xindex() else: From b894f8d85853015bda4e51d052b9c67eb1281cff Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 9 Sep 2021 17:13:44 +0200 Subject: [PATCH 028/159] wip refactor rename Still needs tests Also need to address the problem of multi-index level coordinates (data adapters currently not updated). We'll probably need for `Index.rename()` to also return new index variables? --- xarray/core/dataset.py | 33 +++++++-------------- xarray/core/indexes.py | 56 +++++++++++++++++++++++++++++++++-- xarray/core/indexing.py | 44 +++++++-------------------- xarray/tests/test_indexing.py | 19 ++++++------ 4 files changed, 86 insertions(+), 66 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 043bb1650e0..01474075a37 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -63,6 +63,7 @@ PandasIndex, PandasMultiIndex, default_indexes, + group_coords_by_index, isel_variable_and_index, propagate_indexes, remove_unused_levels_categories, @@ -93,13 +94,7 @@ is_scalar, maybe_wrap_array, ) -from .variable import ( - IndexVariable, - Variable, - as_variable, - assert_unique_multiindex_level_names, - broadcast_variables, -) +from .variable import IndexVariable, Variable, as_variable, broadcast_variables if TYPE_CHECKING: from ..backends import AbstractDataStore, ZarrStore @@ -3310,28 +3305,23 @@ def _rename_vars(self, name_dict, dims_dict): def _rename_dims(self, name_dict): return {name_dict.get(k, k): v for k, v in self.dims.items()} - def _rename_indexes(self, name_dict, dims_set): - # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5645 + def _rename_indexes(self, name_dict, dims_dict): if self._indexes is None: return None + indexes = {} - for k, v in self.indexes.items(): - new_name = name_dict.get(k, k) - if new_name not in dims_set: - continue - if isinstance(v, pd.MultiIndex): - new_names = [name_dict.get(k, k) for k in v.names] - indexes[new_name] = PandasMultiIndex( - v.rename(names=new_names), new_name - ) - else: - indexes[new_name] = PandasIndex(v.rename(new_name), new_name) + + for index, coord_names in group_coords_by_index(self.xindexes): + new_index = index.rename(name_dict, dims_dict) + new_coord_names = [name_dict.get(k, k) for k in coord_names] + indexes.update({k: new_index for k in new_coord_names}) + return indexes def _rename_all(self, name_dict, dims_dict): variables, coord_names = self._rename_vars(name_dict, dims_dict) dims = self._rename_dims(dims_dict) - indexes = self._rename_indexes(name_dict, dims.keys()) + indexes = self._rename_indexes(name_dict, dims_dict) return variables, coord_names, dims, indexes def rename( @@ -3373,7 +3363,6 @@ def rename( variables, coord_names, dims, indexes = self._rename_all( name_dict=name_dict, dims_dict=name_dict ) - assert_unique_multiindex_level_names(variables) return self._replace(variables, coord_names, dims=dims, indexes=indexes) def rename_dims( diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index a6fd2887cab..9e3746ae027 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1,10 +1,12 @@ import collections.abc +from collections import defaultdict from typing import ( TYPE_CHECKING, Any, Dict, Hashable, Iterable, + List, Mapping, Optional, Tuple, @@ -27,6 +29,7 @@ class Index: """Base class inherited by all xarray-compatible indexes.""" + @classmethod def from_variables( cls, variables: Mapping[Any, "Variable"] ) -> Tuple["Index", Optional[IndexVars]]: @@ 
-54,6 +57,11 @@ def union(self, other): # pragma: no cover def intersection(self, other): # pragma: no cover raise NotImplementedError() + def rename( + self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable] + ) -> "Index": + return self + def copy(self, deep: bool = True): # pragma: no cover raise NotImplementedError() @@ -166,6 +174,13 @@ def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None): coord_dtype = self.index.dtype self.coord_dtype = coord_dtype + def _replace(self, index, dim=None, coord_dtype=None) -> "PandasIndex": + if dim is None: + dim = self.dim + if coord_dtype is None: + coord_dtype = self.coord_dtype + return type(self)(index, dim, coord_dtype) + @classmethod def from_variables( cls, variables: Mapping[Any, "Variable"] @@ -187,6 +202,7 @@ def from_variables( dim = var.dims[0] obj = cls(var.data, dim, coord_dtype=var.dtype) + obj.index.name = name data = PandasIndexingAdapter(obj.index, dtype=var.dtype) index_var = IndexVariable( dim, data, attrs=var.attrs, encoding=var.encoding, fastpath=True @@ -276,11 +292,17 @@ def intersection(self, other): new_index = self.index.intersection(other.index) return type(self)(new_index, self.dim) + def rename(self, name_dict, dims_dict): + new_name = name_dict.get(self.index.name, self.index.name) + idx = self.index.rename(new_name) + new_dim = dims_dict.get(self.dim, self.dim) + return self._replace(idx, dim=new_dim) + def copy(self, deep=True): - return type(self)(self.index.copy(deep=deep), self.dim) + return self._replace(self.index.copy(deep=deep)) def __getitem__(self, indexer: Any): - return type(self)(self.index[indexer], self.dim) + return self._replace(self.index[indexer]) def _create_variables_from_multiindex(index, dim, level_meta=None): @@ -321,6 +343,13 @@ def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any = None): level_coords_dtype = {idx.name: idx.dtype for idx in self.index.levels} self.level_coords_dtype = level_coords_dtype + def _replace(self, index, dim=None, level_coords_dtype=None) -> "PandasMultiIndex": + if dim is None: + dim = self.dim + if level_coords_dtype is None: + level_coords_dtype = self.level_coords_dtype + return type(self)(index, dim, level_coords_dtype) + @classmethod def from_variables(cls, variables: Mapping[Any, "Variable"]): if any([var.ndim != 1 for var in variables.values()]): @@ -486,6 +515,14 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult: else: return QueryResult({self.dim: indexer}) + def rename(self, name_dict, dims_dict): + # pandas 1.3.0: could simply do `self.index.rename(names_dict)` + new_names = [name_dict.get(k, k) for k in self.index.names] + idx = self.index.rename(new_names) + new_dim = dims_dict.get(self.dim, self.dim) + + return self._replace(idx, dim=new_dim) + def remove_unused_levels_categories(index: pd.Index) -> pd.Index: """ @@ -545,6 +582,21 @@ def __repr__(self): return formatting.indexes_repr(self) +def group_coords_by_index( + indexes: Mapping[Any, Index] +) -> List[Tuple[Index, List[Hashable]]]: + """Returns a list of unique indexes and their corresponding coordinate names.""" + unique_indexes: Dict[int, Index] = {} + grouped_coord_names: Mapping[int, List[Hashable]] = defaultdict(list) + + for coord_name, index_obj in indexes.items(): + index_id = id(index_obj) + unique_indexes[index_id] = index_obj + grouped_coord_names[index_id].append(coord_name) + + return [(unique_indexes[k], grouped_coord_names[k]) for k in unique_indexes] + + def default_indexes( coords: Mapping[Any, 
"Variable"], dims: Iterable ) -> Dict[Hashable, Index]: diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index c52fe4557d4..044843c5f94 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -99,34 +99,12 @@ def merge_query_results(results: List[QueryResult]) -> QueryResult: return QueryResult(dim_indexers, indexes, index_vars, drop_coords, rename_dims) -def group_coords_by_index( - indexes: Mapping[Any, "Index"] -) -> Dict[Tuple[Hashable, ...], "Index"]: - """From a flat mapping of coordinate names to their corresponding index, return - a dictionnary of unique index items with the name(s) of all their corresponding - coordinate(s) (tuple) as keys. - - """ - index_unique: Dict[int, "Index"] = {} - grouped_coord_names = defaultdict(list) - - for coord_name, index_obj in indexes.items(): - index_id = id(index_obj) - index_unique[index_id] = index_obj - grouped_coord_names[index_id].append(coord_name) - - return {tuple(grouped_coord_names[k]): index_unique[k] for k in index_unique} - - def group_indexers_by_index( obj: Union["DataArray", "Dataset"], indexers: Mapping[Any, Any], options: Mapping[str, Any], -) -> Tuple[Dict[int, "Index"], Dict[Union[int, None], Dict]]: - """Returns a dictionary of unique index items and another dictionary of label indexers - grouped by index (both using the same index ids as keys). - - """ +) -> List[Tuple["Index", Dict[Any, Any]]]: + """Returns a list of unique indexes and their corresponding indexers.""" unique_indexes = {} grouped_indexers: Mapping[Union[int, None], Dict] = defaultdict(dict) @@ -149,9 +127,10 @@ def group_indexers_by_index( else: # key is a dimension without coordinate # failback to location-based selection + unique_indexes[None] = None grouped_indexers[None][key] = label - return unique_indexes, dict(grouped_indexers) + return [(unique_indexes[k], grouped_indexers[k]) for k in unique_indexes] def map_index_queries( @@ -174,16 +153,15 @@ def map_index_queries( options = {"method": method, "tolerance": tolerance} indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "map_index_queries") - indexes, grouped_indexers = group_indexers_by_index(obj, indexers, options) + grouped_indexers = group_indexers_by_index(obj, indexers, options) results = [] - - # forward dimension indexers with no index/coordinate - results.append(QueryResult(grouped_indexers.pop(None, {}))) - - for index_id, index in indexes.items(): - labels = grouped_indexers[index_id] - results.append(index.query(labels, **options)) # type: ignore[call-arg] + for index, labels in grouped_indexers: + if index is None: + # forward dimension indexers with no index/coordinate + results.append(QueryResult(labels)) + else: + results.append(index.query(labels, **options)) # type: ignore[call-arg] merged = merge_query_results(results) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 920b301487a..ee0daad0829 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -63,17 +63,18 @@ def test_group_indexers_by_index(self) -> None: ) data.coords["y2"] = ("y", [2.0, 3.0]) - indexes, grouped_indexers = indexing.group_indexers_by_index( + grouped_indexers = indexing.group_indexers_by_index( data, {"z": 0, "one": "a", "two": 1, "y": 0}, {} ) - for k in indexes: - if indexes[k].equals(data.xindexes["x"]): - assert grouped_indexers[k] == {"one": "a", "two": 1} - elif indexes[k].equals(data.xindexes["y"]): - assert grouped_indexers[k] == {"y": 0} - assert grouped_indexers[None] == {"z": 0} - 
grouped_indexers.pop(None) - assert indexes.keys() == grouped_indexers.keys() + + for idx, indexers in grouped_indexers: + if idx is None: + assert indexers == {"z": 0} + elif idx.equals(data.xindexes["x"]): + assert indexers == {"one": "a", "two": 1} + elif idx.equals(data.xindexes["y"]): + assert indexers == {"y": 0} + assert len(grouped_indexers) == 3 with pytest.raises(KeyError, match=r"no index found for coordinate y2"): indexing.group_indexers_by_index(data, {"y2": 2.0}, {}) From e42e324539daafa8d8e3039efd9c716e0b215993 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 10 Sep 2021 12:57:31 +0200 Subject: [PATCH 029/159] wip refactor rename: return new vars from Index --- xarray/core/dataset.py | 35 +++++++++++++++++++------------- xarray/core/indexes.py | 44 +++++++++++++++++++++++++++++++++-------- xarray/core/variable.py | 14 +++++++++++++ 3 files changed, 71 insertions(+), 22 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 01474075a37..3d423825165 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -94,7 +94,13 @@ is_scalar, maybe_wrap_array, ) -from .variable import IndexVariable, Variable, as_variable, broadcast_variables +from .variable import ( + IndexVariable, + Variable, + as_variable, + broadcast_variables, + propagate_attrs_encoding, +) if TYPE_CHECKING: from ..backends import AbstractDataStore, ZarrStore @@ -1169,20 +1175,14 @@ def _overwrite_indexes( if drop_variables is None: drop_variables = [] + propagate_attrs_encoding(self._variables, variables) + new_variables = self._variables.copy() new_coord_names = self._coord_names.copy() new_indexes = dict(self.xindexes) - for name, var in variables.items(): - old_var = self._variables.get(name) - if old_var is not None: - # propagate attrs and encoding - # TODO: needs a test - var.attrs = {**old_var.attrs, **var.attrs} - var.encoding = {**old_var.encoding, **var.encoding} - new_variables[name] = var - for name in indexes: + new_variables[name] = variables[name] new_indexes[name] = indexes[name] for name in drop_variables: @@ -3307,21 +3307,28 @@ def _rename_dims(self, name_dict): def _rename_indexes(self, name_dict, dims_dict): if self._indexes is None: - return None + return None, {} indexes = {} + variables = {} for index, coord_names in group_coords_by_index(self.xindexes): - new_index = index.rename(name_dict, dims_dict) + new_index, new_index_vars = index.rename(name_dict, dims_dict) + # map new index to its corresponding coordinates new_coord_names = [name_dict.get(k, k) for k in coord_names] indexes.update({k: new_index for k in new_coord_names}) + variables.update(new_index_vars) - return indexes + return indexes, variables def _rename_all(self, name_dict, dims_dict): variables, coord_names = self._rename_vars(name_dict, dims_dict) dims = self._rename_dims(dims_dict) - indexes = self._rename_indexes(name_dict, dims_dict) + + indexes, index_vars = self._rename_indexes(name_dict, dims_dict) + propagate_attrs_encoding(variables, index_vars) + variables = {k: index_vars.get(k, v) for k, v in variables.items()} + return variables, coord_names, dims, indexes def rename( diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 9e3746ae027..dba4c88363b 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -32,7 +32,7 @@ class Index: @classmethod def from_variables( cls, variables: Mapping[Any, "Variable"] - ) -> Tuple["Index", Optional[IndexVars]]: + ) -> Tuple["Index", IndexVars]: raise NotImplementedError() def to_pandas_index(self) -> pd.Index: @@ 
-59,8 +59,8 @@ def intersection(self, other): # pragma: no cover def rename( self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable] - ) -> "Index": - return self + ) -> Tuple["Index", IndexVars]: + return self, {} def copy(self, deep: bool = True): # pragma: no cover raise NotImplementedError() @@ -174,7 +174,7 @@ def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None): coord_dtype = self.index.dtype self.coord_dtype = coord_dtype - def _replace(self, index, dim=None, coord_dtype=None) -> "PandasIndex": + def _replace(self, index, dim=None, coord_dtype=None): if dim is None: dim = self.dim if coord_dtype is None: @@ -293,10 +293,17 @@ def intersection(self, other): return type(self)(new_index, self.dim) def rename(self, name_dict, dims_dict): + if self.index.name not in name_dict and self.dim not in dims_dict: + return self, {} + new_name = name_dict.get(self.index.name, self.index.name) - idx = self.index.rename(new_name) + pd_idx = self.index.rename(new_name) new_dim = dims_dict.get(self.dim, self.dim) - return self._replace(idx, dim=new_dim) + + index, index_vars = self.from_pandas_index(pd_idx, dim=new_dim) + index.coord_dtype = self.coord_dtype + + return index, index_vars def copy(self, deep=True): return self._replace(self.index.copy(deep=deep)) @@ -516,12 +523,20 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult: return QueryResult({self.dim: indexer}) def rename(self, name_dict, dims_dict): + if not set(self.index.names) & set(name_dict) and self.dim not in dims_dict: + return self, {} + # pandas 1.3.0: could simply do `self.index.rename(names_dict)` new_names = [name_dict.get(k, k) for k in self.index.names] - idx = self.index.rename(new_names) + pd_idx = self.index.rename(new_names) new_dim = dims_dict.get(self.dim, self.dim) - return self._replace(idx, dim=new_dim) + index, index_vars = self.from_pandas_index(pd_idx, new_dim) + index.level_coords_dtype = { + k: v for k, v in zip(new_names, self.level_coords_dtype.values()) + } + + return index, index_vars def remove_unused_levels_categories(index: pd.Index) -> pd.Index: @@ -597,6 +612,19 @@ def group_coords_by_index( return [(unique_indexes[k], grouped_coord_names[k]) for k in unique_indexes] +def unique_indexes(indexes: Mapping[Any, Index]) -> List[Index]: + """Returns a list of unique indexes, preserving order.""" + unique_indexes = [] + seen = [] + + for index in indexes.values(): + if index not in seen: + unique_indexes.append(index) + seen.append(index) + + return unique_indexes + + def default_indexes( coords: Mapping[Any, "Variable"], dims: Iterable ) -> Dict[Hashable, Index]: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 191bb4059f5..c0cd5c462a3 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2941,3 +2941,17 @@ def assert_unique_multiindex_level_names(variables): "conflicting level / dimension names. {} " "already exists as a level name.".format(d) ) + + +def propagate_attrs_encoding( + old_variables: Mapping[Any, Variable], new_variables: Mapping[Any, Variable] +) -> None: + """Propagate any attrs and/or encoding items from old variables that are not present + in new variables. 
+ + """ + for name, var in new_variables.items(): + old_var = old_variables.get(name) + if old_var is not None: + var.attrs = {**old_var.attrs, **var.attrs} + var.encoding = {**old_var.encoding, **var.encoding} From 5febd7f8648911467e499a353a664af7acb25710 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 10 Sep 2021 13:01:10 +0200 Subject: [PATCH 030/159] refactor rename: update tests Still need to address default indexes internal checks (these are disabled for now). --- xarray/testing.py | 18 +++++----- xarray/tests/test_dataset.py | 43 ++++++++++++++++++---- xarray/tests/test_indexes.py | 69 ++++++++++++++++++++++++++++++++++-- 3 files changed, 112 insertions(+), 18 deletions(-) diff --git a/xarray/testing.py b/xarray/testing.py index 40ca12852b9..673d474d6cb 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -8,7 +8,7 @@ from xarray.core import duck_array_ops, formatting, utils from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset -from xarray.core.indexes import Index, default_indexes +from xarray.core.indexes import Index from xarray.core.variable import IndexVariable, Variable __all__ = ( @@ -251,7 +251,7 @@ def assert_chunks_equal(a, b): assert left.chunks == right.chunks -def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims): +def _assert_indexes_invariants_checks(indexes, possible_coord_variables): assert isinstance(indexes, dict), indexes assert all(isinstance(v, Index) for v in indexes.values()), { k: type(v) for k, v in indexes.items() @@ -262,11 +262,11 @@ def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims): } assert indexes.keys() <= index_vars, (set(indexes), index_vars) - # Note: when we support non-default indexes, these checks should be opt-in - # only! - defaults = default_indexes(possible_coord_variables, dims) - assert indexes.keys() == defaults.keys(), (set(indexes), set(defaults)) - assert all(v.equals(defaults[k]) for k, v in indexes.items()), (indexes, defaults) + # TODO: benbovy - explicit indexes: do we still need these checks? Or opt-in? + # non-default indexes are now supported. 
+ # defaults = default_indexes(possible_coord_variables, dims) + # assert indexes.keys() == defaults.keys(), (set(indexes), set(defaults)) + # assert all(v.equals(defaults[k]) for k, v in indexes.items()), (indexes, defaults) def _assert_variable_invariants(var: Variable, name: Hashable = None): @@ -302,7 +302,7 @@ def _assert_dataarray_invariants(da: DataArray): _assert_variable_invariants(v, k) if da._indexes is not None: - _assert_indexes_invariants_checks(da._indexes, da._coords, da.dims) + _assert_indexes_invariants_checks(da._indexes, da._coords) def _assert_dataset_invariants(ds: Dataset): @@ -336,7 +336,7 @@ def _assert_dataset_invariants(ds: Dataset): } if ds._indexes is not None: - _assert_indexes_invariants_checks(ds._indexes, ds._variables, ds._dims) + _assert_indexes_invariants_checks(ds._indexes, ds._variables) assert isinstance(ds._encoding, (type(None), dict)) assert isinstance(ds._attrs, (type(None), dict)) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index ee6a01fb94a..fe5cc17b809 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2670,13 +2670,42 @@ def test_rename_vars(self): with pytest.raises(ValueError): original.rename_vars(names_dict_bad) - def test_rename_multiindex(self): - mindex = pd.MultiIndex.from_tuples( - [([1, 2]), ([3, 4])], names=["level0", "level1"] - ) - data = Dataset({}, {"x": mindex}) - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): - data.rename({"x": "level0"}) + def test_rename_dimension_coord(self) -> None: + # rename a dimension coordinate to a non-dimension coordinate + # should preserve index + original = Dataset(coords={"x": ("x", [0, 1, 2])}) + + actual = original.rename_vars({"x": "x_new"}) + assert "x_new" in actual.xindexes + + actual_2 = original.rename_dims({"x": "x_new"}) + assert "x" in actual_2.xindexes + + def test_rename_multiindex(self) -> None: + mindex = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) + original = Dataset({}, {"x": mindex}) + expected = Dataset({}, {"x": mindex.rename(["a", "c"])}) + + actual = original.rename({"b": "c"}) + assert_identical(expected, actual) + + with pytest.raises(ValueError, match=r"'a' conflicts"): + original.rename({"x": "a"}) + with pytest.raises(ValueError, match=r"'x' conflicts"): + original.rename({"a": "x"}) + with pytest.raises(ValueError, match=r"'b' conflicts"): + original.rename({"a": "b"}) + + def test_rename_preserve_attrs_encoding(self) -> None: + # test propagate attrs/encoding to new variable(s) created from Index object + original = Dataset(coords={"x": ("x", [0, 1, 2])}) + expected = Dataset(coords={"y": ("y", [0, 1, 2])}) + for ds, dim in zip([original, expected], ["x", "y"]): + ds[dim].attrs = {"foo": "bar"} + ds[dim].encoding = {"foo": "bar"} + + actual = original.rename({"x": "y"}) + assert_identical(actual, expected) @requires_cftime def test_rename_does_not_change_CFTimeIndex_type(self): diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 7ed7f468303..97c503aea15 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -119,21 +119,43 @@ def test_intersection(self) -> None: assert actual.index.equals(pd.Index([2, 3])) assert actual.dim == "x" + def test_rename(self) -> None: + index = PandasIndex(pd.Index([1, 2, 3], name="a"), "x", coord_dtype=np.int32) + + # shortcut + new_index, index_vars = index.rename({}, {}) + assert new_index is index + assert index_vars == {} + + new_index, index_vars = index.rename({"a": "b"}, {}) +
assert new_index.index.name == "b" + assert new_index.dim == "x" + assert new_index.coord_dtype == np.int32 + xr.testing.assert_identical(index_vars["b"], IndexVariable("x", [1, 2, 3])) + + new_index, index_vars = index.rename({}, {"x": "y"}) + assert new_index.index.name == "a" + assert new_index.dim == "y" + assert new_index.coord_dtype == np.int32 + xr.testing.assert_identical(index_vars["a"], IndexVariable("y", [1, 2, 3])) + def test_copy(self) -> None: - expected = PandasIndex([1, 2, 3], "x") + expected = PandasIndex([1, 2, 3], "x", coord_dtype=np.int32) actual = expected.copy() assert actual.index.equals(expected.index) assert actual.index is not expected.index assert actual.dim == expected.dim + assert actual.coord_dtype == expected.coord_dtype def test_getitem(self) -> None: pd_idx = pd.Index([1, 2, 3]) - expected = PandasIndex(pd_idx, "x") + expected = PandasIndex(pd_idx, "x", coord_dtype=np.int32) actual = expected[1:] assert actual.index.equals(pd_idx[1:]) assert actual.dim == expected.dim + assert actual.coord_dtype == expected.coord_dtype class TestPandasMultiIndex: @@ -207,3 +229,46 @@ def test_query(self) -> None: index.query({"x": {"three": 0}}) with pytest.raises(IndexError): index.query({"x": (slice(None), 1, "no_level")}) + + def test_rename(self) -> None: + level_coords_dtype = {"one": "U<1", "two": np.int32} + index = PandasMultiIndex( + pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), + "x", + level_coords_dtype=level_coords_dtype, + ) + + # shortcut + new_index, index_vars = index.rename({}, {}) + assert new_index is index + assert index_vars == {} + + new_index, index_vars = index.rename({"two": "three"}, {}) + assert new_index.index.names == ["one", "three"] + assert new_index.dim == "x" + assert new_index.level_coords_dtype == {"one": "U<1", "three": np.int32} + assert list(index_vars.keys()) == ["x", "one", "three"] + for v in index_vars.values(): + assert v.dims == ("x",) + + new_index, index_vars = index.rename({}, {"x": "y"}) + assert new_index.index.names == ["one", "two"] + assert new_index.dim == "y" + assert new_index.level_coords_dtype == level_coords_dtype + assert list(index_vars.keys()) == ["y", "one", "two"] + for v in index_vars.values(): + assert v.dims == ("y",) + + def test_copy(self) -> None: + level_coords_dtype = {"one": "U<1", "two": np.int32} + expected = PandasMultiIndex( + pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), + "x", + level_coords_dtype=level_coords_dtype, + ) + actual = expected.copy() + + assert actual.index.equals(expected.index) + assert actual.index is not expected.index + assert actual.dim == expected.dim + assert actual.level_coords_dtype == expected.level_coords_dtype From 508cdcc04b82f8af08121f76a6e0bfcb5d762598 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 10 Sep 2021 13:05:50 +0200 Subject: [PATCH 031/159] typing tweaks --- xarray/core/indexing.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 044843c5f94..ffff51d57d4 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -31,11 +31,10 @@ is_duck_dask_array, sparse_array_type, ) +from .types import T_Xarray from .utils import either_dict_or_kwargs if TYPE_CHECKING: - from .dataarray import DataArray - from .dataset import Dataset from .indexes import Index, IndexVars @@ -100,7 +99,7 @@ def merge_query_results(results: List[QueryResult]) -> QueryResult: def group_indexers_by_index( - obj: Union["DataArray", "Dataset"], + 
obj: T_Xarray, indexers: Mapping[Any, Any], options: Mapping[str, Any], ) -> List[Tuple["Index", Dict[Any, Any]]]: @@ -134,7 +133,7 @@ def group_indexers_by_index( def map_index_queries( - obj: Union["DataArray", "Dataset"], + obj: T_Xarray, indexers: Mapping[Any, Any], method=None, tolerance=None, From f1535a31255416c9ed60fa0425b2afefd0b674cd Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 10 Sep 2021 15:31:41 +0200 Subject: [PATCH 032/159] fix html formatting: dims w/ or w/o index --- xarray/core/formatting_html.py | 21 ++++++++++++++++----- xarray/tests/test_formatting_html.py | 16 ++++++++-------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index ac5b79b287d..e6b923bde47 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -29,12 +29,12 @@ def short_data_repr_html(array): return f"<pre class='xr-text-repr-fallback'>{text}</pre>" -def format_dims(dims, coord_names): +def format_dims(dims, dims_with_index): if not dims: return "" dim_css_map = { - k: " class='xr-has-index'" if k in coord_names else "" for k, v in dims.items() + dim: " class='xr-has-index'" if dim in dims_with_index else "" for dim in dims } dims_li = "".join( @@ -161,8 +161,20 @@ def _mapping_section( ) +def _dims_with_index(obj): + if not hasattr(obj, "indexes"): + return [] + + dims_with_index = set() + for coord_name in obj.indexes: + for dim in obj[coord_name].dims: + dims_with_index.add(dim) + + return dims_with_index + + def dim_section(obj): - dim_list = format_dims(obj.dims, list(obj.coords)) + dim_list = format_dims(obj.dims, _dims_with_index(obj)) return collapsible_section( "Dimensions", inline_details=dim_list, enabled=False, collapsed=True @@ -246,12 +258,11 @@ def array_repr(arr): obj_type = "xarray.{}".format(type(arr).__name__) arr_name = f"'{arr.name}'" if getattr(arr, "name", None) else "" - coord_names = list(arr.coords) if hasattr(arr, "coords") else [] header_components = [
f"<div class='xr-obj-type'>{obj_type}</div>", f"<div class='xr-array-name'>{arr_name}</div>", - format_dims(dims, coord_names), + format_dims(dims, _dims_with_index(arr)), ] sections = [array_section(arr)] diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py index 9c04e47c631..c67619e18c7 100644 --- a/xarray/tests/test_formatting_html.py +++ b/xarray/tests/test_formatting_html.py @@ -64,27 +64,27 @@ def test_short_data_repr_html_dask(dask_dataarray) -> None: def test_format_dims_no_dims() -> None: dims: Dict = {} - coord_names: List = [] - formatted = fh.format_dims(dims, coord_names) + dims_with_index: List = [] + formatted = fh.format_dims(dims, dims_with_index) assert formatted == "" def test_format_dims_unsafe_dim_name() -> None: dims = {"<x>": 3, "y": 2} - coord_names: List = [] - formatted = fh.format_dims(dims, coord_names) + dims_with_index: List = [] + formatted = fh.format_dims(dims, dims_with_index) assert "&lt;x&gt;" in formatted def test_format_dims_non_index() -> None: - dims, coord_names = {"x": 3, "y": 2}, ["time"] - formatted = fh.format_dims(dims, coord_names) + dims, dims_with_index = {"x": 3, "y": 2}, ["time"] + formatted = fh.format_dims(dims, dims_with_index) assert "class='xr-has-index'" not in formatted def test_format_dims_index() -> None: - dims, coord_names = {"x": 3, "y": 2}, ["x"] - formatted = fh.format_dims(dims, coord_names) + dims, dims_with_index = {"x": 3, "y": 2}, ["x"] + formatted = fh.format_dims(dims, dims_with_index) assert "class='xr-has-index'" in formatted From 021090f44d83d6fb98dc48688d46ad2de76b9e2c Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 13 Sep 2021 23:47:20 +0200 Subject: [PATCH 033/159] minor fixes and tweaks --- xarray/core/formatting_html.py | 2 +- xarray/core/indexing.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index e6b923bde47..209533b2027 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -166,7 +166,7 @@ def _dims_with_index(obj): return [] dims_with_index = set() - for coord_name in obj.indexes: + for coord_name in obj.xindexes: for dim in obj[coord_name].dims: dims_with_index.add(dim) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index ffff51d57d4..d930674937e 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -71,15 +71,17 @@ def merge_query_results(results: List[QueryResult]) -> QueryResult: duplicate_dims = {k: v for k, v in all_dims_count.items() if v > 1} if duplicate_dims: + # TODO: this message is not right when combining index queries with + # location-based indexing on a dimension with no dimension-coordinate (fallback) fmt_dims = [ f"{dim!r}: {count} indexes involved" for dim, count in duplicate_dims.items() ] raise ValueError( - "Xarray does not support label-based selection with more than one index" + "Xarray does not support label-based selection with more than one index " "over the following dimension(s):\n" + "\n".join(fmt_dims) - + "Suggestion: use a multi-index for each of those dimension(s)." + + "\nSuggestion: use a multi-index for each of those dimension(s)." ) @@ -124,8 +126,9 @@ def group_indexers_by_index( "that has no associated coordinate or index" ) else: - # key is a dimension without coordinate + # key is a dimension without a "dimension-coordinate" # failback to location-based selection + # TODO: deprecate this implicit behavior and suggest using isel instead?
unique_indexes[None] = None grouped_indexers[None][key] = label From 0086e326dcb03b76ef3e9db3220dce1e682bc23c Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 13 Sep 2021 23:58:46 +0200 Subject: [PATCH 034/159] wip refactor set_index --- xarray/core/dataset.py | 173 ++++++++++++++++++++--------------------- xarray/core/indexes.py | 136 ++++++++++++++++++++++++-------- 2 files changed, 189 insertions(+), 120 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 3d423825165..1dfa266bbd9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -194,90 +194,6 @@ def calculate_dimensions(variables: Mapping[Any, Variable]) -> Dict[Hashable, in return dims -def merge_indexes( - indexes: Mapping[Any, Union[Hashable, Sequence[Hashable]]], - variables: Mapping[Any, Variable], - coord_names: Set[Hashable], - append: bool = False, -) -> Tuple[Dict[Hashable, Variable], Set[Hashable]]: - """Merge variables into multi-indexes. - - Not public API. Used in Dataset and DataArray set_index - methods. - """ - vars_to_replace: Dict[Hashable, Variable] = {} - vars_to_remove: List[Hashable] = [] - dims_to_replace: Dict[Hashable, Hashable] = {} - error_msg = "{} is not the name of an existing variable." - - for dim, var_names in indexes.items(): - if isinstance(var_names, str) or not isinstance(var_names, Sequence): - var_names = [var_names] - - names: List[Hashable] = [] - codes: List[List[int]] = [] - levels: List[List[int]] = [] - current_index_variable = variables.get(dim) - - for n in var_names: - try: - var = variables[n] - except KeyError: - raise ValueError(error_msg.format(n)) - if ( - current_index_variable is not None - and var.dims != current_index_variable.dims - ): - raise ValueError( - f"dimension mismatch between {dim!r} {current_index_variable.dims} and {n!r} {var.dims}" - ) - - if current_index_variable is not None and append: - current_index = current_index_variable.to_index() - if isinstance(current_index, pd.MultiIndex): - names.extend(current_index.names) - codes.extend(current_index.codes) - levels.extend(current_index.levels) - else: - names.append(f"{dim}_level_0") - cat = pd.Categorical(current_index.values, ordered=True) - codes.append(cat.codes) - levels.append(cat.categories) - - if not len(names) and len(var_names) == 1: - idx = pd.Index(variables[var_names[0]].values) - - else: # MultiIndex - for n in var_names: - try: - var = variables[n] - except KeyError: - raise ValueError(error_msg.format(n)) - names.append(n) - cat = pd.Categorical(var.values, ordered=True) - codes.append(cat.codes) - levels.append(cat.categories) - - idx = pd.MultiIndex(levels, codes, names=names) - for n in names: - dims_to_replace[n] = dim - - vars_to_replace[dim] = IndexVariable(dim, idx) - vars_to_remove.extend(var_names) - - new_variables = {k: v for k, v in variables.items() if k not in vars_to_remove} - new_variables.update(vars_to_replace) - - # update dimensions if necessary, GH: 3512 - for k, v in new_variables.items(): - if any(d in dims_to_replace for d in v.dims): - new_dims = [dims_to_replace.get(d, d) for d in v.dims] - new_variables[k] = v._replace(dims=new_dims) - new_coord_names = coord_names | set(vars_to_replace) - new_coord_names -= set(vars_to_remove) - return new_variables, new_coord_names - - def split_indexes( dims_or_levels: Union[Hashable, Sequence[Hashable]], variables: Mapping[Any, Variable], @@ -3307,7 +3223,7 @@ def _rename_dims(self, name_dict): def _rename_indexes(self, name_dict, dims_dict): if self._indexes is None: - return None, {} 
+ return {}, {} indexes = {} variables = {} @@ -3751,11 +3667,90 @@ def set_index( Dataset.reset_index Dataset.swap_dims """ - indexes = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index") - variables, coord_names = merge_indexes( - indexes, self._variables, self._coord_names, append=append + dim_coords = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index") + + new_indexes: Dict[Hashable, Index] = {} + new_variables: Dict[Hashable, IndexVariable] = {} + maybe_drop_indexes: List[Hashable] = [] + drop_variables: List[Hashable] = [] + replace_dims: Dict[Hashable, Hashable] = {} + + index_coord_names = { + k: coord_names + for _, coord_names in group_coords_by_index(self.xindexes) + for k in coord_names + } + + for dim, _var_names in dim_coords.items(): + if isinstance(_var_names, str) or not isinstance(_var_names, Sequence): + var_names = [_var_names] + else: + var_names = list(_var_names) + + invalid_vars = set(var_names) - set(self._variables) + if invalid_vars: + raise ValueError( + ", ".join([str(v) for v in invalid_vars]) + + " variable(s) do not exist" + ) + + current_coord_names = index_coord_names.get(dim, []) + + # drop any pre-existing index involved + maybe_drop_indexes.extend(current_coord_names + var_names) + for k in var_names: + maybe_drop_indexes.extend(index_coord_names.get(k, [])) + + drop_variables.extend(var_names) + + if len(var_names) == 1 and (not append or dim not in self.xindexes): + var_name = var_names[0] + var = self._variables[var_name] + if var.dims != (dim,): + raise ValueError( + f"dimension mismatch: try setting an index for dimension {dim!r} with " + f"variable {var_name!r} that has dimensions {var.dims}" + ) + idx, idx_vars = PandasIndex.from_variables({dim: var}) + else: + if append: + current_variables = { + k: self._variables[k] for k in current_coord_names + } + else: + current_variables = {} + idx, idx_vars = PandasMultiIndex.from_variables_maybe_expand( + dim, + current_variables, + {k: self._variables[k] for k in var_names}, + ) + for n in idx.index.names: + replace_dims[n] = dim + + new_indexes.update({k: idx for k in idx_vars}) + new_variables.update(idx_vars) + + indexes_: Dict[Any, Index] = { + k: v for k, v in self.xindexes.items() if k not in maybe_drop_indexes + } + indexes_.update(new_indexes) + + variables = { + k: v for k, v in self._variables.items() if k not in drop_variables + } + variables.update(new_variables) + + # update dimensions if necessary, GH: 3512 + for k, v in variables.items(): + if any(d in replace_dims for d in v.dims): + new_dims = [replace_dims.get(d, d) for d in v.dims] + variables[k] = v._replace(dims=new_dims) + + coord_names = set(new_variables) | self._coord_names + + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes_ ) - return self._replace_vars_and_dims(variables, coord_names=coord_names) def reset_index( self, diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index dba4c88363b..479e586d70c 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -312,21 +312,39 @@ def __getitem__(self, indexer: Any): return self._replace(self.index[indexer]) -def _create_variables_from_multiindex(index, dim, level_meta=None): - from .variable import IndexVariable +def _check_dim_compat(variables: Mapping[Any, "Variable"]) -> Hashable: + """Check that all multi-index variable candidates share the same (single) dimension + and return the name of that dimension. 
+ + """ + if any([var.ndim != 1 for var in variables.values()]): + raise ValueError("PandasMultiIndex only accepts 1-dimensional variables") - if level_meta is None: - level_meta = {} + dims = set([var.dims for var in variables.values()]) - variables = {} + if len(dims) > 1: + raise ValueError( + "unmatched dimensions for variables " + + ", ".join([f"{k!r} {v.dims}" for k, v in variables.items()]) + ) - dim_coord_adapter = PandasMultiIndexingAdapter(index) - variables[dim] = IndexVariable(dim, dim_coord_adapter, fastpath=True) + return next(iter(dims))[0] - for level in index.names: - meta = level_meta.get(level, {}) + +def _create_variables_from_multiindex(index, dim, var_meta=None): + from .variable import IndexVariable + + if var_meta is None: + var_meta = {} + + def create_variable(name): + if name == dim: + level = None + else: + level = name + meta = var_meta.get(name, {}) data = PandasMultiIndexingAdapter(index, dtype=meta.get("dtype"), level=level) - variables[level] = IndexVariable( + return IndexVariable( dim, data, attrs=meta.get("attrs"), @@ -334,10 +352,16 @@ def _create_variables_from_multiindex(index, dim, level_meta=None): fastpath=True, ) + variables = {} + variables[dim] = create_variable(dim) + for level in index.names: + variables[level] = create_variable(level) + return variables class PandasMultiIndex(PandasIndex): + """Wrap a pandas.MultiIndex as an xarray compatible index.""" level_coords_dtype: Dict[str, Any] @@ -358,31 +382,83 @@ def _replace(self, index, dim=None, level_coords_dtype=None) -> "PandasMultiInde return type(self)(index, dim, level_coords_dtype) @classmethod - def from_variables(cls, variables: Mapping[Any, "Variable"]): - if any([var.ndim != 1 for var in variables.values()]): - raise ValueError("PandasMultiIndex only accepts 1-dimensional variables") - - dims = set([var.dims for var in variables.values()]) - if len(dims) != 1: - raise ValueError( - "unmatched dimensions for variables " - + ",".join([str(k) for k in variables]) - ) + def from_variables( + cls, variables: Mapping[Any, "Variable"] + ) -> Tuple["PandasMultiIndex", IndexVars]: + dim = _check_dim_compat(variables) - dim = next(iter(dims))[0] index = pd.MultiIndex.from_arrays( [var.values for var in variables.values()], names=variables.keys() ) level_coords_dtype = {name: var.dtype for name, var in variables.items()} obj = cls(index, dim, level_coords_dtype=level_coords_dtype) - level_meta = { + var_meta = { name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding} for name, var in variables.items() } - index_vars = _create_variables_from_multiindex( - index, dim, level_meta=level_meta - ) + index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta) + + return obj, index_vars + + @classmethod + def from_variables_maybe_expand( + cls, + dim: Hashable, + current_variables: Mapping[Any, "Variable"], + variables: Mapping[Any, "Variable"], + ) -> Tuple["PandasMultiIndex", IndexVars]: + """Create a new multi-index maybe by expanding an existing one with + new variables as index levels. + + the index might be created along a new dimension. 
+ """ + names: List[Hashable] = [] + codes: List[List[int]] = [] + levels: List[List[int]] = [] + var_meta: Dict[str, Dict] = {} + level_coords_dtype: Dict[Hashable, Any] = {} + + _check_dim_compat({**current_variables, **variables}) + + def add_level_var(name, var): + var_meta[name] = { + "dtype": var.dtype, + "attrs": var.attrs, + "encoding": var.encoding, + } + level_coords_dtype[name] = var.dtype + + if len(current_variables) > 1: + current_index: pd.MultiIndex = next( + iter(current_variables.values()) + )._data.array + names.extend(current_index.names) + codes.extend(current_index.codes) + levels.extend(current_index.levels) + for name in current_index.names: + add_level_var(name, current_variables[name]) + + elif len(current_variables) == 1: + # one 1D variable (no multi-index): convert it to an index level + var = next(iter(current_variables.values())) + new_var_name = f"{dim}_level_0" + names.append(new_var_name) + cat = pd.Categorical(var.values, ordered=True) + codes.append(cat.codes) + levels.append(cat.categories) + add_level_var(new_var_name, var) + + for name, var in variables.items(): + names.append(name) + cat = pd.Categorical(var.values, ordered=True) + codes.append(cat.codes) + levels.append(cat.categories) + add_level_var(name, var) + + index = pd.MultiIndex(levels, codes, names=names) + obj = cls(index, dim, level_coords_dtype=level_coords_dtype) + index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta) return obj, index_vars @@ -390,19 +466,17 @@ def from_variables(cls, variables: Mapping[Any, "Variable"]): def from_pandas_index( cls, index: pd.MultiIndex, dim: Hashable ) -> Tuple["PandasMultiIndex", IndexVars]: - level_meta = {} + var_meta = {} for i, idx in enumerate(index.levels): name = idx.name or f"{dim}_level_{i}" if name == dim: raise ValueError( f"conflicting multi-index level name {name!r} with dimension {dim!r}" ) - level_meta[name] = {"dtype": idx.dtype} + var_meta[name] = {"dtype": idx.dtype} - index = index.rename(level_meta.keys()) - index_vars = _create_variables_from_multiindex( - index, dim, level_meta=level_meta - ) + index = index.rename(var_meta.keys()) + index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta) return cls(index, dim), index_vars def query(self, labels, method=None, tolerance=None) -> QueryResult: From a891e22d33ae6553db8b2d58cf0522a92f21dd5d Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 14 Sep 2021 14:05:34 +0200 Subject: [PATCH 035/159] wip refactor set_index - Refactor reset_index - Improve creating new xarray indexes from pandas indexes with proper propagation of variable metadata (dtype, attrs, encoding) --- xarray/core/dataarray.py | 8 +-- xarray/core/dataset.py | 122 ++++++++++++++++++--------------------- xarray/core/indexes.py | 116 +++++++++++++++++++++++++++---------- 3 files changed, 142 insertions(+), 104 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index bcef818238c..0a862e6c0ee 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -46,7 +46,7 @@ from .common import AbstractArray, DataWithCoords from .computation import unify_chunks from .coordinates import DataArrayCoordinates, assert_coordinate_consistent -from .dataset import Dataset, split_indexes +from .dataset import Dataset from .formatting import format_item from .indexes import Index, Indexes, default_indexes, propagate_indexes from .indexing import is_fancy_indexer, map_index_queries @@ -2016,10 +2016,8 @@ def reset_index( -------- DataArray.set_index """ - 
coords, _ = split_indexes( - dims_or_levels, self._coords, set(), self._level_coords, drop=drop - ) - return self._replace(coords=coords) + ds = self._to_temp_dataset().reset_index(dims_or_levels, drop=drop) + return self._from_temp_dataset(ds) def reorder_levels( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 1dfa266bbd9..8b8bd9435e4 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -14,7 +14,6 @@ Any, Callable, Collection, - DefaultDict, Dict, Hashable, Iterable, @@ -194,64 +193,6 @@ def calculate_dimensions(variables: Mapping[Any, Variable]) -> Dict[Hashable, in return dims -def split_indexes( - dims_or_levels: Union[Hashable, Sequence[Hashable]], - variables: Mapping[Any, Variable], - coord_names: Set[Hashable], - level_coords: Mapping[Any, Hashable], - drop: bool = False, -) -> Tuple[Dict[Hashable, Variable], Set[Hashable]]: - """Extract (multi-)indexes (levels) as variables. - - Not public API. Used in Dataset and DataArray reset_index - methods. - """ - if isinstance(dims_or_levels, str) or not isinstance(dims_or_levels, Sequence): - dims_or_levels = [dims_or_levels] - - dim_levels: DefaultDict[Any, List[Hashable]] = defaultdict(list) - dims = [] - for k in dims_or_levels: - if k in level_coords: - dim_levels[level_coords[k]].append(k) - else: - dims.append(k) - - vars_to_replace = {} - vars_to_create: Dict[Hashable, Variable] = {} - vars_to_remove = [] - - for d in dims: - index = variables[d].to_index() - if isinstance(index, pd.MultiIndex): - dim_levels[d] = index.names - else: - vars_to_remove.append(d) - if not drop: - vars_to_create[str(d) + "_"] = Variable(d, index, variables[d].attrs) - - for d, levs in dim_levels.items(): - index = variables[d].to_index() - if len(levs) == index.nlevels: - vars_to_remove.append(d) - else: - vars_to_replace[d] = IndexVariable(d, index.droplevel(levs)) - - if not drop: - for lev in levs: - idx = index.get_level_values(lev) - vars_to_create[idx.name] = Variable(d, idx, variables[d].attrs) - - new_variables = dict(variables) - for v in set(vars_to_remove): - del new_variables[v] - new_variables.update(vars_to_replace) - new_variables.update(vars_to_create) - new_coord_names = (coord_names | set(vars_to_create)) - set(vars_to_remove) - - return new_variables, new_coord_names - - def _assert_empty(args: tuple, msg: str = "%s") -> None: if args: raise ValueError(msg % args) @@ -3777,14 +3718,61 @@ def reset_index( -------- Dataset.set_index """ - variables, coord_names = split_indexes( - dims_or_levels, - self._variables, - self._coord_names, - cast(Mapping[Hashable, Hashable], self._level_coords), - drop=drop, - ) - return self._replace_vars_and_dims(variables, coord_names=coord_names) + if isinstance(dims_or_levels, str) or not isinstance(dims_or_levels, Sequence): + dims_or_levels = [dims_or_levels] + + invalid_coords = set(dims_or_levels) - set(self.xindexes) + if invalid_coords: + raise ValueError( + f"{tuple(invalid_coords)} are not coordinates with an index" + ) + + drop_indexes: List[Hashable] = [] + drop_variables: List[Hashable] = [] + replaced_indexes: List[PandasMultiIndex] = [] + new_indexes: Dict[Hashable, Index] = {} + new_variables: Dict[Hashable, IndexVariable] = {} + + index_coord_names = { + k: coord_names + for _, coord_names in group_coords_by_index(self.xindexes) + for k in coord_names + } + + for name in dims_or_levels: + index = self.xindexes[name] + drop_indexes += [k for k in index_coord_names[name]] + + if isinstance(index, PandasMultiIndex) and name not in self.dims: + # 
special case for pd.MultiIndex (name is an index level):
+                # replace by a new index with dropped level(s) instead of just dropping the index
+                # TODO: eventually extend Index API to allow this for custom multi-indexes?
+                if index not in replaced_indexes:
+                    level_names = index.index.names
+                    level_vars = {
+                        k: self._variables[k]
+                        for k in level_names
+                        if k not in dims_or_levels
+                    }
+                    idx, idx_vars = index.keep_levels(level_vars)
+                    new_indexes.update({k: idx for k in idx_vars})
+                    new_variables.update(idx_vars)
+                    replaced_indexes.append(index)
+
+            if drop:
+                drop_variables.append(name)
+
+        indexes = {k: v for k, v in self.xindexes.items() if k not in drop_indexes}
+        indexes.update(new_indexes)
+
+        variables = {
+            k: v for k, v in self._variables.items() if k not in drop_variables
+        }
+        variables.update(new_variables)
+
+        coord_names = set(new_variables) | self._coord_names
+
+        return self._replace(variables, coord_names=coord_names, indexes=indexes)
 
     def reorder_levels(
         self,
diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 479e586d70c..97b29cbfd65 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -11,6 +11,7 @@
     Optional,
     Tuple,
     Union,
+    cast,
 )
 
 import numpy as np
@@ -212,7 +213,10 @@ def from_variables(
 
     @classmethod
     def from_pandas_index(
-        cls, index: pd.Index, dim: Hashable
+        cls,
+        index: pd.Index,
+        dim: Hashable,
+        var_meta: Optional[Dict[Any, Dict]] = None,
     ) -> Tuple["PandasIndex", IndexVars]:
         from .variable import IndexVariable
 
@@ -223,10 +227,21 @@ def from_pandas_index(
         else:
             name = index.name
 
-        data = PandasIndexingAdapter(index)
-        index_var = IndexVariable(dim, data, fastpath=True)
+        if var_meta is None:
+            var_meta = {name: {}}
+
+        data = PandasIndexingAdapter(index, dtype=var_meta[name].get("dtype"))
+        index_var = IndexVariable(
+            dim,
+            data,
+            fastpath=True,
+            attrs=var_meta[name].get("attrs"),
+            encoding=var_meta[name].get("encoding"),
+        )
 
-        return cls(index, dim), {name: index_var}
+        return cls(index, dim, coord_dtype=var_meta[name].get("dtype")), {
+            name: index_var
+        }
 
     def to_pandas_index(self) -> pd.Index:
         return self.index
@@ -297,13 +312,11 @@ def rename(self, name_dict, dims_dict):
             return self, {}
 
         new_name = name_dict.get(self.index.name, self.index.name)
-        pd_idx = self.index.rename(new_name)
+        index = self.index.rename(new_name)
         new_dim = dims_dict.get(self.dim, self.dim)
+        var_meta = {new_name: {"dtype": self.coord_dtype}}
 
-        index, index_vars = self.from_pandas_index(pd_idx, dim=new_dim)
-        index.coord_dtype = self.coord_dtype
-
-        return index, index_vars
+        return self.from_pandas_index(index, dim=new_dim, var_meta=var_meta)
 
     def copy(self, deep=True):
         return self._replace(self.index.copy(deep=deep))
@@ -411,13 +424,12 @@ def from_variables_maybe_expand(
         """Create a new multi-index maybe by expanding an existing one with
         new variables as index levels.
 
-        the index might be created along a new dimension.
+        The index and its corresponding coordinates may be created along a new dimension.
""" names: List[Hashable] = [] codes: List[List[int]] = [] levels: List[List[int]] = [] var_meta: Dict[str, Dict] = {} - level_coords_dtype: Dict[Hashable, Any] = {} _check_dim_compat({**current_variables, **variables}) @@ -427,12 +439,13 @@ def add_level_var(name, var): "attrs": var.attrs, "encoding": var.encoding, } - level_coords_dtype[name] = var.dtype if len(current_variables) > 1: - current_index: pd.MultiIndex = next( - iter(current_variables.values()) - )._data.array + # expand from an existing multi-index + data = cast( + PandasMultiIndexingAdapter, next(iter(current_variables.values()))._data + ) + current_index = data.array names.extend(current_index.names) codes.extend(current_index.codes) levels.extend(current_index.levels) @@ -440,7 +453,7 @@ def add_level_var(name, var): add_level_var(name, current_variables[name]) elif len(current_variables) == 1: - # one 1D variable (no multi-index): convert it to an index level + # expand from one 1D variable (no multi-index): convert it to an index level var = next(iter(current_variables.values())) new_var_name = f"{dim}_level_0" names.append(new_var_name) @@ -457,27 +470,63 @@ def add_level_var(name, var): add_level_var(name, var) index = pd.MultiIndex(levels, codes, names=names) - obj = cls(index, dim, level_coords_dtype=level_coords_dtype) - index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta) - return obj, index_vars + return cls.from_pandas_index(index, dim, var_meta=var_meta) + + def keep_levels( + self, level_variables: Mapping[Any, "Variable"] + ) -> Tuple[Union["PandasMultiIndex", PandasIndex], IndexVars]: + """Keep only the provided levels and return a new multi-index with its + corresponding coordinates. + + """ + var_meta: Dict[str, Dict] = {} + + for name, var in level_variables.items(): + var_meta[name] = { + "dtype": var.dtype, + "attrs": var.attrs, + "encoding": var.encoding, + } + + index = self.index.droplevel( + [k for k in self.index.names if k not in level_variables] + ) + + if isinstance(index, pd.MultiIndex): + return self.from_pandas_index(index, self.dim, var_meta=var_meta) + else: + return PandasIndex.from_pandas_index(index, self.dim, var_meta=var_meta) @classmethod def from_pandas_index( - cls, index: pd.MultiIndex, dim: Hashable + cls, + index: pd.MultiIndex, + dim: Hashable, + var_meta: Optional[Dict[Any, Dict]] = None, ) -> Tuple["PandasMultiIndex", IndexVars]: - var_meta = {} + + names = [] + idx_dtypes = {} for i, idx in enumerate(index.levels): name = idx.name or f"{dim}_level_{i}" if name == dim: raise ValueError( f"conflicting multi-index level name {name!r} with dimension {dim!r}" ) - var_meta[name] = {"dtype": idx.dtype} + names.append(name) + idx_dtypes[name] = idx.dtype + + if var_meta is None: + var_meta = {k: {} for k in names} + for name, dtype in idx_dtypes.items(): + var_meta[name]["dtype"] = var_meta[name].get("dtype", dtype) + + level_coords_dtype = {k: var_meta[k]["dtype"] for k in names} - index = index.rename(var_meta.keys()) + index = index.rename(names) index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta) - return cls(index, dim), index_vars + return cls(index, dim, level_coords_dtype=level_coords_dtype), index_vars def query(self, labels, method=None, tolerance=None) -> QueryResult: if method is not None or tolerance is not None: @@ -570,15 +619,19 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult: raise KeyError(f"not all values found in index {coord_name!r}") if new_index is not None: + # variable(s) attrs and 
encoding metadata are propagated + # when replacing the indexes in the resulting xarray object + var_meta = {k: {"dtype": v} for k, v in self.level_coords_dtype.items()} + if isinstance(new_index, pd.MultiIndex): new_index, new_vars = PandasMultiIndex.from_pandas_index( - new_index, self.dim + new_index, self.dim, var_meta=var_meta ) dims_dict = {} drop_coords = set(self.index.names) - set(new_index.index.names) else: new_index, new_vars = PandasIndex.from_pandas_index( - new_index, new_index.name + new_index, new_index.name, var_meta=var_meta ) dims_dict = {self.dim: new_index.index.name} drop_coords = set(self.index.names) - {new_index.index.name} | { @@ -602,15 +655,14 @@ def rename(self, name_dict, dims_dict): # pandas 1.3.0: could simply do `self.index.rename(names_dict)` new_names = [name_dict.get(k, k) for k in self.index.names] - pd_idx = self.index.rename(new_names) - new_dim = dims_dict.get(self.dim, self.dim) + index = self.index.rename(new_names) - index, index_vars = self.from_pandas_index(pd_idx, new_dim) - index.level_coords_dtype = { - k: v for k, v in zip(new_names, self.level_coords_dtype.values()) + new_dim = dims_dict.get(self.dim, self.dim) + var_meta = { + k: {"dtype": v} for k, v in zip(new_names, self.level_coords_dtype.values()) } - return index, index_vars + return self.from_pandas_index(index, new_dim, var_meta=var_meta) def remove_unused_levels_categories(index: pd.Index) -> pd.Index: From e50978e0b3f5d85b1f3e1bf33eb3e37321336ade Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 14 Sep 2021 17:04:03 +0200 Subject: [PATCH 036/159] refactor reorder_levels --- xarray/core/dataarray.py | 13 ++-------- xarray/core/dataset.py | 31 ++++++++++++++-------- xarray/core/indexes.py | 56 +++++++++++++++++++++------------------- 3 files changed, 52 insertions(+), 48 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0a862e6c0ee..6aa68bfd902 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2042,17 +2042,8 @@ def reorder_levels( Another dataarray, with this dataarray's data but replaced coordinates. 
""" - dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") - replace_coords = {} - for dim, order in dim_order.items(): - coord = self._coords[dim] - index = coord.to_index() - if not isinstance(index, pd.MultiIndex): - raise ValueError(f"coordinate {dim!r} has no MultiIndex") - replace_coords[dim] = IndexVariable(coord.dims, index.reorder_levels(order)) - coords = self._coords.copy() - coords.update(replace_coords) - return self._replace(coords=coords) + ds = self._to_temp_dataset().reorder_levels(dim_order, **dim_order_kwargs) + return self._from_temp_dataset(ds) def stack( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 8b8bd9435e4..5f19bbe02b8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3638,11 +3638,11 @@ def set_index( current_coord_names = index_coord_names.get(dim, []) # drop any pre-existing index involved - maybe_drop_indexes.extend(current_coord_names + var_names) + maybe_drop_indexes += current_coord_names + var_names for k in var_names: - maybe_drop_indexes.extend(index_coord_names.get(k, [])) + maybe_drop_indexes += index_coord_names.get(k, []) - drop_variables.extend(var_names) + drop_variables += var_names if len(var_names) == 1 and (not append or dim not in self.xindexes): var_name = var_names[0] @@ -3800,16 +3800,25 @@ def reorder_levels( dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") variables = self._variables.copy() indexes = dict(self.xindexes) + new_indexes: Dict[Hashable, Index] = {} + new_variables: Dict[Hashable, IndexVariable] = {} + for dim, order in dim_order.items(): - coord = self._variables[dim] - # TODO: benbovy - flexible indexes: update when MultiIndex - # has its own class inherited from xarray.Index - index = self.xindexes[dim].to_pandas_index() - if not isinstance(index, pd.MultiIndex): + index = self.xindexes[dim] + + if not isinstance(index, PandasMultiIndex): raise ValueError(f"coordinate {dim} has no MultiIndex") - new_index = index.reorder_levels(order) - variables[dim] = IndexVariable(coord.dims, new_index) - indexes[dim] = PandasMultiIndex(new_index, dim) + + idx, idx_vars = index.reorder_levels({k: self._variables[k] for k in order}) + + new_variables.update(idx_vars) + new_indexes.update({k: idx for k in idx_vars}) + + indexes = {k: v for k, v in self.xindexes.items() if k not in new_indexes} + indexes.update(new_indexes) + + variables = {k: v for k, v in self._variables.items() if k not in new_variables} + variables.update(new_variables) return self._replace(variables, indexes=indexes) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 97b29cbfd65..84afc5ada01 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -344,6 +344,13 @@ def _check_dim_compat(variables: Mapping[Any, "Variable"]) -> Hashable: return next(iter(dims))[0] +def _get_var_metadata(variables: Mapping[Any, "Variable"]) -> Dict[Any, Dict[str, Any]]: + return { + name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding} + for name, var in variables.items() + } + + def _create_variables_from_multiindex(index, dim, var_meta=None): from .variable import IndexVariable @@ -406,11 +413,9 @@ def from_variables( level_coords_dtype = {name: var.dtype for name, var in variables.items()} obj = cls(index, dim, level_coords_dtype=level_coords_dtype) - var_meta = { - name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding} - for name, var in variables.items() - } - index_vars = _create_variables_from_multiindex(index, 
dim, var_meta=var_meta) + index_vars = _create_variables_from_multiindex( + index, dim, var_meta=_get_var_metadata(variables) + ) return obj, index_vars @@ -429,17 +434,10 @@ def from_variables_maybe_expand( names: List[Hashable] = [] codes: List[List[int]] = [] levels: List[List[int]] = [] - var_meta: Dict[str, Dict] = {} + level_variables: Dict[Any, "Variable"] = {} _check_dim_compat({**current_variables, **variables}) - def add_level_var(name, var): - var_meta[name] = { - "dtype": var.dtype, - "attrs": var.attrs, - "encoding": var.encoding, - } - if len(current_variables) > 1: # expand from an existing multi-index data = cast( @@ -450,7 +448,7 @@ def add_level_var(name, var): codes.extend(current_index.codes) levels.extend(current_index.levels) for name in current_index.names: - add_level_var(name, current_variables[name]) + level_variables[name] = current_variables[name] elif len(current_variables) == 1: # expand from one 1D variable (no multi-index): convert it to an index level @@ -460,18 +458,20 @@ def add_level_var(name, var): cat = pd.Categorical(var.values, ordered=True) codes.append(cat.codes) levels.append(cat.categories) - add_level_var(new_var_name, var) + level_variables[new_var_name] = var for name, var in variables.items(): names.append(name) cat = pd.Categorical(var.values, ordered=True) codes.append(cat.codes) levels.append(cat.categories) - add_level_var(name, var) + level_variables[name] = var index = pd.MultiIndex(levels, codes, names=names) - return cls.from_pandas_index(index, dim, var_meta=var_meta) + return cls.from_pandas_index( + index, dim, var_meta=_get_var_metadata(level_variables) + ) def keep_levels( self, level_variables: Mapping[Any, "Variable"] @@ -480,15 +480,7 @@ def keep_levels( corresponding coordinates. """ - var_meta: Dict[str, Dict] = {} - - for name, var in level_variables.items(): - var_meta[name] = { - "dtype": var.dtype, - "attrs": var.attrs, - "encoding": var.encoding, - } - + var_meta = _get_var_metadata(level_variables) index = self.index.droplevel( [k for k in self.index.names if k not in level_variables] ) @@ -498,6 +490,18 @@ def keep_levels( else: return PandasIndex.from_pandas_index(index, self.dim, var_meta=var_meta) + def reorder_levels( + self, level_variables: Mapping[Any, "Variable"] + ) -> Tuple["PandasMultiIndex", IndexVars]: + """Re-arrange index levels using input order and return a new multi-index with + its corresponding coordinates. 
+
+        """
+        index = self.index.reorder_levels(level_variables.keys())
+        return self.from_pandas_index(
+            index, self.dim, var_meta=_get_var_metadata(level_variables)
+        )
+
     @classmethod
     def from_pandas_index(
         cls,

From 8b1e4d56009876d32613becf76a8e1e02f88ad99 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Wed, 15 Sep 2021 11:28:02 +0200
Subject: [PATCH 037/159] Set pd.MultiIndex name from dim name

Closes #4542
---
 xarray/core/indexes.py       | 5 ++++-
 xarray/tests/test_indexes.py | 6 ++++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 84afc5ada01..7df8438a27a 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -397,6 +397,7 @@ def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any = None):
     def _replace(self, index, dim=None, level_coords_dtype=None) -> "PandasMultiIndex":
         if dim is None:
             dim = self.dim
+        index.name = dim
         if level_coords_dtype is None:
             level_coords_dtype = self.level_coords_dtype
         return type(self)(index, dim, level_coords_dtype)
@@ -410,6 +411,7 @@ def from_variables(
         index = pd.MultiIndex.from_arrays(
             [var.values for var in variables.values()], names=variables.keys()
         )
+        index.name = dim
         level_coords_dtype = {name: var.dtype for name, var in variables.items()}
         obj = cls(index, dim, level_coords_dtype=level_coords_dtype)
 
@@ -529,6 +531,7 @@ def from_pandas_index(
         level_coords_dtype = {k: var_meta[k]["dtype"] for k in names}
 
         index = index.rename(names)
+        index.name = dim
         index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta)
         return cls(index, dim, level_coords_dtype=level_coords_dtype), index_vars
 
diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py
index 97c503aea15..590b4b3cc5d 100644
--- a/xarray/tests/test_indexes.py
+++ b/xarray/tests/test_indexes.py
@@ -174,6 +174,7 @@ def test_from_variables(self) -> None:
         expected_idx = pd.MultiIndex.from_arrays([v_level1.data, v_level2.data])
         assert index.dim == "x"
         assert index.index.equals(expected_idx)
+        assert index.index.name == "x"
         assert list(index_vars) == ["x", "level1", "level2"]
         xr.testing.assert_equal(xr.IndexVariable("x", expected_idx), index_vars["x"])
 
@@ -200,6 +201,7 @@ def test_from_pandas_index(self) -> None:
         assert index.dim == "x"
         assert index.index.equals(pd_idx)
         assert index.index.names == ("foo", "bar")
+        assert index.index.name == "x"
         xr.testing.assert_identical(index_vars["x"], IndexVariable("x", pd_idx))
         xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", foo_data))
         xr.testing.assert_identical(index_vars["bar"], IndexVariable("x", bar_data))
@@ -231,7 +233,7 @@ def test_query(self) -> None:
             index.query({"x": (slice(None), 1, "no_level")})
 
     def test_rename(self) -> None:
-        level_coords_dtype = {"one": "U<1", "two": np.int32}
+        level_coords_dtype = {"one": "<U1", "two": np.int32}
@@ ... @@ def test_rename(self) -> None:
         new_index, index_vars = index.rename({"two": "three"}, {})
         assert new_index.index.names == ["one", "three"]
         assert new_index.dim == "x"
-        assert new_index.level_coords_dtype == {"one": "U<1", "three": np.int32}
+        assert new_index.level_coords_dtype == {"one": "<U1", "three": np.int32}

From: Benoit Bovy
Date: Wed, 15 Sep 2021 14:20:12 +0200
Subject: [PATCH 038/159] .sel() with
multi-index: return scalar coords ..instead of dropping those coordinates Closes #1408 once tests are fixed/updated --- xarray/core/dataset.py | 32 +++++++++++++++++++++++++++++--- xarray/core/indexes.py | 33 +++++++++++++++++++++++---------- xarray/core/indexing.py | 26 ++++++++++++++++++-------- 3 files changed, 70 insertions(+), 21 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 5f19bbe02b8..86648c9ed96 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1021,9 +1021,15 @@ def _overwrite_indexes( indexes: Mapping[Hashable, Index], variables: Optional[Mapping[Hashable, Variable]] = None, drop_variables: Optional[List[Hashable]] = None, + drop_indexes: Optional[List[Hashable]] = None, rename_dims: Optional[Mapping[Hashable, Hashable]] = None, ) -> "Dataset": - """Maybe replace indexes and their corresponding index variables.""" + """Maybe replace indexes. + + This function may do a lot more depending on index query + results. + + """ if not indexes: return self @@ -1031,6 +1037,8 @@ def _overwrite_indexes( variables = {} if drop_variables is None: drop_variables = [] + if drop_indexes is None: + drop_indexes = [] propagate_attrs_encoding(self._variables, variables) @@ -1038,13 +1046,31 @@ def _overwrite_indexes( new_coord_names = self._coord_names.copy() new_indexes = dict(self.xindexes) + index_variables = {} + no_index_variables = {} + for k, v in variables.items(): + if k in indexes: + index_variables[k] = v + else: + no_index_variables[k] = v + for name in indexes: new_variables[name] = variables[name] - new_indexes[name] = indexes[name] + + for name in index_variables: + new_variables[name] = variables[name] + + # append no-index variables at the end + for k in no_index_variables: + new_variables.pop(k) + new_variables.update(no_index_variables) + + for name in drop_indexes: + new_indexes.pop(name) for name in drop_variables: new_variables.pop(name) - new_indexes.pop(name) + new_indexes.pop(name, None) new_coord_names.remove(name) replaced = self._replace( diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 7df8438a27a..31a8c13cfae 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -536,12 +536,15 @@ def from_pandas_index( return cls(index, dim, level_coords_dtype=level_coords_dtype), index_vars def query(self, labels, method=None, tolerance=None) -> QueryResult: + from .variable import Variable + if method is not None or tolerance is not None: raise ValueError( "multi-index does not support ``method`` and ``tolerance``" ) new_index = None + scalar_coord_values = {} # label(s) given for multi-index level(s) if all([lbl in self.index.names for lbl in labels]): @@ -567,6 +570,7 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult: indexer, new_index = self.index.get_loc_level( tuple(label_values.values()), level=tuple(label_values.keys()) ) + scalar_coord_values.update(label_values) # GH2619. 
Raise a KeyError if nothing is chosen
            if indexer.dtype.kind == "b" and indexer.sum() == 0:
                raise KeyError(f"{labels} not found")
 
@@ -601,9 +605,9 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult:
             elif len(label) == self.index.nlevels:
                 indexer = self.index.get_loc(label)
             else:
-                indexer, new_index = self.index.get_loc_level(
-                    label, level=list(range(len(label)))
-                )
+                levels = [self.index.names[i] for i in range(len(label))]
+                indexer, new_index = self.index.get_loc_level(label, level=levels)
+                scalar_coord_values.update({k: v for k, v in zip(levels, label)})
 
         else:
             label = (
@@ -613,6 +617,7 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult:
             )
             if label.ndim == 0:
                 indexer, new_index = self.index.get_loc_level(label.item(), level=0)
+                scalar_coord_values[self.index.names[0]] = label.item()
             elif label.dtype.kind == "b":
                 indexer = label
             else:
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index d930674937e..3e892f0ef40 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -35,7 +35,8 @@
 from .utils import either_dict_or_kwargs
 
 if TYPE_CHECKING:
-    from .indexes import Index, IndexVars
+    from .indexes import Index
+    from .variable import IndexVariable, Variable
 
 
 @dataclass
@@ -49,10 +50,12 @@ class QueryResult:
         location-based indexers.
     indexes: dict, optional
         New indexes to replace in the resulting DataArray or Dataset.
-    index_vars : dict, optional
-        New indexed variables to replace in the resulting DataArray or Dataset.
+    variables : dict, optional
+        New variables to replace in the resulting DataArray or Dataset.
    drop_coords : list, optional
        Coordinate(s) to drop in the resulting DataArray or Dataset.
+    drop_indexes : list, optional
+       Index(es) to drop in the resulting DataArray or Dataset.
    rename_dims : dict, optional
       A dictionary in the form ``{old_dim: new_dim}`` for dimension(s) to
      rename in the resulting DataArray or Dataset.
@@ -60,9 +63,12 @@ class QueryResult:
     """
 
     dim_indexers: Dict[Any, Any]
-    indexes: Dict[Hashable, "Index"] = field(default_factory=dict)
-    index_vars: "IndexVars" = field(default_factory=dict)
+    indexes: Dict[Any, "Index"] = field(default_factory=dict)
+    variables: Dict[Any, Union["Variable", "IndexVariable"]] = field(
+        default_factory=dict
+    )
     drop_coords: List[Hashable] = field(default_factory=list)
+    drop_indexes: List[Hashable] = field(default_factory=list)
     rename_dims: Dict[Any, Hashable] = field(default_factory=dict)
 
 
@@ -86,18 +92,22 @@ def merge_query_results(results: List[QueryResult]) -> QueryResult:
 
     dim_indexers = {}
     indexes = {}
-    index_vars = {}
+    variables = {}
     drop_coords = []
+    drop_indexes = []
     rename_dims = {}
 
     for res in results:
         dim_indexers.update(res.dim_indexers)
         indexes.update(res.indexes)
-        index_vars.update(res.index_vars)
+        variables.update(res.variables)
         drop_coords += res.drop_coords
+        drop_indexes += res.drop_indexes
         rename_dims.update(res.rename_dims)
 
-    return QueryResult(dim_indexers, indexes, index_vars, drop_coords, rename_dims)
+    return QueryResult(
+        dim_indexers, indexes, variables, drop_coords, drop_indexes, rename_dims
+    )
 
 
 def group_indexers_by_index(

From 4fbdc860ed2a2c7072d2847523af1fbaa43ecfda Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Thu, 16 Sep 2021 00:42:10 +0200
Subject: [PATCH 039/159] fix multi-index level coordinate inline repr

---
 xarray/core/indexing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index 3e892f0ef40..91fb7b0f605 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -1517,7 +1517,7 @@ def _repr_inline_(self, max_width) -> str:
             indices = np.concatenate(
                 [np.arange(0, n_values), np.arange(-n_values, 0)]
             )
-            subset = self[indices]
+            subset = self[OuterIndexer((indices,))]
         else:
             subset = self

From ef6fbbdb257e0338c29ef93baa588f07bf227bda Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Thu, 16 Sep 2021 00:43:23 +0200
Subject: [PATCH 040/159] wip refactor stack

Some changes:

- Multi-indexes are created only if the stacked dimensions each have a
  single coordinate index. In the other cases, multi-indexes are not
  created, which means that it is an irreversible operation (unstack
  will not work).

- Multi-indexes are created from the stacked coordinate variables and
  no longer from the unstacked dimension indexes (level product).
  There's a significant decrease in performance but it is probably
  acceptable since it's now possible to avoid the creation of
  multi-indexes.

- It is possible to stack a dimension with a multi-index, but it drops
  the index. Otherwise it would make it hard for unstack() to figure
  out what's going on. It makes it clear that this is an irreversible
  operation.
---
 xarray/core/dataset.py | 62 ++++++++++++++++++++++++++++++------------
 1 file changed, 44 insertions(+), 18 deletions(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 86648c9ed96..114895ad1a8 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -3851,28 +3851,54 @@ def reorder_levels(
 
     def _stack_once(self, dims, new_dim):
         if ...
in dims: dims = list(infix_dims(dims, self.dims)) - variables = {} + + # TODO: add default dimension variables (range) if missing + # only if we want backwards compatibility (multi-index always created) + + variables: Dict[Hashable, Variable] = {} + stacked_var_names: List[Hashable] = [] + drop_indexes: List[Hashable] = [] + for name, var in self.variables.items(): - if name not in dims: - if any(d in var.dims for d in dims): - add_dims = [d for d in dims if d not in var.dims] - vdims = list(var.dims) + add_dims - shape = [self.dims[d] for d in vdims] - exp_var = var.set_dims(vdims, shape) - stacked_var = exp_var.stack(**{new_dim: dims}) - variables[name] = stacked_var - else: - variables[name] = var.copy(deep=False) + if any(d in var.dims for d in dims): + add_dims = [d for d in dims if d not in var.dims] + vdims = list(var.dims) + add_dims + shape = [self.dims[d] for d in vdims] + exp_var = var.set_dims(vdims, shape) + stacked_var = exp_var.stack(**{new_dim: dims}) + variables[name] = stacked_var + stacked_var_names.append(name) + else: + variables[name] = var.copy(deep=False) - # consider dropping levels that are unused? - levels = [self.get_index(dim) for dim in dims] - idx = utils.multiindex_from_product_levels(levels, names=dims) - variables[new_dim] = IndexVariable(new_dim, idx) + # drop indexes of stacked coordinates (if any) + index_coord_names = { + k: coord_names + for _, coord_names in group_coords_by_index(self.xindexes) + for k in coord_names + } + for k in stacked_var_names: + drop_indexes += index_coord_names.get(k, []) - coord_names = set(self._coord_names) - set(dims) | {new_dim} + # A new index is created only if all stacked dimensions have an index + # TODO: add API option for the creation of a new index (see GH 5202) + stacked_idx_vars = { + k: variables[k] for k in stacked_var_names if k in self.xindexes + } + if len(stacked_idx_vars) == len(dims): + idx, idx_vars = PandasMultiIndex.from_variables(stacked_idx_vars) + new_indexes = {k: idx for k in idx_vars} + # keep consistent multi-index coordinate order + for k in idx_vars: + variables.pop(k, None) + variables.update(idx_vars) + coord_names = set(self._coord_names) | {new_dim} + else: + new_indexes = {} + coord_names = set(self._coord_names) - indexes = {k: v for k, v in self.xindexes.items() if k not in dims} - indexes[new_dim] = PandasMultiIndex(idx, new_dim) + indexes = {k: v for k, v in self.xindexes.items() if k not in drop_indexes} + indexes.update(new_indexes) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes From 8bc2590905abefb92bde078370a1e098373456a5 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 16 Sep 2021 01:24:21 +0200 Subject: [PATCH 041/159] stack: better rule for add/skip multi-index It is more robust and allows creating a multi-index from non-default (non-dimension) coordinates, as long as there's one and only one 1-d indexed coordinate for each dimension to stack. 
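
An illustrative sketch of this rule (the dataset, variable and
coordinate names below are invented for illustration only; the behavior
shown is the one described above, not a test):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"var": (("x", "y"), np.zeros((2, 3)))},
        coords={"x": [10, 20], "y": ["a", "b", "c"]},
    )

    # each of "x" and "y" has exactly one 1-d indexed coordinate, so a
    # pandas.MultiIndex is created for "z" and the stack is reversible
    stacked = ds.stack(z=("x", "y"))

    # here "y" has no indexed coordinate left, so under the rule above
    # no multi-index is created and unstack would not work
    stacked_no_index = ds.drop_vars("y").stack(z=("x", "y"))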
--- xarray/core/dataset.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 114895ad1a8..e5d3b546a15 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3877,16 +3877,24 @@ def _stack_once(self, dims, new_dim): for _, coord_names in group_coords_by_index(self.xindexes) for k in coord_names } - for k in stacked_var_names: - drop_indexes += index_coord_names.get(k, []) + for name in stacked_var_names: + drop_indexes += index_coord_names.get(name, []) - # A new index is created only if all stacked dimensions have an index + # A new index is created only if each of the stacked dimensions has + # one and only one 1-d coordinate index # TODO: add API option for the creation of a new index (see GH 5202) - stacked_idx_vars = { - k: variables[k] for k in stacked_var_names if k in self.xindexes - } - if len(stacked_idx_vars) == len(dims): - idx, idx_vars = PandasMultiIndex.from_variables(stacked_idx_vars) + idx_vars_candidates = {} + idx_vars_candidates_dim = [] + for name in stacked_var_names: + var = self._variables[name] + if name in self.xindexes and var.ndim == 1: + dim = var.dims[0] + if dim in dims: + idx_vars_candidates[name] = variables[name] + idx_vars_candidates_dim.append(dim) + + if len(set(idx_vars_candidates_dim)) == len(idx_vars_candidates_dim) == 2: + idx, idx_vars = PandasMultiIndex.from_variables(idx_vars_candidates) new_indexes = {k: idx for k in idx_vars} # keep consistent multi-index coordinate order for k in idx_vars: From 64df1e90416ec563bd9fe7664da81058638eacc9 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 16 Sep 2021 12:12:03 +0200 Subject: [PATCH 042/159] add utility methods to class Indexes This removes some ugly & duplicate patterns. --- xarray/core/coordinates.py | 4 +- xarray/core/dataset.py | 36 +++------ xarray/core/indexes.py | 137 ++++++++++++++++++++++++++--------- xarray/tests/test_indexes.py | 40 +++++++++- 4 files changed, 154 insertions(+), 63 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 5bfbd347620..749ac2dcef9 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -49,11 +49,11 @@ def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]: raise NotImplementedError() @property - def indexes(self) -> Indexes: + def indexes(self) -> Indexes[pd.Index]: return self._data.indexes # type: ignore[attr-defined] @property - def xindexes(self) -> Indexes: + def xindexes(self) -> Indexes[Index]: return self._data.xindexes # type: ignore[attr-defined] @property diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e5d3b546a15..0cf4becba66 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -62,7 +62,6 @@ PandasIndex, PandasMultiIndex, default_indexes, - group_coords_by_index, isel_variable_and_index, propagate_indexes, remove_unused_levels_categories, @@ -1589,7 +1588,7 @@ def identical(self, other: "Dataset") -> bool: return False @property - def indexes(self) -> Indexes: + def indexes(self) -> Indexes[pd.Index]: """Mapping of pandas.Index objects used for label based indexing. 
Raises an error if this Dataset has indexes that cannot be coerced @@ -1603,7 +1602,7 @@ def indexes(self) -> Indexes: return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) @property - def xindexes(self) -> Indexes: + def xindexes(self) -> Indexes[Index]: """Mapping of xarray Index objects used for label based indexing.""" if self._indexes is None: self._indexes = default_indexes(self._variables, self._dims) @@ -3195,7 +3194,7 @@ def _rename_indexes(self, name_dict, dims_dict): indexes = {} variables = {} - for index, coord_names in group_coords_by_index(self.xindexes): + for index, coord_names in self.xindexes.group_by_index(): new_index, new_index_vars = index.rename(name_dict, dims_dict) # map new index to its corresponding coordinates new_coord_names = [name_dict.get(k, k) for k in coord_names] @@ -3642,12 +3641,6 @@ def set_index( drop_variables: List[Hashable] = [] replace_dims: Dict[Hashable, Hashable] = {} - index_coord_names = { - k: coord_names - for _, coord_names in group_coords_by_index(self.xindexes) - for k in coord_names - } - for dim, _var_names in dim_coords.items(): if isinstance(_var_names, str) or not isinstance(_var_names, Sequence): var_names = [_var_names] @@ -3661,12 +3654,14 @@ def set_index( + " variable(s) do not exist" ) - current_coord_names = index_coord_names.get(dim, []) + current_coord_names = self.xindexes.get_all_coords(dim, errors="ignore") # drop any pre-existing index involved - maybe_drop_indexes += current_coord_names + var_names + maybe_drop_indexes += list(current_coord_names) + var_names for k in var_names: - maybe_drop_indexes += index_coord_names.get(k, []) + maybe_drop_indexes += list( + self.xindexes.get_all_coords(k, errors="ignore") + ) drop_variables += var_names @@ -3759,15 +3754,9 @@ def reset_index( new_indexes: Dict[Hashable, Index] = {} new_variables: Dict[Hashable, IndexVariable] = {} - index_coord_names = { - k: coord_names - for _, coord_names in group_coords_by_index(self.xindexes) - for k in coord_names - } - for name in dims_or_levels: index = self.xindexes[name] - drop_indexes += [k for k in index_coord_names[name]] + drop_indexes += list(self.xindexes.get_all_coords(name)) if isinstance(index, PandasMultiIndex) and name not in self.dims: # special case for pd.MultiIndex (name is an index level): @@ -3872,13 +3861,8 @@ def _stack_once(self, dims, new_dim): variables[name] = var.copy(deep=False) # drop indexes of stacked coordinates (if any) - index_coord_names = { - k: coord_names - for _, coord_names in group_coords_by_index(self.xindexes) - for k in coord_names - } for name in stacked_var_names: - drop_indexes += index_coord_names.get(name, []) + drop_indexes += list(self.xindexes.get_all_coords(name, errors="ignore")) # A new index is created only if each of the stacked dimensions has # one and only one 1-d coordinate index diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 31a8c13cfae..1087f36e651 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -4,12 +4,15 @@ TYPE_CHECKING, Any, Dict, + Generic, Hashable, Iterable, List, Mapping, Optional, + Set, Tuple, + TypeVar, Union, cast, ) @@ -712,20 +715,114 @@ def remove_unused_levels_categories(index: pd.Index) -> pd.Index: return index -class Indexes(collections.abc.Mapping): - """Immutable proxy for Dataset or DataArrary indexes.""" +# generic type that represents either pandas or xarray indexes +T_Index = TypeVar("T_Index") - __slots__ = ("_indexes",) - def __init__(self, indexes): - """Not for public consumption. 
+class Indexes(collections.abc.Mapping, Generic[T_Index]): + """Immutable proxy for Dataset or DataArrary indexes. + + Keys are coordinate names and values may correspond to either pandas or + xarray indexes. + + Also provides some utility methods. + + """ + + _indexes: Dict[Any, T_Index] + _mappings_cached: bool + _coord_name_id: Dict[Any, int] + _id_coord_names: Dict[int, Tuple[Hashable, ...]] + _id_index: Dict[int, T_Index] + + __slots__ = ( + "_indexes", + "_mappings_cached", + "_coord_name_id", + "_id_coord_names", + "_id_index", + ) + + def __init__(self, indexes: Dict[Any, T_Index]): + """Constructor not for public consumption. Parameters ---------- - indexes : Dict[Any, pandas.Index] + indexes : dict Indexes held by this object. """ self._indexes = indexes + self._mappings_cached = False + self._coord_name_id = {} + self._id_coord_names = {} + self._id_index = {} + + def _cache_mappings(self) -> None: + if self._mappings_cached: + return + + self._coord_name_id = {} + self._id_index = {} + grouped_coord_names: Mapping[int, List[Hashable]] = defaultdict(list) + + for coord_name, index_obj in self._indexes.items(): + index_id = id(index_obj) + self._id_index[index_id] = index_obj + self._coord_name_id[coord_name] = index_id + grouped_coord_names[index_id].append(coord_name) + self._id_coord_names = {k: tuple(v) for k, v in grouped_coord_names.items()} + + self._mappings_cached = True + + def get_unique(self) -> List[T_Index]: + """Returns a list of unique indexes, preserving order.""" + + unique_indexes: List[T_Index] = [] + seen: Set[T_Index] = set() + + for index in self._indexes.values(): + if index not in seen: + unique_indexes.append(index) + seen.add(index) + + return unique_indexes + + def get_all_coords( + self, coord_name: Hashable, errors: str = "raise" + ) -> Tuple[Hashable, ...]: + """Return the names of all coordinates having the same index. + + Parameters + ---------- + coord_name : hashable + Name of an indexed coordinate. + errors : {"raise", "ignore"}, optional + If "raise", raises a ValueError if `coord_name` is not in indexes. + If "ignore", an empty tuple is returned instead. + + Returns + ------- + names : tuple + The names of all coordinates having the same index. 
+ + """ + if errors not in ["raise", "ignore"]: + raise ValueError('errors must be either "raise" or "ignore"') + + if coord_name not in self._indexes: + if errors == "raise": + raise ValueError(f"no index found for {coord_name!r} coordinate") + else: + return tuple() + + self._cache_mappings() + return self._id_coord_names[self._coord_name_id[coord_name]] + + def group_by_index(self) -> List[Tuple[T_Index, Tuple[Hashable, ...]]]: + """Returns a list of unique indexes and their corresponding coordinate names.""" + + self._cache_mappings() + return [(self._id_index[i], self._id_coord_names[i]) for i in self._id_index] def __iter__(self): return iter(self._indexes) @@ -743,34 +840,6 @@ def __repr__(self): return formatting.indexes_repr(self) -def group_coords_by_index( - indexes: Mapping[Any, Index] -) -> List[Tuple[Index, List[Hashable]]]: - """Returns a list of unique indexes and their corresponding coordinate names.""" - unique_indexes: Dict[int, Index] = {} - grouped_coord_names: Mapping[int, List[Hashable]] = defaultdict(list) - - for coord_name, index_obj in indexes.items(): - index_id = id(index_obj) - unique_indexes[index_id] = index_obj - grouped_coord_names[index_id].append(coord_name) - - return [(unique_indexes[k], grouped_coord_names[k]) for k in unique_indexes] - - -def unique_indexes(indexes: Mapping[Any, Index]) -> List[Index]: - """Returns a list of unique indexes, preserving order.""" - unique_indexes = [] - seen = [] - - for index in indexes.values(): - if index not in seen: - unique_indexes.append(index) - seen.append(index) - - return unique_indexes - - def default_indexes( coords: Mapping[Any, "Variable"], dims: Iterable ) -> Dict[Hashable, Index]: diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 590b4b3cc5d..6b108904efd 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -3,7 +3,12 @@ import pytest import xarray as xr -from xarray.core.indexes import PandasIndex, PandasMultiIndex, _asarray_tuplesafe +from xarray.core.indexes import ( + Indexes, + PandasIndex, + PandasMultiIndex, + _asarray_tuplesafe, +) from xarray.core.variable import IndexVariable @@ -175,6 +180,7 @@ def test_from_variables(self) -> None: assert index.dim == "x" assert index.index.equals(expected_idx) assert index.index.name == "x" + assert index.index.names == ["level1", "level2"] assert list(index_vars) == ["x", "level1", "level2"] xr.testing.assert_equal(xr.IndexVariable("x", expected_idx), index_vars["x"]) @@ -274,3 +280,35 @@ def test_copy(self) -> None: assert actual.index is not expected.index assert actual.dim == expected.dim assert actual.level_coords_dtype == expected.level_coords_dtype + + +class TestIndexes: + def test_get_unique(self) -> None: + idx = [PandasIndex([1, 2, 3], "x"), PandasIndex([4, 5, 6], "y")] + indexes = Indexes({"a": idx[0], "b": idx[1], "c": idx[0]}) + + assert indexes.get_unique() == idx + + def test_get_all_coords(self) -> None: + idx = [PandasIndex([1, 2, 3], "x"), PandasIndex([4, 5, 6], "y")] + indexes = Indexes({"a": idx[0], "b": idx[1], "c": idx[0]}) + + assert indexes.get_all_coords("a") == ("a", "c") + # test cached internal dicts + assert indexes.get_all_coords("a") == ("a", "c") + + with pytest.raises(ValueError, match="errors must be.*"): + indexes.get_all_coords("a", errors="invalid") + + with pytest.raises(ValueError, match="no index found.*"): + indexes.get_all_coords("z") + + assert indexes.get_all_coords("z", errors="ignore") == tuple() + + def test_group_by_index(self): + idx = [PandasIndex([1, 
2, 3], "x"), PandasIndex([4, 5, 6], "y")] + indexes = Indexes({"a": idx[0], "b": idx[1], "c": idx[0]}) + + assert indexes.group_by_index() == [(idx[0], ("a", "c")), (idx[1], ("b",))] + # test cached internal dicts + assert indexes.group_by_index() == [(idx[0], ("a", "c")), (idx[1], ("b",))] From da8846ddd71c46f2f0daff14c3b77cf7accf21a9 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 16 Sep 2021 16:46:20 +0200 Subject: [PATCH 043/159] re-arrange class Indexes internals --- xarray/core/indexes.py | 45 +++++++++++++----------------------------- 1 file changed, 14 insertions(+), 31 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 1087f36e651..f35ac2d0f45 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1,4 +1,5 @@ import collections.abc +import functools from collections import defaultdict from typing import ( TYPE_CHECKING, @@ -730,18 +731,6 @@ class Indexes(collections.abc.Mapping, Generic[T_Index]): """ _indexes: Dict[Any, T_Index] - _mappings_cached: bool - _coord_name_id: Dict[Any, int] - _id_coord_names: Dict[int, Tuple[Hashable, ...]] - _id_index: Dict[int, T_Index] - - __slots__ = ( - "_indexes", - "_mappings_cached", - "_coord_name_id", - "_id_coord_names", - "_id_index", - ) def __init__(self, indexes: Dict[Any, T_Index]): """Constructor not for public consumption. @@ -752,27 +741,23 @@ def __init__(self, indexes: Dict[Any, T_Index]): Indexes held by this object. """ self._indexes = indexes - self._mappings_cached = False - self._coord_name_id = {} - self._id_coord_names = {} - self._id_index = {} - def _cache_mappings(self) -> None: - if self._mappings_cached: - return + @functools.cached_property + def _coord_name_id(self) -> Dict[Any, int]: + return {k: id(idx) for k, idx in self._indexes.items()} + + @functools.cached_property + def _id_index(self) -> Dict[int, T_Index]: + return {id(idx): idx for idx in self.get_unique()} - self._coord_name_id = {} - self._id_index = {} - grouped_coord_names: Mapping[int, List[Hashable]] = defaultdict(list) + @functools.cached_property + def _id_coord_names(self) -> Dict[int, Tuple[Hashable, ...]]: + id_coord_names: Mapping[int, List[Hashable]] = defaultdict(list) - for coord_name, index_obj in self._indexes.items(): - index_id = id(index_obj) - self._id_index[index_id] = index_obj - self._coord_name_id[coord_name] = index_id - grouped_coord_names[index_id].append(coord_name) - self._id_coord_names = {k: tuple(v) for k, v in grouped_coord_names.items()} + for k, v in self._coord_name_id.items(): + id_coord_names[v].append(k) - self._mappings_cached = True + return {k: tuple(v) for k, v in id_coord_names.items()} def get_unique(self) -> List[T_Index]: """Returns a list of unique indexes, preserving order.""" @@ -815,13 +800,11 @@ def get_all_coords( else: return tuple() - self._cache_mappings() return self._id_coord_names[self._coord_name_id[coord_name]] def group_by_index(self) -> List[Tuple[T_Index, Tuple[Hashable, ...]]]: """Returns a list of unique indexes and their corresponding coordinate names.""" - self._cache_mappings() return [(self._id_index[i], self._id_coord_names[i]) for i in self._id_index] def __iter__(self): From 28ecad657b3028675e1f72b4bdf45ec1dcb8ce6c Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 16 Sep 2021 16:55:16 +0200 Subject: [PATCH 044/159] wip refactor stack / unstack stack: revert to old way of creating multi-index unstack: support non-default (non-dimension) multi-index (as long as there is exactly one multi-index per specified dimension) --- 
xarray/core/dataset.py | 128 +++++++++++++++++++++++++++++------------ 1 file changed, 90 insertions(+), 38 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 0cf4becba66..28b0a9f4473 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3837,6 +3837,38 @@ def reorder_levels( return self._replace(variables, indexes=indexes) + def _find_stack_index( + self, dim, multi=False + ) -> Tuple[Union[Index, None], List[Hashable]]: + """Used by stack and unstack to find one pandas (multi-)index among + the indexed coordinates along dimension `dim`. + + If it finds exactly one index returns it with its corresponding + coordinate name(s), otherwise returns None and an empty list. + + """ + if multi: + index_cls = PandasMultiIndex + else: + index_cls = PandasIndex # type: ignore[assignment] + + indexes: Set[Index] = set() + names: List[Hashable] = [] + + for name in self._coord_names: + var = self._variables[name] + index = self.xindexes.get(name) + if index is not None and var.ndim == 1: + var_dim = var.dims[0] + if var_dim == dim and type(index) is index_cls: + indexes.add(index) + names.append(name) + + if len(indexes) == 1: + return next(iter(indexes)), names + else: + return None, [] + def _stack_once(self, dims, new_dim): if ... in dims: dims = list(infix_dims(dims, self.dims)) @@ -3866,19 +3898,21 @@ def _stack_once(self, dims, new_dim): # A new index is created only if each of the stacked dimensions has # one and only one 1-d coordinate index - # TODO: add API option for the creation of a new index (see GH 5202) - idx_vars_candidates = {} - idx_vars_candidates_dim = [] - for name in stacked_var_names: - var = self._variables[name] - if name in self.xindexes and var.ndim == 1: - dim = var.dims[0] - if dim in dims: - idx_vars_candidates[name] = variables[name] - idx_vars_candidates_dim.append(dim) - - if len(set(idx_vars_candidates_dim)) == len(idx_vars_candidates_dim) == 2: - idx, idx_vars = PandasMultiIndex.from_variables(idx_vars_candidates) + # TODO: add API option to force/skip the creation of a new index (see GH 5202) + stack_indexes: Dict[Any, Tuple[pd.Index, Any]] = {} + # stack_idx_vars: Dict[Any, Variable] = {} + for dim in dims: + index, names = self._find_stack_index(dim) + if index is not None: + stack_indexes[dim] = cast(PandasIndex, index).index, names[0] + # n = names[0] + # stack_idx_vars[n] = variables[n] + + if len(stack_indexes) == len(dims): + levels, names = zip(*stack_indexes.values()) + midx = utils.multiindex_from_product_levels(levels, names=names) + idx, idx_vars = PandasMultiIndex.from_pandas_index(midx, new_dim) + # idx, idx_vars = PandasMultiIndex.from_variables(stack_idx_vars) new_indexes = {k: idx for k in idx_vars} # keep consistent multi-index coordinate order for k in idx_vars: @@ -4057,15 +4091,20 @@ def ensure_stackable(val): return data_array - def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset": - index = self.get_index(dim) - index = remove_unused_levels_categories(index) + def _unstack_once( + self, + dim: Hashable, + index_and_coords: Tuple[PandasMultiIndex, List[Hashable]], + fill_value, + ) -> "Dataset": + index, index_vnames = index_and_coords + pd_index = remove_unused_levels_categories(index.index) variables: Dict[Hashable, Variable] = {} indexes = {k: v for k, v in self.xindexes.items() if k != dim} for name, var in self.variables.items(): - if name != dim: + if name not in index_vnames: if dim in var.dims: if isinstance(fill_value, Mapping): fill_value_ = fill_value[name] @@ -4073,52 +4112,56 
@@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset": fill_value_ = fill_value variables[name] = var._unstack_once( - index=index, dim=dim, fill_value=fill_value_ + index=pd_index, dim=dim, fill_value=fill_value_ ) else: variables[name] = var - for name, lev in zip(index.names, index.levels): + for name, lev in zip(pd_index.names, pd_index.levels): idx, idx_vars = PandasIndex.from_pandas_index(lev, name) variables[name] = idx_vars[name] indexes[name] = idx - coord_names = set(self._coord_names) - {dim} | set(index.names) + coord_names = set(self._coord_names) - {dim} | set(pd_index.names) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) def _unstack_full_reindex( - self, dim: Hashable, fill_value, sparse: bool + self, + dim: Hashable, + index_and_coords: Tuple[PandasMultiIndex, List[Hashable]], + fill_value, + sparse: bool, ) -> "Dataset": - index = self.get_index(dim) - index = remove_unused_levels_categories(index) - full_idx = pd.MultiIndex.from_product(index.levels, names=index.names) + index, index_vnames = index_and_coords + pd_index = remove_unused_levels_categories(index.index) + full_idx = pd.MultiIndex.from_product(pd_index.levels, names=pd_index.names) # take a shortcut in case the MultiIndex was not modified. - if index.equals(full_idx): + if pd_index.equals(full_idx): obj = self else: obj = self._reindex( {dim: full_idx}, copy=False, fill_value=fill_value, sparse=sparse ) - new_dim_names = index.names - new_dim_sizes = [lev.size for lev in index.levels] + new_dim_names = pd_index.names + new_dim_sizes = [lev.size for lev in pd_index.levels] variables: Dict[Hashable, Variable] = {} indexes = {k: v for k, v in self.xindexes.items() if k != dim} for name, var in obj.variables.items(): - if name != dim: + if name not in index_vnames: if dim in var.dims: new_dims = dict(zip(new_dim_names, new_dim_sizes)) variables[name] = var.unstack({dim: new_dims}) else: variables[name] = var - for name, lev in zip(new_dim_names, index.levels): + for name, lev in zip(new_dim_names, pd_index.levels): idx, idx_vars = PandasIndex.from_pandas_index(lev, name) variables[name] = idx_vars[name] indexes[name] = idx @@ -4162,10 +4205,9 @@ def unstack( -------- Dataset.stack """ + if dim is None: - dims = [ - d for d in self.dims if isinstance(self.get_index(d), pd.MultiIndex) - ] + dims = list(self.dims) else: if isinstance(dim, str) or not isinstance(dim, Iterable): dims = [dim] @@ -4178,13 +4220,21 @@ def unstack( f"Dataset does not contain the dimensions: {missing_dims}" ) - non_multi_dims = [ - d for d in dims if not isinstance(self.get_index(d), pd.MultiIndex) - ] + # each specified dimension must have exactly one multi-index + stacked_indexes: Dict[Any, Tuple[PandasMultiIndex, List[Any]]] = {} + for d in dims: + idx, idx_var_names = self._find_stack_index(d, multi=True) + if idx is not None: + stacked_indexes[d] = cast(PandasMultiIndex, idx), idx_var_names + + if dim is None: + dims = list(stacked_indexes) + else: + non_multi_dims = set(dims) - set(stacked_indexes) if non_multi_dims: raise ValueError( "cannot unstack dimensions that do not " - f"have a MultiIndex: {non_multi_dims}" + f"have exactly one MultiIndex: {tuple(non_multi_dims)}" ) result = self.copy(deep=False) @@ -4216,9 +4266,11 @@ def unstack( not isinstance(v.data, np.ndarray) for v in self.variables.values() ) ): - result = result._unstack_full_reindex(dim, fill_value, sparse) + result = result._unstack_full_reindex( + dim, stacked_indexes[dim], fill_value, sparse + ) else: - 
result = result._unstack_once(dim, fill_value) + result = result._unstack_once(dim, stacked_indexes[dim], fill_value) return result def update(self, other: "CoercibleMapping") -> "Dataset": From e62b9cd0726deb495493bf60d2617808371fb65d Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 23 Sep 2021 11:38:14 +0200 Subject: [PATCH 045/159] add PandasMultiIndex.from_product_variables Refactored utils.multiindex_from_product_levels utility function into a new PandasMultiIndex classmethod. --- xarray/core/dataset.py | 19 ++++----- xarray/core/indexes.py | 44 +++++++++++++++++---- xarray/core/utils.py | 28 -------------- xarray/tests/test_indexes.py | 75 ++++++++++++++++++++++++++++-------- xarray/tests/test_utils.py | 25 ------------ 5 files changed, 105 insertions(+), 86 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 28b0a9f4473..43a8013c638 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3899,20 +3899,17 @@ def _stack_once(self, dims, new_dim): # A new index is created only if each of the stacked dimensions has # one and only one 1-d coordinate index # TODO: add API option to force/skip the creation of a new index (see GH 5202) - stack_indexes: Dict[Any, Tuple[pd.Index, Any]] = {} - # stack_idx_vars: Dict[Any, Variable] = {} + product_vars: Dict[Any, Variable] = {} for dim in dims: index, names = self._find_stack_index(dim) if index is not None: - stack_indexes[dim] = cast(PandasIndex, index).index, names[0] - # n = names[0] - # stack_idx_vars[n] = variables[n] - - if len(stack_indexes) == len(dims): - levels, names = zip(*stack_indexes.values()) - midx = utils.multiindex_from_product_levels(levels, names=names) - idx, idx_vars = PandasMultiIndex.from_pandas_index(midx, new_dim) - # idx, idx_vars = PandasMultiIndex.from_variables(stack_idx_vars) + n = names[0] + product_vars[n] = self.variables[n] + + if len(product_vars) == len(dims): + idx, idx_vars = PandasMultiIndex.from_product_variables( + product_vars, new_dim + ) new_indexes = {k: idx for k in idx_vars} # keep consistent multi-index coordinate order for k in idx_vars: diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index f35ac2d0f45..8a2595ae9bd 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -329,9 +329,9 @@ def __getitem__(self, indexer: Any): return self._replace(self.index[indexer]) -def _check_dim_compat(variables: Mapping[Any, "Variable"]) -> Hashable: - """Check that all multi-index variable candidates share the same (single) dimension - and return the name of that dimension. +def _check_dim_compat(variables: Mapping[Any, "Variable"], all_dims: str = "equal"): + """Check that all multi-index variable candidates are 1-dimensional and + either share the same (single) dimension or each have a different dimension. 
""" if any([var.ndim != 1 for var in variables.values()]): @@ -339,13 +339,17 @@ def _check_dim_compat(variables: Mapping[Any, "Variable"]) -> Hashable: dims = set([var.dims for var in variables.values()]) - if len(dims) > 1: + if all_dims == "equal" and len(dims) > 1: raise ValueError( - "unmatched dimensions for variables " + "unmatched dimensions for multi-index variables " + ", ".join([f"{k!r} {v.dims}" for k, v in variables.items()]) ) - return next(iter(dims))[0] + if all_dims == "different" and len(dims) < len(variables): + raise ValueError( + "conflicting dimensions for multi-index product variables " + + ", ".join([f"{k!r} {v.dims}" for k, v in variables.items()]) + ) def _get_var_metadata(variables: Mapping[Any, "Variable"]) -> Dict[Any, Dict[str, Any]]: @@ -410,7 +414,8 @@ def _replace(self, index, dim=None, level_coords_dtype=None) -> "PandasMultiInde def from_variables( cls, variables: Mapping[Any, "Variable"] ) -> Tuple["PandasMultiIndex", IndexVars]: - dim = _check_dim_compat(variables) + _check_dim_compat(variables) + dim = next(iter(variables.values())).dims[0] index = pd.MultiIndex.from_arrays( [var.values for var in variables.values()], names=variables.keys() @@ -425,6 +430,31 @@ def from_variables( return obj, index_vars + @classmethod + def from_product_variables( + cls, variables: Mapping[Any, "Variable"], dim: Hashable + ) -> Tuple["PandasMultiIndex", IndexVars]: + """Create a new Pandas MultiIndex from the product of 1-d variables (levels) along a + new dimension. + + Level variables must have a dimension distinct from each other. + + Keeps levels the same (doesn't refactorize them) so that it gives back the original + labels after a stack/unstack roundtrip. + + """ + _check_dim_compat(variables, all_dims="different") + + level_indexes = [utils.safe_cast_to_index(var) for var in variables.values()] + + split_labels, levels = zip(*[lev.factorize() for lev in level_indexes]) + labels_mesh = np.meshgrid(*split_labels, indexing="ij") + labels = [x.ravel() for x in labels_mesh] + + index = pd.MultiIndex(levels, labels, sortorder=0, names=variables.keys()) + + return cls.from_pandas_index(index, dim, var_meta=_get_var_metadata(variables)) + @classmethod def from_variables_maybe_expand( cls, diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 85553b1594f..35a284723c7 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -23,7 +23,6 @@ MutableMapping, MutableSet, Optional, - Sequence, Tuple, TypeVar, Union, @@ -113,33 +112,6 @@ def safe_cast_to_index(array: Any) -> pd.Index: return _maybe_cast_to_cftimeindex(index) -def multiindex_from_product_levels( - levels: Sequence[pd.Index], names: Sequence[str] = None -) -> pd.MultiIndex: - """Creating a MultiIndex from a product without refactorizing levels. - - Keeping levels the same gives back the original labels when we unstack. - - Parameters - ---------- - levels : sequence of pd.Index - Values for each MultiIndex level. - names : sequence of str, optional - Names for each level. - - Returns - ------- - pandas.MultiIndex - """ - if any(not isinstance(lev, pd.Index) for lev in levels): - raise TypeError("levels must be a list of pd.Index objects") - - split_labels, levels = zip(*[lev.factorize() for lev in levels]) - labels_mesh = np.meshgrid(*split_labels, indexing="ij") - labels = [x.ravel() for x in labels_mesh] - return pd.MultiIndex(levels, labels, sortorder=0, names=names) - - def maybe_wrap_array(original, new_array): """Wrap a transformed array with __array_wrap__ if it can be done safely. 
diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 6b108904efd..309bab5f95a 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -11,6 +11,8 @@ ) from xarray.core.variable import IndexVariable +from . import assert_equal, assert_identical + def test_asarray_tuplesafe() -> None: res = _asarray_tuplesafe(("a", 1)) @@ -40,7 +42,7 @@ def test_from_variables(self) -> None: ) index, index_vars = PandasIndex.from_variables({"x": var}) - xr.testing.assert_identical(var.to_index_variable(), index_vars["x"]) + assert_identical(var.to_index_variable(), index_vars["x"]) assert index_vars["x"].dtype == var.dtype assert index.dim == "x" assert index.index.equals(index_vars["x"].to_index()) @@ -62,7 +64,7 @@ def test_from_pandas_index(self) -> None: assert index.dim == "x" assert index.index is pd_idx assert index.index.name == "foo" - xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", [1, 2, 3])) + assert_identical(index_vars["foo"], IndexVariable("x", [1, 2, 3])) # test no name set for pd.Index pd_idx.name = None @@ -136,13 +138,13 @@ def test_rename(self) -> None: assert new_index.index.name == "b" assert new_index.dim == "x" assert new_index.coord_dtype == np.int32 - xr.testing.assert_identical(index_vars["b"], IndexVariable("x", [1, 2, 3])) + assert_identical(index_vars["b"], IndexVariable("x", [1, 2, 3])) new_index, index_vars = index.rename({}, {"x": "y"}) assert new_index.index.name == "a" assert new_index.dim == "y" assert new_index.coord_dtype == np.int32 - xr.testing.assert_identical(index_vars["a"], IndexVariable("y", [1, 2, 3])) + assert_identical(index_vars["a"], IndexVariable("y", [1, 2, 3])) def test_copy(self) -> None: expected = PandasIndex([1, 2, 3], "x", coord_dtype=np.int32) @@ -183,9 +185,9 @@ def test_from_variables(self) -> None: assert index.index.names == ["level1", "level2"] assert list(index_vars) == ["x", "level1", "level2"] - xr.testing.assert_equal(xr.IndexVariable("x", expected_idx), index_vars["x"]) - xr.testing.assert_identical(v_level1.to_index_variable(), index_vars["level1"]) - xr.testing.assert_identical(v_level2.to_index_variable(), index_vars["level2"]) + assert_equal(xr.IndexVariable("x", expected_idx), index_vars["x"]) + assert_identical(v_level1.to_index_variable(), index_vars["level1"]) + assert_identical(v_level2.to_index_variable(), index_vars["level2"]) var = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) with pytest.raises( @@ -194,9 +196,56 @@ def test_from_variables(self) -> None: PandasMultiIndex.from_variables({"var": var}) v_level3 = xr.Variable("y", [4, 5, 6]) - with pytest.raises(ValueError, match=r"unmatched dimensions for variables.*"): + with pytest.raises( + ValueError, match=r"unmatched dimensions for multi-index variables.*" + ): PandasMultiIndex.from_variables({"level1": v_level1, "level3": v_level3}) + def test_from_product_variables(self) -> None: + prod_vars = { + "x": xr.Variable("x", pd.Index(["b", "a"]), attrs={"foo": "bar"}), + "y": xr.Variable("y", pd.Index([1, 3, 2])), + } + + index, index_vars = PandasMultiIndex.from_product_variables(prod_vars, "z") + + assert index.dim == "z" + assert index.index.names == ["x", "y"] + np.testing.assert_array_equal( + index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] + ) + + assert list(index_vars) == ["z", "x", "y"] + midx = pd.MultiIndex.from_product([["b", "a"], [1, 3, 2]]) + assert_equal(xr.IndexVariable("z", midx), index_vars["z"]) + assert_identical( + xr.IndexVariable("z", ["b", "b", "b", "a", "a", "a"], 
attrs={"foo": "bar"}), + index_vars["x"], + ) + assert_identical(xr.IndexVariable("z", [1, 3, 2, 1, 3, 2]), index_vars["y"]) + + with pytest.raises( + ValueError, match=r"conflicting dimensions for multi-index product.*" + ): + PandasMultiIndex.from_product_variables( + {"x": xr.Variable("x", ["a", "b"]), "x2": xr.Variable("x", [1, 2])}, + "z", + ) + + def test_from_product_variables_non_unique(self) -> None: + prod_vars = { + "x": xr.Variable("x", pd.Index(["b", "a"]), attrs={"foo": "bar"}), + "y": xr.Variable("y", pd.Index([1, 1, 2])), + } + + index, _ = PandasMultiIndex.from_product_variables(prod_vars, "z") + + np.testing.assert_array_equal( + index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] + ) + np.testing.assert_array_equal(index.index.levels[0], ["b", "a"]) + np.testing.assert_array_equal(index.index.levels[1], [1, 2]) + def test_from_pandas_index(self) -> None: foo_data = np.array([0, 0, 1], dtype="int") bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") @@ -208,9 +257,9 @@ def test_from_pandas_index(self) -> None: assert index.index.equals(pd_idx) assert index.index.names == ("foo", "bar") assert index.index.name == "x" - xr.testing.assert_identical(index_vars["x"], IndexVariable("x", pd_idx)) - xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", foo_data)) - xr.testing.assert_identical(index_vars["bar"], IndexVariable("x", bar_data)) + assert_identical(index_vars["x"], IndexVariable("x", pd_idx)) + assert_identical(index_vars["foo"], IndexVariable("x", foo_data)) + assert_identical(index_vars["bar"], IndexVariable("x", bar_data)) assert index_vars["foo"].dtype == foo_data.dtype assert index_vars["bar"].dtype == bar_data.dtype @@ -293,8 +342,6 @@ def test_get_all_coords(self) -> None: idx = [PandasIndex([1, 2, 3], "x"), PandasIndex([4, 5, 6], "y")] indexes = Indexes({"a": idx[0], "b": idx[1], "c": idx[0]}) - assert indexes.get_all_coords("a") == ("a", "c") - # test cached internal dicts assert indexes.get_all_coords("a") == ("a", "c") with pytest.raises(ValueError, match="errors must be.*"): @@ -310,5 +357,3 @@ def test_group_by_index(self): indexes = Indexes({"a": idx[0], "b": idx[1], "c": idx[0]}) assert indexes.group_by_index() == [(idx[0], ("a", "c")), (idx[1], ("b",))] - # test cached internal dicts - assert indexes.group_by_index() == [(idx[0], ("a", "c")), (idx[1], ("b",))] diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index ce796e9de49..0a720b23d3b 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -89,31 +89,6 @@ def test_safe_cast_to_index_datetime_datetime(): assert isinstance(actual, pd.Index) -def test_multiindex_from_product_levels(): - result = utils.multiindex_from_product_levels( - [pd.Index(["b", "a"]), pd.Index([1, 3, 2])] - ) - np.testing.assert_array_equal( - result.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] - ) - np.testing.assert_array_equal(result.levels[0], ["b", "a"]) - np.testing.assert_array_equal(result.levels[1], [1, 3, 2]) - - other = pd.MultiIndex.from_product([["b", "a"], [1, 3, 2]]) - np.testing.assert_array_equal(result.values, other.values) - - -def test_multiindex_from_product_levels_non_unique(): - result = utils.multiindex_from_product_levels( - [pd.Index(["b", "a"]), pd.Index([1, 1, 2])] - ) - np.testing.assert_array_equal( - result.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] - ) - np.testing.assert_array_equal(result.levels[0], ["b", "a"]) - np.testing.assert_array_equal(result.levels[1], [1, 2]) - - class TestArrayEquiv: def test_0d(self): # verify our 
work around for pd.isnull not working for 0-dimensional From f7aca70761ac67e91dec169441dd5e3cc9f2c378 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 23 Sep 2021 12:52:07 +0200 Subject: [PATCH 046/159] unstack: propagate index coordinate metadata --- xarray/core/dataset.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 43a8013c638..c72ae66acb1 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4115,7 +4115,11 @@ def _unstack_once( variables[name] = var for name, lev in zip(pd_index.names, pd_index.levels): - idx, idx_vars = PandasIndex.from_pandas_index(lev, name) + var = self.variables[name] + meta = { + name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding} + } + idx, idx_vars = PandasIndex.from_pandas_index(lev, name, var_meta=meta) variables[name] = idx_vars[name] indexes[name] = idx @@ -4159,7 +4163,11 @@ def _unstack_full_reindex( variables[name] = var for name, lev in zip(new_dim_names, pd_index.levels): - idx, idx_vars = PandasIndex.from_pandas_index(lev, name) + var = self.variables[name] + meta = { + name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding} + } + idx, idx_vars = PandasIndex.from_pandas_index(lev, name, var_meta=meta) variables[name] = idx_vars[name] indexes[name] = idx From e44c56e635f6f37598cbecb0da290a71b944f3ff Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 24 Sep 2021 03:23:24 +0200 Subject: [PATCH 047/159] wip: fix/update tests --- xarray/core/dataarray.py | 5 +- xarray/core/dataset.py | 12 ++-- xarray/core/indexes.py | 35 ++++++---- xarray/core/indexing.py | 4 +- xarray/core/merge.py | 6 +- xarray/tests/test_dataarray.py | 93 ++++++++++++++----------- xarray/tests/test_dataset.py | 120 ++++++++++++++++++++++++++++----- xarray/tests/test_indexing.py | 12 ++-- 8 files changed, 199 insertions(+), 88 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 6aa68bfd902..1a8dd9582c8 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -59,12 +59,11 @@ _default, either_dict_or_kwargs, ) -from .variable import ( +from .variable import ( # assert_unique_multiindex_level_names, IndexVariable, Variable, as_compatible_data, as_variable, - assert_unique_multiindex_level_names, ) if TYPE_CHECKING: @@ -159,7 +158,7 @@ def _infer_coords_and_dims( "matching the dimension size" ) - assert_unique_multiindex_level_names(new_coords) + # assert_unique_multiindex_level_names(new_coords) return new_coords, dims diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c72ae66acb1..bc699835615 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1054,7 +1054,7 @@ def _overwrite_indexes( no_index_variables[k] = v for name in indexes: - new_variables[name] = variables[name] + new_indexes[name] = indexes[name] for name in index_variables: new_variables[name] = variables[name] @@ -3708,7 +3708,7 @@ def set_index( new_dims = [replace_dims.get(d, d) for d in v.dims] variables[k] = v._replace(dims=new_dims) - coord_names = set(new_variables) | self._coord_names + coord_names = self._coord_names - set(drop_variables) | set(new_variables) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes_ @@ -3761,7 +3761,6 @@ def reset_index( if isinstance(index, PandasMultiIndex) and name not in self.dims: # special case for pd.MultiIndex (name is an index level): # replace by a new index with dropped level(s) instead of just drop the index - # TODO: 
eventually extend Index API to allow this for custom multi-indexes? if index not in replaced_indexes: level_names = index.index.names level_vars = { @@ -3769,9 +3768,10 @@ def reset_index( for k in level_names if k not in dims_or_levels } - idx, idx_vars = index.keep_levels(level_vars) - new_indexes.update({k: idx for k in idx_vars}) - new_variables.update(idx_vars) + if level_vars: + idx, idx_vars = index.keep_levels(level_vars) + new_indexes.update({k: idx for k in idx_vars}) + new_variables.update(idx_vars) replaced_indexes.append(index) if drop: diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 8a2595ae9bd..fb1ddec5d5e 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -570,6 +570,7 @@ def from_pandas_index( return cls(index, dim, level_coords_dtype=level_coords_dtype), index_vars def query(self, labels, method=None, tolerance=None) -> QueryResult: + from .dataarray import DataArray from .variable import Variable if method is not None or tolerance is not None: @@ -644,26 +645,36 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult: scalar_coord_values.update({k: v for k, v in zip(levels, label)}) else: - label = ( - label - if getattr(label, "ndim", 1) > 1 # vectorized-indexing - else _asarray_tuplesafe(label) - ) - if label.ndim == 0: - indexer, new_index = self.index.get_loc_level(label.item(), level=0) - scalar_coord_values[self.index.names[0]] = label.item() - elif label.dtype.kind == "b": - indexer = label + label_array = normalize_label(label) + if label_array.ndim == 0: + label_value = as_scalar(label_array) + indexer, new_index = self.index.get_loc_level(label_value, level=0) + scalar_coord_values[self.index.names[0]] = label_value + elif label_array.dtype.kind == "b": + indexer = label_array else: - if label.ndim > 1: + if label_array.ndim > 1: raise ValueError( "Vectorized selection is not available along " f"coordinate {coord_name!r} with a multi-index" ) - indexer = get_indexer_nd(self.index, label) + indexer = get_indexer_nd(self.index, label_array) if np.any(indexer < 0): raise KeyError(f"not all values found in index {coord_name!r}") + # attach dimension names and/or coordinates to positional indexer + if isinstance(label, Variable): + indexer = Variable(label.dims, indexer) + elif isinstance(label, DataArray): + # do not include label-indexer DataArray coordinates that conflict + # with the level names of this index + coords = { + k: v + for k, v in label._coords.items() + if k not in self.index.names + } + indexer = DataArray(indexer, coords=coords, dims=label.dims) + if new_index is not None: # variable(s) attrs and encoding metadata are propagated # when replacing the indexes in the resulting xarray object diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 91fb7b0f605..217b98d5c34 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -127,9 +127,9 @@ def group_indexers_by_index( unique_indexes[index_id] = index grouped_indexers[index_id][key] = label elif key in obj.coords: - raise KeyError(f"no index found for coordinate {key}") + raise KeyError(f"no index found for coordinate {key!r}") elif key not in obj.dims: - raise KeyError(f"{key} is not a valid dimension or coordinate") + raise KeyError(f"{key!r} is not a valid dimension or coordinate") elif len(options): raise ValueError( f"cannot supply selection options {options!r} for dimension {key!r}" diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 6ae7886c421..1aa3abe65c7 100644 --- a/xarray/core/merge.py +++ 
b/xarray/core/merge.py @@ -24,7 +24,7 @@ from .duck_array_ops import lazy_array_equiv from .indexes import Index, PandasIndex, PandasMultiIndex from .utils import Frozen, compat_dict_union, dict_equiv, equivalent -from .variable import Variable, as_variable, assert_unique_multiindex_level_names +from .variable import Variable, as_variable # , assert_unique_multiindex_level_names if TYPE_CHECKING: from .coordinates import Coordinates @@ -487,7 +487,7 @@ def merge_coords( collected = collect_variables_and_indexes(aligned) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) variables, out_indexes = merge_collected(collected, prioritized, compat=compat) - assert_unique_multiindex_level_names(variables) + # assert_unique_multiindex_level_names(variables) return variables, out_indexes @@ -684,7 +684,7 @@ def merge_core( variables, out_indexes = merge_collected( collected, prioritized, compat=compat, combine_attrs=combine_attrs ) - assert_unique_multiindex_level_names(variables) + # assert_unique_multiindex_level_names(variables) dims = calculate_dimensions(variables) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 0d3d9907783..099355f9d7e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1002,18 +1002,19 @@ def test_sel_float(self): def test_sel_float_multiindex(self): # regression test https://github.com/pydata/xarray/issues/5691 - midx = pd.MultiIndex.from_arrays( - [["a", "a", "b", "b"], [0.1, 0.2, 0.3, 0.4]], names=["lvl1", "lvl2"] + # test multi-index created from coordinates, one with dtype=float32 + lvl1 = ["a", "a", "b", "b"] + lvl2 = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32) + da = xr.DataArray( + [1, 2, 3, 4], dims="x", coords={"lvl1": ("x", lvl1), "lvl2": ("x", lvl2)} ) - da = xr.DataArray([1, 2, 3, 4], coords={"x": midx}, dims="x") + da = da.set_index(x=["lvl1", "lvl2"]) actual = da.sel(lvl1="a", lvl2=0.1) expected = da.isel(x=0) assert_equal(actual, expected) - # TODO: test multi-index created from coordinates, one with dtype=float32 - def test_sel_no_index(self): array = DataArray(np.arange(10), dims="x") assert_identical(array[0], array.sel(x=0)) @@ -1846,31 +1847,33 @@ def test_set_index(self): array2d.set_index(x="level") # Issue 3176: Ensure clear error message on key error. - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match=r".*variable\(s\) do not exist"): obj.set_index(x="level_4") - assert str(excinfo.value) == "level_4 is not the name of an existing variable." 
def test_reset_index(self): indexes = [self.mindex.get_level_values(n) for n in self.mindex.names] coords = {idx.name: ("x", idx) for idx in indexes} + coords["x"] = ("x", self.mindex.values) expected = DataArray(self.mda.values, coords=coords, dims="x") obj = self.mda.reset_index("x") assert_identical(obj, expected) + assert len(obj.xindexes) == 0 obj = self.mda.reset_index(self.mindex.names) assert_identical(obj, expected) + assert len(obj.xindexes) == 0 obj = self.mda.reset_index(["x", "level_1"]) assert_identical(obj, expected) + assert list(obj.xindexes) == ["level_2"] - coords = { - "x": ("x", self.mindex.droplevel("level_1")), - "level_1": ("x", self.mindex.get_level_values("level_1")), - } expected = DataArray(self.mda.values, coords=coords, dims="x") obj = self.mda.reset_index(["level_1"]) assert_identical(obj, expected) + assert list(obj.xindexes) == ["level_2"] + assert type(obj.xindexes["level_2"]) is PandasIndex - expected = DataArray(self.mda.values, dims="x") + coords = {k: v for k, v in coords.items() if k != "x"} + expected = DataArray(self.mda.values, coords=coords, dims="x") obj = self.mda.reset_index("x", drop=True) assert_identical(obj, expected) @@ -1880,15 +1883,16 @@ def test_reset_index(self): # single index array = DataArray([1, 2], coords={"x": ["a", "b"]}, dims="x") - expected = DataArray([1, 2], coords={"x_": ("x", ["a", "b"])}, dims="x") - assert_identical(array.reset_index("x"), expected) + obj = array.reset_index("x") + assert_identical(obj, array) + assert len(obj.xindexes) == 0 def test_reset_index_keep_attrs(self): coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True}) da = DataArray([1, 0], [coord_1]) - expected = DataArray([1, 0], {"coord_1_": coord_1}, dims=["coord_1"]) obj = da.reset_index("coord_1") - assert_identical(expected, obj) + assert_identical(obj, da) + assert len(obj.xindexes) == 0 def test_reorder_levels(self): midx = self.mindex.reorder_levels(["level_2", "level_1"]) @@ -2148,42 +2152,53 @@ def test_dataset_math(self): assert_identical(actual, expected) def test_stack_unstack(self): - orig = DataArray([[0, 1], [2, 3]], dims=["x", "y"], attrs={"foo": 2}) + orig = DataArray( + [[0, 1], [2, 3]], + coords={"x": [0, 1], "y": ["a", "b"]}, + dims=["x", "y"], + attrs={"foo": 2}, + ) assert_identical(orig, orig.unstack()) # test GH3000 - a = orig[:0, :1].stack(dim=("x", "y")).dim.to_index() - if pd.__version__ < "0.24.0": - b = pd.MultiIndex( - levels=[pd.Int64Index([]), pd.Int64Index([0])], - labels=[[], []], - names=["x", "y"], - ) - else: - b = pd.MultiIndex( - levels=[pd.Int64Index([]), pd.Int64Index([0])], - codes=[[], []], - names=["x", "y"], - ) - pd.testing.assert_index_equal(a, b) - - actual = orig.stack(z=["x", "y"]).unstack("z").drop_vars(["x", "y"]) + # no default range index anymore + # a = orig[:0, :1].stack(dim=("x", "y")).dim.to_index() + # if pd.__version__ < "0.24.0": + # b = pd.MultiIndex( + # levels=[pd.Int64Index([]), pd.Int64Index([0])], + # labels=[[], []], + # names=["x", "y"], + # ) + # else: + # b = pd.MultiIndex( + # levels=[pd.Int64Index([]), pd.Int64Index([0])], + # codes=[[], []], + # names=["x", "y"], + # ) + # pd.testing.assert_index_equal(a, b) + + actual = orig.stack(z=["x", "y"]).unstack("z") assert_identical(orig, actual) - actual = orig.stack(z=[...]).unstack("z").drop_vars(["x", "y"]) + actual = orig.stack(z=[...]).unstack("z") assert_identical(orig, actual) dims = ["a", "b", "c", "d", "e"] - orig = xr.DataArray(np.random.rand(1, 2, 3, 2, 1), dims=dims) + coords = { + "a": [0], + "b": [1, 2], + 
"c": [3, 4, 5], + "d": [6, 7], + "e": [8], + } + orig = xr.DataArray(np.random.rand(1, 2, 3, 2, 1), coords=coords, dims=dims) stacked = orig.stack(ab=["a", "b"], cd=["c", "d"]) unstacked = stacked.unstack(["ab", "cd"]) - roundtripped = unstacked.drop_vars(["a", "b", "c", "d"]).transpose(*dims) - assert_identical(orig, roundtripped) + assert_identical(orig, unstacked.transpose(*dims)) unstacked = stacked.unstack() - roundtripped = unstacked.drop_vars(["a", "b", "c", "d"]).transpose(*dims) - assert_identical(orig, roundtripped) + assert_identical(orig, unstacked.transpose(*dims)) def test_stack_unstack_decreasing_coordinate(self): # regression test for GH980 diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index fe5cc17b809..04fe914db2a 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -239,8 +239,8 @@ def test_repr_multiindex(self): Dimensions: (x: 4) Coordinates: * x (x) object MultiIndex - level_1 (x) object 'a' 'a' 'b' 'b' - level_2 (x) int64 1 2 1 2 + * level_1 (x) object 'a' 'a' 'b' 'b' + * level_2 (x) int64 1 2 1 2 Data variables: *empty*""" ) @@ -259,8 +259,8 @@ def test_repr_multiindex(self): Dimensions: (x: 4) Coordinates: * x (x) object MultiIndex - a_quite_long_level_name (x) object 'a' 'a' 'b' 'b' - level_2 (x) int64 1 2 1 2 + * a_quite_long_level_name (x) object 'a' 'a' 'b' 'b' + * level_2 (x) int64 1 2 1 2 Data variables: *empty*""" ) @@ -2994,33 +2994,49 @@ def test_set_index(self): obj = ds.set_index(x=mindex.names) assert_identical(obj, expected) + # ensure pre-existing indexes involved are removed + # (level_2 should be a coordinate with no index) + ds = create_test_multiindex() + coords = {"x": coords["level_1"], "level_2": coords["level_2"]} + expected = Dataset({}, coords=coords) + + obj = ds.set_index(x="level_1") + assert_identical(obj, expected) + # ensure set_index with no existing index and a single data var given # doesn't return multi-index ds = Dataset(data_vars={"x_var": ("x", [0, 1, 2])}) expected = Dataset(coords={"x": [0, 1, 2]}) assert_identical(ds.set_index(x="x_var"), expected) - # Issue 3176: Ensure clear error message on key error. - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match=r"bar variable\(s\) do not exist"): ds.set_index(foo="bar") - assert str(excinfo.value) == "bar is not the name of an existing variable." 
+ + with pytest.raises(ValueError, match=r"dimension mismatch.*"): + ds.set_index(y="x_var") def test_reset_index(self): ds = create_test_multiindex() mindex = ds["x"].to_index() indexes = [mindex.get_level_values(n) for n in mindex.names] coords = {idx.name: ("x", idx) for idx in indexes} + coords["x"] = ("x", mindex.values) expected = Dataset({}, coords=coords) obj = ds.reset_index("x") assert_identical(obj, expected) + assert len(obj.xindexes) == 0 + + ds = Dataset(coords={"y": ("x", [1, 2, 3])}) + with pytest.raises(ValueError, match=r".*not coordinates with an index"): + ds.reset_index("y") def test_reset_index_keep_attrs(self): coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True}) ds = Dataset({}, {"coord_1": coord_1}) - expected = Dataset({}, {"coord_1_": coord_1}) obj = ds.reset_index("coord_1") - assert_identical(expected, obj) + assert_identical(obj, ds) + assert len(obj.xindexes) == 0 def test_reorder_levels(self): ds = create_test_multiindex() @@ -3028,6 +3044,10 @@ def test_reorder_levels(self): midx = mindex.reorder_levels(["level_2", "level_1"]) expected = Dataset({}, coords={"x": midx}) + # check attrs propagated + ds["level_1"].attrs["foo"] = "bar" + expected["level_1"].attrs["foo"] = "bar" + reindexed = ds.reorder_levels(x=["level_2", "level_1"]) assert_identical(reindexed, expected) @@ -3037,15 +3057,22 @@ def test_reorder_levels(self): def test_stack(self): ds = Dataset( - {"a": ("x", [0, 1]), "b": (("x", "y"), [[0, 1], [2, 3]]), "y": ["a", "b"]} + data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])}, + coords={"x": ("x", [0, 1]), "y": ["a", "b"]}, ) exp_index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"]) expected = Dataset( - {"a": ("z", [0, 0, 1, 1]), "b": ("z", [0, 1, 2, 3]), "z": exp_index} + data_vars={"b": ("z", [0, 1, 2, 3])}, + coords={"z": exp_index}, ) + # check attrs propagated + ds["x"].attrs["foo"] = "bar" + expected["x"].attrs["foo"] = "bar" + actual = ds.stack(z=["x", "y"]) assert_identical(expected, actual) + assert list(actual.xindexes) == ["z", "x", "y"] actual = ds.stack(z=[...]) assert_identical(expected, actual) @@ -3060,17 +3087,75 @@ def test_stack(self): exp_index = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=["y", "x"]) expected = Dataset( - {"a": ("z", [0, 1, 0, 1]), "b": ("z", [0, 2, 1, 3]), "z": exp_index} + data_vars={"b": ("z", [0, 2, 1, 3])}, + coords={"z": exp_index}, ) + expected["x"].attrs["foo"] = "bar" + actual = ds.stack(z=["y", "x"]) assert_identical(expected, actual) + assert list(actual.xindexes) == ["z", "y", "x"] + + def test_stack_no_index(self) -> None: + ds = Dataset( + data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])}, + coords={"xx": ("x", [0, 1]), "y": ["a", "b"]}, + ) + expected = Dataset( + data_vars={"b": ("z", [0, 1, 2, 3])}, + coords={"xx": ("z", [0, 0, 1, 1]), "y": ("z", ["a", "b", "a", "b"])}, + ) + + actual = ds.stack(z=["x", "y"]) + assert_identical(expected, actual) + assert len(actual.xindexes) == 0 + + # multi-index on a dimension to stack is discarded too + midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=("lvl1", "lvl2")) + ds = xr.Dataset( + data_vars={"b": (("x", "y"), [[0, 1], [2, 3], [4, 5], [6, 7]])}, + coords={"x": midx, "y": [0, 1]}, + ) + expected = Dataset( + data_vars={"b": ("z", [0, 1, 2, 3, 4, 5, 6, 7])}, + coords={ + "x": ("z", np.repeat(midx.values, 2)), + "lvl1": ("z", np.repeat(midx.get_level_values("lvl1"), 2)), + "lvl2": ("z", np.repeat(midx.get_level_values("lvl2"), 2)), + "y": ("z", [0, 1, 0, 1] * 2), + }, + ) + actual = 
ds.stack(z=["x", "y"]) + assert_identical(expected, actual) + assert len(actual.xindexes) == 0 + + def test_stack_non_dim_coords(self): + ds = Dataset( + data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])}, + coords={"x": ("x", [0, 1]), "y": ["a", "b"]}, + ).rename_vars(x="xx") + + exp_index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["xx", "y"]) + expected = Dataset( + data_vars={"b": ("z", [0, 1, 2, 3])}, + coords={"z": exp_index}, + ) + + actual = ds.stack(z=["x", "y"]) + assert_identical(expected, actual) + assert list(actual.xindexes) == ["z", "xx", "y"] def test_unstack(self): index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"]) - ds = Dataset({"b": ("z", [0, 1, 2, 3]), "z": index}) + ds = Dataset(data_vars={"b": ("z", [0, 1, 2, 3])}, coords={"z": index}) expected = Dataset( {"b": (("x", "y"), [[0, 1], [2, 3]]), "x": [0, 1], "y": ["a", "b"]} ) + + # check attrs propagated + ds["x"].attrs["foo"] = "bar" + expected["x"].attrs["foo"] = "bar" + for dim in ["z", ["z"], None]: actual = ds.unstack(dim) assert_identical(actual, expected) @@ -3079,7 +3164,7 @@ def test_unstack_errors(self): ds = Dataset({"x": [1, 2, 3]}) with pytest.raises(ValueError, match=r"does not contain the dimensions"): ds.unstack("foo") - with pytest.raises(ValueError, match=r"do not have a MultiIndex"): + with pytest.raises(ValueError, match=r".*do not have exactly one MultiIndex"): ds.unstack("x") def test_unstack_fill_value(self): @@ -3139,12 +3224,11 @@ def test_stack_unstack_fast(self): def test_stack_unstack_slow(self): ds = Dataset( - { + data_vars={ "a": ("x", [0, 1]), "b": (("x", "y"), [[0, 1], [2, 3]]), - "x": [0, 1], - "y": ["a", "b"], - } + }, + coords={"x": [0, 1], "y": ["a", "b"]}, ) stacked = ds.stack(z=["x", "y"]) actual = stacked.isel(z=slice(None, None, -1)).unstack("z") diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index ee0daad0829..6a82bc4ca3e 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -76,9 +76,11 @@ def test_group_indexers_by_index(self) -> None: assert indexers == {"y": 0} assert len(grouped_indexers) == 3 - with pytest.raises(KeyError, match=r"no index found for coordinate y2"): + with pytest.raises(KeyError, match=r"no index found for coordinate 'y2'"): indexing.group_indexers_by_index(data, {"y2": 2.0}, {}) - with pytest.raises(KeyError, match=r"w is not a valid dimension or coordinate"): + with pytest.raises( + KeyError, match=r"'w' is not a valid dimension or coordinate" + ): indexing.group_indexers_by_index(data, {"w": "a"}, {}) with pytest.raises(ValueError, match=r"cannot supply.*"): indexing.group_indexers_by_index(data, {"z": 1}, {"method": "nearest"}) @@ -112,9 +114,9 @@ def test_indexer( for k in results.indexes: assert results.indexes[k].equals(expected_idx[k]) - assert results.index_vars.keys() == expected_vars.keys() - for k in results.index_vars: - assert_array_equal(results.index_vars[k], expected_vars[k]) + assert results.variables.keys() == expected_vars.keys() + for k in results.variables: + assert_array_equal(results.variables[k], expected_vars[k]) assert set(results.drop_coords) == set(expected_drop) assert results.rename_dims == expected_rename_dims From f88c2c2f6c8498cd8e1a002e3778e0fed0ff39c1 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 24 Sep 2021 14:46:22 +0200 Subject: [PATCH 048/159] fix/refactor reindex --- xarray/core/alignment.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/xarray/core/alignment.py 
b/xarray/core/alignment.py index c50d6b54afb..a288fe9bcca 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -20,7 +20,7 @@ from . import dtypes from .indexes import Index, PandasIndex, get_indexer_nd from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index -from .variable import IndexVariable, Variable +from .variable import Variable if TYPE_CHECKING: from .common import DataWithCoords @@ -574,8 +574,15 @@ def reindex_variables( "from that to be indexed along {:s}".format(str(indexer.dims), dim) ) - target = safe_cast_to_index(indexers[dim]) - new_indexes[dim] = PandasIndex(target, dim) + var_meta = {dim: {"dtype": getattr(indexer, "dtype", None)}} + if dim in variables: + var = variables[dim] + var_meta[dim].update({"attrs": var.attrs, "encoding": var.encoding}) + + target = safe_cast_to_index(indexers[dim]).rename(dim) + idx, idx_vars = PandasIndex.from_pandas_index(target, dim, var_meta=var_meta) + new_indexes[dim] = idx + reindexed.update(idx_vars) if dim in indexes: # TODO (benbovy - flexible indexes): support other indexes than pd.Index? @@ -598,13 +605,6 @@ def reindex_variables( int_indexers[dim] = int_indexer - if dim in variables: - var = variables[dim] - args: tuple = (var.attrs, var.encoding) - else: - args = () - reindexed[dim] = IndexVariable((dim,), indexers[dim], *args) - for dim in sizes: if dim not in indexes and dim in indexers: existing_size = sizes[dim] From 1daaf8d6c405a904569a3119ac75aa00c0ffb727 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 24 Sep 2021 15:11:10 +0200 Subject: [PATCH 049/159] functools.cached_property is since py38 Use the good old way as we still support py37 --- xarray/core/indexes.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index fb1ddec5d5e..4dc71c40bcd 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1,5 +1,4 @@ import collections.abc -import functools from collections import defaultdict from typing import ( TYPE_CHECKING, @@ -783,22 +782,31 @@ def __init__(self, indexes: Dict[Any, T_Index]): """ self._indexes = indexes - @functools.cached_property + self.__coord_name_id: Optional[Dict[Any, int]] = None + self.__id_index: Optional[Dict[int, T_Index]] = None + self.__id_coord_names: Optional[Dict[int, Tuple[Hashable, ...]]] = None + + @property def _coord_name_id(self) -> Dict[Any, int]: - return {k: id(idx) for k, idx in self._indexes.items()} + if self.__coord_name_id is None: + self.__coord_name_id = {k: id(idx) for k, idx in self._indexes.items()} + return self.__coord_name_id - @functools.cached_property + @property def _id_index(self) -> Dict[int, T_Index]: - return {id(idx): idx for idx in self.get_unique()} + if self.__id_index is None: + self.__id_index = {id(idx): idx for idx in self.get_unique()} + return self.__id_index - @functools.cached_property + @property def _id_coord_names(self) -> Dict[int, Tuple[Hashable, ...]]: - id_coord_names: Mapping[int, List[Hashable]] = defaultdict(list) - - for k, v in self._coord_name_id.items(): - id_coord_names[v].append(k) + if self.__id_coord_names is None: + id_coord_names: Mapping[int, List[Hashable]] = defaultdict(list) + for k, v in self._coord_name_id.items(): + id_coord_names[v].append(k) + self.__id_coord_names = {k: tuple(v) for k, v in id_coord_names.items()} - return {k: tuple(v) for k, v in id_coord_names.items()} + return self.__id_coord_names def get_unique(self) -> List[T_Index]: """Returns a list 
of unique indexes, preserving order.""" From 41ebc45b792e88f69b6a4c42de3977614f688770 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 24 Sep 2021 15:34:15 +0200 Subject: [PATCH 050/159] update reset_coords --- xarray/core/coordinates.py | 3 ++- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 2 +- xarray/tests/test_dataarray.py | 6 ++++++ 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 749ac2dcef9..0caf78728d6 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -360,7 +360,8 @@ def to_dataset(self) -> "Dataset": from .dataset import Dataset coords = {k: v.copy(deep=False) for k, v in self._data._coords.items()} - return Dataset._construct_direct(coords, set(coords)) + indexes = dict(self._data.xindexes) + return Dataset._construct_direct(coords, set(coords), indexes=indexes) def __delitem__(self, key: Hashable) -> None: if key not in self: diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 1a8dd9582c8..886b59572de 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -862,7 +862,7 @@ def reset_coords( Dataset, or DataArray if ``drop == True`` """ if names is None: - names = set(self.coords) - set(self.dims) + names = set(self.coords) - set(self.xindexes) dataset = self.coords.to_dataset().reset_coords(names, drop) if drop: return self._replace(coords=dataset._variables) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bc699835615..3aaa5e00473 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1677,7 +1677,7 @@ def reset_coords( else: names = list(names) self._assert_all_in_dataset(names) - bad_coords = set(names) & set(self.dims) + bad_coords = set(names) & set(self.xindexes) if bad_coords: raise ValueError( f"cannot remove index coordinates with reset_coords: {bad_coords}" diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 099355f9d7e..a6d466014c0 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1430,6 +1430,12 @@ def test_reset_coords(self): with pytest.raises(ValueError, match=r"cannot remove index"): data.reset_coords("y") + # non-dimension index coordinate + midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=("lvl1", "lvl2")) + data = DataArray([1, 2, 3, 4], coords={"x": midx}, dims="x", name="foo") + with pytest.raises(ValueError, match=r"cannot remove index"): + data.reset_coords("lvl1") + def test_assign_coords(self): array = DataArray(10) actual = array.assign_coords(c=42) From e7508832dedae10613c174d777499b445284883a Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 24 Sep 2021 15:35:01 +0200 Subject: [PATCH 051/159] fix ipython key completion test make sure the stacked dataset used has a multi-index --- xarray/tests/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 04fe914db2a..2d4bf66097a 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5533,7 +5533,7 @@ def test_ipython_key_completion(self): assert sorted(actual) == sorted(expected) # MultiIndex - ds_midx = ds.stack(dim12=["dim1", "dim2"]) + ds_midx = ds.stack(dim12=["dim2", "dim3"]) actual = ds_midx._ipython_key_completions_() expected = [ "var1", From 6c6f09eb396ba7bdaa6415886fcf1f0874992cfa Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 27 Sep 2021 14:11:06 +0200 Subject: [PATCH 052/159] update test_map_index_queries --- 
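Note: the test rewrite below replaces the loose optional keyword arguments of
the old ``test_indexer`` helper with fully-built ``QueryResult`` objects, so
each case reads as data. A minimal sketch of that expectation-factory idiom,
using a hypothetical dataclass as a stand-in for ``indexing.QueryResult``
(the field names are taken from the diff; everything else is illustrative):

    from dataclasses import dataclass, field
    from typing import Any, Dict, List

    @dataclass
    class FakeQueryResult:
        dim_indexers: Dict[Any, Any]
        indexes: Dict[Any, Any] = field(default_factory=dict)
        variables: Dict[Any, Any] = field(default_factory=dict)
        drop_coords: List[Any] = field(default_factory=list)
        drop_indexes: List[Any] = field(default_factory=list)
        rename_dims: Dict[Any, Any] = field(default_factory=dict)

    # simple cases only need the positional indexer; everything else defaults
    expected = FakeQueryResult({"x": 0})
    assert expected.dim_indexers == {"x": 0} and not expected.rename_dims
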
xarray/tests/test_indexing.py | 143 +++++++++++++++++++++------------- 1 file changed, 91 insertions(+), 52 deletions(-) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 6a82bc4ca3e..d9ded2d0430 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -1,4 +1,5 @@ import itertools +from typing import Any, Dict, cast import numpy as np import pandas as pd @@ -7,6 +8,7 @@ from xarray import DataArray, Dataset, Variable from xarray.core import indexing, nputils from xarray.core.indexes import PandasIndex, PandasMultiIndex +from xarray.core.types import T_Xarray from . import IndexerMaker, ReturnItem, assert_array_equal @@ -86,40 +88,51 @@ def test_group_indexers_by_index(self) -> None: indexing.group_indexers_by_index(data, {"z": 1}, {"method": "nearest"}) def test_map_index_queries(self) -> None: + def create_query_results( + x_indexer, + x_index, + index_vars, + other_vars, + drop_coords, + drop_indexes, + rename_dims, + ): + dim_indexers = {"x": x_indexer} + indexes = {k: x_index for k in index_vars} + variables = {} + variables.update(index_vars) + variables.update(other_vars) + + return indexing.QueryResult( + dim_indexers=dim_indexers, + indexes=indexes, + variables=variables, + drop_coords=drop_coords, + drop_indexes=drop_indexes, + rename_dims=rename_dims, + ) + def test_indexer( - data, - x, - expected_pos, - expected_idx=None, - expected_vars=None, - expected_drop=None, - expected_rename_dims=None, + data: T_Xarray, + x: Any, + expected: indexing.QueryResult, ) -> None: - if expected_vars is None: - expected_vars = {} - if expected_idx is None: - expected_idx = {} - else: - expected_idx = {k: expected_idx for k in expected_vars} - if expected_drop is None: - expected_drop = [] - if expected_rename_dims is None: - expected_rename_dims = {} - results = indexing.map_index_queries(data, {"x": x}) - assert_array_equal(results.dim_indexers.get("x"), expected_pos) + assert results.dim_indexers.keys() == expected.dim_indexers.keys() + assert_array_equal(results.dim_indexers["x"], expected.dim_indexers["x"]) - assert results.indexes.keys() == expected_idx.keys() + assert results.indexes.keys() == expected.indexes.keys() for k in results.indexes: - assert results.indexes[k].equals(expected_idx[k]) + assert results.indexes[k].equals(expected.indexes[k]) - assert results.variables.keys() == expected_vars.keys() + assert results.variables.keys() == expected.variables.keys() for k in results.variables: - assert_array_equal(results.variables[k], expected_vars[k]) + assert_array_equal(results.variables[k], expected.variables[k]) - assert set(results.drop_coords) == set(expected_drop) - assert results.rename_dims == expected_rename_dims + assert set(results.drop_coords) == set(expected.drop_coords) + assert set(results.drop_indexes) == set(expected.drop_indexes) + assert results.rename_dims == expected.rename_dims data = Dataset({"x": ("x", [1, 2, 3])}) mindex = pd.MultiIndex.from_product( @@ -127,68 +140,94 @@ def test_indexer( ) mdata = DataArray(range(8), [("x", mindex)]) - test_indexer(data, 1, 0) - test_indexer(data, np.int32(1), 0) - test_indexer(data, Variable([], 1), 0) - test_indexer(mdata, ("a", 1, -1), 0) - test_indexer( - mdata, - ("a", 1), + test_indexer(data, 1, indexing.QueryResult({"x": 0})) + test_indexer(data, np.int32(1), indexing.QueryResult({"x": 0})) + test_indexer(data, Variable([], 1), indexing.QueryResult({"x": 0})) + test_indexer(mdata, ("a", 1, -1), indexing.QueryResult({"x": 0})) + + expected = create_query_results( 
[True, True, False, False, False, False, False, False], *PandasIndex.from_pandas_index(pd.Index([-1, -2]), "three"), - ["x", "one", "two"], + {"one": Variable((), "a"), "two": Variable((), 1)}, + ["x"], + ["one", "two"], {"x": "three"}, ) - test_indexer( - mdata, - "a", + test_indexer(mdata, ("a", 1), expected) + + expected = create_query_results( slice(0, 4, None), *PandasMultiIndex.from_pandas_index( pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), "x", ), + {"one": Variable((), "a")}, + [], ["one"], + {}, ) - test_indexer( - mdata, - ("a",), + test_indexer(mdata, "a", expected) + + expected = create_query_results( [True, True, True, True, False, False, False, False], *PandasMultiIndex.from_pandas_index( pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), "x", ), + {"one": Variable((), "a")}, + [], ["one"], + {}, + ) + test_indexer(mdata, ("a",), expected) + + test_indexer( + mdata, [("a", 1, -1), ("b", 2, -2)], indexing.QueryResult({"x": [0, 7]}) + ) + test_indexer( + mdata, slice("a", "b"), indexing.QueryResult({"x": slice(0, 8, None)}) ) - test_indexer(mdata, [("a", 1, -1), ("b", 2, -2)], [0, 7]) - test_indexer(mdata, slice("a", "b"), slice(0, 8, None)) - test_indexer(mdata, slice(("a", 1), ("b", 1)), slice(0, 6, None)) - test_indexer(mdata, {"one": "a", "two": 1, "three": -1}, 0) test_indexer( mdata, - {"one": "a", "two": 1}, + slice(("a", 1), ("b", 1)), + indexing.QueryResult({"x": slice(0, 6, None)}), + ) + test_indexer( + mdata, {"one": "a", "two": 1, "three": -1}, indexing.QueryResult({"x": 0}) + ) + + expected = create_query_results( [True, True, False, False, False, False, False, False], *PandasIndex.from_pandas_index(pd.Index([-1, -2]), "three"), - ["x", "one", "two"], + {"one": Variable((), "a"), "two": Variable((), 1)}, + ["x"], + ["one", "two"], {"x": "three"}, ) - test_indexer( - mdata, - {"one": "a", "three": -1}, + test_indexer(mdata, {"one": "a", "two": 1}, expected) + + expected = create_query_results( [True, False, True, False, False, False, False, False], *PandasIndex.from_pandas_index(pd.Index([1, 2]), "two"), - ["x", "one", "three"], + {"one": Variable((), "a"), "three": Variable((), -1)}, + ["x"], + ["one", "three"], {"x": "two"}, ) - test_indexer( - mdata, - {"one": "a"}, + test_indexer(mdata, {"one": "a", "three": -1}, expected) + + expected = create_query_results( [True, True, True, True, False, False, False, False], *PandasMultiIndex.from_pandas_index( pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), "x", ), + {"one": Variable((), "a")}, + [], ["one"], + {}, ) + test_indexer(mdata, {"one": "a"}, expected) def test_read_only_view(self) -> None: From 6437e13c86f45be6f51e5037aeeb123ceca15453 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 27 Sep 2021 15:13:54 +0200 Subject: [PATCH 053/159] do not coerce bool indexer as float coord dtype Fixes #5727 --- xarray/core/indexes.py | 4 +++- xarray/tests/test_indexes.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 4dc71c40bcd..e5c27411d43 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -139,9 +139,11 @@ def _is_nested_tuple(possible_tuple): def normalize_label(value, dtype=None) -> np.ndarray: if getattr(value, "ndim", 1) <= 1: value = _asarray_tuplesafe(value) - if dtype is not None and dtype.kind == "f": + if dtype is not None and dtype.kind == "f" and value.dtype.kind != "b": # pd.Index built from coordinate with float precision != 64 # see 
https://github.com/pydata/xarray/pull/3153 for details + # bypass coercing dtype for boolean indexers (ignore index) + # see https://github.com/pydata/xarray/issues/5727 value = np.asarray(value, dtype=dtype) return value diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 309bab5f95a..24c000d2dcf 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -88,6 +88,16 @@ def test_query(self) -> None: with pytest.raises(ValueError, match=r"does not have a MultiIndex"): index.query({"x": {"one": 0}}) + def test_query_boolean(self) -> None: + # index should be ignored and indexer dtype should not be coerced + # see https://github.com/pydata/xarray/issues/5727 + index = PandasIndex(pd.Index([0.0, 2.0, 1.0, 3.0]), "x") + actual = index.query({"x": [False, True, False, True]}) + expected_dim_indexers = {"x": [False, True, False, True]} + np.testing.assert_array_equal( + actual.dim_indexers["x"], expected_dim_indexers["x"] + ) + def test_query_datetime(self) -> None: index = PandasIndex( pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]), "x" From a955d457172073fc31d66552d2007cb10b28edd5 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 13 Oct 2021 13:07:56 +0200 Subject: [PATCH 054/159] add Index.create_variables() method This will probably be used instead of including index coordinate variables in the signature (return type) of many methods of ``Index``. It has several advantages: - More DRY, and for custom indexes that do not need to create coordinate variables with special data adapters, it's easier to just skip implementing this method and not bother with returning empty dicts in other methods - This allows to decouple index vs. coordinates creation. For many cases this can be done at the same time but for some cases like alignment this is useful - It's a more elegant solution when we need to propagate metadata (attrs, encoding) --- xarray/core/indexes.py | 32 ++++++++++++++++++++++++++++++++ xarray/tests/test_indexes.py | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index e5c27411d43..5caa919d45e 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -39,6 +39,11 @@ def from_variables( ) -> Tuple["Index", IndexVars]: raise NotImplementedError() + def create_variables( + self, attrs: Mapping[Any, Any], encoding: Mapping[Any, Any] + ) -> IndexVars: + return {} + def to_pandas_index(self) -> pd.Index: """Cast this xarray index to a pandas.Index object or raise a TypeError if this is not supported. 
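
A minimal sketch, for illustration only (``KDTreeIndex`` is hypothetical, not
part of this patch), of a custom index that can simply rely on this default:

    from xarray.core.indexes import Index

    class KDTreeIndex(Index):
        @classmethod
        def from_variables(cls, variables):
            ...  # build a spatial tree from the coordinate data

        # no create_variables() override needed: the base class default
        # above returns an empty dict of coordinate variables
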
@@ -216,6 +221,18 @@ def from_variables( return obj, {name: index_var} + def create_variables( + self, attrs: Mapping[Any, Any], encoding: Mapping[Any, Any] + ) -> IndexVars: + from .variable import IndexVariable + + name = self.index.name + data = PandasIndexingAdapter(self.index, dtype=self.coord_dtype) + var = IndexVariable( + self.dim, data, attrs=attrs.get(name), encoding=encoding.get(name) + ) + return {name: var} + @classmethod def from_pandas_index( cls, @@ -570,6 +587,21 @@ def from_pandas_index( index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta) return cls(index, dim, level_coords_dtype=level_coords_dtype), index_vars + def create_variables( + self, attrs: Mapping[Any, Any], encoding: Mapping[Any, Any] + ) -> IndexVars: + var_meta = {} + for name in self.index.names: + var_meta[name] = { + "dtype": self.level_coords_dtype[name], + "attrs": attrs.get(name, {}), + "encoding": encoding.get(name, {}), + } + + return _create_variables_from_multiindex( + self.index, self.dim, var_meta=var_meta + ) + def query(self, labels, method=None, tolerance=None) -> QueryResult: from .dataarray import DataArray from .variable import Variable diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 24c000d2dcf..9ee121c341b 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -73,7 +73,19 @@ def test_from_pandas_index(self) -> None: assert index.index is not pd_idx assert index.index.name == "x" - def to_pandas_index(self): + def test_create_variables(self) -> None: + pd_idx = pd.Index([1, 2, 3], name="foo") + index, _ = PandasIndex.from_pandas_index(pd_idx, "x") + attrs = {"unit": "m"} + encoding = {"fill_value": 0} + + actual = index.create_variables( + attrs={"foo": attrs}, encoding={"foo": encoding} + ) + expected = {"foo": IndexVariable("x", pd_idx, attrs=attrs, encoding=encoding)} + assert_identical(actual["foo"], expected["foo"]) + + def test_to_pandas_index(self) -> None: pd_idx = pd.Index([1, 2, 3], name="foo") index = PandasIndex(pd_idx, "x") assert index.to_pandas_index() is pd_idx @@ -276,6 +288,25 @@ def test_from_pandas_index(self) -> None: with pytest.raises(ValueError, match=".*conflicting multi-index level name.*"): PandasMultiIndex.from_pandas_index(pd_idx, "foo") + def test_create_variables(self) -> None: + foo_data = np.array([0, 0, 1], dtype="int") + bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") + pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data], names=("foo", "bar")) + + index, _ = PandasMultiIndex.from_pandas_index(pd_idx, "x") + index_vars = index.create_variables( + attrs={"foo": {"unit": "m"}}, + encoding={"bar": {"fill_value": 0}}, + ) + + assert_identical(index_vars["x"], IndexVariable("x", pd_idx)) + assert_identical( + index_vars["foo"], IndexVariable("x", foo_data, attrs={"unit": "m"}) + ) + assert_identical( + index_vars["bar"], IndexVariable("x", bar_data, encoding={"fill_value": 0}) + ) + def test_query(self) -> None: index = PandasMultiIndex( pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), "x" From c87586e958eccd1f5025f253fbc644530f7796af Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 13 Oct 2021 13:46:13 +0200 Subject: [PATCH 055/159] improve Dataset/DataArray indexes proxy class One possible solution to the recurring problem of accessing coordinate variables related to a given index in some of Xarray's internals. I feel that the ``Indexes`` proxy is the right place for storing references to those indexed coordinate variables. 
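
A rough sketch of the intended lookups (assuming a dataset ``ds`` whose ``x``
dimension is backed by a multi-index with levels ``one`` and ``two``):

    midx = ds.xindexes["one"]                   # the underlying xarray index
    coords = ds.xindexes.get_all_coords("one")  # {"x": ..., "one": ..., "two": ...}
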
It's safer than relying on a public attribute/property of ``Index``. --- xarray/core/dataarray.py | 4 +- xarray/core/dataset.py | 4 +- xarray/core/indexes.py | 67 +++++++++++++++++++++++++++------ xarray/tests/test_indexes.py | 73 ++++++++++++++++++++++++++++-------- 4 files changed, 117 insertions(+), 31 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 886b59572de..34049288f90 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -827,14 +827,14 @@ def indexes(self) -> Indexes: DataArray.xindexes """ - return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) + return self.xindexes.to_pandas_indexes() @property def xindexes(self) -> Indexes: """Mapping of xarray Index objects used for label based indexing.""" if self._indexes is None: self._indexes = default_indexes(self._coords, self.dims) - return Indexes(self._indexes) + return Indexes(self._indexes, {k: self._coords[k] for k in self._indexes}) @property def coords(self) -> DataArrayCoordinates: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 3aaa5e00473..16c7f186c00 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1599,14 +1599,14 @@ def indexes(self) -> Indexes[pd.Index]: Dataset.xindexes """ - return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) + return self.xindexes.to_pandas_indexes() @property def xindexes(self) -> Indexes[Index]: """Mapping of xarray Index objects used for label based indexing.""" if self._indexes is None: self._indexes = default_indexes(self._variables, self._dims) - return Indexes(self._indexes) + return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) @property def coords(self) -> DatasetCoordinates: diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 5caa919d45e..e98eefc36d5 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -22,9 +22,10 @@ from . import formatting, utils from .indexing import PandasIndexingAdapter, PandasMultiIndexingAdapter, QueryResult -from .utils import is_dict_like, is_scalar +from .utils import FrozenDict, is_dict_like, is_scalar if TYPE_CHECKING: + from .utils import Frozen from .variable import IndexVariable, Variable IndexVars = Dict[Any, "IndexVariable"] @@ -805,16 +806,29 @@ class Indexes(collections.abc.Mapping, Generic[T_Index]): """ _indexes: Dict[Any, T_Index] + _variables: Dict[Any, "Variable"] + + __slots__ = ( + "_indexes", + "_variables", + "__coord_name_id", + "__id_index", + "__id_coord_names", + ) - def __init__(self, indexes: Dict[Any, T_Index]): + def __init__(self, indexes: Dict[Any, T_Index], variables: Dict[Any, "Variable"]): """Constructor not for public consumption. Parameters ---------- indexes : dict Indexes held by this object. + variables : dict + Indexed coordinate variables in this object. 
+ """ self._indexes = indexes + self._variables = variables self.__coord_name_id: Optional[Dict[Any, int]] = None self.__id_index: Optional[Dict[int, T_Index]] = None @@ -842,8 +856,13 @@ def _id_coord_names(self) -> Dict[int, Tuple[Hashable, ...]]: return self.__id_coord_names + @property + def coords(self) -> Frozen: + """Return an immutable dictionnary of all indexed coordinate variables.""" + return FrozenDict(self._variables) + def get_unique(self) -> List[T_Index]: - """Returns a list of unique indexes, preserving order.""" + """Return a list of unique indexes, preserving order.""" unique_indexes: List[T_Index] = [] seen: Set[T_Index] = set() @@ -857,8 +876,8 @@ def get_unique(self) -> List[T_Index]: def get_all_coords( self, coord_name: Hashable, errors: str = "raise" - ) -> Tuple[Hashable, ...]: - """Return the names of all coordinates having the same index. + ) -> Dict[Hashable, "Variable"]: + """Return all coordinates having the same index. Parameters ---------- @@ -870,8 +889,8 @@ def get_all_coords( Returns ------- - names : tuple - The names of all coordinates having the same index. + coords : dict + A dictionary of all coordinate variables having the same index. """ if errors not in ["raise", "ignore"]: @@ -881,14 +900,38 @@ def get_all_coords( if errors == "raise": raise ValueError(f"no index found for {coord_name!r} coordinate") else: - return tuple() + return {} + + all_coord_names = self._id_coord_names[self._coord_name_id[coord_name]] + return {k: self._variables[k] for k in all_coord_names} - return self._id_coord_names[self._coord_name_id[coord_name]] + def group_by_index(self) -> List[Tuple[T_Index, Dict[Hashable, "Variable"]]]: + """Returns a list of unique indexes and their corresponding coordinates.""" - def group_by_index(self) -> List[Tuple[T_Index, Tuple[Hashable, ...]]]: - """Returns a list of unique indexes and their corresponding coordinate names.""" + index_coords = [] - return [(self._id_index[i], self._id_coord_names[i]) for i in self._id_index] + for i in self._id_index: + index = self._id_index[i] + coords = {k: self._variables[k] for k in self._id_coord_names[i]} + index_coords.append((index, coords)) + + return index_coords + + def to_pandas_indexes(self): + """Returns an immutable proxy for Dataset or DataArrary pandas indexes. + + Raises an error if this proxy contains indexes that cannot be coerced to + pandas.Index objects. + + """ + indexes = {} + for k, idx in self._indexes.items(): + if isinstance(idx, pd.Index): + indexes[k] = idx + elif isinstance(idx, Index): + indexes[k] = idx.to_pandas_index() + + return Indexes(indexes, self._variables) def __iter__(self): return iter(self._indexes) diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 9ee121c341b..f774eabc56a 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -1,15 +1,18 @@ +from typing import Any, Dict, List, Tuple + import numpy as np import pandas as pd import pytest import xarray as xr from xarray.core.indexes import ( + Index, Indexes, PandasIndex, PandasMultiIndex, _asarray_tuplesafe, ) -from xarray.core.variable import IndexVariable +from xarray.core.variable import IndexVariable, Variable from . 
import assert_equal, assert_identical @@ -373,28 +376,68 @@ def test_copy(self) -> None: class TestIndexes: - def test_get_unique(self) -> None: - idx = [PandasIndex([1, 2, 3], "x"), PandasIndex([4, 5, 6], "y")] - indexes = Indexes({"a": idx[0], "b": idx[1], "c": idx[0]}) + def _create_indexes(self) -> Tuple[Indexes[Index], List[PandasIndex]]: + x_idx = PandasIndex(pd.Index([1, 2, 3], name="x"), "x") + y_idx = PandasIndex(pd.Index([4, 5, 6], name="y"), "y") + z_pd_midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=["one", "two"] + ) + z_midx = PandasMultiIndex(z_pd_midx, "z") + + unique_indexes = [x_idx, y_idx, z_midx] + indexes: Dict[Any, Index] = { + "x": x_idx, + "y": y_idx, + "z": z_midx, + "one": z_midx, + "two": z_midx, + } + variables: Dict[Any, Variable] = {} + for idx in unique_indexes: + variables.update(idx.create_variables({}, {})) + + return Indexes(indexes, variables), unique_indexes - assert indexes.get_unique() == idx + def test_coords(self) -> None: + indexes, _ = self._create_indexes() + assert tuple(indexes.coords) == ("x", "y", "z", "one", "two") + + def test_get_unique(self) -> None: + indexes, unique = self._create_indexes() + assert indexes.get_unique() == unique def test_get_all_coords(self) -> None: - idx = [PandasIndex([1, 2, 3], "x"), PandasIndex([4, 5, 6], "y")] - indexes = Indexes({"a": idx[0], "b": idx[1], "c": idx[0]}) + indexes, _ = self._create_indexes() - assert indexes.get_all_coords("a") == ("a", "c") + expected = { + "z": indexes.coords["z"], + "one": indexes.coords["one"], + "two": indexes.coords["two"], + } + assert indexes.get_all_coords("one") == expected with pytest.raises(ValueError, match="errors must be.*"): - indexes.get_all_coords("a", errors="invalid") + indexes.get_all_coords("x", errors="invalid") with pytest.raises(ValueError, match="no index found.*"): - indexes.get_all_coords("z") + indexes.get_all_coords("no_coord") - assert indexes.get_all_coords("z", errors="ignore") == tuple() + assert indexes.get_all_coords("no_coord", errors="ignore") == {} def test_group_by_index(self): - idx = [PandasIndex([1, 2, 3], "x"), PandasIndex([4, 5, 6], "y")] - indexes = Indexes({"a": idx[0], "b": idx[1], "c": idx[0]}) - - assert indexes.group_by_index() == [(idx[0], ("a", "c")), (idx[1], ("b",))] + indexes, unique = self._create_indexes() + + expected = [ + (unique[0], {"x": indexes.coords["x"]}), + (unique[1], {"y": indexes.coords["y"]}), + ( + unique[2], + { + "z": indexes.coords["z"], + "one": indexes.coords["one"], + "two": indexes.coords["two"], + }, + ), + ] + + assert indexes.group_by_index() == expected From db65e4400ddf1da8ccafd77283e87efe4caba41d Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 13 Oct 2021 23:16:53 +0200 Subject: [PATCH 056/159] align Indexes API with Coordinates API Expose similar properties: - variables - dims --- xarray/core/coordinates.py | 6 +----- xarray/core/dataset.py | 27 +-------------------------- xarray/core/indexes.py | 19 ++++++++++++++----- xarray/core/merge.py | 4 ++-- xarray/core/variable.py | 26 ++++++++++++++++++++++++++ xarray/tests/test_indexes.py | 24 ++++++++++++++---------- 6 files changed, 58 insertions(+), 48 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 0caf78728d6..9ce0e83282f 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -20,7 +20,7 @@ from .indexes import Index, Indexes from .merge import merge_coordinates_without_align, merge_coords from .utils import Frozen, ReprObject -from .variable import Variable +from 
.variable import calculate_dimensions, Variable if TYPE_CHECKING: from .dataarray import DataArray @@ -272,8 +272,6 @@ def to_dataset(self) -> "Dataset": def _update_coords( self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index] ) -> None: - from .dataset import calculate_dimensions - variables = self._data._variables.copy() variables.update(coords) @@ -335,8 +333,6 @@ def __getitem__(self, key: Hashable) -> "DataArray": def _update_coords( self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index] ) -> None: - from .dataset import calculate_dimensions - coords_plus_data = coords.copy() coords_plus_data[_THIS_ARRAY] = self._data.variable dims = calculate_dimensions(coords_plus_data) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 16c7f186c00..4e946e2bfc4 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -93,6 +93,7 @@ maybe_wrap_array, ) from .variable import ( + calculate_dimensions, IndexVariable, Variable, as_variable, @@ -166,32 +167,6 @@ def _get_virtual_variable( return ref_name, var_name, virtual_var -def calculate_dimensions(variables: Mapping[Any, Variable]) -> Dict[Hashable, int]: - """Calculate the dimensions corresponding to a set of variables. - - Returns dictionary mapping from dimension names to sizes. Raises ValueError - if any of the dimension sizes conflict. - """ - dims: Dict[Hashable, int] = {} - last_used = {} - scalar_vars = {k for k, v in variables.items() if not v.dims} - for k, var in variables.items(): - for dim, size in zip(var.dims, var.shape): - if dim in scalar_vars: - raise ValueError( - f"dimension {dim!r} already exists as a scalar variable" - ) - if dim not in dims: - dims[dim] = size - last_used[dim] = k - elif dims[dim] != size: - raise ValueError( - f"conflicting sizes for dimension {dim!r}: " - f"length {size} on {k!r} and length {dims[dim]} on {last_used!r}" - ) - return dims - - def _assert_empty(args: tuple, msg: str = "%s") -> None: if args: raise ValueError(msg % args) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index e98eefc36d5..c7c0e99dcee 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -22,10 +22,9 @@ from . 
import formatting, utils from .indexing import PandasIndexingAdapter, PandasMultiIndexingAdapter, QueryResult -from .utils import FrozenDict, is_dict_like, is_scalar +from .utils import Frozen, is_dict_like, is_scalar if TYPE_CHECKING: - from .utils import Frozen from .variable import IndexVariable, Variable IndexVars = Dict[Any, "IndexVariable"] @@ -811,6 +810,7 @@ class Indexes(collections.abc.Mapping, Generic[T_Index]): __slots__ = ( "_indexes", "_variables", + "_dims", "__coord_name_id", "__id_index", "__id_coord_names", @@ -830,6 +830,7 @@ def __init__(self, indexes: Dict[Any, T_Index], variables: Dict[Any, "Variable"] self._indexes = indexes self._variables = variables + self._dims: Optional[Mapping[Hashable, int]] = None self.__coord_name_id: Optional[Dict[Any, int]] = None self.__id_index: Optional[Dict[int, T_Index]] = None self.__id_coord_names: Optional[Dict[int, Tuple[Hashable, ...]]] = None @@ -857,9 +858,17 @@ def _id_coord_names(self) -> Dict[int, Tuple[Hashable, ...]]: return self.__id_coord_names @property - def coords(self) -> Frozen: - """Return an immutable dictionnary of all indexed coordinate variables.""" - return FrozenDict(self._variables) + def variables(self) -> Mapping[Hashable, "Variable"]: + return Frozen(self._variables) + + @property + def dims(self) -> Mapping[Hashable, int]: + from .variable import calculate_dimensions + + if self._dims is None: + self._dims = calculate_dimensions(self._variables) + + return Frozen(self._dims) def get_unique(self) -> List[T_Index]: """Return a list of unique indexes, preserving order.""" diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 1aa3abe65c7..36ef65110f2 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -24,7 +24,7 @@ from .duck_array_ops import lazy_array_equiv from .indexes import Index, PandasIndex, PandasMultiIndex from .utils import Frozen, compat_dict_union, dict_equiv, equivalent -from .variable import Variable, as_variable # , assert_unique_multiindex_level_names +from .variable import calculate_dimensions, Variable, as_variable # , assert_unique_multiindex_level_names if TYPE_CHECKING: from .coordinates import Coordinates @@ -671,7 +671,7 @@ def merge_core( MergeError if the merge cannot be done successfully. """ from .dataarray import DataArray - from .dataset import Dataset, calculate_dimensions + from .dataset import Dataset _assert_compat_valid(compat) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c0cd5c462a3..819f6d90772 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2955,3 +2955,29 @@ def propagate_attrs_encoding( if old_var is not None: var.attrs = {**old_var.attrs, **var.attrs} var.encoding = {**old_var.encoding, **var.encoding} + + +def calculate_dimensions(variables: Mapping[Any, Variable]) -> Dict[Hashable, int]: + """Calculate the dimensions corresponding to a set of variables. + + Returns dictionary mapping from dimension names to sizes. Raises ValueError + if any of the dimension sizes conflict. 
+ """ + dims: Dict[Hashable, int] = {} + last_used = {} + scalar_vars = {k for k, v in variables.items() if not v.dims} + for k, var in variables.items(): + for dim, size in zip(var.dims, var.shape): + if dim in scalar_vars: + raise ValueError( + f"dimension {dim!r} already exists as a scalar variable" + ) + if dim not in dims: + dims[dim] = size + last_used[dim] = k + elif dims[dim] != size: + raise ValueError( + f"conflicting sizes for dimension {dim!r}: " + f"length {size} on {k!r} and length {dims[dim]} on {last_used!r}" + ) + return dims diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index f774eabc56a..ec24a1d3f79 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -398,9 +398,13 @@ def _create_indexes(self) -> Tuple[Indexes[Index], List[PandasIndex]]: return Indexes(indexes, variables), unique_indexes - def test_coords(self) -> None: + def test_variables(self) -> None: indexes, _ = self._create_indexes() - assert tuple(indexes.coords) == ("x", "y", "z", "one", "two") + assert tuple(indexes.variables) == ("x", "y", "z", "one", "two") + + def test_dims(self) -> None: + indexes, _ = self._create_indexes() + assert indexes.dims == {"x": 3, "y": 3, "z": 4} def test_get_unique(self) -> None: indexes, unique = self._create_indexes() @@ -410,9 +414,9 @@ def test_get_all_coords(self) -> None: indexes, _ = self._create_indexes() expected = { - "z": indexes.coords["z"], - "one": indexes.coords["one"], - "two": indexes.coords["two"], + "z": indexes.variables["z"], + "one": indexes.variables["one"], + "two": indexes.variables["two"], } assert indexes.get_all_coords("one") == expected @@ -428,14 +432,14 @@ def test_group_by_index(self): indexes, unique = self._create_indexes() expected = [ - (unique[0], {"x": indexes.coords["x"]}), - (unique[1], {"y": indexes.coords["y"]}), + (unique[0], {"x": indexes.variables["x"]}), + (unique[1], {"y": indexes.variables["y"]}), ( unique[2], { - "z": indexes.coords["z"], - "one": indexes.coords["one"], - "two": indexes.coords["two"], + "z": indexes.variables["z"], + "one": indexes.variables["one"], + "two": indexes.variables["two"], }, ), ] From 417f1f44ba0045f510d55d32db9dbd8c2d57a849 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 13 Oct 2021 23:42:44 +0200 Subject: [PATCH 057/159] clean-up formatting using updated Indexes API --- xarray/core/formatting_html.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 209533b2027..87429f084a4 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -161,20 +161,8 @@ def _mapping_section( ) -def _dims_with_index(obj): - if not hasattr(obj, "indexes"): - return [] - - dims_with_index = set() - for coord_name in obj.xindexes: - for dim in obj[coord_name].dims: - dims_with_index.add(dim) - - return dims_with_index - - def dim_section(obj): - dim_list = format_dims(obj.dims, _dims_with_index(obj)) + dim_list = format_dims(obj.dims, obj.xindexes.dims) return collapsible_section( "Dimensions", inline_details=dim_list, enabled=False, collapsed=True @@ -255,6 +243,10 @@ def _obj_repr(obj, header_components, sections): def array_repr(arr): dims = OrderedDict((k, v) for k, v in zip(arr.dims, arr.shape)) + if hasattr(arr, "xindexes"): + indexed_dims = arr.xindexes.dims + else: + indexed_dims = {} obj_type = "xarray.{}".format(type(arr).__name__) arr_name = f"'{arr.name}'" if getattr(arr, "name", None) else "" @@ -262,7 +254,7 @@ def 
array_repr(arr):
     header_components = [
         f"<div class='xr-obj-type'>{obj_type}</div>",
         f"<div class='xr-array-name'>{arr_name}</div>
    ", - format_dims(dims, _dims_with_index(arr)), + format_dims(dims, indexed_dims), ] sections = [array_section(arr)] From 929b8f366309d94d28f05863abdb1f141cfd7560 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 14 Oct 2021 18:20:37 +0200 Subject: [PATCH 058/159] wip: refactor alignment --- xarray/core/alignment.py | 260 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 258 insertions(+), 2 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index a288fe9bcca..ca643070bc6 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -5,11 +5,17 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, + FrozenSet, Hashable, + List, Mapping, Optional, + Sequence, + Set, Tuple, + Type, TypeVar, Union, ) @@ -18,9 +24,9 @@ import pandas as pd from . import dtypes -from .indexes import Index, PandasIndex, get_indexer_nd +from .indexes import Index, Indexes, PandasIndex, get_indexer_nd from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index -from .variable import Variable +from .variable import Variable, calculate_dimensions if TYPE_CHECKING: from .common import DataWithCoords @@ -75,6 +81,256 @@ def _override_indexes(objects, all_indexes, exclude): return objects +CoordNamesAndDims = FrozenSet[Tuple[Hashable, Tuple[Hashable, ...]]] +MatchingIndexKey = Tuple[CoordNamesAndDims, Type[Index]] +NormalizedIndexes = Dict[MatchingIndexKey, Index] +NormalizedCoords = Dict[MatchingIndexKey, Dict[Hashable, Variable]] + + + +class Alignator: + """Implements all the complex logic for the alignment of Xarray objects.""" + + objects: List[Union["Dataset", "DataArray"]] + join: str + exclude_dims: FrozenSet + indexes: Dict[MatchingIndexKey, Index] + coords: Dict[MatchingIndexKey, Dict[Hashable, Variable]] + all_indexes: Mapping[MatchingIndexKey, List[Index]] + all_coords: Mapping[MatchingIndexKey, List[Dict[Hashable, Variable]]] + unindexed_dim_sizes: Mapping[Hashable, Set] + aligned_indexes: Dict[Hashable, Index] + aligned_index_coords: Dict[Hashable, Variable] + + def __init__( + self, + objects: List[Union["Dataset", "DataArray"]], + join: str, + indexes: Union[Mapping[Any, Any], None], + exclude: Union[str, Set, Sequence], + ): + self.objects = objects + + if join not in ["inner", "outer", "overwrite", "exact", "left", "right"]: + raise ValueError(f"invalid value for join: {join}") + self.join = join + + if isinstance(exclude, str): + exclude = [exclude] + self.exclude_dims = frozenset(exclude) + + if indexes is None: + indexes = {} + self.indexes, self.coords = self._normalize_indexes(indexes) + + self.all_indexes = defaultdict(list) + self.all_coords = defaultdict(list) + self.unindexed_dim_sizes = defaultdict(set) + + self.aligned_indexes = {} + self.aligned_index_coords = {} + + def _normalize_indexes( + self, + indexes: Mapping[Any, Any], + ) -> Tuple[NormalizedIndexes, NormalizedCoords]: + """Normalize the indexes used for alignment. + + Return dictionaries of xarray Index objects and coordinates such that we can + group matching indexes based on the dictionary keys. 
+ + """ + if isinstance(indexes, Indexes): + variables = dict(indexes.variables) + else: + variables = {} + + xr_indexes = {} + for k, idx in indexes.items(): + if not isinstance(idx, Index): + pd_idx = safe_cast_to_index(idx).copy() + pd_idx.name = k + idx, _ = PandasIndex.from_pandas_index(pd_idx, k) + variables.update(idx.create_variables()) + xr_indexes[k] = idx + + normalized_indexes = {} + normalized_coords = {} + for idx, coords in Indexes(xr_indexes, variables).group_by_index(): + coord_names_and_dims = [] + all_dims = set() + + for name, var in coords.items(): + dims = var.dims + coord_names_and_dims.append((name, dims)) + all_dims.update(dims) + + exclude_dims = all_dims & self.exclude_dims + if exclude_dims == all_dims: + continue + elif exclude_dims: + excl_dims_str = ", ".join(str(d) for d in exclude_dims) + incl_dims_str = ", ".join(str(d) for d in all_dims - exclude_dims) + raise ValueError( + f"cannot exclude dimension(s) {excl_dims_str} from alignment because " + "these are used by an index together with non-excluded dimensions " + f"{incl_dims_str}" + ) + + key = (frozenset(coord_names_and_dims), type(idx)) + normalized_indexes[key] = idx + normalized_coords[key] = coords + + return normalized_indexes, normalized_coords + + def find_matching_indexes(self): + for obj in self.objects: + obj_indexes, obj_coords = self._normalize_indexes(obj.xindexes) + for key, idx in obj_indexes.items(): + self.all_indexes[key].append(idx) + self.all_coords[key].append(obj_coords[key]) + + def find_matching_unindexed_dims(self): + for obj in self.objects: + for dim in obj.dims: + if dim not in self.exclude_dims and dim not in obj.xindexes.dims: + self.unindexed_dim_sizes[dim].add(obj.sizes[dim]) + + def assert_no_index_conflict(self): + """Check for uniqueness of both coordinate and dimension names accross all sets + of matching indexes. + + We need to make sure that all indexes used for alignment are fully compatible + and do not conflict each other. + + """ + matching_keys = set(self.all_indexes) | set(self.indexes) + + coord_count = defaultdict(int) + dim_count = defaultdict(int) + for coord_names_dims, _ in matching_keys: + dims_set = set() + for name, dims in coord_names_dims: + coord_count[name] += 1 + dims_set |= dims + for dim in dims_set: + dim_count[dim] += 1 + + for count, msg in [(coord_count, "coordinates"), (dim_count, "dimensions")]: + dup = {k: v for k, v in count.items() if v > 1} + if dup: + items_msg = ", ".join(f"{k} ({v} conflicting indexes)" for k, v in dup.items()) + raise ValueError( + "cannot align objects with conflicting indexes found for " + f"the following {msg}: {items_msg}\n" + "Conflicting indexes may occur when\n" + "- they relate to different sets of coordinates and/or dimensions\n" + "- they don't have the same type\n" + "- they are used to reindex data along common dimensions" + ) + + def _should_reindex(self, dims, index, other_indexes, coords, other_coords) -> bool: + """Whether or not we'll need to reindex all the variables + for a set of matching indexes. + + We won't reindex when all matching indexes are equal for two reasons: + - It's faster for the usual case (already aligned objects). + - It ensures it's possible to do operations that don't require alignment + on indexes with duplicate values (which cannot be reindexed with + pandas). This is useful, e.g., for overwriting such duplicate indexes. 
+ + """ + try: + index_not_equal = any(not index.equals(idx) for idx in other_indexes) + except NotImplementedError: + # check coordinates equality for indexes that do not support alignment + index_not_equal = any( + not coords[k].equals(o_coords[k]) for o_coords in other_coords for k in coords + ) + has_unindexed_dims = any(dim in self.unindexed_dim_sizes for dim in dims) + return index_not_equal or has_unindexed_dims + + def _get_index_joiner(self, index_cls) -> Callable: + if self.join == "outer": + return functools.partial(functools.reduce, index_cls.union) + elif self.join == "inner": + return functools.partial(functools.reduce, index_cls.intersection) + elif self.join == "left": + return operator.itemgetter(0) + elif self.join == "right": + return operator.itemgetter(-1) + elif self.join == "override": + # We rewrite all indexes and then use join='left' + return operator.itemgetter(0) + else: + # join='exact' return dummy lambda (error is raised) + return lambda *args: None + + def align_indexes(self): + for key, matching_indexes in self.all_indexes.items(): + matching_coords = self.all_coords[key] + dims = set([d for coord in matching_coords[0].values() for d in coord.dims]) + index_cls = key[1] + + if key in self.indexes: + joined_index = self.indexes[key] + joined_index_coords = self.coords[key] + reindex = self._should_reindex( + dims, joined_index, matching_indexes, joined_index_coords, matching_coords + ) + else: + reindex = self._should_reindex( + dims, + matching_indexes[0], + matching_indexes[1:], + matching_coords[0], + matching_coords[1:], + ) + if reindex: + if self.join == "exact": + raise ValueError(f"indexes are not equal") + joiner = self._get_index_joiner(index_cls) + try: + joined_index = joiner(matching_indexes) + if self.join == "left": + joined_index_coords = matching_coords[0] + elif self.join == "right": + joined_index_coords = matching_coords[-1] + else: + joined_index_coords = joined_index.create_variables() + except NotImplementedError: + raise TypeError( + f"{index_cls.__qualname__} doesn't support alignment " + "with inner/outer join method" + ) + else: + joined_index = matching_indexes[0] + joined_index_coords = matching_coords[0] + + for name, var in joined_index_coords.items(): + self.aligned_indexes[name] = joined_index + self.aligned_index_coords[name] = var + + def assert_unindexed_dim_sizes_equal(self): + aligned_indexes_dim_sizes = calculate_dimensions(self.aligned_index_coords) + + for dim, sizes in self.unindexed_dim_sizes.items(): + if dim in aligned_indexes_dim_sizes: + sizes.add(aligned_indexes_dim_sizes[dim]) + if len(sizes) > 1: + raise ValueError( + f"arguments without labels along dimension {dim!r} cannot be " + f"aligned because they have different dimension sizes: {sizes!r}" + ) + + def align(self): + self.find_matching_indexes() + self.find_matching_unindexed_dims() + self.assert_no_index_conflict() + self.align_indexes() + self.assert_unindexed_dim_sizes_equal() + + def align( *objects: "DataAlignable", join="inner", From 02cc5d2375d13976d0e433fc295dc69a4920a9fc Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 14 Oct 2021 23:21:18 +0200 Subject: [PATCH 059/159] wip refactor alignment Almost done rewriting `align`. --- xarray/core/alignment.py | 119 ++++++++++++++++++++++++++++----------- 1 file changed, 87 insertions(+), 32 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index ca643070bc6..4c9b5bd9db4 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -26,7 +26,7 @@ from . 
import dtypes from .indexes import Index, Indexes, PandasIndex, get_indexer_nd from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index -from .variable import Variable, calculate_dimensions +from .variable import Variable if TYPE_CHECKING: from .common import DataWithCoords @@ -87,20 +87,23 @@ def _override_indexes(objects, all_indexes, exclude): NormalizedCoords = Dict[MatchingIndexKey, Dict[Hashable, Variable]] - class Alignator: - """Implements all the complex logic for the alignment of Xarray objects.""" + """Implements all the complex logic for the alignment of Xarray objects. + + For internal use only, not public API. + + """ objects: List[Union["Dataset", "DataArray"]] join: str exclude_dims: FrozenSet + reindex_dims: Set indexes: Dict[MatchingIndexKey, Index] coords: Dict[MatchingIndexKey, Dict[Hashable, Variable]] all_indexes: Mapping[MatchingIndexKey, List[Index]] all_coords: Mapping[MatchingIndexKey, List[Dict[Hashable, Variable]]] unindexed_dim_sizes: Mapping[Hashable, Set] - aligned_indexes: Dict[Hashable, Index] - aligned_index_coords: Dict[Hashable, Variable] + aligned_indexes: Indexes[Index] def __init__( self, @@ -119,6 +122,8 @@ def __init__( exclude = [exclude] self.exclude_dims = frozenset(exclude) + self.reindex_dims = set() + if indexes is None: indexes = {} self.indexes, self.coords = self._normalize_indexes(indexes) @@ -127,9 +132,6 @@ def __init__( self.all_coords = defaultdict(list) self.unindexed_dim_sizes = defaultdict(set) - self.aligned_indexes = {} - self.aligned_index_coords = {} - def _normalize_indexes( self, indexes: Mapping[Any, Any], @@ -219,21 +221,23 @@ def assert_no_index_conflict(self): for count, msg in [(coord_count, "coordinates"), (dim_count, "dimensions")]: dup = {k: v for k, v in count.items() if v > 1} if dup: - items_msg = ", ".join(f"{k} ({v} conflicting indexes)" for k, v in dup.items()) + items_msg = ", ".join( + f"{k} ({v} conflicting indexes)" for k, v in dup.items() + ) raise ValueError( "cannot align objects with conflicting indexes found for " f"the following {msg}: {items_msg}\n" "Conflicting indexes may occur when\n" - "- they relate to different sets of coordinates and/or dimensions\n" + "- they relate to different sets of coordinate and/or dimension names\n" "- they don't have the same type\n" "- they are used to reindex data along common dimensions" ) - def _should_reindex(self, dims, index, other_indexes, coords, other_coords) -> bool: - """Whether or not we'll need to reindex all the variables - for a set of matching indexes. + def _need_reindex(self, dims, index, other_indexes, coords, other_coords) -> bool: + """Whether or not we need to reindex variables for a set of + matching indexes. - We won't reindex when all matching indexes are equal for two reasons: + We don't reindex when all matching indexes are equal for two reasons: - It's faster for the usual case (already aligned objects). 
- It ensures it's possible to do operations that don't require alignment on indexes with duplicate values (which cannot be reindexed with @@ -245,7 +249,9 @@ def _should_reindex(self, dims, index, other_indexes, coords, other_coords) -> b except NotImplementedError: # check coordinates equality for indexes that do not support alignment index_not_equal = any( - not coords[k].equals(o_coords[k]) for o_coords in other_coords for k in coords + not coords[k].equals(o_coords[k]) + for o_coords in other_coords + for k in coords ) has_unindexed_dims = any(dim in self.unindexed_dim_sizes for dim in dims) return index_not_equal or has_unindexed_dims @@ -267,6 +273,9 @@ def _get_index_joiner(self, index_cls) -> Callable: return lambda *args: None def align_indexes(self): + aligned_indexes = {} + aligned_index_vars = {} + for key, matching_indexes in self.all_indexes.items(): matching_coords = self.all_coords[key] dims = set([d for coord in matching_coords[0].values() for d in coord.dims]) @@ -274,12 +283,16 @@ def align_indexes(self): if key in self.indexes: joined_index = self.indexes[key] - joined_index_coords = self.coords[key] - reindex = self._should_reindex( - dims, joined_index, matching_indexes, joined_index_coords, matching_coords + joined_index_vars = self.coords[key] + reindex = self._need_reindex( + dims, + joined_index, + matching_indexes, + joined_index_vars, + matching_coords, ) else: - reindex = self._should_reindex( + reindex = self._need_reindex( dims, matching_indexes[0], matching_indexes[1:], @@ -288,16 +301,17 @@ def align_indexes(self): ) if reindex: if self.join == "exact": - raise ValueError(f"indexes are not equal") + # TODO: more informative error message + raise ValueError("indexes are not equal") joiner = self._get_index_joiner(index_cls) try: joined_index = joiner(matching_indexes) if self.join == "left": - joined_index_coords = matching_coords[0] + joined_index_vars = matching_coords[0] elif self.join == "right": - joined_index_coords = matching_coords[-1] + joined_index_vars = matching_coords[-1] else: - joined_index_coords = joined_index.create_variables() + joined_index_vars = joined_index.create_variables() except NotImplementedError: raise TypeError( f"{index_cls.__qualname__} doesn't support alignment " @@ -305,30 +319,71 @@ def align_indexes(self): ) else: joined_index = matching_indexes[0] - joined_index_coords = matching_coords[0] + joined_index_vars = matching_coords[0] - for name, var in joined_index_coords.items(): - self.aligned_indexes[name] = joined_index - self.aligned_index_coords[name] = var + for name, var in joined_index_vars.items(): + aligned_indexes[name] = joined_index + aligned_index_vars[name] = var - def assert_unindexed_dim_sizes_equal(self): - aligned_indexes_dim_sizes = calculate_dimensions(self.aligned_index_coords) + self.aligned_indexes = Indexes(aligned_indexes, aligned_index_vars) + if reindex: + self.reindex_dims = {dim for dim in self.aligned_indexes.dims} + def assert_unindexed_dim_sizes_equal(self): for dim, sizes in self.unindexed_dim_sizes.items(): - if dim in aligned_indexes_dim_sizes: - sizes.add(aligned_indexes_dim_sizes[dim]) + index_size = self.aligned_indexes.dims.get(dim) + if index_size is not None: + sizes.add(index_size) + add_err_msg = ( + f" (note: indexed labels also found for dimension {dim!r} " + f"with size {index_size!r})" + ) + else: + add_err_msg = "" if len(sizes) > 1: raise ValueError( f"arguments without labels along dimension {dim!r} cannot be " f"aligned because they have different dimension sizes: 
{sizes!r}" + + add_err_msg ) - def align(self): + def reindex(self, copy, fill_value) -> Tuple[Union["Dataset", "DataArray"], ...]: + result = [] + + for obj in self.objects: + valid_indexers = {} + for dim in self.aligned_indexes.dims: + if ( + dim in obj.dims + and dim in self.reindex_dims + # TODO: default dim var instead? + and dim in self.aligned_indexes.variables + ): + valid_indexers[dim] = safe_cast_to_index( + self.aligned_indexes.variables[dim] + ) + if not valid_indexers: + # fast path for no reindexing necessary + new_obj = obj.copy(deep=copy) + else: + # TODO: propagate aligned indexes and index vars + new_obj = obj.reindex( + copy=copy, fill_value=fill_value, indexers=valid_indexers + ) + new_obj.encoding = obj.encoding + result.append(new_obj) + + return tuple(result) + + def align( + self, copy=True, fill_value=dtypes.NA + ) -> Tuple[Union["Dataset", "DataArray"], ...]: self.find_matching_indexes() self.find_matching_unindexed_dims() self.assert_no_index_conflict() self.align_indexes() self.assert_unindexed_dim_sizes_equal() + return self.reindex(copy, fill_value) def align( From d1bbc4a51486bf66bacd88a62e9d80aae7704366 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 15 Oct 2021 12:07:17 +0200 Subject: [PATCH 060/159] tweaks and fixes - Some fixes and clean-up in new implementation of align - Add Index.join method (only for inner/outer join) - Tweak index generic types - IndexVars type: use Variable instead of IndexVariable (IndexVariable may eventually be dropped) --- xarray/core/alignment.py | 104 +++++++++++++++++++++------------------ xarray/core/indexes.py | 82 +++++++++++++++++++++--------- xarray/core/indexing.py | 6 +-- xarray/core/types.py | 2 + 4 files changed, 120 insertions(+), 74 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 4c9b5bd9db4..30709f99ebd 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -81,12 +81,6 @@ def _override_indexes(objects, all_indexes, exclude): return objects -CoordNamesAndDims = FrozenSet[Tuple[Hashable, Tuple[Hashable, ...]]] -MatchingIndexKey = Tuple[CoordNamesAndDims, Type[Index]] -NormalizedIndexes = Dict[MatchingIndexKey, Index] -NormalizedCoords = Dict[MatchingIndexKey, Dict[Hashable, Variable]] - - class Alignator: """Implements all the complex logic for the alignment of Xarray objects. @@ -94,15 +88,21 @@ class Alignator: """ + CoordNamesAndDims = FrozenSet[Tuple[Hashable, Tuple[Hashable, ...]]] + MatchingIndexKey = Tuple[CoordNamesAndDims, Type[Index]] + NormalizedIndexes = Dict[MatchingIndexKey, Index] + NormalizedIndexVars = Dict[MatchingIndexKey, Dict[Hashable, Variable]] + AlignedObjects = Tuple[Union["Dataset", "DataArray"], ...] 
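+    # e.g. a dimension coordinate "x" backed by a PandasIndex is grouped
+    # under the matching key (frozenset({("x", ("x",))}), PandasIndex)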
+ objects: List[Union["Dataset", "DataArray"]] join: str exclude_dims: FrozenSet reindex_dims: Set indexes: Dict[MatchingIndexKey, Index] - coords: Dict[MatchingIndexKey, Dict[Hashable, Variable]] - all_indexes: Mapping[MatchingIndexKey, List[Index]] - all_coords: Mapping[MatchingIndexKey, List[Dict[Hashable, Variable]]] - unindexed_dim_sizes: Mapping[Hashable, Set] + index_vars: Dict[MatchingIndexKey, Dict[Hashable, Variable]] + all_indexes: Dict[MatchingIndexKey, List[Index]] + all_index_vars: Dict[MatchingIndexKey, List[Dict[Hashable, Variable]]] + unindexed_dim_sizes: Dict[Hashable, Set] aligned_indexes: Indexes[Index] def __init__( @@ -126,20 +126,20 @@ def __init__( if indexes is None: indexes = {} - self.indexes, self.coords = self._normalize_indexes(indexes) + self.indexes, self.index_vars = self._normalize_indexes(indexes) self.all_indexes = defaultdict(list) - self.all_coords = defaultdict(list) + self.all_index_vars = defaultdict(list) self.unindexed_dim_sizes = defaultdict(set) def _normalize_indexes( self, indexes: Mapping[Any, Any], - ) -> Tuple[NormalizedIndexes, NormalizedCoords]: + ) -> Tuple[NormalizedIndexes, NormalizedIndexVars]: """Normalize the indexes used for alignment. - Return dictionaries of xarray Index objects and coordinates such that we can - group matching indexes based on the dictionary keys. + Return dictionaries of xarray Index objects and coordinate variables + such that we can group matching indexes based on the dictionary keys. """ if isinstance(indexes, Indexes): @@ -157,12 +157,12 @@ def _normalize_indexes( xr_indexes[k] = idx normalized_indexes = {} - normalized_coords = {} - for idx, coords in Indexes(xr_indexes, variables).group_by_index(): + normalized_index_vars = {} + for idx, index_vars in Indexes(xr_indexes, variables).group_by_index(): coord_names_and_dims = [] all_dims = set() - for name, var in coords.items(): + for name, var in index_vars.items(): dims = var.dims coord_names_and_dims.append((name, dims)) all_dims.update(dims) @@ -181,16 +181,16 @@ def _normalize_indexes( key = (frozenset(coord_names_and_dims), type(idx)) normalized_indexes[key] = idx - normalized_coords[key] = coords + normalized_index_vars[key] = index_vars - return normalized_indexes, normalized_coords + return normalized_indexes, normalized_index_vars def find_matching_indexes(self): for obj in self.objects: - obj_indexes, obj_coords = self._normalize_indexes(obj.xindexes) + obj_indexes, obj_index_vars = self._normalize_indexes(obj.xindexes) for key, idx in obj_indexes.items(): self.all_indexes[key].append(idx) - self.all_coords[key].append(obj_coords[key]) + self.all_index_vars[key].append(obj_index_vars[key]) def find_matching_unindexed_dims(self): for obj in self.objects: @@ -230,7 +230,7 @@ def assert_no_index_conflict(self): "Conflicting indexes may occur when\n" "- they relate to different sets of coordinate and/or dimension names\n" "- they don't have the same type\n" - "- they are used to reindex data along common dimensions" + "- they may be used to reindex data along common dimensions" ) def _need_reindex(self, dims, index, other_indexes, coords, other_coords) -> bool: @@ -257,10 +257,8 @@ def _need_reindex(self, dims, index, other_indexes, coords, other_coords) -> boo return index_not_equal or has_unindexed_dims def _get_index_joiner(self, index_cls) -> Callable: - if self.join == "outer": - return functools.partial(functools.reduce, index_cls.union) - elif self.join == "inner": - return functools.partial(functools.reduce, index_cls.intersection) + if 
self.join in ["outer", "inner"]: + return functools.partial(functools.reduce, index_cls.join, how=self.join) elif self.join == "left": return operator.itemgetter(0) elif self.join == "right": @@ -270,46 +268,54 @@ def _get_index_joiner(self, index_cls) -> Callable: return operator.itemgetter(0) else: # join='exact' return dummy lambda (error is raised) - return lambda *args: None + return lambda _: None def align_indexes(self): aligned_indexes = {} aligned_index_vars = {} + reindex_dims = set() for key, matching_indexes in self.all_indexes.items(): - matching_coords = self.all_coords[key] - dims = set([d for coord in matching_coords[0].values() for d in coord.dims]) + matching_index_vars = self.all_index_vars[key] + dims = set( + [d for coord in matching_index_vars[0].values() for d in coord.dims] + ) index_cls = key[1] if key in self.indexes: joined_index = self.indexes[key] - joined_index_vars = self.coords[key] + joined_index_vars = self.index_vars[key] reindex = self._need_reindex( dims, joined_index, matching_indexes, joined_index_vars, - matching_coords, + matching_index_vars, ) else: reindex = self._need_reindex( dims, matching_indexes[0], matching_indexes[1:], - matching_coords[0], - matching_coords[1:], + matching_index_vars[0], + matching_index_vars[1:], ) if reindex: if self.join == "exact": # TODO: more informative error message - raise ValueError("indexes are not equal") + raise ValueError( + "cannot align objects with join='exact' where " + "index/labels/sizes are not equal along " + "these coordinates (dimensions): " + + ", ".join(f"{name!r} {dims!r}" for name, dims in key[0]) + ) joiner = self._get_index_joiner(index_cls) try: joined_index = joiner(matching_indexes) if self.join == "left": - joined_index_vars = matching_coords[0] + joined_index_vars = matching_index_vars[0] elif self.join == "right": - joined_index_vars = matching_coords[-1] + joined_index_vars = matching_index_vars[-1] else: joined_index_vars = joined_index.create_variables() except NotImplementedError: @@ -319,15 +325,17 @@ def align_indexes(self): ) else: joined_index = matching_indexes[0] - joined_index_vars = matching_coords[0] + joined_index_vars = matching_index_vars[0] for name, var in joined_index_vars.items(): aligned_indexes[name] = joined_index aligned_index_vars[name] = var + if reindex: + reindex_dims |= dims + self.aligned_indexes = Indexes(aligned_indexes, aligned_index_vars) - if reindex: - self.reindex_dims = {dim for dim in self.aligned_indexes.dims} + self.reindex_dims = reindex_dims def assert_unindexed_dim_sizes_equal(self): for dim, sizes in self.unindexed_dim_sizes.items(): @@ -347,7 +355,7 @@ def assert_unindexed_dim_sizes_equal(self): + add_err_msg ) - def reindex(self, copy, fill_value) -> Tuple[Union["Dataset", "DataArray"], ...]: + def reindex(self, copy: bool, fill_value: Any) -> AlignedObjects: result = [] for obj in self.objects: @@ -359,9 +367,7 @@ def reindex(self, copy, fill_value) -> Tuple[Union["Dataset", "DataArray"], ...] # TODO: default dim var instead? and dim in self.aligned_indexes.variables ): - valid_indexers[dim] = safe_cast_to_index( - self.aligned_indexes.variables[dim] - ) + valid_indexers[dim] = self.aligned_indexes.variables[dim] if not valid_indexers: # fast path for no reindexing necessary new_obj = obj.copy(deep=copy) @@ -375,9 +381,13 @@ def reindex(self, copy, fill_value) -> Tuple[Union["Dataset", "DataArray"], ...] 
return tuple(result) - def align( - self, copy=True, fill_value=dtypes.NA - ) -> Tuple[Union["Dataset", "DataArray"], ...]: + def align(self, copy: bool = True, fill_value: Any = dtypes.NA) -> AlignedObjects: + + if not self.indexes and len(self.objects) == 1: + # fast path for the trivial case + (obj,) = self.objects + return (obj.copy(deep=copy),) + self.find_matching_indexes() self.find_matching_unindexed_dims() self.assert_no_index_conflict() diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index c7c0e99dcee..7d205c9f9f3 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -22,12 +22,13 @@ from . import formatting, utils from .indexing import PandasIndexingAdapter, PandasMultiIndexingAdapter, QueryResult +from .types import T_Index from .utils import Frozen, is_dict_like, is_scalar if TYPE_CHECKING: - from .variable import IndexVariable, Variable + from .variable import Variable -IndexVars = Dict[Any, "IndexVariable"] +IndexVars = Dict[Any, "Variable"] class Index: @@ -40,7 +41,9 @@ def from_variables( raise NotImplementedError() def create_variables( - self, attrs: Mapping[Any, Any], encoding: Mapping[Any, Any] + self, + attrs: Optional[Mapping[Any, Any]] = None, + encoding: Optional[Mapping[Any, Any]] = None, ) -> IndexVars: return {} @@ -57,6 +60,9 @@ def to_pandas_index(self) -> pd.Index: def query(self, labels: Dict[Any, Any]) -> QueryResult: raise NotImplementedError() + def join(self: T_Index, other: T_Index, how: str = "inner") -> T_Index: + raise NotImplementedError() + def equals(self, other): # pragma: no cover raise NotImplementedError() @@ -222,10 +228,17 @@ def from_variables( return obj, {name: index_var} def create_variables( - self, attrs: Mapping[Any, Any], encoding: Mapping[Any, Any] + self, + attrs: Optional[Mapping[Any, Any]] = None, + encoding: Optional[Mapping[Any, Any]] = None, ) -> IndexVars: from .variable import IndexVariable + if attrs is None: + attrs = {} + if encoding is None: + encoding = {} + name = self.index.name data = PandasIndexingAdapter(self.index, dtype=self.coord_dtype) var = IndexVariable( @@ -318,8 +331,19 @@ def query(self, labels: Dict[Any, Any], method=None, tolerance=None) -> QueryRes return QueryResult({self.dim: indexer}) - def equals(self, other): - return self.index.equals(other.index) + def equals(self, other: Index): + if not isinstance(other, PandasIndex): + return False + return self.index.equals(other.index) and self.dim == other.dim + + def join(self, other: "PandasIndex", how: str = "inner") -> "PandasIndex": + # TODO: handle coord_dtype + # Move logic from ``utils.maybe_coerce_to_str`` here + if how == "outer": + return type(self)(self.index.union(other.index), self.dim) + else: + # how = "inner" + return type(self)(self.index.intersection(other.index), self.dim) def union(self, other): new_index = self.index.union(other.index) @@ -588,8 +612,15 @@ def from_pandas_index( return cls(index, dim, level_coords_dtype=level_coords_dtype), index_vars def create_variables( - self, attrs: Mapping[Any, Any], encoding: Mapping[Any, Any] + self, + attrs: Optional[Mapping[Any, Any]] = None, + encoding: Optional[Mapping[Any, Any]] = None, ) -> IndexVars: + if attrs is None: + attrs = {} + if encoding is None: + encoding = {} + var_meta = {} for name in self.index.names: var_meta[name] = { @@ -729,9 +760,7 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult: indexes = cast(Dict[Any, Index], {k: new_index for k in new_vars}) # add scalar variable for each dropped level - variables = cast( - 
Dict[Hashable, Union["Variable", "IndexVariable"]], new_vars - ) + variables = new_vars for name, val in scalar_coord_values.items(): variables[name] = Variable([], val) @@ -790,11 +819,11 @@ def remove_unused_levels_categories(index: pd.Index) -> pd.Index: return index -# generic type that represents either pandas or xarray indexes -T_Index = TypeVar("T_Index") +# generic type that represents either a pandas or an xarray index +T_PandasOrXarrayIndex = TypeVar("T_PandasOrXarrayIndex") -class Indexes(collections.abc.Mapping, Generic[T_Index]): +class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): """Immutable proxy for Dataset or DataArrary indexes. Keys are coordinate names and values may correspond to either pandas or @@ -804,7 +833,7 @@ class Indexes(collections.abc.Mapping, Generic[T_Index]): """ - _indexes: Dict[Any, T_Index] + _indexes: Dict[Any, T_PandasOrXarrayIndex] _variables: Dict[Any, "Variable"] __slots__ = ( @@ -816,7 +845,11 @@ class Indexes(collections.abc.Mapping, Generic[T_Index]): "__id_coord_names", ) - def __init__(self, indexes: Dict[Any, T_Index], variables: Dict[Any, "Variable"]): + def __init__( + self, + indexes: Dict[Any, T_PandasOrXarrayIndex], + variables: Dict[Any, "Variable"], + ): """Constructor not for public consumption. Parameters @@ -832,7 +865,7 @@ def __init__(self, indexes: Dict[Any, T_Index], variables: Dict[Any, "Variable"] self._dims: Optional[Mapping[Hashable, int]] = None self.__coord_name_id: Optional[Dict[Any, int]] = None - self.__id_index: Optional[Dict[int, T_Index]] = None + self.__id_index: Optional[Dict[int, T_PandasOrXarrayIndex]] = None self.__id_coord_names: Optional[Dict[int, Tuple[Hashable, ...]]] = None @property @@ -842,7 +875,7 @@ def _coord_name_id(self) -> Dict[Any, int]: return self.__coord_name_id @property - def _id_index(self) -> Dict[int, T_Index]: + def _id_index(self) -> Dict[int, T_PandasOrXarrayIndex]: if self.__id_index is None: self.__id_index = {id(idx): idx for idx in self.get_unique()} return self.__id_index @@ -870,11 +903,11 @@ def dims(self) -> Mapping[Hashable, int]: return Frozen(self._dims) - def get_unique(self) -> List[T_Index]: + def get_unique(self) -> List[T_PandasOrXarrayIndex]: """Return a list of unique indexes, preserving order.""" - unique_indexes: List[T_Index] = [] - seen: Set[T_Index] = set() + unique_indexes: List[T_PandasOrXarrayIndex] = [] + seen: Set[T_PandasOrXarrayIndex] = set() for index in self._indexes.values(): if index not in seen: @@ -914,7 +947,9 @@ def get_all_coords( all_coord_names = self._id_coord_names[self._coord_name_id[coord_name]] return {k: self._variables[k] for k in all_coord_names} - def group_by_index(self) -> List[Tuple[T_Index, Dict[Hashable, "Variable"]]]: + def group_by_index( + self, + ) -> List[Tuple[T_PandasOrXarrayIndex, Dict[Hashable, "Variable"]]]: """Returns a list of unique indexes and their corresponding coordinates.""" index_coords = [] @@ -926,14 +961,15 @@ def group_by_index(self) -> List[Tuple[T_Index, Dict[Hashable, "Variable"]]]: return index_coords - def to_pandas_indexes(self): + def to_pandas_indexes(self) -> "Indexes[pd.Index]": """Returns an immutable proxy for Dataset or DataArrary pandas indexes. Raises an error if this proxy contains indexes that cannot be coerced to pandas.Index objects. 
""" - indexes = {} + indexes: Dict[Hashable, pd.Index] = {} + for k, idx in self._indexes.items(): if isinstance(idx, pd.Index): indexes[k] = idx diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 217b98d5c34..4e140e891ee 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: from .indexes import Index - from .variable import IndexVariable, Variable + from .variable import Variable @dataclass @@ -64,9 +64,7 @@ class QueryResult: dim_indexers: Dict[Any, Any] indexes: Dict[Any, "Index"] = field(default_factory=dict) - variables: Dict[Any, Union["Variable", "IndexVariable"]] = field( - default_factory=dict - ) + variables: Dict[Any, "Variable"] = field(default_factory=dict) drop_coords: List[Hashable] = field(default_factory=list) drop_indexes: List[Hashable] = field(default_factory=list) rename_dims: Dict[Any, Hashable] = field(default_factory=dict) diff --git a/xarray/core/types.py b/xarray/core/types.py index 9f0f9eee54c..3f368501b25 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -9,6 +9,7 @@ from .dataarray import DataArray from .dataset import Dataset from .groupby import DataArrayGroupBy, GroupBy + from .indexes import Index from .npcompat import ArrayLike from .variable import Variable @@ -21,6 +22,7 @@ T_Dataset = TypeVar("T_Dataset", bound="Dataset") T_DataArray = TypeVar("T_DataArray", bound="DataArray") T_Variable = TypeVar("T_Variable", bound="Variable") +T_Index = TypeVar("T_Index", bound="Index") # Maybe we rename this to T_Data or something less Fortran-y? T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") From 06242045463ecbb37f05490fde4755e77100ab1e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 18 Oct 2021 17:49:38 +0200 Subject: [PATCH 061/159] wip refactor alignment and reindex Imported all re-indexing logic into the `Alignator` class, which can be reused directly for both `align()` and `obj.reindex()`. TODO: - import override indexes into `Alignator` - connect Alignator to public API entry points - clean-up (remove old implementation functions) --- xarray/core/alignment.py | 281 ++++++++++++++++++++++++++------------- xarray/core/indexes.py | 31 ++++- 2 files changed, 215 insertions(+), 97 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 30709f99ebd..7e0998a1ea8 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -12,7 +12,6 @@ List, Mapping, Optional, - Sequence, Set, Tuple, Type, @@ -24,7 +23,7 @@ import pandas as pd from . import dtypes -from .indexes import Index, Indexes, PandasIndex, get_indexer_nd +from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex, get_indexer_nd from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index from .variable import Variable @@ -82,7 +81,8 @@ def _override_indexes(objects, all_indexes, exclude): class Alignator: - """Implements all the complex logic for the alignment of Xarray objects. + """Implements all the complex logic for the re-indexing and alignment of Xarray + objects. For internal use only, not public API. @@ -94,71 +94,103 @@ class Alignator: NormalizedIndexVars = Dict[MatchingIndexKey, Dict[Hashable, Variable]] AlignedObjects = Tuple[Union["Dataset", "DataArray"], ...] - objects: List[Union["Dataset", "DataArray"]] + objects: Tuple[Union["Dataset", "DataArray"], ...] + objects_matching_indexes: Tuple[Dict[MatchingIndexKey, Index], ...] 
join: str exclude_dims: FrozenSet - reindex_dims: Set + copy: bool + fill_value: Any + sparse: bool indexes: Dict[MatchingIndexKey, Index] index_vars: Dict[MatchingIndexKey, Dict[Hashable, Variable]] all_indexes: Dict[MatchingIndexKey, List[Index]] all_index_vars: Dict[MatchingIndexKey, List[Dict[Hashable, Variable]]] + aligned_indexes: Dict[MatchingIndexKey, Index] + aligned_index_vars: Dict[MatchingIndexKey, Dict[Hashable, Variable]] + reindex: Dict[MatchingIndexKey, bool] + reindex_kwargs: Dict[str, Any] unindexed_dim_sizes: Dict[Hashable, Set] - aligned_indexes: Indexes[Index] + new_indexes: Indexes[Index] def __init__( self, objects: List[Union["Dataset", "DataArray"]], join: str, - indexes: Union[Mapping[Any, Any], None], - exclude: Union[str, Set, Sequence], + indexes_or_indexers: Optional[Mapping[Any, Any]] = None, + exclude: Any = frozenset(), + method: Optional[str] = None, + tolerance: Any = None, + copy: bool = True, + fill_value: Any = dtypes.NA, + sparse: bool = False, ): - self.objects = objects + self.objects = tuple(objects) + self.objects_matching_indexes = () if join not in ["inner", "outer", "overwrite", "exact", "left", "right"]: raise ValueError(f"invalid value for join: {join}") self.join = join + self.copy = copy + self.fill_value = fill_value + self.sparse = sparse + + if method is None and tolerance is None: + self.reindex_kwargs = {} + else: + self.reindex_kwargs = {"method": method, "tolerance": tolerance} + if isinstance(exclude, str): exclude = [exclude] self.exclude_dims = frozenset(exclude) - self.reindex_dims = set() + if indexes_or_indexers is None: + indexes_or_indexers = {} + self.indexes, self.index_vars = self._normalize_indexes(indexes_or_indexers) - if indexes is None: - indexes = {} - self.indexes, self.index_vars = self._normalize_indexes(indexes) + self.all_indexes = {} + self.all_index_vars = {} + self.unindexed_dim_sizes = {} - self.all_indexes = defaultdict(list) - self.all_index_vars = defaultdict(list) - self.unindexed_dim_sizes = defaultdict(set) + self.aligned_indexes = {} + self.aligned_index_vars = {} + self.reindex = {} def _normalize_indexes( self, - indexes: Mapping[Any, Any], + indexes_or_indexers: Mapping[Any, Any], ) -> Tuple[NormalizedIndexes, NormalizedIndexVars]: - """Normalize the indexes used for alignment. + """Normalize the indexes/indexers used for re-indexing or alignment. Return dictionaries of xarray Index objects and coordinate variables such that we can group matching indexes based on the dictionary keys. 
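+
+        For example, a raw indexer such as ``{"x": pd.Index([10, 20])}`` is
+        first wrapped into a new ``PandasIndex`` and then grouped under a key
+        built from its coordinate name(s), dimension(s) and index type.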
""" - if isinstance(indexes, Indexes): - variables = dict(indexes.variables) + if isinstance(indexes_or_indexers, Indexes): + xr_variables = dict(indexes_or_indexers.variables) else: - variables = {} + xr_variables = {} xr_indexes = {} - for k, idx in indexes.items(): + for k, idx in indexes_or_indexers.items(): if not isinstance(idx, Index): + if getattr(idx, "dims", (k,)) != (k,): + raise ValueError( + "Indexer has dimensions {:s} that are different " + "from that to be indexed along {:s}".format(str(idx.dims), k) + ) pd_idx = safe_cast_to_index(idx).copy() pd_idx.name = k - idx, _ = PandasIndex.from_pandas_index(pd_idx, k) - variables.update(idx.create_variables()) + if isinstance(pd_idx, pd.MultiIndex): + idx, _ = PandasMultiIndex.from_pandas_index(pd_idx, k) + else: + idx, _ = PandasIndex.from_pandas_index(pd_idx, k) + xr_variables.update(idx.create_variables()) xr_indexes[k] = idx normalized_indexes = {} normalized_index_vars = {} - for idx, index_vars in Indexes(xr_indexes, variables).group_by_index(): + for idx, index_vars in Indexes(xr_indexes, xr_variables).group_by_index(): coord_names_and_dims = [] all_dims = set() @@ -186,24 +218,37 @@ def _normalize_indexes( return normalized_indexes, normalized_index_vars def find_matching_indexes(self): + all_indexes = defaultdict(list) + all_index_vars = defaultdict(list) + objects_matching_indexes = [] + for obj in self.objects: obj_indexes, obj_index_vars = self._normalize_indexes(obj.xindexes) + objects_matching_indexes.append(obj_indexes) for key, idx in obj_indexes.items(): - self.all_indexes[key].append(idx) - self.all_index_vars[key].append(obj_index_vars[key]) + all_indexes[key].append(idx) + all_index_vars[key].append(obj_index_vars[key]) + + self.objects_matching_indexes = tuple(objects_matching_indexes) + self.all_indexes = all_indexes + self.all_index_vars = all_index_vars def find_matching_unindexed_dims(self): + unindexed_dim_sizes = defaultdict(set) + for obj in self.objects: for dim in obj.dims: if dim not in self.exclude_dims and dim not in obj.xindexes.dims: - self.unindexed_dim_sizes[dim].add(obj.sizes[dim]) + unindexed_dim_sizes[dim].add(obj.sizes[dim]) + + self.unindexed_dim_sizes = unindexed_dim_sizes def assert_no_index_conflict(self): """Check for uniqueness of both coordinate and dimension names accross all sets of matching indexes. - We need to make sure that all indexes used for alignment are fully compatible - and do not conflict each other. + We need to make sure that all indexes used for re-indexing or alignment + are fully compatible and do not conflict each other. 
""" matching_keys = set(self.all_indexes) | set(self.indexes) @@ -225,7 +270,7 @@ def assert_no_index_conflict(self): f"{k} ({v} conflicting indexes)" for k, v in dup.items() ) raise ValueError( - "cannot align objects with conflicting indexes found for " + "cannot re-index or align objects with conflicting indexes found for " f"the following {msg}: {items_msg}\n" "Conflicting indexes may occur when\n" "- they relate to different sets of coordinate and/or dimension names\n" @@ -273,7 +318,9 @@ def _get_index_joiner(self, index_cls) -> Callable: def align_indexes(self): aligned_indexes = {} aligned_index_vars = {} - reindex_dims = set() + reindex = {} + new_indexes = {} + new_index_vars = {} for key, matching_indexes in self.all_indexes.items(): matching_index_vars = self.all_index_vars[key] @@ -285,7 +332,7 @@ def align_indexes(self): if key in self.indexes: joined_index = self.indexes[key] joined_index_vars = self.index_vars[key] - reindex = self._need_reindex( + need_reindex = self._need_reindex( dims, joined_index, matching_indexes, @@ -293,14 +340,17 @@ def align_indexes(self): matching_index_vars, ) else: - reindex = self._need_reindex( - dims, - matching_indexes[0], - matching_indexes[1:], - matching_index_vars[0], - matching_index_vars[1:], - ) - if reindex: + if len(matching_indexes) > 1: + need_reindex = self._need_reindex( + dims, + matching_indexes[0], + matching_indexes[1:], + matching_index_vars[0], + matching_index_vars[1:], + ) + else: + need_reindex = False + if need_reindex: if self.join == "exact": # TODO: more informative error message raise ValueError( @@ -310,90 +360,141 @@ def align_indexes(self): + ", ".join(f"{name!r} {dims!r}" for name, dims in key[0]) ) joiner = self._get_index_joiner(index_cls) - try: - joined_index = joiner(matching_indexes) - if self.join == "left": - joined_index_vars = matching_index_vars[0] - elif self.join == "right": - joined_index_vars = matching_index_vars[-1] - else: - joined_index_vars = joined_index.create_variables() - except NotImplementedError: - raise TypeError( - f"{index_cls.__qualname__} doesn't support alignment " - "with inner/outer join method" - ) + joined_index = joiner(matching_indexes) + if self.join == "left": + joined_index_vars = matching_index_vars[0] + elif self.join == "right": + joined_index_vars = matching_index_vars[-1] + else: + joined_index_vars = joined_index.create_variables() else: joined_index = matching_indexes[0] joined_index_vars = matching_index_vars[0] - for name, var in joined_index_vars.items(): - aligned_indexes[name] = joined_index - aligned_index_vars[name] = var + reindex[key] = need_reindex + aligned_indexes[key] = joined_index + aligned_index_vars[key] = joined_index_vars - if reindex: - reindex_dims |= dims + for name, var in joined_index_vars.items(): + new_indexes[name] = joined_index + new_index_vars[name] = var - self.aligned_indexes = Indexes(aligned_indexes, aligned_index_vars) - self.reindex_dims = reindex_dims + self.aligned_indexes = aligned_indexes + self.aligned_index_vars = aligned_index_vars + self.reindex = reindex + self.new_indexes = Indexes(new_indexes, new_index_vars) def assert_unindexed_dim_sizes_equal(self): for dim, sizes in self.unindexed_dim_sizes.items(): - index_size = self.aligned_indexes.dims.get(dim) + index_size = self.new_indexes.dims.get(dim) if index_size is not None: sizes.add(index_size) add_err_msg = ( - f" (note: indexed labels also found for dimension {dim!r} " + f" (note: an index is found for dimension {dim!r} " f"with size {index_size!r})" ) else: 
add_err_msg = "" if len(sizes) > 1: raise ValueError( - f"arguments without labels along dimension {dim!r} cannot be " - f"aligned because they have different dimension sizes: {sizes!r}" - + add_err_msg + f"cannot reindex or align along dimension {dim!r} without labels " + f"because of conflicting dimension sizes: {sizes!r}" + add_err_msg ) - def reindex(self, copy: bool, fill_value: Any) -> AlignedObjects: - result = [] + def _reindex_one(self, obj, matching_indexes): + from .dataarray import DataArray - for obj in self.objects: - valid_indexers = {} - for dim in self.aligned_indexes.dims: - if ( - dim in obj.dims - and dim in self.reindex_dims - # TODO: default dim var instead? - and dim in self.aligned_indexes.variables - ): - valid_indexers[dim] = self.aligned_indexes.variables[dim] - if not valid_indexers: - # fast path for no reindexing necessary - new_obj = obj.copy(deep=copy) + new_variables = {} + new_indexes = {} + dim_reindexers = {} + + for key, aligned_idx in self.aligned_indexes.items(): + obj_idx = matching_indexes.get(key) + if obj_idx is not None: + for name, var in self.aligned_index_vars[key].items(): + new_indexes[name] = aligned_idx + new_variables[name] = var + if self.reindex[key]: + indexers = obj_idx.reindex_like(aligned_idx, **self.reindex_kwargs) + dim_reindexers.update(indexers) + + if not dim_reindexers: + # fast path for no reindexing necessary + new_obj = obj.copy(deep=self.copy) + else: + if isinstance(obj, DataArray): + ds_obj = obj._to_temp_dataset() else: - # TODO: propagate aligned indexes and index vars - new_obj = obj.reindex( - copy=copy, fill_value=fill_value, indexers=valid_indexers + ds_obj = obj + + # Negative values in dim_indexers mean values missing in the new index + masked_dims = [(indxr < 0).any() for indxr in dim_reindexers] + unchanged_dims = [dim not in dim_reindexers for dim in obj.dims] + + for name, var in ds_obj.variables.items(): + if name in new_variables: + continue + + if isinstance(self.fill_value, dict): + fill_value = self.fill_value.get(name, dtypes.NA) + else: + fill_value = self.fill_value + + if self.sparse: + var = var._as_sparse(fill_value=fill_value) + key = tuple( + slice(None) + if d in unchanged_dims + else dim_reindexers.get(d, slice(None)) + for d in var.dims ) - new_obj.encoding = obj.encoding - result.append(new_obj) + needs_masking = any(d in masked_dims for d in var.dims) + + if needs_masking: + new_var = var._getitem_with_mask(key, fill_value=fill_value) + elif all(is_full_slice(k) for k in key): + # no reindexing necessary + # here we need to manually deal with copying data, since + # we neither created a new ndarray nor used fancy indexing + new_var = var.copy(deep=self.copy) + else: + new_var = var[key] - return tuple(result) + new_variables[name] = new_var + + new_coord_names = ds_obj._coord_names | set(new_indexes) + new_ds_obj = ds_obj._replace_with_new_dims( + new_variables, new_coord_names, indexes=new_indexes + ) + + if isinstance(obj, DataArray): + new_obj = obj._from_temp_dataset(new_ds_obj) + else: + new_obj = new_ds_obj + + new_obj.encoding = obj.encoding + return new_obj - def align(self, copy: bool = True, fill_value: Any = dtypes.NA) -> AlignedObjects: + def reindex_all(self) -> AlignedObjects: + result = [] + + for obj, matching_indexes in zip(self.objects, self.objects_matching_indexes): + result.append(self._reindex_one(obj, matching_indexes)) + + return tuple(result) + def align(self) -> AlignedObjects: if not self.indexes and len(self.objects) == 1: # fast path for the trivial case (obj,) = 
self.objects - return (obj.copy(deep=copy),) + return (obj.copy(deep=self.copy),) self.find_matching_indexes() self.find_matching_unindexed_dims() self.assert_no_index_conflict() self.align_indexes() self.assert_unindexed_dim_sizes_equal() - return self.reindex(copy, fill_value) + return self.reindex_all() def align( diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 7d205c9f9f3..1f67f4c3e46 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -55,13 +55,18 @@ def to_pandas_index(self) -> pd.Index: pandas.Index object. """ - raise TypeError(f"{type(self)} cannot be cast to a pandas.Index object.") + raise TypeError(f"{self!r} cannot be cast to a pandas.Index object") def query(self, labels: Dict[Any, Any]) -> QueryResult: - raise NotImplementedError() + raise NotImplementedError(f"{self!r} doesn't support label-based selection") def join(self: T_Index, other: T_Index, how: str = "inner") -> T_Index: - raise NotImplementedError() + raise NotImplementedError( + f"{self!r} doesn't support alignment with inner/outer join method" + ) + + def reindex_like(self: T_Index, other: T_Index) -> Dict[Hashable, Any]: + raise NotImplementedError(f"{self!r} doesn't support re-indexing labels") def equals(self, other): # pragma: no cover raise NotImplementedError() @@ -337,13 +342,14 @@ def equals(self, other: Index): return self.index.equals(other.index) and self.dim == other.dim def join(self, other: "PandasIndex", how: str = "inner") -> "PandasIndex": - # TODO: handle coord_dtype - # Move logic from ``utils.maybe_coerce_to_str`` here if how == "outer": - return type(self)(self.index.union(other.index), self.dim) + index = self.index.union(other.index) else: # how = "inner" - return type(self)(self.index.intersection(other.index), self.dim) + index = self.index.intersection(other.index) + + coord_dtype = np.result_type(self.coord_dtype, other.coord_dtype).type + return type(self)(index, self.dim, coord_dtype=coord_dtype) def union(self, other): new_index = self.index.union(other.index) @@ -353,6 +359,17 @@ def intersection(self, other): new_index = self.index.intersection(other.index) return type(self)(new_index, self.dim) + def reindex_like( + self, other: "PandasIndex", method=None, tolerance=None + ) -> Dict[Hashable, Any]: + if not self.index.is_unique: + raise ValueError( + f"cannot reindex or align along dimension {self.dim!r} because the " + "(pandas) index has duplicate values" + ) + + return {self.dim: get_indexer_nd(self.index, other.index, method, tolerance)} + def rename(self, name_dict, dims_dict): if self.index.name not in name_dict and self.dim not in dims_dict: return self, {} From 30fee83913c78439ace594b47f10483102d4f10d Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 19 Oct 2021 16:10:50 +0200 Subject: [PATCH 062/159] wip refactor alignment / reindex: fixes and tweaks --- xarray/core/alignment.py | 35 ++++++++++++++++++++++++----------- xarray/core/indexes.py | 24 ++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 7e0998a1ea8..e3001491774 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -85,10 +85,13 @@ class Alignator: objects. For internal use only, not public API. + Usage: + + aligned_objects = Alignator(objects, **kwargs).align() """ - CoordNamesAndDims = FrozenSet[Tuple[Hashable, Tuple[Hashable, ...]]] + CoordNamesAndDims = Tuple[Tuple[Hashable, Tuple[Hashable, ...]], ...] 
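+    # a matching-index key pairs the (coord name, dims) tuples of an index's
+    # coordinates with the index type, e.g. a 1-d "x" coordinate backed by a
+    # pandas index could yield the key ((("x", ("x",)),), PandasIndex)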
MatchingIndexKey = Tuple[CoordNamesAndDims, Type[Index]] NormalizedIndexes = Dict[MatchingIndexKey, Index] NormalizedIndexVars = Dict[MatchingIndexKey, Dict[Hashable, Variable]] @@ -171,7 +174,7 @@ def _normalize_indexes( else: xr_variables = {} - xr_indexes = {} + xr_indexes: Dict[Hashable, Index] = {} for k, idx in indexes_or_indexers.items(): if not isinstance(idx, Index): if getattr(idx, "dims", (k,)) != (k,): @@ -211,7 +214,7 @@ def _normalize_indexes( f"{incl_dims_str}" ) - key = (frozenset(coord_names_and_dims), type(idx)) + key = (tuple(coord_names_and_dims), type(idx)) normalized_indexes[key] = idx normalized_index_vars[key] = index_vars @@ -250,6 +253,12 @@ def assert_no_index_conflict(self): We need to make sure that all indexes used for re-indexing or alignment are fully compatible and do not conflict each other. + Note: perhaps we could choose less restrictive constraints and instead + check for conflicts among the dimension (position) indexers returned by + `Index.reindex_like()` for each matching pair of object index / aligned + index? + (ref: https://github.com/pydata/xarray/issues/1603#issuecomment-442965602) + """ matching_keys = set(self.all_indexes) | set(self.indexes) @@ -259,7 +268,7 @@ def assert_no_index_conflict(self): dims_set = set() for name, dims in coord_names_dims: coord_count[name] += 1 - dims_set |= dims + dims_set.update(dims) for dim in dims_set: dim_count[dim] += 1 @@ -267,7 +276,7 @@ def assert_no_index_conflict(self): dup = {k: v for k, v in count.items() if v > 1} if dup: items_msg = ", ".join( - f"{k} ({v} conflicting indexes)" for k, v in dup.items() + f"{k!r} ({v} conflicting indexes)" for k, v in dup.items() ) raise ValueError( "cannot re-index or align objects with conflicting indexes found for " @@ -303,7 +312,10 @@ def _need_reindex(self, dims, index, other_indexes, coords, other_coords) -> boo def _get_index_joiner(self, index_cls) -> Callable: if self.join in ["outer", "inner"]: - return functools.partial(functools.reduce, index_cls.join, how=self.join) + return functools.partial( + functools.reduce, + functools.partial(index_cls.join, how=self.join), + ) elif self.join == "left": return operator.itemgetter(0) elif self.join == "right": @@ -352,7 +364,6 @@ def align_indexes(self): need_reindex = False if need_reindex: if self.join == "exact": - # TODO: more informative error message raise ValueError( "cannot align objects with join='exact' where " "index/labels/sizes are not equal along " @@ -390,14 +401,14 @@ def assert_unindexed_dim_sizes_equal(self): if index_size is not None: sizes.add(index_size) add_err_msg = ( - f" (note: an index is found for dimension {dim!r} " - f"with size {index_size!r})" + f" (note: an index is found along that dimension " + f"with size={index_size!r})" ) else: add_err_msg = "" if len(sizes) > 1: raise ValueError( - f"cannot reindex or align along dimension {dim!r} without labels " + f"cannot reindex or align along unlabeled dimension {dim!r} " f"because of conflicting dimension sizes: {sizes!r}" + add_err_msg ) @@ -428,7 +439,9 @@ def _reindex_one(self, obj, matching_indexes): ds_obj = obj # Negative values in dim_indexers mean values missing in the new index - masked_dims = [(indxr < 0).any() for indxr in dim_reindexers] + masked_dims = [ + dim for dim, indxr in dim_reindexers.items() if (indxr < 0).any() + ] unchanged_dims = [dim not in dim_reindexers for dim in obj.dims] for name, var in ds_obj.variables.items(): diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 1f67f4c3e46..fa33a92e55e 
100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -341,14 +341,16 @@ def equals(self, other: Index): return False return self.index.equals(other.index) and self.dim == other.dim - def join(self, other: "PandasIndex", how: str = "inner") -> "PandasIndex": + def join( + self: "PandasIndex", other: "PandasIndex", how: str = "inner" + ) -> "PandasIndex": if how == "outer": index = self.index.union(other.index) else: # how = "inner" index = self.index.intersection(other.index) - coord_dtype = np.result_type(self.coord_dtype, other.coord_dtype).type + coord_dtype = np.result_type(self.coord_dtype, other.coord_dtype) return type(self)(index, self.dim, coord_dtype=coord_dtype) def union(self, other): @@ -793,6 +795,24 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult: else: return QueryResult({self.dim: indexer}) + def join(self, other, how: str = "inner"): + if how == "outer": + # bug in pandas? need to reset index.name + other_index = other.index.copy() + other_index.name = None + index = self.index.union(other_index) + index.name = self.dim + else: + # how = "inner" + index = self.index.intersection(other.index) + + level_coords_dtype = { + k: np.result_type(lvl_dtype, other.level_coords_dtype[k]) + for k, lvl_dtype in self.level_coords_dtype.items() + } + + return type(self)(index, self.dim, level_coords_dtype=level_coords_dtype) + def rename(self, name_dict, dims_dict): if not set(self.index.names) & set(name_dict) and self.dim not in dims_dict: return self, {} From 981fa0deaff7cd99f0bbe7ec21baa1883e802bfa Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 19 Oct 2021 17:27:20 +0200 Subject: [PATCH 063/159] wip refactor alignment (support join='override') --- xarray/core/alignment.py | 50 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index e3001491774..3e120500517 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -25,7 +25,7 @@ from . 
import dtypes from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex, get_indexer_nd from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index -from .variable import Variable +from .variable import Variable, calculate_dimensions if TYPE_CHECKING: from .common import DataWithCoords @@ -130,7 +130,7 @@ def __init__( self.objects = tuple(objects) self.objects_matching_indexes = () - if join not in ["inner", "outer", "overwrite", "exact", "left", "right"]: + if join not in ["inner", "outer", "override", "exact", "left", "right"]: raise ValueError(f"invalid value for join: {join}") self.join = join @@ -223,6 +223,7 @@ def _normalize_indexes( def find_matching_indexes(self): all_indexes = defaultdict(list) all_index_vars = defaultdict(list) + all_indexes_dim_sizes = defaultdict(lambda: defaultdict(set)) objects_matching_indexes = [] for obj in self.objects: @@ -230,12 +231,24 @@ def find_matching_indexes(self): objects_matching_indexes.append(obj_indexes) for key, idx in obj_indexes.items(): all_indexes[key].append(idx) - all_index_vars[key].append(obj_index_vars[key]) + for key, index_vars in obj_index_vars.items(): + all_index_vars[key].append(index_vars) + for dim, size in calculate_dimensions(index_vars).items(): + all_indexes_dim_sizes[key][dim].add(size) self.objects_matching_indexes = tuple(objects_matching_indexes) self.all_indexes = all_indexes self.all_index_vars = all_index_vars + if self.join == "override": + for dim_sizes in all_indexes_dim_sizes.values(): + for dim, sizes in dim_sizes.items(): + if len(sizes) > 1: + raise ValueError( + "cannot align objects with join='override' with matching indexes " + f"along dimension {dim!r} that don't have the same size." + ) + def find_matching_unindexed_dims(self): unindexed_dim_sizes = defaultdict(set) @@ -341,7 +354,11 @@ def align_indexes(self): ) index_cls = key[1] - if key in self.indexes: + if self.join == "override": + joined_index = matching_indexes[0] + joined_index_vars = matching_index_vars[0] + need_reindex = False + elif key in self.indexes: joined_index = self.indexes[key] joined_index_vars = self.index_vars[key] need_reindex = self._need_reindex( @@ -412,6 +429,25 @@ def assert_unindexed_dim_sizes_equal(self): f"because of conflicting dimension sizes: {sizes!r}" + add_err_msg ) + def override_indexes(self) -> AlignedObjects: + objects = list(self.objects) + + for i, obj in enumerate(objects[1:]): + new_indexes = {} + new_variables = {} + matching_indexes = self.objects_matching_indexes[i + 1] + + for key, aligned_idx in self.aligned_indexes.items(): + obj_idx = matching_indexes.get(key) + if obj_idx is not None: + for name, var in self.aligned_index_vars[key].items(): + new_indexes[name] = aligned_idx + new_variables[name] = var + + objects[i + 1] = obj._overwrite_indexes(new_indexes, new_variables) + + return tuple(objects) + def _reindex_one(self, obj, matching_indexes): from .dataarray import DataArray @@ -507,7 +543,11 @@ def align(self) -> AlignedObjects: self.assert_no_index_conflict() self.align_indexes() self.assert_unindexed_dim_sizes_equal() - return self.reindex_all() + + if self.join == "override": + return self.override_indexes() + else: + return self.reindex_all() def align( From 79b52237b7dfb71042f0d72a6faffdd4dc34b993 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 20 Oct 2021 17:52:57 +0200 Subject: [PATCH 064/159] wip alignment: clean-up + tests I will fix mypy errors later. 
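For the record, a minimal sketch of how the refactored class is meant to be
driven after this change (toy datasets; default fill value assumed):

    import xarray as xr
    from xarray.core.alignment import Aligner

    a = xr.Dataset(coords={"x": [0, 1, 2]})
    b = xr.Dataset(coords={"x": [1, 2, 3]})

    # join="outer" re-indexes both objects onto the union [0, 1, 2, 3];
    # re-indexed data variables (none here) get NaN where labels are missing
    aligned_a, aligned_b = Aligner((a, b), join="outer").align()

The public `align()` entry point is now a thin wrapper around this class.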
--- xarray/core/alignment.py | 261 +++++++++++-------------------- xarray/core/dataarray.py | 43 +++-- xarray/core/dataset.py | 46 +++--- xarray/core/merge.py | 25 ++- xarray/tests/test_computation.py | 2 +- xarray/tests/test_dataarray.py | 46 ++++-- xarray/tests/test_dataset.py | 18 ++- xarray/tests/test_indexes.py | 94 ++++++++--- xarray/tests/test_merge.py | 2 +- 9 files changed, 257 insertions(+), 280 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 3e120500517..74041bbbca7 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -2,6 +2,7 @@ import operator from collections import defaultdict from contextlib import suppress +from numbers import Number from typing import ( TYPE_CHECKING, Any, @@ -12,11 +13,13 @@ List, Mapping, Optional, + Sequence, Set, Tuple, Type, TypeVar, Union, + cast, ) import numpy as np @@ -24,70 +27,26 @@ from . import dtypes from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex, get_indexer_nd -from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index +from .utils import is_dict_like, is_full_slice, safe_cast_to_index from .variable import Variable, calculate_dimensions if TYPE_CHECKING: - from .common import DataWithCoords from .dataarray import DataArray from .dataset import Dataset - DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords) - - -def _get_joiner(join, index_cls): - if join == "outer": - return functools.partial(functools.reduce, index_cls.union) - elif join == "inner": - return functools.partial(functools.reduce, index_cls.intersection) - elif join == "left": - return operator.itemgetter(0) - elif join == "right": - return operator.itemgetter(-1) - elif join == "exact": - # We cannot return a function to "align" in this case, because it needs - # access to the dimension name to give a good error message. - return None - elif join == "override": - # We rewrite all indexes and then use join='left' - return operator.itemgetter(0) - else: - raise ValueError(f"invalid value for join: {join}") - - -def _override_indexes(objects, all_indexes, exclude): - for dim, dim_indexes in all_indexes.items(): - if dim not in exclude: - lengths = { - getattr(index, "size", index.to_pandas_index().size) - for index in dim_indexes - } - if len(lengths) != 1: - raise ValueError( - f"Indexes along dimension {dim!r} don't have the same length." - " Cannot use join='override'." - ) - - objects = list(objects) - for idx, obj in enumerate(objects[1:]): - new_indexes = { - dim: all_indexes[dim][0] for dim in obj.xindexes if dim not in exclude - } - - # TODO: benbovy - explicit indexes: not refactored yet! - objects[idx + 1] = obj._overwrite_indexes(new_indexes) + T_Xarray = TypeVar("T_Xarray", bound=Union[Dataset, DataArray]) + DataAlignable = Sequence[Union[Dataset, DataArray]] + DataAligned = Tuple[Union[Dataset, DataArray], ...] - return objects - -class Alignator: +class Aligner: """Implements all the complex logic for the re-indexing and alignment of Xarray objects. For internal use only, not public API. Usage: - aligned_objects = Alignator(objects, **kwargs).align() + aligned_objects = Alignator(*objects, **kwargs).align() """ @@ -95,9 +54,8 @@ class Alignator: MatchingIndexKey = Tuple[CoordNamesAndDims, Type[Index]] NormalizedIndexes = Dict[MatchingIndexKey, Index] NormalizedIndexVars = Dict[MatchingIndexKey, Dict[Hashable, Variable]] - AlignedObjects = Tuple[Union["Dataset", "DataArray"], ...] - objects: Tuple[Union["Dataset", "DataArray"], ...] 
+ objects: "DataAlignable" objects_matching_indexes: Tuple[Dict[MatchingIndexKey, Index], ...] join: str exclude_dims: FrozenSet @@ -117,12 +75,12 @@ class Alignator: def __init__( self, - objects: List[Union["Dataset", "DataArray"]], - join: str, - indexes_or_indexers: Optional[Mapping[Any, Any]] = None, + objects: "DataAlignable", + join: str = "inner", + indexes: Mapping[Any, Any] = None, exclude: Any = frozenset(), - method: Optional[str] = None, - tolerance: Any = None, + method: str = None, + tolerance: Number = None, copy: bool = True, fill_value: Any = dtypes.NA, sparse: bool = False, @@ -147,9 +105,9 @@ def __init__( exclude = [exclude] self.exclude_dims = frozenset(exclude) - if indexes_or_indexers is None: - indexes_or_indexers = {} - self.indexes, self.index_vars = self._normalize_indexes(indexes_or_indexers) + if indexes is None: + indexes = {} + self.indexes, self.index_vars = self._normalize_indexes(indexes) self.all_indexes = {} self.all_index_vars = {} @@ -161,7 +119,7 @@ def __init__( def _normalize_indexes( self, - indexes_or_indexers: Mapping[Any, Any], + indexes: Mapping[Any, Any], ) -> Tuple[NormalizedIndexes, NormalizedIndexVars]: """Normalize the indexes/indexers used for re-indexing or alignment. @@ -169,18 +127,18 @@ def _normalize_indexes( such that we can group matching indexes based on the dictionary keys. """ - if isinstance(indexes_or_indexers, Indexes): - xr_variables = dict(indexes_or_indexers.variables) + if isinstance(indexes, Indexes): + xr_variables = dict(indexes.variables) else: xr_variables = {} xr_indexes: Dict[Hashable, Index] = {} - for k, idx in indexes_or_indexers.items(): + for k, idx in indexes.items(): if not isinstance(idx, Index): if getattr(idx, "dims", (k,)) != (k,): raise ValueError( - "Indexer has dimensions {:s} that are different " - "from that to be indexed along {:s}".format(str(idx.dims), k) + f"Indexer has dimensions {idx.dims} that are different " + f"from that to be indexed along '{k}'" ) pd_idx = safe_cast_to_index(idx).copy() pd_idx.name = k @@ -246,7 +204,7 @@ def find_matching_indexes(self): if len(sizes) > 1: raise ValueError( "cannot align objects with join='override' with matching indexes " - f"along dimension {dim!r} that don't have the same size." 
+ f"along dimension {dim!r} that don't have the same size" ) def find_matching_unindexed_dims(self): @@ -429,7 +387,7 @@ def assert_unindexed_dim_sizes_equal(self): f"because of conflicting dimension sizes: {sizes!r}" + add_err_msg ) - def override_indexes(self) -> AlignedObjects: + def override_indexes(self) -> "DataAligned": objects = list(self.objects) for i, obj in enumerate(objects[1:]): @@ -448,17 +406,27 @@ def override_indexes(self) -> AlignedObjects: return tuple(objects) - def _reindex_one(self, obj, matching_indexes): + def _reindex_one(self, obj: "T_Xarray", matching_indexes) -> "T_Xarray": from .dataarray import DataArray new_variables = {} new_indexes = {} dim_reindexers = {} + added_new_indexes = False for key, aligned_idx in self.aligned_indexes.items(): + index_vars = self.aligned_index_vars[key] obj_idx = matching_indexes.get(key) + if obj_idx is None: + # add the index if it relates to unindexed dimensions in obj + index_vars_dims = set( + d for var in index_vars.values() for d in var.dims + ) + if index_vars_dims <= set(obj.dims): + obj_idx = aligned_idx + added_new_indexes = True if obj_idx is not None: - for name, var in self.aligned_index_vars[key].items(): + for name, var in index_vars.items(): new_indexes[name] = aligned_idx new_variables[name] = var if self.reindex[key]: @@ -468,6 +436,8 @@ def _reindex_one(self, obj, matching_indexes): if not dim_reindexers: # fast path for no reindexing necessary new_obj = obj.copy(deep=self.copy) + if added_new_indexes: + new_obj = new_obj._overwrite_indexes(new_indexes, new_variables) else: if isinstance(obj, DataArray): ds_obj = obj._to_temp_dataset() @@ -524,7 +494,7 @@ def _reindex_one(self, obj, matching_indexes): new_obj.encoding = obj.encoding return new_obj - def reindex_all(self) -> AlignedObjects: + def reindex_all(self) -> "DataAligned": result = [] for obj, matching_indexes in zip(self.objects, self.objects_matching_indexes): @@ -532,7 +502,7 @@ def reindex_all(self) -> AlignedObjects: return tuple(result) - def align(self) -> AlignedObjects: + def align(self) -> "DataAligned": if not self.indexes and len(self.objects) == 1: # fast path for the trivial case (obj,) = self.objects @@ -551,13 +521,13 @@ def align(self) -> AlignedObjects: def align( - *objects: "DataAlignable", + *objects: Union["Dataset", "DataArray"], join="inner", copy=True, indexes=None, exclude=frozenset(), fill_value=dtypes.NA, -) -> Tuple["DataAlignable", ...]: +) -> "DataAligned": """ Given any number of Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. @@ -746,107 +716,16 @@ def align( * lon (lon) float64 100.0 120.0 """ - if indexes is None: - indexes = {} - - if not indexes and len(objects) == 1: - # fast path for the trivial case - (obj,) = objects - return (obj.copy(deep=copy),) - - all_indexes = defaultdict(list) - all_coords = defaultdict(list) - unlabeled_dim_sizes = defaultdict(set) - for obj in objects: - for dim in obj.dims: - if dim not in exclude: - all_coords[dim].append(obj.coords[dim]) - try: - index = obj.xindexes[dim] - except KeyError: - unlabeled_dim_sizes[dim].add(obj.sizes[dim]) - else: - all_indexes[dim].append(index) - - if join == "override": - objects = _override_indexes(objects, all_indexes, exclude) - - # We don't reindex over dimensions with all equal indexes for two reasons: - # - It's faster for the usual case (already aligned objects). 
- # - It ensures it's possible to do operations that don't require alignment - # on indexes with duplicate values (which cannot be reindexed with - # pandas). This is useful, e.g., for overwriting such duplicate indexes. - joined_indexes = {} - for dim, matching_indexes in all_indexes.items(): - if dim in indexes: - index, _ = PandasIndex.from_pandas_index( - safe_cast_to_index(indexes[dim]), dim - ) - if ( - any(not index.equals(other) for other in matching_indexes) - or dim in unlabeled_dim_sizes - ): - joined_indexes[dim] = indexes[dim] - else: - if ( - any( - not matching_indexes[0].equals(other) - for other in matching_indexes[1:] - ) - or dim in unlabeled_dim_sizes - ): - if join == "exact": - raise ValueError(f"indexes along dimension {dim!r} are not equal") - joiner = _get_joiner(join, type(matching_indexes[0])) - index = joiner(matching_indexes) - # make sure str coords are not cast to object - index = maybe_coerce_to_str(index.to_pandas_index(), all_coords[dim]) - joined_indexes[dim] = index - else: - index = all_coords[dim][0] - - if dim in unlabeled_dim_sizes: - unlabeled_sizes = unlabeled_dim_sizes[dim] - # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 - if isinstance(index, PandasIndex): - labeled_size = index.to_pandas_index().size - else: - labeled_size = index.size - if len(unlabeled_sizes | {labeled_size}) > 1: - raise ValueError( - f"arguments without labels along dimension {dim!r} cannot be " - f"aligned because they have different dimension size(s) {unlabeled_sizes!r} " - f"than the size of the aligned dimension labels: {labeled_size!r}" - ) - - for dim, sizes in unlabeled_dim_sizes.items(): - if dim not in all_indexes and len(sizes) > 1: - raise ValueError( - f"arguments without labels along dimension {dim!r} cannot be " - f"aligned because they have different dimension sizes: {sizes!r}" - ) - - result = [] - for obj in objects: - # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 - valid_indexers = {} - for k, index in joined_indexes.items(): - if k in obj.dims: - if isinstance(index, Index): - valid_indexers[k] = index.to_pandas_index() - else: - valid_indexers[k] = index - if not valid_indexers: - # fast path for no reindexing necessary - new_obj = obj.copy(deep=copy) - else: - new_obj = obj.reindex( - copy=copy, fill_value=fill_value, indexers=valid_indexers - ) - new_obj.encoding = obj.encoding - result.append(new_obj) + aligner = Aligner( + objects, + join=join, + copy=copy, + indexes=indexes, + exclude=exclude, + fill_value=fill_value, + ) - return tuple(result) + return aligner.align() def deep_align( @@ -932,6 +811,40 @@ def is_alignable(obj): return out +def reindex( + obj: "T_Xarray", + indexers: Mapping[Any, Any], + method: str = None, + tolerance: Number = None, + copy: bool = True, + fill_value: Any = dtypes.NA, + sparse: bool = False, +) -> "T_Xarray": + """Re-index either a Dataset or a DataArray. + + Not public API. 
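+
+    For example, assuming ``ds`` has an indexed "x" dimension, a call like
+    ``reindex(ds, {"x": [0, 1, 2]}, fill_value=0)`` conforms ``ds`` to the
+    new labels, filling missing values with 0.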
+ + """ + bad_keys = [k for k in indexers if k not in obj.xindexes and k not in obj.dims] + if bad_keys: + raise ValueError( + f"indexer keys {bad_keys} do not correspond to any indexed coordinate " + "or unindexed dimension in the object to reindex" + ) + + aligner = Aligner( + (obj,), + indexes=indexers, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, + sparse=sparse, + ) + + return cast("T_Xarray", aligner.align()[0]) + + def reindex_like_indexers( target: "Union[DataArray, Dataset]", other: "Union[DataArray, Dataset]" ) -> Dict[Hashable, pd.Index]: diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 34049288f90..f590f85bf8a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -23,6 +23,7 @@ from ..plot.plot import _PlotMethods from . import ( + alignment, computation, dtypes, groupby, @@ -36,12 +37,7 @@ ) from .accessor_dt import CombinedDatetimelikeAccessor from .accessor_str import StringAccessor -from .alignment import ( - _broadcast_helper, - _get_broadcast_dims_map_common_coords, - align, - reindex_like_indexers, -) +from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align from .arithmetic import DataArrayArithmetic from .common import AbstractArray, DataWithCoords from .computation import unify_chunks @@ -1443,6 +1439,16 @@ def broadcast_like( return _broadcast_helper(args[1], exclude, dims_map, common_coords) + def _normalize_fill_value(self, fill_value): + if isinstance(fill_value, dict): + fill_value = fill_value.copy() + sentinel = object() + value = fill_value.pop(self.name, sentinel) + if value is not sentinel: + fill_value[_THIS_ARRAY] = value + + return fill_value + def reindex_like( self, other: Union["DataArray", Dataset], @@ -1496,13 +1502,13 @@ def reindex_like( DataArray.reindex align """ - indexers = reindex_like_indexers(self, other) - return self.reindex( - indexers=indexers, + return alignment.reindex( + self, + indexers=other.xindexes, method=method, - tolerance=tolerance, copy=copy, - fill_value=fill_value, + fill_value=self._normalize_fill_value(fill_value), + tolerance=tolerance, ) def reindex( @@ -1581,22 +1587,15 @@ def reindex( DataArray.reindex_like align """ - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") - if isinstance(fill_value, dict): - fill_value = fill_value.copy() - sentinel = object() - value = fill_value.pop(self.name, sentinel) - if value is not sentinel: - fill_value[_THIS_ARRAY] = value - - ds = self._to_temp_dataset().reindex( + indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") + return alignment.reindex( + self, indexers=indexers, method=method, tolerance=tolerance, copy=copy, - fill_value=fill_value, + fill_value=self._normalize_fill_value(fill_value), ) - return self._from_temp_dataset(ds) def interp( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4e946e2bfc4..f4be598d67d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -93,11 +93,11 @@ maybe_wrap_array, ) from .variable import ( - calculate_dimensions, IndexVariable, Variable, as_variable, broadcast_variables, + calculate_dimensions, propagate_attrs_encoding, ) @@ -2563,9 +2563,9 @@ def reindex_like( Dataset.reindex align """ - indexers = alignment.reindex_like_indexers(self, other) - return self.reindex( - indexers=indexers, + return alignment.reindex( + self, + indexers=other.xindexes, method=method, copy=copy, fill_value=fill_value, @@ -2772,14 +2772,14 @@ def reindex( original dataset, use 
the :py:meth:`~Dataset.fillna()` method.
 
         """
-        return self._reindex(
-            indexers,
-            method,
-            tolerance,
-            copy,
-            fill_value,
-            sparse=False,
-            **indexers_kwargs,
+        indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex")
+        return alignment.reindex(
+            self,
+            indexers=indexers,
+            method=method,
+            tolerance=tolerance,
+            copy=copy,
+            fill_value=fill_value,
         )
 
     def _reindex(
@@ -2793,28 +2793,18 @@ def _reindex(
         **indexers_kwargs: Any,
     ) -> "Dataset":
         """
-        same to _reindex but support sparse option
+        Same as reindex but also supports the sparse option.
         """
         indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex")
-
-        bad_dims = [d for d in indexers if d not in self.dims]
-        if bad_dims:
-            raise ValueError(f"invalid reindex dimensions: {bad_dims}")
-
-        variables, indexes = alignment.reindex_variables(
-            self.variables,
-            self.sizes,
-            self.xindexes,
-            indexers,
-            method,
-            tolerance,
+        return alignment.reindex(
+            self,
+            indexers=indexers,
+            method=method,
+            tolerance=tolerance,
             copy=copy,
             fill_value=fill_value,
             sparse=sparse,
         )
-        coord_names = set(self._coord_names)
-        coord_names.update(indexers)
-        return self._replace_with_new_dims(variables, coord_names, indexes=indexes)
 
     def interp(
         self,
diff --git a/xarray/core/merge.py b/xarray/core/merge.py
index 36ef65110f2..a137d2c6bd2 100644
--- a/xarray/core/merge.py
+++ b/xarray/core/merge.py
@@ -22,9 +22,13 @@
 from . import dtypes, pdcompat
 from .alignment import deep_align
 from .duck_array_ops import lazy_array_equiv
-from .indexes import Index, PandasIndex, PandasMultiIndex
+from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex
 from .utils import Frozen, compat_dict_union, dict_equiv, equivalent
-from .variable import calculate_dimensions, Variable, as_variable  # , assert_unique_multiindex_level_names
+from .variable import (  # , assert_unique_multiindex_level_names
+    Variable,
+    as_variable,
+    calculate_dimensions,
+)
 
 if TYPE_CHECKING:
     from .coordinates import Coordinates
@@ -497,7 +501,11 @@ def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="ou
     objects = [data_vars, coords]
     explicit_coords = coords.keys()
     return merge_core(
-        objects, compat, join, explicit_coords=explicit_coords, indexes=indexes
+        objects,
+        compat,
+        join,
+        explicit_coords=explicit_coords,
+        indexes=Indexes(indexes, coords),
     )
 
@@ -1027,18 +1035,9 @@ def dataset_update_method(
         if coord_names:
             other[key] = value.drop_vars(coord_names)
 
-    # use ds.coords and not ds.indexes, else str coords are cast to object
-    # TODO: benbovy - flexible indexes: make it work with any xarray index
-    indexes = {}
-    for key, index in dataset.xindexes.items():
-        if isinstance(index, PandasIndex):
-            indexes[key] = dataset.coords[key]
-        else:
-            indexes[key] = index
-
     return merge_core(
         [dataset, other],
         priority_arg=1,
-        indexes=indexes,  # type: ignore
+        indexes=dataset.xindexes,
         combine_attrs="override",
     )
diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py
index 22a3efce999..4680857219d 100644
--- a/xarray/tests/test_computation.py
+++ b/xarray/tests/test_computation.py
@@ -1872,7 +1872,7 @@ def test_dot_align_coords(use_dask) -> None:
     xr.testing.assert_allclose(expected, actual)
 
     with xr.set_options(arithmetic_join="exact"):
-        with pytest.raises(ValueError, match=r"indexes along dimension"):
+        with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*not equal.*"):
             xr.dot(da_a, da_b)
 
     # NOTE: dot always uses `join="inner"` because `(a * b).sum()` yields the same for all
diff --git
a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index a6d466014c0..3ca89911323 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -92,8 +92,8 @@ def test_repr_multiindex(self): array([0, 1, 2, 3]) Coordinates: * x (x) object MultiIndex - level_1 (x) object 'a' 'a' 'b' 'b' - level_2 (x) int64 1 2 1 2""" + * level_1 (x) object 'a' 'a' 'b' 'b' + * level_2 (x) int64 1 2 1 2""" ) assert expected == repr(self.mda) @@ -114,8 +114,8 @@ def test_repr_multiindex_long(self): 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]) Coordinates: * x (x) object MultiIndex - level_1 (x) object 'a' 'a' 'a' 'a' 'a' 'a' 'a' ... 'd' 'd' 'd' 'd' 'd' 'd' - level_2 (x) int64 1 2 3 4 5 6 7 8 1 2 3 4 5 6 ... 4 5 6 7 8 1 2 3 4 5 6 7 8""" + * level_1 (x) object 'a' 'a' 'a' 'a' 'a' 'a' 'a' ... 'd' 'd' 'd' 'd' 'd' 'd' + * level_2 (x) int64 1 2 3 4 5 6 7 8 1 2 3 4 5 6 ... 4 5 6 7 8 1 2 3 4 5 6 7 8""" ) assert expected == repr(mda_long) @@ -1323,9 +1323,9 @@ def test_coords(self): expected = DataArray(da.values, {"y": [0, 1, 2]}, dims=["x", "y"], name="foo") assert_identical(da, expected) - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): - self.mda["level_1"] = np.arange(4) - self.mda.coords["level_1"] = np.arange(4) + # TODO: benbovy (explicit indexes) check that multi-index is reset + self.mda["level_1"] = ("x", np.arange(4)) + self.mda.coords["level_1"] = ("x", np.arange(4)) def test_coords_to_index(self): da = DataArray(np.zeros((2, 3)), [("x", [1, 2]), ("y", list("abc"))]) @@ -1442,8 +1442,8 @@ def test_assign_coords(self): expected = DataArray(10, {"c": 42}) assert_identical(actual, expected) - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): - self.mda.assign_coords(level_1=range(4)) + # TODO: benbovy (explicit indexes) check that multi-index is reset + self.mda.assign_coords(level_1=("x", range(4))) # GH: 2112 da = xr.DataArray([0, 1, 2], dims="x") @@ -1451,6 +1451,8 @@ def test_assign_coords(self): da["x"] = [0, 1, 2, 3] # size conflict with pytest.raises(ValueError): da.coords["x"] = [0, 1, 2, 3] # size conflict + with pytest.raises(ValueError): + da.coords["x"] = ("y", [1, 2, 3]) # no new dimension to a DataArray def test_coords_alignment(self): lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])]) @@ -2690,7 +2692,9 @@ def test_align_override(self): assert_identical(left.isel(x=0, drop=True), new_left) assert_identical(right, new_right) - with pytest.raises(ValueError, match=r"Indexes along dimension 'x' don't have"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*override.*same size" + ): align(left.isel(x=0).expand_dims("x"), right, join="override") @pytest.mark.parametrize( @@ -2709,7 +2713,9 @@ def test_align_override(self): ], ) def test_align_override_error(self, darrays): - with pytest.raises(ValueError, match=r"Indexes along dimension 'x' don't have"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*override.*same size" + ): xr.align(*darrays, join="override") def test_align_exclude(self): @@ -2765,10 +2771,16 @@ def test_align_mixed_indexes(self): assert_identical(result1, array_with_coord) def test_align_without_indexes_errors(self): - with pytest.raises(ValueError, match=r"cannot be aligned"): + with pytest.raises( + ValueError, + match=r"cannot.*align.*unlabeled dimension.*conflicting.*sizes.*", + ): align(DataArray([1, 2, 3], dims=["x"]), DataArray([1, 2], dims=["x"])) - with pytest.raises(ValueError, match=r"cannot be aligned"): + with pytest.raises( + ValueError, + 
match=r"cannot.*align.*unlabeled dimension.*conflicting.*sizes.*", + ): align( DataArray([1, 2, 3], dims=["x"]), DataArray([1, 2], coords=[("x", [0, 1])]), @@ -3576,7 +3588,9 @@ def test_dot_align_coords(self): dm = DataArray(dm_vals, coords=[z_m], dims=["z"]) with xr.set_options(arithmetic_join="exact"): - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*exact.*not equal.*" + ): da.dot(dm) da_aligned, dm_aligned = xr.align(da, dm, join="inner") @@ -3627,7 +3641,9 @@ def test_matmul_align_coords(self): assert_identical(result, expected) with xr.set_options(arithmetic_join="exact"): - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*exact.*not equal.*" + ): da_a @ da_b def test_binary_op_propagate_indexes(self): diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 2d4bf66097a..754386e4611 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2044,7 +2044,7 @@ def test_align_exact(self): assert_identical(left1, left) assert_identical(left2, left) - with pytest.raises(ValueError, match=r"indexes .* not equal"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*not equal.*"): xr.align(left, right, join="exact") def test_align_override(self): @@ -2066,7 +2066,9 @@ def test_align_override(self): assert_identical(left.isel(x=0, drop=True), new_left) assert_identical(right, new_right) - with pytest.raises(ValueError, match=r"Indexes along dimension 'x' don't have"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*override.*same size" + ): xr.align(left.isel(x=0).expand_dims("x"), right, join="override") def test_align_exclude(self): @@ -2146,11 +2148,15 @@ def test_align_non_unique(self): def test_align_str_dtype(self): - a = Dataset({"foo": ("x", [0, 1]), "x": ["a", "b"]}) - b = Dataset({"foo": ("x", [1, 2]), "x": ["b", "c"]}) + a = Dataset({"foo": ("x", [0, 1])}, coords={"x": ["a", "b"]}) + b = Dataset({"foo": ("x", [1, 2])}, coords={"x": ["b", "c"]}) - expected_a = Dataset({"foo": ("x", [0, 1, np.NaN]), "x": ["a", "b", "c"]}) - expected_b = Dataset({"foo": ("x", [np.NaN, 1, 2]), "x": ["a", "b", "c"]}) + expected_a = Dataset( + {"foo": ("x", [0, 1, np.NaN])}, coords={"x": ["a", "b", "c"]} + ) + expected_b = Dataset( + {"foo": ("x", [np.NaN, 1, 2])}, coords={"x": ["a", "b", "c"]} + ) actual_a, actual_b = xr.align(a, b, join="outer") diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index ec24a1d3f79..3ac06622356 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -137,6 +137,22 @@ def test_equals(self) -> None: index2 = PandasIndex([1, 2, 3], "x") assert index1.equals(index2) is True + def test_join(self) -> None: + index1 = PandasIndex(["a", "aa", "aaa"], "x", coord_dtype=" None: index1 = PandasIndex([1, 2, 3], "x") index2 = PandasIndex([4, 5, 6], "y") @@ -151,6 +167,19 @@ def test_intersection(self) -> None: assert actual.index.equals(pd.Index([2, 3])) assert actual.dim == "x" + def test_reindex_like(self) -> None: + index1 = PandasIndex([0, 1, 2], "x") + index2 = PandasIndex([1, 2, 3, 4], "x") + + expected = {"x": [1, 2, -1, -1]} + actual = index1.reindex_like(index2) + assert actual.keys() == expected.keys() + np.testing.assert_array_equal(actual["x"], expected["x"]) + + index3 = PandasIndex([1, 1, 2], "x") + with pytest.raises(ValueError, match=r".*index has duplicate values"): + 
index3.reindex_like(index2) + def test_rename(self) -> None: index = PandasIndex(pd.Index([1, 2, 3], name="a"), "x", coord_dtype=np.int32) @@ -331,6 +360,20 @@ def test_query(self) -> None: with pytest.raises(IndexError): index.query({"x": (slice(None), 1, "no_level")}) + def test_join(self): + midx = pd.MultiIndex.from_product([["a", "aa"], [1, 2]], names=("one", "two")) + level_coords_dtype = {"one": " None: level_coords_dtype = {"one": " None: class TestIndexes: - def _create_indexes(self) -> Tuple[Indexes[Index], List[PandasIndex]]: + @pytest.fixture + def unique_indexes(self) -> List[PandasIndex]: x_idx = PandasIndex(pd.Index([1, 2, 3], name="x"), "x") y_idx = PandasIndex(pd.Index([4, 5, 6], name="y"), "y") z_pd_midx = pd.MultiIndex.from_product( @@ -384,7 +428,11 @@ def _create_indexes(self) -> Tuple[Indexes[Index], List[PandasIndex]]: ) z_midx = PandasMultiIndex(z_pd_midx, "z") - unique_indexes = [x_idx, y_idx, z_midx] + return [x_idx, y_idx, z_midx] + + @pytest.fixture + def indexes(self, unique_indexes) -> Indexes[Index]: + x_idx, y_idx, z_midx = unique_indexes indexes: Dict[Any, Index] = { "x": x_idx, "y": y_idx, @@ -394,25 +442,27 @@ def _create_indexes(self) -> Tuple[Indexes[Index], List[PandasIndex]]: } variables: Dict[Any, Variable] = {} for idx in unique_indexes: - variables.update(idx.create_variables({}, {})) + variables.update(idx.create_variables()) + + return Indexes(indexes, variables) - return Indexes(indexes, variables), unique_indexes + def test_interface(self, unique_indexes, indexes) -> None: + x_idx = unique_indexes[0] + assert list(indexes) == ["x", "y", "z", "one", "two"] + assert len(indexes) == 5 + assert "x" in indexes + assert indexes["x"] is x_idx - def test_variables(self) -> None: - indexes, _ = self._create_indexes() + def test_variables(self, indexes) -> None: assert tuple(indexes.variables) == ("x", "y", "z", "one", "two") - def test_dims(self) -> None: - indexes, _ = self._create_indexes() + def test_dims(self, indexes) -> None: assert indexes.dims == {"x": 3, "y": 3, "z": 4} - def test_get_unique(self) -> None: - indexes, unique = self._create_indexes() - assert indexes.get_unique() == unique - - def test_get_all_coords(self) -> None: - indexes, _ = self._create_indexes() + def test_get_unique(self, unique_indexes, indexes) -> None: + assert indexes.get_unique() == unique_indexes + def test_get_all_coords(self, indexes) -> None: expected = { "z": indexes.variables["z"], "one": indexes.variables["one"], @@ -428,14 +478,12 @@ def test_get_all_coords(self) -> None: assert indexes.get_all_coords("no_coord", errors="ignore") == {} - def test_group_by_index(self): - indexes, unique = self._create_indexes() - + def test_group_by_index(self, unique_indexes, indexes): expected = [ - (unique[0], {"x": indexes.variables["x"]}), - (unique[1], {"y": indexes.variables["y"]}), + (unique_indexes[0], {"x": indexes.variables["x"]}), + (unique_indexes[1], {"y": indexes.variables["y"]}), ( - unique[2], + unique_indexes[2], { "z": indexes.variables["z"], "one": indexes.variables["one"], @@ -445,3 +493,9 @@ def test_group_by_index(self): ] assert indexes.group_by_index() == expected + + def test_to_pandas_indexes(self, indexes) -> None: + pd_indexes = indexes.to_pandas_indexes() + assert isinstance(pd_indexes, Indexes) + assert all([isinstance(idx, pd.Index) for idx in pd_indexes.values()]) + assert indexes.variables == pd_indexes.variables diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 555a29b1952..6dca04ed069 100644 --- 
a/xarray/tests/test_merge.py
+++ b/xarray/tests/test_merge.py
@@ -242,7 +242,7 @@ def test_merge_error(self):
     def test_merge_alignment_error(self):
         ds = xr.Dataset(coords={"x": [1, 2]})
         other = xr.Dataset(coords={"x": [2, 3]})
-        with pytest.raises(ValueError, match=r"indexes .* not equal"):
+        with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*not equal.*"):
             xr.merge([ds, other], join="exact")
 
     def test_merge_wrong_input_error(self):

From 63a4291df008ad42e0db736718ded08dd8550d25 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Thu, 21 Oct 2021 16:57:48 +0200
Subject: [PATCH 065/159] refactor alignment, fix tests & type annotations

All `*align*` and `*reindex*` tests are passing! (locally)

Mypy doesn't complain.

TODO: refactor `interp` and `interp_like` (for now I only plan to fix them
using the new Aligner class)
---
 xarray/core/alignment.py       | 214 ++++++++++++++++++++-------------
 xarray/core/dataarray.py       |  27 +++--
 xarray/core/dataset.py         |  45 ++++++-
 xarray/tests/test_dataarray.py |   4 +-
 xarray/tests/test_dataset.py   |   2 +-
 5 files changed, 194 insertions(+), 98 deletions(-)

diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py
index 74041bbbca7..52b1bbad6fc 100644
--- a/xarray/core/alignment.py
+++ b/xarray/core/alignment.py
@@ -9,6 +9,7 @@
     Callable,
     Dict,
     FrozenSet,
+    Generic,
     Hashable,
     List,
     Mapping,
@@ -19,13 +20,13 @@
     Type,
     TypeVar,
     Union,
-    cast,
 )
 
 import numpy as np
 import pandas as pd
 
 from . import dtypes
+from .common import DataWithCoords
 from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex, get_indexer_nd
 from .utils import is_dict_like, is_full_slice, safe_cast_to_index
 from .variable import Variable, calculate_dimensions
@@ -34,12 +35,65 @@
     from .dataarray import DataArray
     from .dataset import Dataset
 
-    T_Xarray = TypeVar("T_Xarray", bound=Union[Dataset, DataArray])
-    DataAlignable = Sequence[Union[Dataset, DataArray]]
-    DataAligned = Tuple[Union[Dataset, DataArray], ...]
 
+DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords)
 
-class Aligner:
+
+def reindex_variables(
+    variables: Mapping[Any, Variable],
+    dim_pos_indexers: Mapping[Any, Any],
+    copy: bool = True,
+    fill_value: Any = dtypes.NA,
+    sparse: bool = False,
+) -> Dict[Hashable, Variable]:
+    """Conform a dictionary of variables onto a new set of variables reindexed
+    with dimension positional indexers and possibly filled with missing values.
+
+    Not public API.
+
+    """
+    new_variables = {}
+    dim_sizes = calculate_dimensions(variables)
+
+    masked_dims = set()
+    unchanged_dims = set()
+    for dim, indxr in dim_pos_indexers.items():
+        # Negative values in dim_pos_indexers mean values missing in the new index
+        # See ``Index.reindex_like``.
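+        # A quick sketch of the convention (values are illustrative only):
+        # a positional indexer like np.array([1, 0, -1]) takes positions 1
+        # and 0 of the original data along ``dim`` and masks the last
+        # position with the fill value.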
+ if (indxr < 0).any(): + masked_dims.add(dim) + elif np.array_equal(indxr, np.arange(dim_sizes.get(dim, 0))): + unchanged_dims.add(dim) + + for name, var in variables.items(): + if isinstance(fill_value, dict): + fill_value_ = fill_value.get(name, dtypes.NA) + else: + fill_value_ = fill_value + + if sparse: + var = var._as_sparse(fill_value=fill_value_) + indxr = tuple( + slice(None) if d in unchanged_dims else dim_pos_indexers.get(d, slice(None)) + for d in var.dims + ) + needs_masking = any(d in masked_dims for d in var.dims) + + if needs_masking: + new_var = var._getitem_with_mask(indxr, fill_value=fill_value_) + elif all(is_full_slice(k) for k in indxr): + # no reindexing necessary + # here we need to manually deal with copying data, since + # we neither created a new ndarray nor used fancy indexing + new_var = var.copy(deep=copy) + else: + new_var = var[indxr] + + new_variables[name] = new_var + + return new_variables + + +class Aligner(Generic[DataAlignable]): """Implements all the complex logic for the re-indexing and alignment of Xarray objects. @@ -55,7 +109,7 @@ class Aligner: NormalizedIndexes = Dict[MatchingIndexKey, Index] NormalizedIndexVars = Dict[MatchingIndexKey, Dict[Hashable, Variable]] - objects: "DataAlignable" + objects: Tuple[DataAlignable, ...] objects_matching_indexes: Tuple[Dict[MatchingIndexKey, Index], ...] join: str exclude_dims: FrozenSet @@ -75,7 +129,7 @@ class Aligner: def __init__( self, - objects: "DataAlignable", + objects: Sequence[DataAlignable], join: str = "inner", indexes: Mapping[Any, Any] = None, exclude: Any = frozenset(), @@ -299,6 +353,8 @@ def _get_index_joiner(self, index_cls) -> Callable: return lambda _: None def align_indexes(self): + """Compute all aligned indexes and their corresponding coordinate variables.""" + aligned_indexes = {} aligned_index_vars = {} reindex = {} @@ -365,6 +421,18 @@ def align_indexes(self): new_indexes[name] = joined_index new_index_vars[name] = var + # Explicitly provided indexes that are not found in objects to align + # may relate to unindexed dimensions so we add them too + for key, idx in self.indexes.items(): + if key not in aligned_indexes: + index_vars = self.index_vars[key] + reindex[key] = False + aligned_indexes[key] = idx + aligned_index_vars[key] = index_vars + for name, var in index_vars.items(): + new_indexes[name] = idx + new_index_vars[name] = var + self.aligned_indexes = aligned_indexes self.aligned_index_vars = aligned_index_vars self.reindex = reindex @@ -383,11 +451,11 @@ def assert_unindexed_dim_sizes_equal(self): add_err_msg = "" if len(sizes) > 1: raise ValueError( - f"cannot reindex or align along unlabeled dimension {dim!r} " + f"cannot reindex or align along dimension {dim!r} " f"because of conflicting dimension sizes: {sizes!r}" + add_err_msg ) - def override_indexes(self) -> "DataAligned": + def override_indexes(self) -> Tuple[DataAlignable, ...]: objects = list(self.objects) for i, obj in enumerate(objects[1:]): @@ -406,13 +474,14 @@ def override_indexes(self) -> "DataAligned": return tuple(objects) - def _reindex_one(self, obj: "T_Xarray", matching_indexes) -> "T_Xarray": - from .dataarray import DataArray - + def _reindex_one( + self, + obj: DataAlignable, + matching_indexes: Dict[MatchingIndexKey, Index], + ) -> DataAlignable: new_variables = {} new_indexes = {} - dim_reindexers = {} - added_new_indexes = False + dim_pos_indexers = {} for key, aligned_idx in self.aligned_indexes.items(): index_vars = self.aligned_index_vars[key] @@ -424,77 +493,21 @@ def _reindex_one(self, obj: 
"T_Xarray", matching_indexes) -> "T_Xarray": ) if index_vars_dims <= set(obj.dims): obj_idx = aligned_idx - added_new_indexes = True if obj_idx is not None: for name, var in index_vars.items(): new_indexes[name] = aligned_idx new_variables[name] = var if self.reindex[key]: - indexers = obj_idx.reindex_like(aligned_idx, **self.reindex_kwargs) - dim_reindexers.update(indexers) - - if not dim_reindexers: - # fast path for no reindexing necessary - new_obj = obj.copy(deep=self.copy) - if added_new_indexes: - new_obj = new_obj._overwrite_indexes(new_indexes, new_variables) - else: - if isinstance(obj, DataArray): - ds_obj = obj._to_temp_dataset() - else: - ds_obj = obj - - # Negative values in dim_indexers mean values missing in the new index - masked_dims = [ - dim for dim, indxr in dim_reindexers.items() if (indxr < 0).any() - ] - unchanged_dims = [dim not in dim_reindexers for dim in obj.dims] - - for name, var in ds_obj.variables.items(): - if name in new_variables: - continue - - if isinstance(self.fill_value, dict): - fill_value = self.fill_value.get(name, dtypes.NA) - else: - fill_value = self.fill_value - - if self.sparse: - var = var._as_sparse(fill_value=fill_value) - key = tuple( - slice(None) - if d in unchanged_dims - else dim_reindexers.get(d, slice(None)) - for d in var.dims - ) - needs_masking = any(d in masked_dims for d in var.dims) - - if needs_masking: - new_var = var._getitem_with_mask(key, fill_value=fill_value) - elif all(is_full_slice(k) for k in key): - # no reindexing necessary - # here we need to manually deal with copying data, since - # we neither created a new ndarray nor used fancy indexing - new_var = var.copy(deep=self.copy) - else: - new_var = var[key] - - new_variables[name] = new_var - - new_coord_names = ds_obj._coord_names | set(new_indexes) - new_ds_obj = ds_obj._replace_with_new_dims( - new_variables, new_coord_names, indexes=new_indexes - ) - - if isinstance(obj, DataArray): - new_obj = obj._from_temp_dataset(new_ds_obj) - else: - new_obj = new_ds_obj + indexers = obj_idx.reindex_like(aligned_idx, **self.reindex_kwargs) # type: ignore[call-arg] + dim_pos_indexers.update(indexers) + new_obj = obj._reindex_callback( + self, dim_pos_indexers, new_variables, new_indexes, self.fill_value + ) new_obj.encoding = obj.encoding return new_obj - def reindex_all(self) -> "DataAligned": + def reindex_all(self) -> Tuple[DataAlignable, ...]: result = [] for obj, matching_indexes in zip(self.objects, self.objects_matching_indexes): @@ -502,7 +515,7 @@ def reindex_all(self) -> "DataAligned": return tuple(result) - def align(self) -> "DataAligned": + def align(self) -> Tuple[DataAlignable, ...]: if not self.indexes and len(self.objects) == 1: # fast path for the trivial case (obj,) = self.objects @@ -521,13 +534,13 @@ def align(self) -> "DataAligned": def align( - *objects: Union["Dataset", "DataArray"], + *objects: DataAlignable, join="inner", copy=True, indexes=None, exclude=frozenset(), fill_value=dtypes.NA, -) -> "DataAligned": +) -> Tuple[DataAlignable, ...]: """ Given any number of Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. 
@@ -724,7 +737,6 @@ def align( exclude=exclude, fill_value=fill_value, ) - return aligner.align() @@ -812,14 +824,14 @@ def is_alignable(obj): def reindex( - obj: "T_Xarray", + obj: DataAlignable, indexers: Mapping[Any, Any], method: str = None, tolerance: Number = None, copy: bool = True, fill_value: Any = dtypes.NA, sparse: bool = False, -) -> "T_Xarray": +) -> DataAlignable: """Re-index either a Dataset or a DataArray. Not public API. @@ -841,8 +853,42 @@ def reindex( fill_value=fill_value, sparse=sparse, ) + return aligner.align()[0] + + +def reindex_like( + obj: DataAlignable, + other: Union["Dataset", "DataArray"], + method: str = None, + tolerance: Number = None, + copy: bool = True, + fill_value: Any = dtypes.NA, +) -> DataAlignable: + """Re-index either a Dataset or a DataArray like another Dataset/DataArray. - return cast("T_Xarray", aligner.align()[0]) + Not public API. + + """ + if not other.xindexes: + # This check is not performed in Aligner. + for dim in other.dims: + if dim in obj.dims: + other_size = other.sizes[dim] + obj_size = obj.sizes[dim] + if other_size != obj_size: + raise ValueError( + "different size for unlabeled " + f"dimension on argument {dim!r}: {other_size!r} vs {obj_size!r}" + ) + + return reindex( + obj, + indexers=other.xindexes, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, + ) def reindex_like_indexers( @@ -887,7 +933,7 @@ def reindex_like_indexers( return indexers -def reindex_variables( +def _reindex_variables( variables: Mapping[Any, Variable], sizes: Mapping[Any, int], indexes: Mapping[Any, Index], diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index f590f85bf8a..b002baf92c3 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1439,7 +1439,16 @@ def broadcast_like( return _broadcast_helper(args[1], exclude, dims_map, common_coords) - def _normalize_fill_value(self, fill_value): + def _reindex_callback( + self, + aligner: alignment.Aligner, + dim_pos_indexers: Dict[Hashable, Any], + variables: Dict[Hashable, Variable], + indexes: Dict[Hashable, Index], + fill_value: Any = None, + ) -> "DataArray": + """Callback called from ``Aligner`` to create a new reindexed DataArray.""" + if isinstance(fill_value, dict): fill_value = fill_value.copy() sentinel = object() @@ -1447,7 +1456,11 @@ def _normalize_fill_value(self, fill_value): if value is not sentinel: fill_value[_THIS_ARRAY] = value - return fill_value + ds = self._to_temp_dataset() + reindexed = ds._reindex_callback( + aligner, dim_pos_indexers, variables, indexes, fill_value + ) + return self._from_temp_dataset(reindexed) def reindex_like( self, @@ -1502,13 +1515,13 @@ def reindex_like( DataArray.reindex align """ - return alignment.reindex( + return alignment.reindex_like( self, - indexers=other.xindexes, + other=other, method=method, - copy=copy, - fill_value=self._normalize_fill_value(fill_value), tolerance=tolerance, + copy=copy, + fill_value=fill_value, ) def reindex( @@ -1594,7 +1607,7 @@ def reindex( method=method, tolerance=tolerance, copy=copy, - fill_value=self._normalize_fill_value(fill_value), + fill_value=fill_value, ) def interp( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f4be598d67d..c3e46ce7b1c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1032,6 +1032,7 @@ def _overwrite_indexes( new_indexes[name] = indexes[name] for name in index_variables: + new_coord_names.add(name) new_variables[name] = variables[name] # append no-index variables at the end @@ -2511,6 +2512,42 @@ 
def broadcast_like(
         return _broadcast_helper(args[1], exclude, dims_map, common_coords)
 
+    def _reindex_callback(
+        self,
+        aligner: alignment.Aligner,
+        dim_pos_indexers: Dict[Hashable, Any],
+        variables: Dict[Hashable, Variable],
+        indexes: Dict[Hashable, Index],
+        fill_value: Any,
+    ) -> "Dataset":
+        """Callback called from ``Aligner`` to create a new reindexed Dataset."""
+
+        new_variables = variables.copy()
+
+        if not dim_pos_indexers:
+            # fast path for no reindexing necessary
+            if set(indexes) - set(self.xindexes):
+                # this only adds new indexes and their coordinate variables
+                reindexed = self._overwrite_indexes(indexes, variables)
+            else:
+                reindexed = self.copy(deep=aligner.copy)
+        else:
+            to_reindex = {k: v for k, v in self.variables.items() if k not in variables}
+            reindexed_vars = alignment.reindex_variables(
+                to_reindex,
+                dim_pos_indexers,
+                copy=aligner.copy,
+                fill_value=fill_value,
+                sparse=aligner.sparse,
+            )
+            new_variables.update(reindexed_vars)
+            new_coord_names = self._coord_names | set(indexes)
+            reindexed = self._replace_with_new_dims(
+                new_variables, new_coord_names, indexes=indexes
+            )
+
+        return reindexed
+
     def reindex_like(
         self,
         other: Union["Dataset", "DataArray"],
@@ -2563,13 +2600,13 @@ def reindex_like(
         Dataset.reindex
         align
         """
-        return alignment.reindex(
+        return alignment.reindex_like(
             self,
-            indexers=other.xindexes,
+            other=other,
             method=method,
+            tolerance=tolerance,
             copy=copy,
             fill_value=fill_value,
-            tolerance=tolerance,
         )
 
     def reindex(
@@ -3028,7 +3065,7 @@ def _validate_interp_indexer(x, new_x):
 
         if to_reindex:
             # Reindex variables:
-            variables_reindex = alignment.reindex_variables(
+            variables_reindex = alignment._reindex_variables(
                 variables=to_reindex,
                 sizes=obj.sizes,
                 indexes=obj.xindexes,
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 3ca89911323..17fb792ead5 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -2773,13 +2773,13 @@ def test_align_mixed_indexes(self):
     def test_align_without_indexes_errors(self):
         with pytest.raises(
             ValueError,
-            match=r"cannot.*align.*unlabeled dimension.*conflicting.*sizes.*",
+            match=r"cannot.*align.*dimension.*conflicting.*sizes.*",
         ):
             align(DataArray([1, 2, 3], dims=["x"]), DataArray([1, 2], dims=["x"]))
 
         with pytest.raises(
             ValueError,
-            match=r"cannot.*align.*unlabeled dimension.*conflicting.*sizes.*",
+            match=r"cannot.*align.*dimension.*conflicting.*sizes.*",
         ):
             align(
                 DataArray([1, 2, 3], dims=["x"]),
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index 754386e4611..d818de7ce17 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -1819,7 +1819,7 @@ def test_reindex(self):
             data.reindex("foo")
 
         # invalid dimension
-        with pytest.raises(ValueError, match=r"invalid reindex dim"):
+        with pytest.raises(ValueError, match=r"indexer keys.*not correspond.*"):
             data.reindex(invalid=0)
 
         # out of order

From 9f8d26d7fff0078fd150b3c6fbcc25af902e97e6 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Thu, 21 Oct 2021 23:52:26 +0200
Subject: [PATCH 066/159] refactor interp and interp_like

It now reuses the new alignment/reindex internals, but it still only
supports dimension coordinates with a single index.
---
 xarray/core/alignment.py | 255 ++++++++-------------------------------
 xarray/core/dataarray.py |   6 +-
 xarray/core/dataset.py   |  59 ++++++---
 3 files changed, 93 insertions(+), 227 deletions(-)

diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py
index 52b1bbad6fc..c143f50a5fe 100644
---
a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -11,10 +11,9 @@ FrozenSet, Generic, Hashable, + Iterable, List, Mapping, - Optional, - Sequence, Set, Tuple, Type, @@ -27,7 +26,7 @@ from . import dtypes from .common import DataWithCoords -from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex, get_indexer_nd +from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex from .utils import is_dict_like, is_full_slice, safe_cast_to_index from .variable import Variable, calculate_dimensions @@ -112,7 +111,8 @@ class Aligner(Generic[DataAlignable]): objects: Tuple[DataAlignable, ...] objects_matching_indexes: Tuple[Dict[MatchingIndexKey, Index], ...] join: str - exclude_dims: FrozenSet + exclude_dims: FrozenSet[Hashable] + exclude_vars: FrozenSet[Hashable] copy: bool fill_value: Any sparse: bool @@ -129,10 +129,11 @@ class Aligner(Generic[DataAlignable]): def __init__( self, - objects: Sequence[DataAlignable], + objects: Iterable[DataAlignable], join: str = "inner", indexes: Mapping[Any, Any] = None, - exclude: Any = frozenset(), + exclude_dims: Iterable = frozenset(), + exclude_vars: Iterable[Hashable] = frozenset(), method: str = None, tolerance: Number = None, copy: bool = True, @@ -155,9 +156,10 @@ def __init__( else: self.reindex_kwargs = {"method": method, "tolerance": tolerance} - if isinstance(exclude, str): - exclude = [exclude] - self.exclude_dims = frozenset(exclude) + if isinstance(exclude_dims, str): + exclude_dims = [exclude_dims] + self.exclude_dims = frozenset(exclude_dims) + self.exclude_vars = frozenset(exclude_vars) if indexes is None: indexes = {} @@ -474,14 +476,28 @@ def override_indexes(self) -> Tuple[DataAlignable, ...]: return tuple(objects) - def _reindex_one( + def _get_dim_pos_indexers( + self, + matching_indexes: Dict[MatchingIndexKey, Index], + ) -> Dict[Hashable, Any]: + dim_pos_indexers = {} + + for key, aligned_idx in self.aligned_indexes.items(): + obj_idx = matching_indexes.get(key) + if obj_idx is not None: + if self.reindex[key]: + indexers = obj_idx.reindex_like(aligned_idx, **self.reindex_kwargs) # type: ignore[call-arg] + dim_pos_indexers.update(indexers) + + return dim_pos_indexers + + def _get_indexes_and_vars( self, obj: DataAlignable, matching_indexes: Dict[MatchingIndexKey, Index], - ) -> DataAlignable: - new_variables = {} + ) -> Tuple[Dict[Hashable, Index], Dict[Hashable, Variable]]: new_indexes = {} - dim_pos_indexers = {} + new_variables = {} for key, aligned_idx in self.aligned_indexes.items(): index_vars = self.aligned_index_vars[key] @@ -497,12 +513,24 @@ def _reindex_one( for name, var in index_vars.items(): new_indexes[name] = aligned_idx new_variables[name] = var - if self.reindex[key]: - indexers = obj_idx.reindex_like(aligned_idx, **self.reindex_kwargs) # type: ignore[call-arg] - dim_pos_indexers.update(indexers) + + return new_indexes, new_variables + + def _reindex_one( + self, + obj: DataAlignable, + matching_indexes: Dict[MatchingIndexKey, Index], + ) -> DataAlignable: + new_indexes, new_variables = self._get_indexes_and_vars(obj, matching_indexes) + dim_pos_indexers = self._get_dim_pos_indexers(matching_indexes) new_obj = obj._reindex_callback( - self, dim_pos_indexers, new_variables, new_indexes, self.fill_value + self, + dim_pos_indexers, + new_variables, + new_indexes, + self.fill_value, + self.exclude_vars, ) new_obj.encoding = obj.encoding return new_obj @@ -734,7 +762,7 @@ def align( join=join, copy=copy, indexes=indexes, - exclude=exclude, + exclude_dims=exclude, fill_value=fill_value, ) return 
aligner.align() @@ -831,6 +859,7 @@ def reindex( copy: bool = True, fill_value: Any = dtypes.NA, sparse: bool = False, + exclude_vars: Iterable[Hashable] = frozenset(), ) -> DataAlignable: """Re-index either a Dataset or a DataArray. @@ -852,6 +881,7 @@ def reindex( copy=copy, fill_value=fill_value, sparse=sparse, + exclude_vars=exclude_vars, ) return aligner.align()[0] @@ -891,195 +921,6 @@ def reindex_like( ) -def reindex_like_indexers( - target: "Union[DataArray, Dataset]", other: "Union[DataArray, Dataset]" -) -> Dict[Hashable, pd.Index]: - """Extract indexers to align target with other. - - Not public API. - - Parameters - ---------- - target : Dataset or DataArray - Object to be aligned. - other : Dataset or DataArray - Object to be aligned with. - - Returns - ------- - Dict[Hashable, pandas.Index] providing indexes for reindex keyword - arguments. - - Raises - ------ - ValueError - If any dimensions without labels have different sizes. - """ - # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 - # this doesn't support yet indexes other than pd.Index - indexers = { - k: v.to_pandas_index() for k, v in other.xindexes.items() if k in target.dims - } - - for dim in other.dims: - if dim not in indexers and dim in target.dims: - other_size = other.sizes[dim] - target_size = target.sizes[dim] - if other_size != target_size: - raise ValueError( - "different size for unlabeled " - f"dimension on argument {dim!r}: {other_size!r} vs {target_size!r}" - ) - return indexers - - -def _reindex_variables( - variables: Mapping[Any, Variable], - sizes: Mapping[Any, int], - indexes: Mapping[Any, Index], - indexers: Mapping, - method: Optional[str] = None, - tolerance: Any = None, - copy: bool = True, - fill_value: Optional[Any] = dtypes.NA, - sparse: bool = False, -) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: - """Conform a dictionary of aligned variables onto a new set of variables, - filling in missing values with NaN. - - Not public API. - - Parameters - ---------- - variables : dict-like - Dictionary of xarray.Variable objects. - sizes : dict-like - Dictionary from dimension names to integer sizes. - indexes : dict-like - Dictionary of indexes associated with variables. - indexers : dict - Dictionary with keys given by dimension names and values given by - arrays of coordinates tick labels. Any mis-matched coordinate values - will be filled in with NaN, and any mis-matched dimension names will - simply be ignored. - method : {None, 'nearest', 'pad'/'ffill', 'backfill'/'bfill'}, optional - Method to use for filling index values in ``indexers`` not found in - this dataset: - * None (default): don't fill gaps - * pad / ffill: propagate last valid index value forward - * backfill / bfill: propagate next valid index value backward - * nearest: use nearest valid index value - tolerance : optional - Maximum distance between original and new labels for inexact matches. - The values of the index at the matching locations must satisfy the - equation ``abs(index[indexer] - target) <= tolerance``. - copy : bool, optional - If ``copy=True``, data in the return values is always copied. If - ``copy=False`` and reindexing is unnecessary, or can be performed - with only slice operations, then the output may share memory with - the input. In either case, new xarray objects are always returned. 
- fill_value : scalar, optional - Value to use for newly missing values - sparse : bool, optional - Use an sparse-array - - Returns - ------- - reindexed : dict - Dict of reindexed variables. - new_indexes : dict - Dict of indexes associated with the reindexed variables. - """ - from .dataarray import DataArray - - # create variables for the new dataset - reindexed: Dict[Hashable, Variable] = {} - - # build up indexers for assignment along each dimension - int_indexers = {} - new_indexes = dict(indexes) - masked_dims = set() - unchanged_dims = set() - - for dim, indexer in indexers.items(): - if isinstance(indexer, DataArray) and indexer.dims != (dim,): - raise ValueError( - "Indexer has dimensions {:s} that are different " - "from that to be indexed along {:s}".format(str(indexer.dims), dim) - ) - - var_meta = {dim: {"dtype": getattr(indexer, "dtype", None)}} - if dim in variables: - var = variables[dim] - var_meta[dim].update({"attrs": var.attrs, "encoding": var.encoding}) - - target = safe_cast_to_index(indexers[dim]).rename(dim) - idx, idx_vars = PandasIndex.from_pandas_index(target, dim, var_meta=var_meta) - new_indexes[dim] = idx - reindexed.update(idx_vars) - - if dim in indexes: - # TODO (benbovy - flexible indexes): support other indexes than pd.Index? - index = indexes[dim].to_pandas_index() - - if not index.is_unique: - raise ValueError( - f"cannot reindex or align along dimension {dim!r} because the " - "index has duplicate values" - ) - - int_indexer = get_indexer_nd(index, target, method, tolerance) - - # We uses negative values from get_indexer_nd to signify - # values that are missing in the index. - if (int_indexer < 0).any(): - masked_dims.add(dim) - elif np.array_equal(int_indexer, np.arange(len(index))): - unchanged_dims.add(dim) - - int_indexers[dim] = int_indexer - - for dim in sizes: - if dim not in indexes and dim in indexers: - existing_size = sizes[dim] - new_size = indexers[dim].size - if existing_size != new_size: - raise ValueError( - f"cannot reindex or align along dimension {dim!r} without an " - f"index because its size {existing_size!r} is different from the size of " - f"the new index {new_size!r}" - ) - - for name, var in variables.items(): - if name not in indexers: - if isinstance(fill_value, dict): - fill_value_ = fill_value.get(name, dtypes.NA) - else: - fill_value_ = fill_value - - if sparse: - var = var._as_sparse(fill_value=fill_value_) - key = tuple( - slice(None) if d in unchanged_dims else int_indexers.get(d, slice(None)) - for d in var.dims - ) - needs_masking = any(d in masked_dims for d in var.dims) - - if needs_masking: - new_var = var._getitem_with_mask(key, fill_value=fill_value_) - elif all(is_full_slice(k) for k in key): - # no reindexing necessary - # here we need to manually deal with copying data, since - # we neither created a new ndarray nor used fancy indexing - new_var = var.copy(deep=copy) - else: - new_var = var[key] - - reindexed[name] = new_var - - return reindexed, new_indexes - - def _get_broadcast_dims_map_common_coords(args, exclude): common_coords = {} diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b002baf92c3..550574d18e5 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -7,6 +7,7 @@ Any, Callable, Dict, + FrozenSet, Hashable, Iterable, List, @@ -1445,7 +1446,8 @@ def _reindex_callback( dim_pos_indexers: Dict[Hashable, Any], variables: Dict[Hashable, Variable], indexes: Dict[Hashable, Index], - fill_value: Any = None, + fill_value: Any, + exclude_vars: FrozenSet[Hashable], ) 
-> "DataArray": """Callback called from ``Aligner`` to create a new reindexed DataArray.""" @@ -1458,7 +1460,7 @@ def _reindex_callback( ds = self._to_temp_dataset() reindexed = ds._reindex_callback( - aligner, dim_pos_indexers, variables, indexes, fill_value + aligner, dim_pos_indexers, variables, indexes, fill_value, exclude_vars ) return self._from_temp_dataset(reindexed) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c3e46ce7b1c..bf50052d67b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -15,6 +15,7 @@ Callable, Collection, Dict, + FrozenSet, Hashable, Iterable, Iterator, @@ -2519,6 +2520,7 @@ def _reindex_callback( variables: Dict[Hashable, Variable], indexes: Dict[Hashable, Index], fill_value: Any, + exclude_vars: FrozenSet[Hashable], ) -> "Dataset": """Callback called from ``Aligner`` to create a new reindexed Dataset.""" @@ -2532,7 +2534,11 @@ def _reindex_callback( else: reindexed = self.copy(deep=aligner.copy) else: - to_reindex = {k: v for k, v in self.variables.items() if k not in variables} + to_reindex = { + k: v + for k, v in self.variables.items() + if k not in variables and k not in exclude_vars + } reindexed_vars = alignment.reindex_variables( to_reindex, dim_pos_indexers, @@ -3034,7 +3040,7 @@ def _validate_interp_indexer(x, new_x): } variables: Dict[Hashable, Variable] = {} - to_reindex: Dict[Hashable, Variable] = {} + reindex: bool = False for name, var in obj._variables.items(): if name in indexers: continue @@ -3052,42 +3058,52 @@ def _validate_interp_indexer(x, new_x): elif dtype_kind in "ObU" and (use_indexers.keys() & var.dims): # For types that we do not understand do stepwise # interpolation to avoid modifying the elements. - # Use reindex_variables instead because it supports + # reindex the variable instead because it supports # booleans and objects and retains the dtype but inside # this loop there might be some duplicate code that slows it # down, therefore collect these signals and run it later: - to_reindex[name] = var + reindex = True elif all(d not in indexers for d in var.dims): # For anything else we can only keep variables if they # are not dependent on any coords that are being # interpolated along: variables[name] = var - if to_reindex: - # Reindex variables: - variables_reindex = alignment._reindex_variables( - variables=to_reindex, - sizes=obj.sizes, - indexes=obj.xindexes, - indexers={k: v[-1] for k, v in validated_indexers.items()}, + if reindex: + reindex_indexers = { + k: v for k, (_, v) in validated_indexers.items() if v.dims == (k,) + } + reindexed = alignment.reindex( + obj, + indexers=reindex_indexers, method=method_non_numeric, - )[0] - variables.update(variables_reindex) + exclude_vars=variables.keys(), + ) + indexes = dict(reindexed.xindexes) + variables.update(reindexed.variables) + else: + # Get the indexes that are not being interpolated along + indexes = {k: v for k, v in obj.xindexes.items() if k not in indexers} # Get the coords that also exist in the variables: coord_names = obj._coord_names & variables.keys() - # Get the indexes that are not being interpolated along: - indexes = {k: v for k, v in obj.xindexes.items() if k not in indexers} selected = self._replace_with_new_dims( variables.copy(), coord_names, indexes=indexes ) # Attach indexer as coordinate - variables.update(indexers) for k, v in indexers.items(): assert isinstance(v, Variable) if v.dims == (k,): - indexes[k] = v._to_xindex() + index = PandasIndex(v, k, coord_dtype=v.dtype) + index_vars = index.create_variables( + 
attrs={k: v.attrs}, + encoding={k: v.encoding}, + ) + indexes[k] = index + variables.update(index_vars) + else: + variables[k] = v # Extract coordinates from indexers coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords) @@ -3148,7 +3164,14 @@ def interp_like( """ if kwargs is None: kwargs = {} - coords = alignment.reindex_like_indexers(self, other) + + # pick only dimension coordinates with a single index + coords = {} + other_indexes = other.xindexes + for dim in self.dims: + other_dim_coords = other_indexes.get_all_coords(dim, errors="ignore") + if len(other_dim_coords) == 1: + coords[dim] = other_dim_coords[dim] numeric_coords: Dict[Hashable, pd.Index] = {} object_coords: Dict[Hashable, pd.Index] = {} From f291c657a0ead24d2fc27e00a6c13194b315e794 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 22 Oct 2021 18:49:28 +0200 Subject: [PATCH 067/159] wip review merge --- xarray/core/alignment.py | 32 ++++++------------ xarray/core/indexes.py | 67 ++++++++++++++++++++++++++++++++++++ xarray/core/merge.py | 73 ++++++++++++++-------------------------- 3 files changed, 103 insertions(+), 69 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index c143f50a5fe..5b4e54bc1db 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -26,7 +26,7 @@ from . import dtypes from .common import DataWithCoords -from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex +from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex, indexes_equal from .utils import is_dict_like, is_full_slice, safe_cast_to_index from .variable import Variable, calculate_dimensions @@ -314,7 +314,7 @@ def assert_no_index_conflict(self): "- they may be used to reindex data along common dimensions" ) - def _need_reindex(self, dims, index, other_indexes, coords, other_coords) -> bool: + def _need_reindex(self, dims, cmp_indexes) -> bool: """Whether or not we need to reindex variables for a set of matching indexes. @@ -325,17 +325,8 @@ def _need_reindex(self, dims, index, other_indexes, coords, other_coords) -> boo pandas). This is useful, e.g., for overwriting such duplicate indexes. 
""" - try: - index_not_equal = any(not index.equals(idx) for idx in other_indexes) - except NotImplementedError: - # check coordinates equality for indexes that do not support alignment - index_not_equal = any( - not coords[k].equals(o_coords[k]) - for o_coords in other_coords - for k in coords - ) has_unindexed_dims = any(dim in self.unindexed_dim_sizes for dim in dims) - return index_not_equal or has_unindexed_dims + return not (indexes_equal(cmp_indexes)) or has_unindexed_dims def _get_index_joiner(self, index_cls) -> Callable: if self.join in ["outer", "inner"]: @@ -377,21 +368,18 @@ def align_indexes(self): elif key in self.indexes: joined_index = self.indexes[key] joined_index_vars = self.index_vars[key] - need_reindex = self._need_reindex( - dims, - joined_index, - matching_indexes, - joined_index_vars, - matching_index_vars, + cmp_indexes = list( + zip( + [joined_index] + matching_indexes, + [joined_index_vars] + matching_index_vars, + ) ) + need_reindex = self._need_reindex(dims, cmp_indexes) else: if len(matching_indexes) > 1: need_reindex = self._need_reindex( dims, - matching_indexes[0], - matching_indexes[1:], - matching_index_vars[0], - matching_index_vars[1:], + list(zip(matching_indexes, matching_index_vars)), ) else: need_reindex = False diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index fa33a92e55e..0ba59cf8940 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -10,6 +10,7 @@ List, Mapping, Optional, + Sequence, Set, Tuple, TypeVar, @@ -829,6 +830,41 @@ def rename(self, name_dict, dims_dict): return self.from_pandas_index(index, new_dim, var_meta=var_meta) +def create_default_index_implicit( + dim_variable: "Variable", + all_variables: Optional[Mapping] = None, +) -> Tuple[Index, IndexVars]: + """Create a default index from a dimension variable. + + Create a PandasMultiIndex if the given variable wraps a pandas.MultiIndex, + otherwise create a PandasIndex. + + This function will become obsolete once we depreciate + implcitly passing a pandas.MultiIndex as a coordinate. + + """ + if all_variables is None: + all_variables = {} + + name = dim_variable.dims[0] + array = getattr(dim_variable._data, "array", None) + index: PandasIndex + + if isinstance(array, pd.MultiIndex): + index, index_vars = PandasMultiIndex.from_pandas_index(array, name) + # check for conflict between level names and variable names + duplicate_names = [k for k in index_vars if k in all_variables and k != name] + if duplicate_names: + conflict_str = "\n".join(duplicate_names) + raise ValueError( + f"conflicting MultiIndex level / variable name(s):\n{conflict_str}" + ) + else: + index, index_vars = PandasIndex.from_variables({name: dim_variable}) + + return index, index_vars + + def remove_unused_levels_categories(index: pd.Index) -> pd.Index: """ Remove unused levels from MultiIndex and unused categories from CategoricalIndex @@ -1121,3 +1157,34 @@ def propagate_indexes( new_indexes = None # type: ignore[assignment] return new_indexes + + +def indexes_equal(elements: Sequence[Tuple[Index, Dict[Hashable, "Variable"]]]) -> bool: + """Check if indexes are all equal. + + If they are not of the same type or they do not implement this check, check + if their coordinate variables are all equal instead. 
+ + """ + + def check_variables(): + variables = [e[1] for e in elements] + return any( + not variables[0][k].equals(other_vars[k]) + for other_vars in variables[1:] + for k in variables[0] + ) + + indexes = [e[0] for e in elements] + same_type = all(type(indexes[0]) is type(other_idx) for other_idx in indexes[1:]) + if same_type: + try: + not_equal = any( + not indexes[0].equals(other_idx) for other_idx in indexes[1:] + ) + except NotImplementedError: + not_equal = check_variables() + else: + not_equal = check_variables() + + return not not_equal diff --git a/xarray/core/merge.py b/xarray/core/merge.py index a137d2c6bd2..0d71f737693 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import defaultdict from typing import ( TYPE_CHECKING, AbstractSet, @@ -22,13 +23,9 @@ from . import dtypes, pdcompat from .alignment import deep_align from .duck_array_ops import lazy_array_equiv -from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex +from .indexes import Index, Indexes, create_default_index_implicit, indexes_equal from .utils import Frozen, compat_dict_union, dict_equiv, equivalent -from .variable import ( # , assert_unique_multiindex_level_names - Variable, - as_variable, - calculate_dimensions, -) +from .variable import Variable, as_variable, calculate_dimensions if TYPE_CHECKING: from .coordinates import Coordinates @@ -220,13 +217,18 @@ def merge_collected( # TODO(shoyer): consider adjusting this logic. Are we really # OK throwing away variable without an index in favor of # indexed variables, without even checking if values match? + # TODO: benbovy (flexible indexes): possible duplicate index.equals calls + # in case of multi-coordinate indexes. Depending on how this affects the perfs, + # we might need to group the merge elements by matching index. variable, index = indexed_elements[0] - for _, other_index in indexed_elements[1:]: - if not index.equals(other_index): - raise MergeError( - f"conflicting values for index {name!r} on objects to be " - f"combined:\nfirst value: {index!r}\nsecond value: {other_index!r}" - ) + if not indexes_equal( + [(idx, {name: var}) for var, idx in indexed_elements] + ): + # TODO: show differing values/reprs in error msg? + raise MergeError( + f"conflicting values/indexes on objects to be combined " + f"for coordinate {name!r}" + ) if compat == "identical": for other_variable, _ in indexed_elements[1:]: if not dict_equiv(variable.attrs, other_variable.attrs): @@ -281,11 +283,10 @@ def collect_variables_and_indexes( if indexes is None: indexes = {} - grouped: Dict[Hashable, List[Tuple[Variable, Optional[Index]]]] = {} + grouped: Dict[Hashable, List[MergeElement]] = defaultdict(list) def append(name, variable, index): - values = grouped.setdefault(name, []) - values.append((variable, index)) + grouped[name].append((variable, index)) def append_all(variables, indexes): for name, variable in variables.items(): @@ -307,16 +308,12 @@ def append_all(variables, indexes): variable = as_variable(variable, name=name) if name in indexes: - index = indexes[name] + append(name, variable, indexes[name]) elif variable.dims == (name,): - # TODO: benbovy - explicit indexes: do we still need this? 
- # default "dimension" indexes are already created elsewhere - idx_variable = variable.to_index_variable() - index = idx_variable._to_xindex() - variable = idx_variable + idx, idx_vars = create_default_index_implicit(variable) + append_all(idx_vars, {k: idx for k in idx_vars}) else: - index = None - append(name, variable, index) + append(name, variable, None) return grouped @@ -491,7 +488,6 @@ def merge_coords( collected = collect_variables_and_indexes(aligned) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) variables, out_indexes = merge_collected(collected, prioritized, compat=compat) - # assert_unique_multiindex_level_names(variables) return variables, out_indexes @@ -514,9 +510,9 @@ def _create_indexes_from_coords(coords, data_vars=None): Return those indexes and updated coordinates. """ - all_var_names = set(coords.keys()) + all_variables = dict(coords) if data_vars is not None: - all_var_names |= set(data_vars.keys()) + all_variables.update(data_vars) indexes = {} updated_coords = {} @@ -525,26 +521,10 @@ def _create_indexes_from_coords(coords, data_vars=None): variable = as_variable(obj, name=name) if variable.dims == (name,): - array = getattr(variable._data, "array", None) - if isinstance(array, pd.MultiIndex): - # TODO: benbovy - explicit indexes: depreciate passing multi-indexes as coords? - index, index_vars = PandasMultiIndex.from_pandas_index(array, name) - # check for conflict between level names and variable names - duplicate_names = [ - k for k in index_vars if k in all_var_names and k != name - ] - if duplicate_names: - conflict_str = "\n".join(duplicate_names) - raise ValueError( - f"conflicting MultiIndex level / variable name(s):\n{conflict_str}" - ) - all_var_names |= set(index_vars.keys()) - else: - index, index_vars = PandasIndex.from_variables({name: variable}) - - indexes.update({k: index for k in index_vars}) - updated_coords.update(index_vars) - + idx, idx_vars = create_default_index_implicit(variable, all_variables) + indexes.update({k: idx for k in idx_vars}) + updated_coords.update(idx_vars) + all_variables.update(idx_vars) else: updated_coords[name] = obj @@ -692,7 +672,6 @@ def merge_core( variables, out_indexes = merge_collected( collected, prioritized, compat=compat, combine_attrs=combine_attrs ) - # assert_unique_multiindex_level_names(variables) dims = calculate_dimensions(variables) From 8e3eef24cfbfa9fd6eced84455d4fc62e0f2113b Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 25 Oct 2021 10:07:21 +0200 Subject: [PATCH 068/159] refactor swap_dims --- xarray/core/dataset.py | 19 +++++++++---------- xarray/tests/test_dataarray.py | 15 +++------------ xarray/tests/test_dataset.py | 10 ++-------- 3 files changed, 14 insertions(+), 30 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bf50052d67b..77692d8771e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -62,6 +62,7 @@ Indexes, PandasIndex, PandasMultiIndex, + create_default_index_implicit, default_indexes, isel_variable_and_index, propagate_indexes, @@ -3450,21 +3451,19 @@ def swap_dims( dims = tuple(dims_dict.get(dim, dim) for dim in v.dims) if k in result_dims: var = v.to_index_variable() + var.dims = dims if k in self.xindexes: indexes[k] = self.xindexes[k] + variables[k] = var else: - new_index = var.to_index() - if new_index.nlevels == 1: - # make sure index name matches dimension name - new_index = new_index.rename(k) - if isinstance(new_index, pd.MultiIndex): - indexes[k] = PandasMultiIndex(new_index, k) - 
else: - indexes[k] = PandasIndex(new_index, k) + index, index_vars = create_default_index_implicit(var) + indexes.update({name: index for name in index_vars}) + variables.update(index_vars) + coord_names.update(index_vars) else: var = v.to_base_variable() - var.dims = dims - variables[k] = var + var.dims = dims + variables[k] = var return self._replace_with_new_dims(variables, coord_names, indexes=indexes) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 17fb792ead5..0c92badea84 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1636,10 +1636,7 @@ def test_swap_dims(self): actual = array.swap_dims({"x": "y"}) assert_identical(expected, actual) for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): - pd.testing.assert_index_equal( - expected.xindexes[dim_name].to_pandas_index(), - actual.xindexes[dim_name].to_pandas_index(), - ) + assert actual.xindexes[dim_name].equals(expected.xindexes[dim_name]) # as kwargs array = DataArray(np.random.randn(3), {"x": list("abc")}, "x") @@ -1647,10 +1644,7 @@ def test_swap_dims(self): actual = array.swap_dims(x="y") assert_identical(expected, actual) for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): - pd.testing.assert_index_equal( - expected.xindexes[dim_name].to_pandas_index(), - actual.xindexes[dim_name].to_pandas_index(), - ) + assert actual.xindexes[dim_name].equals(expected.xindexes[dim_name]) # multiindex case idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"]) @@ -1659,10 +1653,7 @@ def test_swap_dims(self): actual = array.swap_dims({"x": "y"}) assert_identical(expected, actual) for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): - pd.testing.assert_index_equal( - expected.xindexes[dim_name].to_pandas_index(), - actual.xindexes[dim_name].to_pandas_index(), - ) + assert actual.xindexes[dim_name].equals(expected.xindexes[dim_name]) def test_expand_dims_error(self): array = DataArray( diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index d818de7ce17..688b0f47ed4 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2771,10 +2771,7 @@ def test_swap_dims(self): assert_identical(expected, actual) assert isinstance(actual.variables["y"], IndexVariable) assert isinstance(actual.variables["x"], Variable) - pd.testing.assert_index_equal( - actual.xindexes["y"].to_pandas_index(), - expected.xindexes["y"].to_pandas_index(), - ) + assert actual.xindexes["y"].equals(expected.xindexes["y"]) roundtripped = actual.swap_dims({"y": "x"}) assert_identical(original.set_coords("y"), roundtripped) @@ -2805,10 +2802,7 @@ def test_swap_dims(self): assert_identical(expected, actual) assert isinstance(actual.variables["y"], IndexVariable) assert isinstance(actual.variables["x"], Variable) - pd.testing.assert_index_equal( - actual.xindexes["y"].to_pandas_index(), - expected.xindexes["y"].to_pandas_index(), - ) + assert actual.xindexes["y"].equals(expected.xindexes["y"]) def test_expand_dims_error(self): original = Dataset( From 5e5abb70293a2a8fceb4711b512351906bb16027 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 25 Oct 2021 16:30:49 +0200 Subject: [PATCH 069/159] refactor isel --- xarray/core/dataset.py | 100 ++++++++++++++++------------ xarray/core/indexes.py | 118 +++++++++++++-------------------- xarray/tests/test_dataarray.py | 2 +- xarray/tests/test_dataset.py | 2 +- xarray/tests/test_indexes.py | 33 +++++---- 5 files changed, 121 
insertions(+), 134 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 77692d8771e..90c3d9ab5fc 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -64,7 +64,6 @@ PandasMultiIndex, create_default_index_implicit, default_indexes, - isel_variable_and_index, propagate_indexes, remove_unused_levels_categories, roll_index, @@ -1235,7 +1234,10 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset": if ref_name in self._coord_names or ref_name in self.dims: coord_names.add(var_name) if (var_name,) == var.dims: - indexes[var_name] = var._to_xindex() + index, index_vars = create_default_index_implicit(var, names) + indexes.update({k: index for k in index_vars}) + variables.update(index_vars) + coord_names.update(index_vars) needed_dims: OrderedSet[Hashable] = OrderedSet() for v in variables.values(): @@ -2197,26 +2199,22 @@ def isel( variables = {} dims: Dict[Hashable, Tuple[int, ...]] = {} coord_names = self._coord_names.copy() - indexes = self._indexes.copy() if self._indexes is not None else None - - for var_name, var_value in self._variables.items(): - var_indexers = {k: v for k, v in indexers.items() if k in var_value.dims} - if var_indexers: - var_value = var_value.isel(var_indexers) - if drop and var_value.ndim == 0 and var_name in coord_names: - coord_names.remove(var_name) - if indexes: - indexes.pop(var_name, None) - continue - if indexes and var_name in indexes: - # TODO benbovy - flexible indexes: this won't be always desirable - # (e.g., 1-d out-of-core coordinate, "meta"-index, etc.) - if var_value.ndim == 1: - indexes[var_name] = var_value._to_xindex() - else: - del indexes[var_name] - variables[var_name] = var_value - dims.update(zip(var_value.dims, var_value.shape)) + + indexes, index_variables = self._isel_indexes(indexers) + + for name, var in self._variables.items(): + # preserve variable order + if name in index_variables: + var = index_variables[name] + else: + var_indexers = {k: v for k, v in indexers.items() if k in var.dims} + if var_indexers: + var = var.isel(var_indexers) + if drop and var.ndim == 0 and name in coord_names: + coord_names.remove(name) + continue + variables[name] = var + dims.update(zip(var.dims, var.shape)) return self._construct_direct( variables=variables, @@ -2228,6 +2226,30 @@ def isel( close=self._close, ) + def _isel_indexes( + self, + indexers: Mapping[Any, Any], + ) -> Tuple[Dict[Hashable, Index], Dict[Hashable, Variable]]: + index_variables: Dict[Hashable, Variable] = {} + indexes: Dict[Hashable, Index] = ( + self._indexes.copy() if self._indexes is not None else {} + ) + + for index, index_vars in self.xindexes.group_by_index(): + index_dims = set(d for var in index_vars.values() for d in var.dims) + index_indexers = {k: v for k, v in indexers.items() if k in index_dims} + if index_indexers: + new_index = index.isel(index_indexers) + if new_index is not None: + indexes.update({k: new_index for k in index_vars}) + new_index_vars = new_index.create_variables(index_vars) + index_variables.update(new_index_vars) + else: + for k in index_vars: + indexes.pop(k, None) + + return indexes, index_variables + def _isel_fancy( self, indexers: Mapping[Any, Any], @@ -2235,29 +2257,24 @@ def _isel_fancy( drop: bool, missing_dims: str = "raise", ) -> "Dataset": - # Note: we need to preserve the original indexers variable in order to merge the - # coords below - indexers_list = list(self._validate_indexers(indexers, missing_dims)) + valid_indexers = dict(self._validate_indexers(indexers, missing_dims)) 
variables: Dict[Hashable, Variable] = {} indexes: Dict[Hashable, Index] = {} - for name, var in self.variables.items(): - var_indexers = {k: v for k, v in indexers_list if k in var.dims} - if drop and name in var_indexers: - continue # drop this variable + indexes, index_variables = self._isel_indexes(valid_indexers) - if name in self.xindexes: - new_var, new_index = isel_variable_and_index( - name, var, self.xindexes[name], var_indexers - ) - if new_index is not None: - indexes[name] = new_index - elif var_indexers: - new_var = var.isel(indexers=var_indexers) + for name, var in self.variables.items(): + if name in index_variables: + new_var = index_variables[name] else: - new_var = var.copy(deep=False) - + var_indexers = { + k: v for k, v in valid_indexers.items() if k in var.dims + } + if var_indexers: + new_var = var.isel(indexers=var_indexers) + else: + new_var = var.copy(deep=False) variables[name] = new_var coord_names = self._coord_names & variables.keys() @@ -3097,10 +3114,7 @@ def _validate_interp_indexer(x, new_x): assert isinstance(v, Variable) if v.dims == (k,): index = PandasIndex(v, k, coord_dtype=v.dtype) - index_vars = index.create_variables( - attrs={k: v.attrs}, - encoding={k: v.encoding}, - ) + index_vars = index.create_variables({k: v}) indexes[k] = index variables.update(index_vars) else: diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 0ba59cf8940..7b004f954a6 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -42,9 +42,7 @@ def from_variables( raise NotImplementedError() def create_variables( - self, - attrs: Optional[Mapping[Any, Any]] = None, - encoding: Optional[Mapping[Any, Any]] = None, + self, variables: Optional[Mapping[Any, "Variable"]] = None ) -> IndexVars: return {} @@ -58,6 +56,11 @@ def to_pandas_index(self) -> pd.Index: """ raise TypeError(f"{self!r} cannot be cast to a pandas.Index object") + def isel( + self, indexers: Mapping[Any, Union[int, slice, np.ndarray, "Variable"]] + ) -> Union["Index", None]: + return None + def query(self, labels: Dict[Any, Any]) -> QueryResult: raise NotImplementedError(f"{self!r} doesn't support label-based selection") @@ -234,22 +237,25 @@ def from_variables( return obj, {name: index_var} def create_variables( - self, - attrs: Optional[Mapping[Any, Any]] = None, - encoding: Optional[Mapping[Any, Any]] = None, + self, variables: Optional[Mapping[Any, "Variable"]] = None ) -> IndexVars: from .variable import IndexVariable - if attrs is None: - attrs = {} - if encoding is None: - encoding = {} + name = self.index.name + attrs: Union[Mapping[Hashable, Any], None] + encoding: Union[Mapping[Hashable, Any], None] + + if variables is not None and name in variables: + var = variables[name] + attrs = var.attrs + encoding = var.encoding + else: + attrs = None + encoding = None name = self.index.name data = PandasIndexingAdapter(self.index, dtype=self.coord_dtype) - var = IndexVariable( - self.dim, data, attrs=attrs.get(name), encoding=encoding.get(name) - ) + var = IndexVariable(self.dim, data, attrs=attrs, encoding=encoding) return {name: var} @classmethod @@ -287,6 +293,24 @@ def from_pandas_index( def to_pandas_index(self) -> pd.Index: return self.index + def isel( + self, indexers: Mapping[Any, Union[int, slice, np.ndarray, "Variable"]] + ) -> Optional["PandasIndex"]: + from .variable import Variable + + indxr = indexers[self.dim] + if isinstance(indxr, int): + # can't preserve index with single value + return None + elif isinstance(indxr, Variable): + if indxr.dims != (self.dim,): + # can't 
preserve an index if result has new dimensions
+                return None
+            else:
+                indxr = indxr.data
+
+        return self._replace(self.index[indxr])
+
     def query(self, labels: Dict[Any, Any], method=None, tolerance=None) -> QueryResult:
         from .dataarray import DataArray
         from .variable import Variable
@@ -632,22 +656,17 @@ def from_pandas_index(
         return cls(index, dim, level_coords_dtype=level_coords_dtype), index_vars
 
     def create_variables(
-        self,
-        attrs: Optional[Mapping[Any, Any]] = None,
-        encoding: Optional[Mapping[Any, Any]] = None,
+        self, variables: Optional[Mapping[Any, "Variable"]] = None
     ) -> IndexVars:
-        if attrs is None:
-            attrs = {}
-        if encoding is None:
-            encoding = {}
-
         var_meta = {}
-        for name in self.index.names:
-            var_meta[name] = {
-                "dtype": self.level_coords_dtype[name],
-                "attrs": attrs.get(name, {}),
-                "encoding": encoding.get(name, {}),
-            }
+        if variables is not None:
+            for name in self.index.names:
+                var = variables[name]
+                var_meta[name] = {
+                    "dtype": self.level_coords_dtype[name],
+                    "attrs": var.attrs,
+                    "encoding": var.encoding,
+                }
 
         return _create_variables_from_multiindex(
             self.index, self.dim, var_meta=var_meta
@@ -832,7 +851,7 @@ def rename(self, name_dict, dims_dict):
 
 def create_default_index_implicit(
     dim_variable: "Variable",
-    all_variables: Optional[Mapping] = None,
+    all_variables: Optional[Union[Mapping, Iterable[Hashable]]] = None,
 ) -> Tuple[Index, IndexVars]:
     """Create a default index from a dimension variable.
@@ -1087,49 +1106,6 @@ def default_indexes(
     return {key: coords[key]._to_xindex() for key in dims if key in coords}
 
 
-def isel_variable_and_index(
-    name: Hashable,
-    variable: "Variable",
-    index: Index,
-    indexers: Mapping[Any, Union[int, slice, np.ndarray, "Variable"]],
-) -> Tuple["Variable", Optional[Index]]:
-    """Index a Variable and an Index together.
-
-    If the index cannot be indexed, return None (it will be dropped).
-
-    (note: not compatible yet with xarray flexible indexes).
- - """ - from .variable import Variable - - if not indexers: - # nothing to index - return variable.copy(deep=False), index - - if len(variable.dims) > 1: - raise NotImplementedError( - "indexing multi-dimensional variable with indexes is not supported yet" - ) - - new_variable = variable.isel(indexers) - - if new_variable.dims != (name,): - # can't preserve a index if result has new dimensions - return new_variable, None - - # we need to compute the new index - (dim,) = variable.dims - indexer = indexers[dim] - if isinstance(indexer, Variable): - indexer = indexer.data - try: - new_index = index[indexer] - except NotImplementedError: - new_index = None - - return new_variable, new_index - - def roll_index(index: PandasIndex, count: int, axis: int = 0) -> PandasIndex: """Roll an pandas.Index.""" pd_index = index.to_pandas_index() diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 0c92badea84..5d1c943d8a6 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -894,7 +894,7 @@ def test_isel_fancy(self): assert "station" in actual.dims assert_identical(actual["station"], stations["station"]) - with pytest.raises(ValueError, match=r"conflicting values for "): + with pytest.raises(ValueError, match=r"conflicting values/indexes on "): da.isel( x=DataArray([0, 1, 2], dims="station", coords={"station": [0, 1, 2]}), y=DataArray([0, 1, 2], dims="station", coords={"station": [0, 1, 3]}), diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 688b0f47ed4..207fc598e4c 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1131,7 +1131,7 @@ def test_isel_fancy(self): assert "station" in actual.dims assert_identical(actual["station"].drop_vars(["dim2"]), stations["station"]) - with pytest.raises(ValueError, match=r"conflicting values for "): + with pytest.raises(ValueError, match=r"conflicting values/indexes on "): data.isel( dim1=DataArray( [0, 1, 2], dims="station", coords={"station": [0, 1, 2]} diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 3ac06622356..a2e1bf5d59f 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -79,14 +79,14 @@ def test_from_pandas_index(self) -> None: def test_create_variables(self) -> None: pd_idx = pd.Index([1, 2, 3], name="foo") index, _ = PandasIndex.from_pandas_index(pd_idx, "x") - attrs = {"unit": "m"} - encoding = {"fill_value": 0} + index_vars = { + "foo": IndexVariable( + "x", pd_idx, attrs={"unit": "m"}, encoding={"fill_value": 0} + ) + } - actual = index.create_variables( - attrs={"foo": attrs}, encoding={"foo": encoding} - ) - expected = {"foo": IndexVariable("x", pd_idx, attrs=attrs, encoding=encoding)} - assert_identical(actual["foo"], expected["foo"]) + actual = index.create_variables(index_vars) + assert_identical(actual["foo"], index_vars["foo"]) def test_to_pandas_index(self) -> None: pd_idx = pd.Index([1, 2, 3], name="foo") @@ -324,20 +324,17 @@ def test_create_variables(self) -> None: foo_data = np.array([0, 0, 1], dtype="int") bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data], names=("foo", "bar")) + index_vars = { + "x": IndexVariable("x", pd_idx), + "foo": IndexVariable("x", foo_data, attrs={"unit": "m"}), + "bar": IndexVariable("x", bar_data, encoding={"fill_value": 0}), + } index, _ = PandasMultiIndex.from_pandas_index(pd_idx, "x") - index_vars = index.create_variables( - attrs={"foo": {"unit": "m"}}, - encoding={"bar": 
{"fill_value": 0}}, - ) + actual = index.create_variables(index_vars) - assert_identical(index_vars["x"], IndexVariable("x", pd_idx)) - assert_identical( - index_vars["foo"], IndexVariable("x", foo_data, attrs={"unit": "m"}) - ) - assert_identical( - index_vars["bar"], IndexVariable("x", bar_data, encoding={"fill_value": 0}) - ) + for k, expected in index_vars.items(): + assert_identical(actual[k], expected) def test_query(self) -> None: index = PandasMultiIndex( From 2c11dbcd236e58af9ea94854d6f3b4442df8908c Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 26 Oct 2021 00:53:20 +0200 Subject: [PATCH 070/159] add create_index option to stack --- xarray/core/dataarray.py | 10 ++- xarray/core/dataset.py | 132 +++++++++++++++++++-------------- xarray/core/indexes.py | 26 +++++-- xarray/tests/test_dask.py | 4 +- xarray/tests/test_dataarray.py | 2 +- xarray/tests/test_dataset.py | 32 +++++--- xarray/tests/test_indexes.py | 4 + 7 files changed, 131 insertions(+), 79 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 550574d18e5..d9087f41dd7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2061,6 +2061,7 @@ def reorder_levels( def stack( self, dimensions: Mapping[Any, Sequence[Hashable]] = None, + create_index: bool = True, **dimensions_kwargs: Sequence[Hashable], ) -> "DataArray": """ @@ -2077,6 +2078,11 @@ def stack( replace. An ellipsis (`...`) will be replaced by all unlisted dimensions. Passing a list containing an ellipsis (`stacked_dim=[...]`) will stack over all dimensions. + create_index : bool, optional + If True (default), create a multi-index for each of the stacked dimensions. + If False, don't create any index. + If None, create a multi-index only if one single (1-d) coordinate index + is found for every dimension to stack. **dimensions_kwargs The keyword arguments form of ``dimensions``. One of dimensions or dimensions_kwargs must be provided. @@ -2113,7 +2119,9 @@ def stack( -------- DataArray.unstack """ - ds = self._to_temp_dataset().stack(dimensions, **dimensions_kwargs) + ds = self._to_temp_dataset().stack( + dimensions, create_index=create_index, **dimensions_kwargs + ) return self._from_temp_dataset(ds) def unstack( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 90c3d9ab5fc..eb481d51b67 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3876,45 +3876,61 @@ def reorder_levels( return self._replace(variables, indexes=indexes) def _find_stack_index( - self, dim, multi=False - ) -> Tuple[Union[Index, None], List[Hashable]]: + self, + dim, + multi=False, + create_index=False, + ) -> Tuple[Union[Index, None], Dict[Hashable, Variable]]: """Used by stack and unstack to find one pandas (multi-)index among the indexed coordinates along dimension `dim`. If it finds exactly one index returns it with its corresponding - coordinate name(s), otherwise returns None and an empty list. + coordinate variables(s), otherwise returns None and an empty dict. 
""" - if multi: - index_cls = PandasMultiIndex - else: - index_cls = PandasIndex # type: ignore[assignment] - - indexes: Set[Index] = set() - names: List[Hashable] = [] + stack_index: Union[Index, None] = None + stack_coords: Dict[Hashable, Variable] = {} - for name in self._coord_names: + for name, index in self.xindexes.items(): var = self._variables[name] - index = self.xindexes.get(name) - if index is not None and var.ndim == 1: - var_dim = var.dims[0] - if var_dim == dim and type(index) is index_cls: - indexes.add(index) - names.append(name) - - if len(indexes) == 1: - return next(iter(indexes)), names - else: - return None, [] + if ( + var.ndim == 1 + and var.dims[0] == dim + and ( + not multi + and not self.xindexes.is_multi(name) + or multi + and isinstance(index, PandasMultiIndex) + ) + ): + print(dim, name, index) + if stack_index is not None and index is not stack_index: + # more than one index found, stop + if create_index: + raise ValueError( + f"cannot stack dimension {dim!r} with `create_index=True` " + "and with more than one index found along that dimension" + ) + return None, {} + stack_index = index + stack_coords[name] = var - def _stack_once(self, dims, new_dim): + if create_index and stack_index is None: + if dim in self._variables: + var = self._variables[dim] + else: + _, _, var = _get_virtual_variable(self._variables, dim, self.dims) + # dummy index (only var will be used to construct the multi-index) + stack_index = PandasIndex([0], dim) + stack_coords = {dim: var} + + return stack_index, stack_coords + + def _stack_once(self, dims, new_dim, create_index=True): if ... in dims: dims = list(infix_dims(dims, self.dims)) - # TODO: add default dimension variables (range) if missing - # only if we want backwards compatibility (multi-index always created) - - variables: Dict[Hashable, Variable] = {} + new_variables: Dict[Hashable, Variable] = {} stacked_var_names: List[Hashable] = [] drop_indexes: List[Hashable] = [] @@ -3925,55 +3941,52 @@ def _stack_once(self, dims, new_dim): shape = [self.dims[d] for d in vdims] exp_var = var.set_dims(vdims, shape) stacked_var = exp_var.stack(**{new_dim: dims}) - variables[name] = stacked_var + new_variables[name] = stacked_var stacked_var_names.append(name) else: - variables[name] = var.copy(deep=False) + new_variables[name] = var.copy(deep=False) # drop indexes of stacked coordinates (if any) for name in stacked_var_names: drop_indexes += list(self.xindexes.get_all_coords(name, errors="ignore")) - # A new index is created only if each of the stacked dimensions has - # one and only one 1-d coordinate index - # TODO: add API option to force/skip the creation of a new index (see GH 5202) - product_vars: Dict[Any, Variable] = {} - for dim in dims: - index, names = self._find_stack_index(dim) - if index is not None: - n = names[0] - product_vars[n] = self.variables[n] - - if len(product_vars) == len(dims): - idx, idx_vars = PandasMultiIndex.from_product_variables( - product_vars, new_dim - ) - new_indexes = {k: idx for k in idx_vars} - # keep consistent multi-index coordinate order - for k in idx_vars: - variables.pop(k, None) - variables.update(idx_vars) - coord_names = set(self._coord_names) | {new_dim} - else: - new_indexes = {} - coord_names = set(self._coord_names) + new_indexes = {} + new_coord_names = set(self._coord_names) + if create_index or create_index is None: + product_vars: Dict[Any, Variable] = {} + for dim in dims: + idx, idx_vars = self._find_stack_index(dim, create_index=create_index) + if idx is not None: + 
product_vars.update(idx_vars) + + if len(product_vars) == len(dims): + idx, idx_vars = PandasMultiIndex.from_product_variables( + product_vars, new_dim + ) + new_indexes.update({k: idx for k in idx_vars}) + # keep consistent multi-index coordinate order + for k in idx_vars: + new_variables.pop(k, None) + new_variables.update(idx_vars) + new_coord_names.update({new_dim}) indexes = {k: v for k, v in self.xindexes.items() if k not in drop_indexes} indexes.update(new_indexes) return self._replace_with_new_dims( - variables, coord_names=coord_names, indexes=indexes + new_variables, coord_names=new_coord_names, indexes=indexes ) def stack( self, dimensions: Mapping[Any, Sequence[Hashable]] = None, + create_index: Union[bool, None] = True, **dimensions_kwargs: Sequence[Hashable], ) -> "Dataset": """ Stack any number of existing dimensions into a single new dimension. - New dimensions will be added at the end, and the corresponding + New dimensions will be added at the end, and by default the corresponding coordinate variables will be combined into a MultiIndex. Parameters @@ -3984,6 +3997,11 @@ def stack( ellipsis (`...`) will be replaced by all unlisted dimensions. Passing a list containing an ellipsis (`stacked_dim=[...]`) will stack over all dimensions. + create_index : bool, optional + If True (default), create a multi-index for each of the stacked dimensions. + If False, don't create any index. + If None, create a multi-index only if one single (1-d) coordinate index + is found for every dimension to stack. **dimensions_kwargs The keyword arguments form of ``dimensions``. One of dimensions or dimensions_kwargs must be provided. @@ -4000,7 +4018,7 @@ def stack( dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, "stack") result = self for new_dim, dims in dimensions.items(): - result = result._stack_once(dims, new_dim) + result = result._stack_once(dims, new_dim, create_index) return result def to_stacked_array( @@ -4266,9 +4284,9 @@ def unstack( # each specified dimension must have exactly one multi-index stacked_indexes: Dict[Any, Tuple[PandasMultiIndex, List[Any]]] = {} for d in dims: - idx, idx_var_names = self._find_stack_index(d, multi=True) + idx, idx_vars = self._find_stack_index(d, multi=True) if idx is not None: - stacked_indexes[d] = cast(PandasMultiIndex, idx), idx_var_names + stacked_indexes[d] = cast(PandasMultiIndex, idx), list(idx_vars) if dim is None: dims = list(stacked_indexes) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 7b004f954a6..7830ea2ca42 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -532,6 +532,12 @@ def from_product_variables( _check_dim_compat(variables, all_dims="different") level_indexes = [utils.safe_cast_to_index(var) for var in variables.values()] + for name, idx in zip(variables, level_indexes): + if isinstance(idx, pd.MultiIndex): + raise ValueError( + f"cannot create a multi-index along stacked dimension {dim!r} " + f"from variable {name!r} that wraps a multi-index" + ) split_labels, levels = zip(*[lev.factorize() for lev in level_indexes]) labels_mesh = np.meshgrid(*split_labels, indexing="ij") @@ -1008,17 +1014,23 @@ def get_unique(self) -> List[T_PandasOrXarrayIndex]: return unique_indexes + def is_multi(self, key: Hashable) -> bool: + """Return True if ``key`` maps to a multi-coordinate index, + False otherwise. 
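+
+        For example (hypothetical keys): ``is_multi("one")`` is True when
+        "one" is one level of a multi-coordinate index, while
+        ``is_multi("x")`` is False for a single-coordinate index.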
+ """ + return len(self._id_coord_names[self._coord_name_id[key]]) > 1 + def get_all_coords( - self, coord_name: Hashable, errors: str = "raise" + self, key: Hashable, errors: str = "raise" ) -> Dict[Hashable, "Variable"]: """Return all coordinates having the same index. Parameters ---------- - coord_name : hashable - Name of an indexed coordinate. + key : hashable + Index key. errors : {"raise", "ignore"}, optional - If "raise", raises a ValueError if `coord_name` is not in indexes. + If "raise", raises a ValueError if `key` is not in indexes. If "ignore", an empty tuple is returned instead. Returns @@ -1030,13 +1042,13 @@ def get_all_coords( if errors not in ["raise", "ignore"]: raise ValueError('errors must be either "raise" or "ignore"') - if coord_name not in self._indexes: + if key not in self._indexes: if errors == "raise": - raise ValueError(f"no index found for {coord_name!r} coordinate") + raise ValueError(f"no index found for {key!r} coordinate") else: return {} - all_coord_names = self._id_coord_names[self._coord_name_id[coord_name]] + all_coord_names = self._id_coord_names[self._coord_name_id[key]] return {k: self._variables[k] for k in all_coord_names} def group_by_index( diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index d5d460056aa..c244d0d6028 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -50,14 +50,14 @@ def assertLazyAnd(self, expected, actual, test): if isinstance(actual, Dataset): for k, v in actual.variables.items(): - if k in actual.dims: + if k in actual.xindexes: assert isinstance(v.data, np.ndarray) else: assert isinstance(v.data, da.Array) elif isinstance(actual, DataArray): assert isinstance(actual.data, da.Array) for k, v in actual.coords.items(): - if k in actual.dims: + if k in actual.xindexes: assert isinstance(v.data, np.ndarray) else: assert isinstance(v.data, da.Array) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 5d1c943d8a6..84ba30675cf 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6618,7 +6618,7 @@ def test_clip(da): assert_array_equal(result.isel(time=[0, 1]), with_nans.isel(time=[0, 1])) # Unclear whether we want this work, OK to adjust the test when we have decided. 
-    with pytest.raises(ValueError, match="arguments without labels along dimension"):
+    with pytest.raises(ValueError, match="cannot reindex or align along dimension.*"):
         result = da.clip(min=da.mean("x"), max=da.mean("a").isel(x=[0, 1]))


diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index 207fc598e4c..b6321f88c43 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -3096,20 +3096,27 @@ def test_stack(self):
         assert_identical(expected, actual)
         assert list(actual.xindexes) == ["z", "y", "x"]

-    def test_stack_no_index(self) -> None:
+    @pytest.mark.parametrize(
+        "create_index,expected_keys",
+        [
+            (True, ["z", "x", "y"]),
+            (False, []),
+            (None, ["z", "x", "y"]),
+        ],
+    )
+    def test_stack_create_index(self, create_index, expected_keys) -> None:
         ds = Dataset(
             data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])},
-            coords={"xx": ("x", [0, 1]), "y": ["a", "b"]},
-        )
-        expected = Dataset(
-            data_vars={"b": ("z", [0, 1, 2, 3])},
-            coords={"xx": ("z", [0, 0, 1, 1]), "y": ("z", ["a", "b", "a", "b"])},
+            coords={"x": ("x", [0, 1]), "y": ["a", "b"]},
         )
-        actual = ds.stack(z=["x", "y"])
-        assert_identical(expected, actual)
-        assert len(actual.xindexes) == 0
+        actual = ds.stack(z=["x", "y"], create_index=create_index)
+        assert list(actual.xindexes) == expected_keys

+        # TODO: benbovy (flexible indexes) - test error multiple indexes found
+        # along dimension + create_index=True
+
+    def test_stack_multi_index(self) -> None:
         # multi-index on a dimension to stack is discarded too
         midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=("lvl1", "lvl2"))
         ds = xr.Dataset(
@@ -3125,10 +3132,13 @@ def test_stack_no_index(self) -> None:
                 "y": ("z", [0, 1, 0, 1] * 2),
             },
         )
-        actual = ds.stack(z=["x", "y"])
+        actual = ds.stack(z=["x", "y"], create_index=False)
         assert_identical(expected, actual)
         assert len(actual.xindexes) == 0

+        with pytest.raises(ValueError, match=r"cannot create.*wraps a multi-index"):
+            ds.stack(z=["x", "y"], create_index=True)
+
     def test_stack_non_dim_coords(self):
         ds = Dataset(
             data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])},
@@ -4418,7 +4428,7 @@ def test_where_other(self):
         with pytest.raises(ValueError, match=r"cannot set"):
             ds.where(ds > 1, other=0, drop=True)

-        with pytest.raises(ValueError, match=r"indexes .* are not equal"):
+        with pytest.raises(ValueError, match=r"cannot align .* are not equal"):
             ds.where(ds > 1, ds.isel(x=slice(3)))

         with pytest.raises(ValueError, match=r"exact match required"):
diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py
index a2e1bf5d59f..af53e3c5a29 100644
--- a/xarray/tests/test_indexes.py
+++ b/xarray/tests/test_indexes.py
@@ -459,6 +459,10 @@ def test_dims(self, indexes) -> None:
     def test_get_unique(self, unique_indexes, indexes) -> None:
         assert indexes.get_unique() == unique_indexes

+    def test_is_multi(self, indexes) -> None:
+        assert indexes.is_multi("one") is True
+        assert indexes.is_multi("x") is False
+
     def test_get_all_coords(self, indexes) -> None:
         expected = {
             "z": indexes.variables["z"],

From e65e3f314f785330ea3eafedb6607ad5affb2daf Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Wed, 27 Oct 2021 12:32:12 +0200
Subject: [PATCH 071/159] add Index.stack and Index.unstack methods

Also added an `index_cls` parameter to `Dataset.stack` and
`DataArray.stack`.

Custom indexes may thus implement their own logic for stack/unstack,
with the current limitation that a `pandas.MultiIndex` is still
explicitly required for the optimized versions of unstack.
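A rough usage sketch of the new keywords (the toy dataset and
coordinate names below are hypothetical, not part of this change):

    import numpy as np
    import xarray as xr
    from xarray.core.indexes import PandasMultiIndex

    ds = xr.Dataset(
        {"a": (("x", "y"), np.zeros((2, 3)))},
        coords={"x": [0, 1], "y": ["u", "v", "w"]},
    )
    # default: wrap the stacked coordinates in a pandas multi-index
    stacked = ds.stack(z=("x", "y"), index_cls=PandasMultiIndex)
    # skip index creation entirely
    no_index = ds.stack(z=("x", "y"), create_index=False)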
--- xarray/core/dataarray.py | 22 +++++-- xarray/core/dataset.py | 118 +++++++++++++++++------------------ xarray/core/indexes.py | 83 ++++++++++++++---------- xarray/tests/test_dataset.py | 2 +- 4 files changed, 129 insertions(+), 96 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d9087f41dd7..3baa290b3bc 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -15,6 +15,7 @@ Optional, Sequence, Tuple, + Type, Union, cast, ) @@ -45,7 +46,13 @@ from .coordinates import DataArrayCoordinates, assert_coordinate_consistent from .dataset import Dataset from .formatting import format_item -from .indexes import Index, Indexes, default_indexes, propagate_indexes +from .indexes import ( + Index, + Indexes, + PandasMultiIndex, + default_indexes, + propagate_indexes, +) from .indexing import is_fancy_indexer, map_index_queries from .merge import PANDAS_TYPES, MergeError, _create_indexes_from_coords from .options import OPTIONS, _get_keep_attrs @@ -2062,6 +2069,7 @@ def stack( self, dimensions: Mapping[Any, Sequence[Hashable]] = None, create_index: bool = True, + index_cls: Type[Index] = PandasMultiIndex, **dimensions_kwargs: Sequence[Hashable], ) -> "DataArray": """ @@ -2081,8 +2089,11 @@ def stack( create_index : bool, optional If True (default), create a multi-index for each of the stacked dimensions. If False, don't create any index. - If None, create a multi-index only if one single (1-d) coordinate index - is found for every dimension to stack. + If None, create a multi-index only if exactly one single (1-d) coordinate + index is found for every dimension to stack. + index_cls: class, optional + Can be used to pass a custom multi-index type. Must be an Xarray index that + implements `.stack()`. By default, a pandas multi-index wrapper is used. **dimensions_kwargs The keyword arguments form of ``dimensions``. One of dimensions or dimensions_kwargs must be provided. @@ -2120,7 +2131,10 @@ def stack( DataArray.unstack """ ds = self._to_temp_dataset().stack( - dimensions, create_index=create_index, **dimensions_kwargs + dimensions, + create_index=create_index, + index_cls=index_cls, + **dimensions_kwargs, ) return self._from_temp_dataset(ds) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index eb481d51b67..2aa7dcb9c38 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -26,6 +26,7 @@ Sequence, Set, Tuple, + Type, Union, cast, overload, @@ -3875,17 +3876,20 @@ def reorder_levels( return self._replace(variables, indexes=indexes) - def _find_stack_index( + def _get_stack_index( self, dim, multi=False, create_index=False, ) -> Tuple[Union[Index, None], Dict[Hashable, Variable]]: - """Used by stack and unstack to find one pandas (multi-)index among + """Used by stack and unstack to get one pandas (multi-)index among the indexed coordinates along dimension `dim`. - If it finds exactly one index returns it with its corresponding - coordinate variables(s), otherwise returns None and an empty dict. + If exactly one index is found, return it with its corresponding + coordinate variables(s), otherwise return None and an empty dict. + + If `create_index=True`, create a new index if none is found or raise + an error if multiple indexes are found. 
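+
+        As a rough sketch of the contract (hypothetical dimension name
+        "x"): ``self._get_stack_index("x")`` returns either
+        ``(index, {"x": var})`` or ``(None, {})``.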
""" stack_index: Union[Index, None] = None @@ -3897,13 +3901,14 @@ def _find_stack_index( var.ndim == 1 and var.dims[0] == dim and ( + # stack: must be a single coordinate index not multi and not self.xindexes.is_multi(name) + # unstack: must be an index that implements .unstack or multi - and isinstance(index, PandasMultiIndex) + and type(index).unstack is not Index.unstack ) ): - print(dim, name, index) if stack_index is not None and index is not stack_index: # more than one index found, stop if create_index: @@ -3920,13 +3925,13 @@ def _find_stack_index( var = self._variables[dim] else: _, _, var = _get_virtual_variable(self._variables, dim, self.dims) - # dummy index (only var will be used to construct the multi-index) + # dummy index (only `stack_coords` will be used to construct the multi-index) stack_index = PandasIndex([0], dim) stack_coords = {dim: var} return stack_index, stack_coords - def _stack_once(self, dims, new_dim, create_index=True): + def _stack_once(self, dims, new_dim, index_cls, create_index=True): if ... in dims: dims = list(infix_dims(dims, self.dims)) @@ -3955,14 +3960,12 @@ def _stack_once(self, dims, new_dim, create_index=True): if create_index or create_index is None: product_vars: Dict[Any, Variable] = {} for dim in dims: - idx, idx_vars = self._find_stack_index(dim, create_index=create_index) + idx, idx_vars = self._get_stack_index(dim, create_index=create_index) if idx is not None: product_vars.update(idx_vars) if len(product_vars) == len(dims): - idx, idx_vars = PandasMultiIndex.from_product_variables( - product_vars, new_dim - ) + idx, idx_vars = index_cls.stack(product_vars, new_dim) new_indexes.update({k: idx for k in idx_vars}) # keep consistent multi-index coordinate order for k in idx_vars: @@ -3981,6 +3984,7 @@ def stack( self, dimensions: Mapping[Any, Sequence[Hashable]] = None, create_index: Union[bool, None] = True, + index_cls: Type[Index] = PandasMultiIndex, **dimensions_kwargs: Sequence[Hashable], ) -> "Dataset": """ @@ -4000,8 +4004,11 @@ def stack( create_index : bool, optional If True (default), create a multi-index for each of the stacked dimensions. If False, don't create any index. - If None, create a multi-index only if one single (1-d) coordinate index - is found for every dimension to stack. + If None, create a multi-index only if exactly one single (1-d) coordinate + index is found for every dimension to stack. + index_cls: class, optional + Can be used to pass a custom multi-index type (must be an Xarray index that + implements `.stack()`). By default, a pandas multi-index wrapper is used. **dimensions_kwargs The keyword arguments form of ``dimensions``. One of dimensions or dimensions_kwargs must be provided. 
@@ -4018,7 +4025,7 @@ def stack( dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, "stack") result = self for new_dim, dims in dimensions.items(): - result = result._stack_once(dims, new_dim, create_index) + result = result._stack_once(dims, new_dim, index_cls, create_index) return result def to_stacked_array( @@ -4147,17 +4154,21 @@ def ensure_stackable(val): def _unstack_once( self, dim: Hashable, - index_and_coords: Tuple[PandasMultiIndex, List[Hashable]], + index_and_vars: Tuple[Index, Dict[Hashable, Variable]], fill_value, ) -> "Dataset": - index, index_vnames = index_and_coords - pd_index = remove_unused_levels_categories(index.index) - + index, index_vars = index_and_vars variables: Dict[Hashable, Variable] = {} indexes = {k: v for k, v in self.xindexes.items() if k != dim} + new_indexes, clean_index = index.unstack() + indexes.update(new_indexes) + + for name, idx in new_indexes.items(): + variables.update(idx.create_variables(index_vars)) + for name, var in self.variables.items(): - if name not in index_vnames: + if name not in index_vars: if dim in var.dims: if isinstance(fill_value, Mapping): fill_value_ = fill_value[name] @@ -4165,21 +4176,12 @@ def _unstack_once( fill_value_ = fill_value variables[name] = var._unstack_once( - index=pd_index, dim=dim, fill_value=fill_value_ + index=clean_index, dim=dim, fill_value=fill_value_ ) else: variables[name] = var - for name, lev in zip(pd_index.names, pd_index.levels): - var = self.variables[name] - meta = { - name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding} - } - idx, idx_vars = PandasIndex.from_pandas_index(lev, name, var_meta=meta) - variables[name] = idx_vars[name] - indexes[name] = idx - - coord_names = set(self._coord_names) - {dim} | set(pd_index.names) + coord_names = set(self._coord_names) - {dim} | set(new_indexes) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes @@ -4188,46 +4190,44 @@ def _unstack_once( def _unstack_full_reindex( self, dim: Hashable, - index_and_coords: Tuple[PandasMultiIndex, List[Hashable]], + index_and_vars: Tuple[Index, Dict[Hashable, Variable]], fill_value, sparse: bool, ) -> "Dataset": - index, index_vnames = index_and_coords - pd_index = remove_unused_levels_categories(index.index) - full_idx = pd.MultiIndex.from_product(pd_index.levels, names=pd_index.names) + index, index_vars = index_and_vars + variables: Dict[Hashable, Variable] = {} + indexes = {k: v for k, v in self.xindexes.items() if k != dim} + + new_indexes, clean_index = index.unstack() + indexes.update(new_indexes) + + new_index_variables = {} + for name, idx in new_indexes.items(): + new_index_variables.update(idx.create_variables(index_vars)) + + new_dim_sizes = {k: v.size for k, v in new_index_variables.items()} + variables.update(new_index_variables) # take a shortcut in case the MultiIndex was not modified. 
-        if pd_index.equals(full_idx):
+        full_idx = pd.MultiIndex.from_product(
+            clean_index.levels, names=clean_index.names
+        )
+        if clean_index.equals(full_idx):
             obj = self
         else:
+            # TODO: we may deprecate implicit re-indexing with a pandas.MultiIndex
             obj = self._reindex(
                 {dim: full_idx}, copy=False, fill_value=fill_value, sparse=sparse
             )

-        new_dim_names = pd_index.names
-        new_dim_sizes = [lev.size for lev in pd_index.levels]
-
-        variables: Dict[Hashable, Variable] = {}
-        indexes = {k: v for k, v in self.xindexes.items() if k != dim}
-
         for name, var in obj.variables.items():
-            if name not in index_vnames:
+            if name not in index_vars:
                 if dim in var.dims:
-                    new_dims = dict(zip(new_dim_names, new_dim_sizes))
-                    variables[name] = var.unstack({dim: new_dims})
+                    variables[name] = var.unstack({dim: new_dim_sizes})
                 else:
                     variables[name] = var

-        for name, lev in zip(new_dim_names, pd_index.levels):
-            var = self.variables[name]
-            meta = {
-                name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding}
-            }
-            idx, idx_vars = PandasIndex.from_pandas_index(lev, name, var_meta=meta)
-            variables[name] = idx_vars[name]
-            indexes[name] = idx
-
-        coord_names = set(self._coord_names) - {dim} | set(new_dim_names)
+        coord_names = set(self._coord_names) - {dim} | set(new_dim_sizes)

         return self._replace_with_new_dims(
             variables, coord_names=coord_names, indexes=indexes
@@ -4282,11 +4282,11 @@ def unstack(
         )

         # each specified dimension must have exactly one multi-index
-        stacked_indexes: Dict[Any, Tuple[PandasMultiIndex, List[Any]]] = {}
+        stacked_indexes: Dict[Any, Tuple[Index, Dict[Hashable, Variable]]] = {}
         for d in dims:
-            idx, idx_var_names = self._find_stack_index(d, multi=True)
+            idx, idx_vars = self._get_stack_index(d, multi=True)
             if idx is not None:
-                stacked_indexes[d] = cast(PandasMultiIndex, idx), list(idx_var_names)
+                stacked_indexes[d] = idx, idx_vars

         if dim is None:
             dims = list(stacked_indexes)
@@ -4295,7 +4295,7 @@ def unstack(
         if non_multi_dims:
             raise ValueError(
                 "cannot unstack dimensions that do not "
-                f"have exactly one MultiIndex: {tuple(non_multi_dims)}"
+                f"have exactly one multi-index: {tuple(non_multi_dims)}"
             )

         result = self.copy(deep=False)
diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 7830ea2ca42..9655b2ed40a 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -41,6 +41,17 @@ def from_variables(
     ) -> Tuple["Index", IndexVars]:
         raise NotImplementedError()

+    @classmethod
+    def stack(
+        cls, variables: Mapping[Any, "Variable"], dim: Hashable
+    ) -> Tuple["Index", IndexVars]:
+        raise NotImplementedError(
+            f"{cls!r} cannot be used for creating an index of stacked coordinates"
+        )
+
+    def unstack(self) -> Tuple[Dict[Hashable, "Index"], pd.MultiIndex]:
+        raise NotImplementedError()
+
     def create_variables(
         self, variables: Optional[Mapping[Any, "Variable"]] = None
     ) -> IndexVars:
@@ -474,6 +485,33 @@ def create_variable(name):
     return variables


+def remove_unused_levels_categories(index: pd.Index) -> pd.Index:
+    """
+    Remove unused levels from MultiIndex and unused categories from CategoricalIndex
+    """
+    if isinstance(index, pd.MultiIndex):
+        index = index.remove_unused_levels()
+        # if it contains CategoricalIndex, we need to remove unused categories
+        # manually. See https://github.com/pandas-dev/pandas/issues/30846
+        if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels):
+            levels = []
+            for i, level in enumerate(index.levels):
+                if isinstance(level, pd.CategoricalIndex):
+                    level = level[index.codes[i]].remove_unused_categories()
+                else:
+                    level = level[index.codes[i]]
+                levels.append(level)
+            # TODO: calling from_array() reorders MultiIndex levels. It would
+            # be best to avoid this, if possible, e.g., by using
+            # MultiIndex.remove_unused_levels() (which does not reorder) on the
+            # part of the MultiIndex that is not categorical, or by fixing this
+            # upstream in pandas.
+            index = pd.MultiIndex.from_arrays(levels, names=index.names)
+    elif isinstance(index, pd.CategoricalIndex):
+        index = index.remove_unused_categories()
+    return index
+
+
 class PandasMultiIndex(PandasIndex):
     """Wrap a pandas.MultiIndex as an xarray compatible index."""

@@ -517,7 +555,7 @@ def from_variables(
         return obj, index_vars

     @classmethod
-    def from_product_variables(
+    def stack(
         cls, variables: Mapping[Any, "Variable"], dim: Hashable
     ) -> Tuple["PandasMultiIndex", IndexVars]:
         """Create a new Pandas MultiIndex from the product of 1-d variables (levels) along a
@@ -547,6 +585,16 @@ def from_product_variables(

         return cls.from_pandas_index(index, dim, var_meta=_get_var_metadata(variables))

+    def unstack(self) -> Tuple[Dict[Hashable, Index], pd.MultiIndex]:
+        clean_index = remove_unused_levels_categories(self.index)
+
+        new_indexes: Dict[Hashable, Index] = {}
+        for name, lev in zip(clean_index.names, clean_index.levels):
+            idx = PandasIndex(lev, name, coord_dtype=self.level_coords_dtype[name])
+            new_indexes[name] = idx
+
+        return new_indexes, clean_index
+
     @classmethod
     def from_variables_maybe_expand(
         cls,
@@ -862,10 +910,8 @@ def create_default_index_implicit(
     """Create a default index from a dimension variable.

     Create a PandasMultiIndex if the given variable wraps a pandas.MultiIndex,
-    otherwise create a PandasIndex.
-
-    This function will become obsolete once we depreciate
-    implcitly passing a pandas.MultiIndex as a coordinate.
+    otherwise create a PandasIndex (note that this will become obsolete once we
+    deprecate implicitly passing a pandas.MultiIndex as a coordinate).

     """
     if all_variables is None:
@@ -890,33 +936,6 @@ def create_default_index_implicit(
     return index, index_vars


-def remove_unused_levels_categories(index: pd.Index) -> pd.Index:
-    """
-    Remove unused levels from MultiIndex and unused categories from CategoricalIndex
-    """
-    if isinstance(index, pd.MultiIndex):
-        index = index.remove_unused_levels()
-        # if it contains CategoricalIndex, we need to remove unused categories
-        # manually. See https://github.com/pandas-dev/pandas/issues/30846
-        if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels):
-            levels = []
-            for i, level in enumerate(index.levels):
-                if isinstance(level, pd.CategoricalIndex):
-                    level = level[index.codes[i]].remove_unused_categories()
-                else:
-                    level = level[index.codes[i]]
-                levels.append(level)
-            # TODO: calling from_array() reorders MultiIndex levels. It would
-            # be best to avoid this, if possible, e.g., by using
-            # MultiIndex.remove_unused_levels() (which does not reorder) on the
-            # part of the MultiIndex that is not categorical, or by fixing this
-            # upstream in pandas.
- index = pd.MultiIndex.from_arrays(levels, names=index.names) - elif isinstance(index, pd.CategoricalIndex): - index = index.remove_unused_categories() - return index - - # generic type that represents either a pandas or an xarray index T_PandasOrXarrayIndex = TypeVar("T_PandasOrXarrayIndex") diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index b6321f88c43..b33f41742d3 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -3174,7 +3174,7 @@ def test_unstack_errors(self): ds = Dataset({"x": [1, 2, 3]}) with pytest.raises(ValueError, match=r"does not contain the dimensions"): ds.unstack("foo") - with pytest.raises(ValueError, match=r".*do not have exactly one MultiIndex"): + with pytest.raises(ValueError, match=r".*do not have exactly one multi-index"): ds.unstack("x") def test_unstack_fill_value(self): From d02dc17c6f3a1aa9d556a11299e3205aeb2fd78e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 27 Oct 2021 13:27:07 +0200 Subject: [PATCH 072/159] fix PandasIndex.isel with scalar indexers Fix concat tests and some groupby tests --- xarray/core/indexes.py | 15 ++++++++++----- xarray/tests/test_concat.py | 4 ++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 9655b2ed40a..c5c6b4908c0 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -310,17 +310,22 @@ def isel( from .variable import Variable indxr = indexers[self.dim] - if isinstance(indxr, int): - # can't preserve index with single value - return None - elif isinstance(indxr, Variable): + if isinstance(indxr, Variable): if indxr.dims != (self.dim,): # can't preserve a index if result has new dimensions return None else: indxr = indxr.data + if not isinstance(indxr, slice) and is_scalar(indxr): + # scalar indexer: drop index + return None - return self._replace(self.index[indxr]) + indexed_index = self.index[indxr] + if not len(indexed_index): + # empty index + return None + else: + return self._replace(indexed_index) def query(self, labels: Dict[Any, Any], method=None, tolerance=None) -> QueryResult: from .dataarray import DataArray diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index e049f843bed..a35551f35b6 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -247,7 +247,7 @@ def test_concat_join_kwarg(self): coords={"x": [0, 1], "y": [0]}, ) - with pytest.raises(ValueError, match=r"indexes along dimension 'y'"): + with pytest.raises(ValueError, match=r"cannot align.*exact.*dimensions.*'y'"): actual = concat([ds1, ds2], join="exact", dim="x") for join in expected: @@ -627,7 +627,7 @@ def test_concat_join_kwarg(self): coords={"x": [0, 1], "y": [0]}, ) - with pytest.raises(ValueError, match=r"indexes along dimension 'y'"): + with pytest.raises(ValueError, match=r"cannot align.*exact.*dimensions.*'y'"): actual = concat([ds1, ds2], join="exact", dim="x") for join in expected: From 557e8b1fbbb1ef82bba484e0357c61746da53445 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 27 Oct 2021 20:57:09 +0200 Subject: [PATCH 073/159] fix index adapter (variable) dtype --- xarray/core/indexes.py | 8 +++++--- xarray/core/indexing.py | 15 +++------------ xarray/core/utils.py | 20 ++++++++++++++++++++ xarray/tests/test_indexing.py | 2 +- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index c5c6b4908c0..d26fbc6a9a6 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -24,7 +24,7 @@ from . 
import formatting, utils
 from .indexing import PandasIndexingAdapter, PandasMultiIndexingAdapter, QueryResult
 from .types import T_Index
-from .utils import Frozen, is_dict_like, is_scalar
+from .utils import Frozen, get_valid_numpy_dtype, is_dict_like, is_scalar

 if TYPE_CHECKING:
     from .variable import Variable
@@ -208,7 +208,7 @@ def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None):
         self.dim = dim

         if coord_dtype is None:
-            coord_dtype = self.index.dtype
+            coord_dtype = get_valid_numpy_dtype(np.asarray(array))
         self.coord_dtype = coord_dtype

     def _replace(self, index, dim=None, coord_dtype=None):
@@ -528,7 +528,9 @@ def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any = None):
         super().__init__(array, dim)

         if level_coords_dtype is None:
-            level_coords_dtype = {idx.name: idx.dtype for idx in self.index.levels}
+            level_coords_dtype = {
+                idx.name: get_valid_numpy_dtype(idx) for idx in self.index.levels
+            }
         self.level_coords_dtype = level_coords_dtype

     def _replace(self, index, dim=None, level_coords_dtype=None) -> "PandasMultiIndex":
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index 4e140e891ee..b1d0b453d80 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -32,7 +32,7 @@
     sparse_array_type,
 )
 from .types import T_Xarray
-from .utils import either_dict_or_kwargs
+from .utils import either_dict_or_kwargs, get_valid_numpy_dtype

 if TYPE_CHECKING:
     from .indexes import Index
@@ -1361,18 +1361,9 @@ def __init__(self, array: pd.Index, dtype: DTypeLike = None):
         self.array = utils.safe_cast_to_index(array)

         if dtype is None:
-            if isinstance(array, pd.PeriodIndex):
-                dtype_ = np.dtype("O")
-            elif hasattr(array, "categories"):
-                # category isn't a real numpy dtype
-                dtype_ = array.categories.dtype
-            elif not utils.is_valid_numpy_dtype(array.dtype):
-                dtype_ = np.dtype("O")
-            else:
-                dtype_ = array.dtype
+            self._dtype = get_valid_numpy_dtype(array)
         else:
-            dtype_ = np.dtype(dtype)  # type: ignore[assignment]
-        self._dtype = dtype_
+            self._dtype = np.dtype(dtype)  # type: ignore[assignment]

     @property
     def dtype(self) -> np.dtype:
diff --git a/xarray/core/utils.py b/xarray/core/utils.py
index 35a284723c7..5068db2e506 100644
--- a/xarray/core/utils.py
+++ b/xarray/core/utils.py
@@ -71,6 +71,26 @@ def _maybe_cast_to_cftimeindex(index: pd.Index) -> pd.Index:
     return index


+def get_valid_numpy_dtype(array: Union[np.ndarray, pd.Index]):
+    """Return a numpy-compatible dtype from either
+    a numpy array or a pandas.Index.
+
+    Used for wrapping a pandas.Index as an xarray.Variable.
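+
+    For example, a pandas.PeriodIndex has no numpy-native dtype, so this
+    returns object dtype (``np.dtype("O")``); likewise, for a
+    pandas.CategoricalIndex the dtype of its categories is returned.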
+ + """ + if isinstance(array, pd.PeriodIndex): + dtype = np.dtype("O") + elif hasattr(array, "categories"): + # category isn't a real numpy dtype + dtype = array.categories.dtype # type: ignore[union-attr] + elif not is_valid_numpy_dtype(array.dtype): + dtype = np.dtype("O") + else: + dtype = array.dtype + + return dtype + + def maybe_coerce_to_str(index, original_coords): """maybe coerce a pandas Index back to a nunpy array of type str diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index d9ded2d0430..e588bb4c661 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -1,5 +1,5 @@ import itertools -from typing import Any, Dict, cast +from typing import Any import numpy as np import pandas as pd From 15f6b17a2f5ac37e70502ef59b13d23569d8f519 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 28 Oct 2021 14:53:51 +0200 Subject: [PATCH 074/159] more indexes invariants checks Check consistency between PandasIndex objects and coordinate variables data adapters (PandasIndexingAdapter). --- xarray/testing.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/xarray/testing.py b/xarray/testing.py index 673d474d6cb..44aa36bcc22 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -4,11 +4,12 @@ from typing import Hashable, Set, Union import numpy as np +import pandas as pd from xarray.core import duck_array_ops, formatting, utils from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset -from xarray.core.indexes import Index +from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex from xarray.core.variable import IndexVariable, Variable __all__ = ( @@ -262,6 +263,32 @@ def _assert_indexes_invariants_checks(indexes, possible_coord_variables): } assert indexes.keys() <= index_vars, (set(indexes), index_vars) + # check pandas index wrappers vs. coordinate data adapters + for k, index in indexes.items(): + if isinstance(index, PandasIndex): + pd_index = index.index + var = possible_coord_variables[k] + assert (index.dim,) == var.dims, (pd_index, var) + if k == index.dim: + # skip multi-index levels here (checked below) + assert index.coord_dtype == var.dtype, (pd_index, var) + assert isinstance(var._data.array, pd.Index), var._data.array + # TODO: check identity instead of equality? + assert pd_index.equals(var._data.array), (pd_index, var) + if isinstance(index, PandasMultiIndex): + pd_index = index.index + for name in index.index.names: + assert name in possible_coord_variables, (pd_index, index_vars) + var = possible_coord_variables[name] + assert (index.dim,) == var.dims, (pd_index, var) + assert index.level_coords_dtype[name] == var.dtype, (pd_index, var) + assert isinstance(var._data.array, pd.MultiIndex), var._data.array + assert pd_index.equals(var._data.array), (pd_index, var) + # check all all levels are in `indexes` + assert name in indexes, (name, set(indexes)) + # index identity is used to find unique indexes in `indexes` + assert index is indexes[name], (pd_index, indexes[name].index) + # TODO: benbovy - explicit indexes: do we still need these checks? Or opt-in? # non-default indexes are now supported. 
# defaults = default_indexes(possible_coord_variables, dims) From 29035aa4ec89515d1d23bd64cf4dfd32a15c46e3 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 28 Oct 2021 15:01:47 +0200 Subject: [PATCH 075/159] fix DataArray.isel() --- xarray/core/dataarray.py | 21 +++++++++++++-------- xarray/core/dataset.py | 31 +++---------------------------- xarray/core/indexes.py | 23 +++++++++++++++++++++++ 3 files changed, 39 insertions(+), 36 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 3baa290b3bc..858073290c5 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -51,6 +51,7 @@ Indexes, PandasMultiIndex, default_indexes, + isel_indexes, propagate_indexes, ) from .indexing import is_fancy_indexer, map_index_queries @@ -1190,19 +1191,23 @@ def isel( # lists, or zero or one-dimensional np.ndarray's variable = self._variable.isel(indexers, missing_dims=missing_dims) + indexes, index_variables = isel_indexes(self.xindexes, indexers) coords = {} for coord_name, coord_value in self._coords.items(): - coord_indexers = { - k: v for k, v in indexers.items() if k in coord_value.dims - } - if coord_indexers: - coord_value = coord_value.isel(coord_indexers) - if drop and coord_value.ndim == 0: - continue + if coord_name in index_variables: + coord_value = index_variables[coord_name] + else: + coord_indexers = { + k: v for k, v in indexers.items() if k in coord_value.dims + } + if coord_indexers: + coord_value = coord_value.isel(coord_indexers) + if drop and coord_value.ndim == 0: + continue coords[coord_name] = coord_value - return self._replace(variable=variable, coords=coords) + return self._replace(variable=variable, coords=coords, indexes=indexes) def sel( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2aa7dcb9c38..382d5ba09c9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -65,6 +65,7 @@ PandasMultiIndex, create_default_index_implicit, default_indexes, + isel_indexes, propagate_indexes, remove_unused_levels_categories, roll_index, @@ -2201,7 +2202,7 @@ def isel( dims: Dict[Hashable, Tuple[int, ...]] = {} coord_names = self._coord_names.copy() - indexes, index_variables = self._isel_indexes(indexers) + indexes, index_variables = isel_indexes(self.xindexes, indexers) for name, var in self._variables.items(): # preserve variable order @@ -2227,30 +2228,6 @@ def isel( close=self._close, ) - def _isel_indexes( - self, - indexers: Mapping[Any, Any], - ) -> Tuple[Dict[Hashable, Index], Dict[Hashable, Variable]]: - index_variables: Dict[Hashable, Variable] = {} - indexes: Dict[Hashable, Index] = ( - self._indexes.copy() if self._indexes is not None else {} - ) - - for index, index_vars in self.xindexes.group_by_index(): - index_dims = set(d for var in index_vars.values() for d in var.dims) - index_indexers = {k: v for k, v in indexers.items() if k in index_dims} - if index_indexers: - new_index = index.isel(index_indexers) - if new_index is not None: - indexes.update({k: new_index for k in index_vars}) - new_index_vars = new_index.create_variables(index_vars) - index_variables.update(new_index_vars) - else: - for k in index_vars: - indexes.pop(k, None) - - return indexes, index_variables - def _isel_fancy( self, indexers: Mapping[Any, Any], @@ -2261,9 +2238,7 @@ def _isel_fancy( valid_indexers = dict(self._validate_indexers(indexers, missing_dims)) variables: Dict[Hashable, Variable] = {} - indexes: Dict[Hashable, Index] = {} - - indexes, index_variables = self._isel_indexes(valid_indexers) + indexes, 
index_variables = isel_indexes(self.xindexes, valid_indexers) for name, var in self.variables.items(): if name in index_variables: diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index d26fbc6a9a6..e5e2f5da357 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1202,3 +1202,26 @@ def check_variables(): not_equal = check_variables() return not not_equal + + +def isel_indexes( + indexes: Indexes[Index], + indexers: Mapping[Any, Any], +) -> Tuple[Dict[Hashable, Index], Dict[Hashable, "Variable"]]: + new_indexes: Dict[Hashable, Index] = {k: v for k, v in indexes.items()} + new_index_variables: Dict[Hashable, Variable] = {} + + for index, index_vars in indexes.group_by_index(): + index_dims = set(d for var in index_vars.values() for d in var.dims) + index_indexers = {k: v for k, v in indexers.items() if k in index_dims} + if index_indexers: + new_index = index.isel(index_indexers) + if new_index is not None: + new_indexes.update({k: new_index for k in index_vars}) + new_index_vars = new_index.create_variables(index_vars) + new_index_variables.update(new_index_vars) + else: + for k in index_vars: + new_indexes.pop(k, None) + + return new_indexes, new_index_variables From 2a067bd2a61faa150a32f14471e8f43e5e7b1935 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 28 Oct 2021 15:06:15 +0200 Subject: [PATCH 076/159] refactor Dataset/DataArray copy Filter indexes and preserve unique index objects for multi-coordinate indexes. Also fixed indexes tests (stack/unstack). --- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 12 ++++++++---- xarray/core/indexes.py | 37 +++++++++++++++++++++++++++++++++++- xarray/tests/test_indexes.py | 31 +++++++++++++++++++++++++----- 4 files changed, 71 insertions(+), 11 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 858073290c5..6f462ba8506 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1053,7 +1053,7 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: if self._indexes is None: indexes = self._indexes else: - indexes = {k: v.copy(deep=deep) for k, v in self._indexes.items()} + indexes = self.xindexes.copy_indexes(deep=deep) return self._replace(variable, coords, indexes=indexes) def __copy__(self) -> "DataArray": diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 382d5ba09c9..c97fbb44117 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -65,6 +65,7 @@ PandasMultiIndex, create_default_index_implicit, default_indexes, + filter_indexes_from_coords, isel_indexes, propagate_indexes, remove_unused_levels_categories, @@ -1164,6 +1165,7 @@ def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset": """ if data is None: variables = {k: v.copy(deep=deep) for k, v in self._variables.items()} + indexes = self.xindexes.copy_indexes(deep=deep) elif not utils.is_dict_like(data): raise ValueError("Data must be dict-like") else: @@ -1185,10 +1187,12 @@ def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset": k: v.copy(deep=deep, data=data.get(k)) for k, v in self._variables.items() } + # drop all existing indexes (will create new, default ones) + indexes = {} attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs) - return self._replace(variables, attrs=attrs) + return self._replace(variables, indexes=indexes, attrs=attrs) def as_numpy(self: "Dataset") -> "Dataset": """ @@ -1255,8 +1259,8 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset": if 
set(self.variables[k].dims) <= needed_dims: variables[k] = self._variables[k] coord_names.add(k) - if k in self.xindexes: - indexes[k] = self.xindexes[k] + + indexes.update(filter_indexes_from_coords(self.xindexes, coord_names)) return self._replace(variables, coord_names, dims, indexes=indexes) @@ -1280,7 +1284,7 @@ def _construct_dataarray(self, name: Hashable) -> "DataArray": if self._indexes is None: indexes = None else: - indexes = {k: v for k, v in self._indexes.items() if k in coords} + indexes = filter_indexes_from_coords(self.xindexes, set(coords)) return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index e5e2f5da357..e8d096e2386 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -944,7 +944,7 @@ def create_default_index_implicit( # generic type that represents either a pandas or an xarray index -T_PandasOrXarrayIndex = TypeVar("T_PandasOrXarrayIndex") +T_PandasOrXarrayIndex = TypeVar("T_PandasOrXarrayIndex", Index, pd.Index) class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): @@ -1108,6 +1108,18 @@ def to_pandas_indexes(self) -> "Indexes[pd.Index]": return Indexes(indexes, self._variables) + def copy_indexes(self, deep: bool = True) -> Dict[Hashable, Index]: + """Return a new dictionary with copies of indexes, preserving + unique indexes. + + """ + new_indexes = {} + for idx, coords in self.group_by_index(): + new_idx = idx.copy(deep=deep) + new_indexes.update({k: new_idx for k in coords}) + + return new_indexes + def __iter__(self): return iter(self._indexes) @@ -1225,3 +1237,26 @@ def isel_indexes( new_indexes.pop(k, None) return new_indexes, new_index_variables + + +def filter_indexes_from_coords( + indexes: Mapping[Any, Index], + filtered_coord_names: Set, +) -> Dict[Hashable, Index]: + """Return filtered indexes from a mapping of filtered coordinate variables. + + Ensure that all multi-coordinate index items are dropped if any of those + coordinate variables is not present in the filtered collection. 
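+
+    For example (hypothetical names): if one index backs the coordinates
+    "z", "one" and "two" and only {"one", "two"} remain after filtering,
+    the index is dropped for all three names rather than partially kept.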
+ + """ + filtered_indexes = {} + + index_coord_names = defaultdict(set) + for name, idx in indexes.items(): + index_coord_names[id(idx)].add(name) + + for idx_coord_names in index_coord_names.values(): + if idx_coord_names <= filtered_coord_names: + filtered_indexes.update({k: indexes[k] for k in idx_coord_names}) + + return filtered_indexes diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index af53e3c5a29..1af6dbea9d7 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -255,13 +255,13 @@ def test_from_variables(self) -> None: ): PandasMultiIndex.from_variables({"level1": v_level1, "level3": v_level3}) - def test_from_product_variables(self) -> None: + def test_stack(self) -> None: prod_vars = { "x": xr.Variable("x", pd.Index(["b", "a"]), attrs={"foo": "bar"}), "y": xr.Variable("y", pd.Index([1, 3, 2])), } - index, index_vars = PandasMultiIndex.from_product_variables(prod_vars, "z") + index, index_vars = PandasMultiIndex.stack(prod_vars, "z") assert index.dim == "z" assert index.index.names == ["x", "y"] @@ -281,18 +281,18 @@ def test_from_product_variables(self) -> None: with pytest.raises( ValueError, match=r"conflicting dimensions for multi-index product.*" ): - PandasMultiIndex.from_product_variables( + PandasMultiIndex.stack( {"x": xr.Variable("x", ["a", "b"]), "x2": xr.Variable("x", [1, 2])}, "z", ) - def test_from_product_variables_non_unique(self) -> None: + def test_stack_non_unique(self) -> None: prod_vars = { "x": xr.Variable("x", pd.Index(["b", "a"]), attrs={"foo": "bar"}), "y": xr.Variable("y", pd.Index([1, 1, 2])), } - index, _ = PandasMultiIndex.from_product_variables(prod_vars, "z") + index, _ = PandasMultiIndex.stack(prod_vars, "z") np.testing.assert_array_equal( index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] @@ -300,6 +300,18 @@ def test_from_product_variables_non_unique(self) -> None: np.testing.assert_array_equal(index.index.levels[0], ["b", "a"]) np.testing.assert_array_equal(index.index.levels[1], [1, 2]) + def test_unstack(self) -> None: + pd_midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2, 3]], names=["one", "two"] + ) + index = PandasMultiIndex(pd_midx, "x") + + new_indexes, new_pd_idx = index.unstack() + assert list(new_indexes) == ["one", "two"] + assert new_indexes["one"].equals(PandasIndex(["a", "b"], "one")) + assert new_indexes["two"].equals(PandasIndex([1, 2, 3], "two")) + assert new_pd_idx.equals(pd_midx) + def test_from_pandas_index(self) -> None: foo_data = np.array([0, 0, 1], dtype="int") bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") @@ -500,3 +512,12 @@ def test_to_pandas_indexes(self, indexes) -> None: assert isinstance(pd_indexes, Indexes) assert all([isinstance(idx, pd.Index) for idx in pd_indexes.values()]) assert indexes.variables == pd_indexes.variables + + def test_copy_indexes(self, indexes) -> None: + copied = indexes.copy_indexes() + + assert copied.keys() == indexes.keys() + for new, original in zip(copied.values(), indexes.values()): + assert new.equals(original) + # check unique index objects preserved + assert copied["z"] is copied["one"] is copied["two"] From c8e7a0688a8881302887e01f28a781dce655b6d6 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 28 Oct 2021 15:23:09 +0200 Subject: [PATCH 077/159] refactor computation (and merge) --- xarray/core/computation.py | 58 +++++++++++++++++++++----------- xarray/core/dataarray.py | 2 +- xarray/core/merge.py | 24 +++++++++---- xarray/tests/test_computation.py | 2 +- 4 files changed, 59 insertions(+), 27 
deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index bbaae1f5b36..393c372ad15 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -28,6 +28,7 @@ from . import dtypes, duck_array_ops, utils from .alignment import align, deep_align +from .indexes import Index, filter_indexes_from_coords from .merge import merge_attrs, merge_coordinates_without_align from .options import OPTIONS, _get_keep_attrs from .pycompat import is_duck_dask_array @@ -208,13 +209,13 @@ def _get_coords_list(args) -> List[Coordinates]: return coords_list -def build_output_coords( +def build_output_coords_and_indexes( args: list, signature: _UFuncSignature, exclude_dims: AbstractSet = frozenset(), combine_attrs: str = "override", -) -> "List[Dict[Any, Variable]]": - """Build output coordinates for an operation. +) -> Tuple[List[Dict[Any, Variable]], List[Dict[Any, Index]]]: + """Build output coordinates and indexes for an operation. Parameters ---------- @@ -229,7 +230,7 @@ def build_output_coords( Returns ------- - Dictionary of Variable objects with merged coordinates. + Dictionaries of Variable and Index objects with merged coordinates. """ coords_list = _get_coords_list(args) @@ -237,24 +238,30 @@ def build_output_coords( # we can skip the expensive merge (unpacked_coords,) = coords_list merged_vars = dict(unpacked_coords.variables) + merged_indexes = dict(unpacked_coords.xindexes) else: - # TODO: save these merged indexes, instead of re-computing them later - merged_vars, unused_indexes = merge_coordinates_without_align( + merged_vars, merged_indexes = merge_coordinates_without_align( coords_list, exclude_dims=exclude_dims, combine_attrs=combine_attrs ) output_coords = [] + output_indexes = [] for output_dims in signature.output_core_dims: dropped_dims = signature.all_input_core_dims - set(output_dims) if dropped_dims: - filtered = { + filtered_coords = { k: v for k, v in merged_vars.items() if dropped_dims.isdisjoint(v.dims) } + filtered_indexes = filter_indexes_from_coords( + merged_indexes, set(filtered_coords) + ) else: - filtered = merged_vars - output_coords.append(filtered) + filtered_coords = merged_vars + filtered_indexes = merged_indexes + output_coords.append(filtered_coords) + output_indexes.append(filtered_indexes) - return output_coords + return output_coords, output_indexes def apply_dataarray_vfunc( @@ -282,7 +289,7 @@ def apply_dataarray_vfunc( else: first_obj = _first_of_type(args, DataArray) name = first_obj.name - result_coords = build_output_coords( + result_coords, result_indexes = build_output_coords_and_indexes( args, signature, exclude_dims, combine_attrs=keep_attrs ) @@ -291,12 +298,19 @@ def apply_dataarray_vfunc( if signature.num_outputs > 1: out = tuple( - DataArray(variable, coords, name=name, fastpath=True) - for variable, coords in zip(result_var, result_coords) + DataArray( + variable, coords=coords, indexes=indexes, name=name, fastpath=True + ) + for variable, coords, indexes in zip( + result_var, result_coords, result_indexes + ) ) else: (coords,) = result_coords - out = DataArray(result_var, coords, name=name, fastpath=True) + (indexes,) = result_indexes + out = DataArray( + result_var, coords=coords, indexes=indexes, name=name, fastpath=True + ) attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) if isinstance(out, tuple): @@ -397,7 +411,9 @@ def apply_dict_of_variables_vfunc( def _fast_dataset( - variables: Dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable] + variables: Dict[Hashable, 
Variable], + coord_variables: Mapping[Hashable, Variable], + indexes: Dict[Hashable, Index], ) -> Dataset: """Create a dataset as quickly as possible. @@ -407,7 +423,7 @@ def _fast_dataset( variables.update(coord_variables) coord_names = set(coord_variables) - return Dataset._construct_direct(variables, coord_names) + return Dataset._construct_direct(variables, coord_names, indexes=indexes) def apply_dataset_vfunc( @@ -439,7 +455,7 @@ def apply_dataset_vfunc( args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False ) - list_of_coords = build_output_coords( + list_of_coords, list_of_indexes = build_output_coords_and_indexes( args, signature, exclude_dims, combine_attrs=keep_attrs ) args = [getattr(arg, "data_vars", arg) for arg in args] @@ -449,10 +465,14 @@ def apply_dataset_vfunc( ) if signature.num_outputs > 1: - out = tuple(_fast_dataset(*args) for args in zip(result_vars, list_of_coords)) + out = tuple( + _fast_dataset(*args) + for args in zip(result_vars, list_of_coords, list_of_indexes) + ) else: (coord_vars,) = list_of_coords - out = _fast_dataset(result_vars, coord_vars) + (indexes,) = list_of_indexes + out = _fast_dataset(result_vars, coord_vars, indexes=indexes) attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) if isinstance(out, tuple): diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 6f462ba8506..1e855ab6b31 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -371,7 +371,7 @@ def __init__( name: Hashable = None, attrs: Mapping = None, # internal parameters - indexes: Dict[Hashable, pd.Index] = None, + indexes: Dict[Hashable, Index] = None, fastpath: bool = False, ): if fastpath: diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 0d71f737693..d209aac1291 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -23,7 +23,13 @@ from . import dtypes, pdcompat from .alignment import deep_align from .duck_array_ops import lazy_array_equiv -from .indexes import Index, Indexes, create_default_index_implicit, indexes_equal +from .indexes import ( + Index, + Indexes, + create_default_index_implicit, + filter_indexes_from_coords, + indexes_equal, +) from .utils import Frozen, compat_dict_union, dict_equiv, equivalent from .variable import Variable, as_variable, calculate_dimensions @@ -212,7 +218,6 @@ def merge_collected( for variable, index in elements_list if index is not None ] - if indexed_elements: # TODO(shoyer): consider adjusting this logic. 
Are we really # OK throwing away variable without an index in favor of @@ -322,14 +327,14 @@ def collect_from_coordinates( list_of_coords: "List[Coordinates]", ) -> Dict[Hashable, List[MergeElement]]: """Collect variables and indexes to be merged from Coordinate objects.""" - grouped: Dict[Hashable, List[Tuple[Variable, Optional[Index]]]] = {} + grouped: Dict[Hashable, List[Tuple[Variable, Optional[Index]]]] = defaultdict(list) for coords in list_of_coords: variables = coords.variables indexes = coords.xindexes for name, variable in variables.items(): - value = grouped.setdefault(name, []) - value.append((variable, indexes.get(name))) + grouped[name].append((variable, indexes.get(name))) + return grouped @@ -359,7 +364,14 @@ def merge_coordinates_without_align( else: filtered = collected - return merge_collected(filtered, prioritized, combine_attrs=combine_attrs) + # TODO: indexes should probably be filtered in collected elements + # before merging them + merged_coords, merged_indexes = merge_collected( + filtered, prioritized, combine_attrs=combine_attrs + ) + merged_indexes = filter_indexes_from_coords(merged_indexes, set(merged_coords)) + + return merged_coords, merged_indexes def determine_coords( diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 4680857219d..4a0ce7e522f 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -964,7 +964,7 @@ def test_dataset_join() -> None: ds1 = xr.Dataset({"a": ("x", [99, 3]), "x": [1, 2]}) # by default, cannot have different labels - with pytest.raises(ValueError, match=r"indexes .* are not equal"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*"): apply_ufunc(operator.add, ds0, ds1) with pytest.raises(TypeError, match=r"must supply"): apply_ufunc(operator.add, ds0, ds1, dataset_join="outer") From 9963520ed31dd8f18544bd7d4b0d03a5b8a9f9eb Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 28 Oct 2021 17:42:57 +0200 Subject: [PATCH 078/159] minor fixes and tweaks --- xarray/core/indexes.py | 10 ++++++++-- xarray/testing.py | 7 +++++-- xarray/tests/test_combine.py | 4 ++-- xarray/tests/test_groupby.py | 12 +++++++++--- 4 files changed, 24 insertions(+), 9 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index e8d096e2386..a519de4877d 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -204,7 +204,14 @@ class PandasIndex(Index): __slots__ = ("index", "dim", "coord_dtype") def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None): - self.index = utils.safe_cast_to_index(array) + index = utils.safe_cast_to_index(array) + if index.name is None: + # cannot use pd.Index.rename as this constructor is also + # called from PandasMultiIndex + index = index.copy() + index.name = dim + + self.index = index self.dim = dim if coord_dtype is None: @@ -264,7 +271,6 @@ def create_variables( attrs = None encoding = None - name = self.index.name data = PandasIndexingAdapter(self.index, dtype=self.coord_dtype) var = IndexVariable(self.dim, data, attrs=attrs, encoding=encoding) return {name: var} diff --git a/xarray/testing.py b/xarray/testing.py index 44aa36bcc22..e37aa3c8aa5 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -271,7 +271,7 @@ def _assert_indexes_invariants_checks(indexes, possible_coord_variables): assert (index.dim,) == var.dims, (pd_index, var) if k == index.dim: # skip multi-index levels here (checked below) - assert index.coord_dtype == var.dtype, (pd_index, var) + assert index.coord_dtype == 
var.dtype, (index.coord_dtype, var.dtype) assert isinstance(var._data.array, pd.Index), var._data.array # TODO: check identity instead of equality? assert pd_index.equals(var._data.array), (pd_index, var) @@ -281,7 +281,10 @@ def _assert_indexes_invariants_checks(indexes, possible_coord_variables): assert name in possible_coord_variables, (pd_index, index_vars) var = possible_coord_variables[name] assert (index.dim,) == var.dims, (pd_index, var) - assert index.level_coords_dtype[name] == var.dtype, (pd_index, var) + assert index.level_coords_dtype[name] == var.dtype, ( + index.level_coords_dtype[name], + var.dtype, + ) assert isinstance(var._data.array, pd.MultiIndex), var._data.array assert pd_index.equals(var._data.array), (pd_index, var) # check all all levels are in `indexes` diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 3ca964b94e1..ce202b865ef 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -391,7 +391,7 @@ def test_combine_nested_join(self, join, expected): def test_combine_nested_join_exact(self): objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact"): combine_nested(objs, concat_dim="x", join="exact") def test_empty_input(self): @@ -757,7 +757,7 @@ def test_combine_coords_join(self, join, expected): def test_combine_coords_join_exact(self): objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*"): combine_nested(objs, concat_dim="x", join="exact") @pytest.mark.parametrize( diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index d48726e8304..ebcfb1d70de 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -450,10 +450,16 @@ def test_groupby_drops_nans() -> None: actual = grouped.mean() stacked = ds.stack({"xy": ["lat", "lon"]}) expected = ( - stacked.variable.where(stacked.id.notnull()).rename({"xy": "id"}).to_dataset() + stacked.variable.where(stacked.id.notnull()) + .rename({"xy": "id"}) + .to_dataset() + .reset_index("id", drop=True) + .drop_vars(["lon", "lat"]) + .assign(id=stacked.id.values) + .dropna("id") + .transpose(*actual.dims) ) - expected["id"] = stacked.id.values - assert_identical(actual, expected.dropna("id").transpose(*actual.dims)) + assert_identical(actual, expected) # reduction operation along a different dimension actual = grouped.mean("time") From e0f498ab2ff988cd4ea3c6678bda9ecacacb17b1 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 28 Oct 2021 17:47:57 +0200 Subject: [PATCH 079/159] propagate_indexes -> filter_indexes_from_coords We can't just look at excluded dimensions anymore... 
--- xarray/core/dataarray.py | 17 +++++++---------- xarray/core/dataset.py | 7 ++++--- xarray/core/groupby.py | 4 ++-- xarray/core/indexes.py | 18 ------------------ xarray/tests/test_dataarray.py | 4 ++-- 5 files changed, 15 insertions(+), 35 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 1e855ab6b31..636689f32f9 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -51,8 +51,8 @@ Indexes, PandasMultiIndex, default_indexes, + filter_indexes_from_coords, isel_indexes, - propagate_indexes, ) from .indexing import is_fancy_indexer, map_index_queries from .merge import PANDAS_TYPES, MergeError, _create_indexes_from_coords @@ -429,9 +429,11 @@ def _replace( variable = self.variable if coords is None: coords = self._coords + if indexes is None: + indexes = self._indexes if name is _default: name = self.name - return type(self)(variable, coords, name=name, fastpath=True, indexes=indexes) + return type(self)(variable, coords, name=name, indexes=indexes, fastpath=True) def _replace_maybe_drop_dims( self, variable: Variable, name: Union[Hashable, None, Default] = _default @@ -447,18 +449,13 @@ def _replace_maybe_drop_dims( for k, v in self._coords.items() if v.shape == tuple(new_sizes[d] for d in v.dims) } - changed_dims = [ - k for k in variable.dims if variable.sizes[k] != self.sizes[k] - ] - indexes = propagate_indexes(self._indexes, exclude=changed_dims) + indexes = filter_indexes_from_coords(self.xindexes, set(coords)) else: allowed_dims = set(variable.dims) coords = { k: v for k, v in self._coords.items() if set(v.dims) <= allowed_dims } - indexes = propagate_indexes( - self._indexes, exclude=(set(self.dims) - allowed_dims) - ) + indexes = filter_indexes_from_coords(self.xindexes, set(coords)) return self._replace(variable, coords, name, indexes=indexes) def _overwrite_indexes( @@ -517,8 +514,8 @@ def subset(dim, label): variables = {label: subset(dim, label) for label in self.get_index(dim)} variables.update({k: v for k, v in self._coords.items() if k != dim}) - indexes = propagate_indexes(self._indexes, exclude=dim) coord_names = set(self._coords) - {dim} + indexes = filter_indexes_from_coords(self.xindexes, coord_names) dataset = Dataset._construct_direct( variables, coord_names, indexes=indexes, attrs=self.attrs ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c97fbb44117..57fd7b41e09 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -67,7 +67,6 @@ default_indexes, filter_indexes_from_coords, isel_indexes, - propagate_indexes, remove_unused_levels_categories, roll_index, ) @@ -5397,8 +5396,10 @@ def to_array(self, dim="variable", name=None): data = duck_array_ops.stack([b.data for b in broadcast_vars], axis=0) coords = dict(self.coords) - coords[dim] = list(self.data_vars) - indexes = propagate_indexes(self._indexes) + indexes = filter_indexes_from_coords(self.xindexes, set(coords)) + new_dim_index = PandasIndex(list(self.data_vars), dim) + indexes[dim] = new_dim_index + coords.update(new_dim_index.create_variables()) dims = (dim,) + broadcast_vars[0].dims diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 1ca5de965d0..f83ef0ece62 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -8,7 +8,7 @@ from .arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic from .concat import concat from .formatting import format_array_flat -from .indexes import propagate_indexes +from .indexes import filter_indexes_from_coords from .options import
_get_keep_attrs from .pycompat import integer_types from .utils import ( @@ -518,7 +518,7 @@ def _maybe_unstack(self, obj): for dim in self._inserted_dims: if dim in obj.coords: del obj.coords[dim] - obj._indexes = propagate_indexes(obj._indexes, exclude=self._inserted_dims) + obj._indexes = filter_indexes_from_coords(obj.xindexes, set(obj.coords)) return obj def fillna(self, value): diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index a519de4877d..38b901f0f1d 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1173,24 +1173,6 @@ def roll_index(index: PandasIndex, count: int, axis: int = 0) -> PandasIndex: return PandasIndex(new_idx, index.dim) -def propagate_indexes( - indexes: Optional[Dict[Hashable, Index]], exclude: Optional[Any] = None -) -> Optional[Dict[Hashable, Index]]: - """Creates new indexes dict from existing dict optionally excluding some dimensions.""" - if exclude is None: - exclude = () - - if is_scalar(exclude): - exclude = (exclude,) - - if indexes is not None: - new_indexes = {k: v for k, v in indexes.items() if k not in exclude} - else: - new_indexes = None # type: ignore[assignment] - - return new_indexes - - def indexes_equal(elements: Sequence[Tuple[Index, Dict[Hashable, "Variable"]]]) -> bool: """Check if indexes are all equal. diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 84ba30675cf..fd2d8fa2b1d 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -23,7 +23,7 @@ from xarray.convert import from_cdms2 from xarray.core import dtypes from xarray.core.common import full_like -from xarray.core.indexes import Index, PandasIndex, propagate_indexes +from xarray.core.indexes import Index, PandasIndex, filter_indexes_from_coords from xarray.core.utils import is_scalar from xarray.tests import ( LooseVersion, @@ -1319,7 +1319,7 @@ def test_coords(self): assert expected == actual del da.coords["x"] - da._indexes = propagate_indexes(da._indexes, exclude="x") + da._indexes = filter_indexes_from_coords(da.xindexes, set(da.coords)) expected = DataArray(da.values, {"y": [0, 1, 2]}, dims=["x", "y"], name="foo") assert_identical(da, expected) From 26bc7c9595de4f20bdac69ac17a75cfee7bec443 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 30 Oct 2021 00:53:42 +0200 Subject: [PATCH 080/159] Update xarray/core/coordinates.py --- xarray/core/coordinates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 9ce0e83282f..11f03cb56c4 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -20,7 +20,7 @@ from .indexes import Index, Indexes from .merge import merge_coordinates_without_align, merge_coords from .utils import Frozen, ReprObject -from .variable import calculate_dimensions, Variable +from .variable import Variable, calculate_dimensions if TYPE_CHECKING: from .dataarray import DataArray From 5da11f4831c74cf4cf2517a541737b8120669bab Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 2 Nov 2021 16:36:31 +0100 Subject: [PATCH 081/159] refactor expand_dims --- xarray/core/dataset.py | 21 ++++++++++++--------- xarray/core/indexes.py | 2 +- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ac0980732bd..98f24242d78 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3541,6 +3541,7 @@ def expand_dims( ) variables: Dict[Hashable, Variable] = {} + indexes: 
Dict[Hashable, Index] = dict(self.xindexes) coord_names = self._coord_names.copy() # If dim is a dict, then ensure that the values are either integers # or iterables. @@ -3550,7 +3551,9 @@ def expand_dims( # save the coordinates to the variables dict, and set the # value within the dim dict to the length of the iterable # for later use. - variables[k] = xr.IndexVariable((k,), v) + index = PandasIndex(v, k) + indexes[k] = index + variables.update(index.create_variables()) coord_names.add(k) dim[k] = variables[k].size elif isinstance(v, int): @@ -3586,15 +3589,15 @@ def expand_dims( all_dims.insert(d, c) variables[k] = v.set_dims(dict(all_dims)) else: - # If dims includes a label of a non-dimension coordinate, - # it will be promoted to a 1D coordinate with a single value. - variables[k] = v.set_dims(k).to_index_variable() - - new_dims = self._dims.copy() - new_dims.update(dim) + if k not in variables: + # If dims includes a label of a non-dimension coordinate, + # it will be promoted to a 1D coordinate with a single value. + index, index_vars = create_default_index_implicit(v.set_dims(k)) + indexes[k] = index + variables.update(index_vars) - return self._replace_vars_and_dims( - variables, dims=new_dims, coord_names=coord_names + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes ) def set_index( diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index c47982f8b42..1492b4e5f7f 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -920,7 +920,7 @@ def rename(self, name_dict, dims_dict): def create_default_index_implicit( dim_variable: "Variable", all_variables: Optional[Union[Mapping, Iterable[Hashable]]] = None, -) -> Tuple[Index, IndexVars]: +) -> Tuple[PandasIndex, IndexVars]: """Create a default index from a dimension variable. Create a PandasMultiIndex if the given variable wraps a pandas.MultiIndex, From e0b08c1dc2976b82e1aa4771f1e95627814207e7 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 2 Nov 2021 18:43:24 +0100 Subject: [PATCH 082/159] merge: avoid compare same indexes more than once --- xarray/core/alignment.py | 4 ++-- xarray/core/indexes.py | 43 +++++++++++++++++++++++++++++++++++++++- xarray/core/merge.py | 21 ++++++++++---------- 3 files changed, 54 insertions(+), 14 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 5b4e54bc1db..cee568f2d52 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -26,7 +26,7 @@ from . 
import dtypes from .common import DataWithCoords -from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex, indexes_equal +from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex, indexes_all_equal from .utils import is_dict_like, is_full_slice, safe_cast_to_index from .variable import Variable, calculate_dimensions @@ -326,7 +326,7 @@ def _need_reindex(self, dims, cmp_indexes) -> bool: """ has_unindexed_dims = any(dim in self.unindexed_dim_sizes for dim in dims) - return not (indexes_equal(cmp_indexes)) or has_unindexed_dims + return not (indexes_all_equal(cmp_indexes)) or has_unindexed_dims def _get_index_joiner(self, index_cls) -> Callable: if self.join in ["outer", "inner"]: diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 1492b4e5f7f..62ad38e61b9 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1174,7 +1174,48 @@ def roll_index(index: PandasIndex, count: int, axis: int = 0) -> PandasIndex: return PandasIndex(new_idx, index.dim) -def indexes_equal(elements: Sequence[Tuple[Index, Dict[Hashable, "Variable"]]]) -> bool: +def indexes_equal( + index: Index, + other_index: Index, + variable: "Variable", + other_variable: "Variable", + cache: Dict[Tuple[int, int], Union[bool, None]] = None, +) -> bool: + """Check if two indexes are equal, possibly with cached results. + + If the two indexes are not of the same type or they do not implement + equality, fallback to coordinate labels equality check. + + """ + if cache is None: + # dummy cache + cache = {} + + key = (id(index), id(other_index)) + equal: Union[bool, None] = None + + if key not in cache: + if type(index) is type(other_index): + try: + equal = index.equals(other_index) + except NotImplementedError: + equal = None + else: + cache[key] = equal + else: + equal = None + else: + equal = cache[key] + + if equal is None: + equal = variable.equals(other_variable) + + return cast(bool, equal) + + +def indexes_all_equal( + elements: Sequence[Tuple[Index, Dict[Hashable, "Variable"]]] +) -> bool: """Check if indexes are all equal. If they are not of the same type or they do not implement this check, check diff --git a/xarray/core/merge.py b/xarray/core/merge.py index d209aac1291..4716f0a8fad 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -205,6 +205,7 @@ def merge_collected( merged_vars: Dict[Hashable, Variable] = {} merged_indexes: Dict[Hashable, Index] = {} + index_cmp_cache: Dict[Tuple[int, int], Union[bool, None]] = {} for name, elements_list in grouped.items(): if name in prioritized: @@ -222,18 +223,16 @@ def merge_collected( # TODO(shoyer): consider adjusting this logic. Are we really # OK throwing away variable without an index in favor of # indexed variables, without even checking if values match? - # TODO: benbovy (flexible indexes): possible duplicate index.equals calls - # in case of multi-coordinate indexes. Depending on how this affects the perfs, - # we might need to group the merge elements by matching index. variable, index = indexed_elements[0] - if not indexes_equal( - [(idx, {name: var}) for var, idx in indexed_elements] - ): - # TODO: show differing values/reprs in error msg? 
- raise MergeError( - f"conflicting values/indexes on objects to be combined " - f"for coordinate {name!r}" - ) + for other_var, other_index in indexed_elements[1:]: + if not indexes_equal( + index, other_index, variable, other_var, index_cmp_cache + ): + raise MergeError( + f"conflicting values/indexes on objects to be combined for coordinate {name!r}\n" + f"first index: {index!r}\nsecond index: {other_index!r}\n" + f"first variable: {variable!r}\nsecond variable: {other_var!r}\n" + ) if compat == "identical": for other_variable, _ in indexed_elements[1:]: if not dict_equiv(variable.attrs, other_variable.attrs): From fbf556c0810e11ea2ad6b0a15b2a03ede23d0988 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 2 Nov 2021 19:37:00 +0100 Subject: [PATCH 083/159] wip refactor update coords Check in merge that prioritized elements won't corrupt collected indexes. --- xarray/core/merge.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 4716f0a8fad..f40d4ad005f 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -177,6 +177,42 @@ def _assert_compat_valid(compat): MergeElement = Tuple[Variable, Optional[Index]] +def _assert_prioritized_valid( + grouped: Dict[Hashable, List[MergeElement]], + prioritized: Mapping[Any, MergeElement], +) -> None: + """Make sure that elements given in prioritized will not corrupt any + index given in grouped. + + """ + prioritized_by_index: Dict[int, Set[Hashable]] = defaultdict(set) + grouped_by_index: Dict[int, Set[Hashable]] = defaultdict(set) + + for name, (_, index) in prioritized.items(): + if index is not None: + prioritized_by_index[id(index)].add(name) + + for name, elements_list in grouped.items(): + for (_, index) in elements_list: + if index is not None: + grouped_by_index[id(index)].add(name) + + prioritized_cnames = list(prioritized_by_index.values()) + # add non-indexed elements in prioritized individually + prioritized_cnames += [k for k, (_, index) in prioritized.items() if index is None] + + # discard single-coordinate indexes found in `grouped` as they can't be corrupted + grouped_cnames = [v for v in grouped_by_index.values() if len(v) > 1] + + for p_cnames in prioritized_cnames: + for g_cnames in grouped_cnames: + if p_cnames < g_cnames: + raise ValueError( + f"cannot set or update coordinate(s) {list(p_cnames)!r} which would corrupt " + f"an index built from coordinates {list(g_cnames)!r}" + ) + + def merge_collected( grouped: Dict[Hashable, List[MergeElement]], prioritized: Mapping[Any, MergeElement] = None, @@ -202,6 +238,7 @@ def merge_collected( prioritized = {} _assert_compat_valid(compat) + _assert_prioritized_valid(grouped, prioritized) merged_vars: Dict[Hashable, Variable] = {} merged_indexes: Dict[Hashable, Index] = {} From ef3feb60fc64c950bec270ca1ea9faf43a28bc5e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 3 Nov 2021 14:45:24 +0100 Subject: [PATCH 084/159] refactor update/remove coords Fixes and tweaks Refactored assign, update, setitem, delitem, drop_vars Added or updated tests --- xarray/core/coordinates.py | 3 +- xarray/core/dataset.py | 5 +++ xarray/core/indexes.py | 20 ++++++++++ xarray/core/merge.py | 41 ++++++++++---------- xarray/tests/test_dataarray.py | 31 +++++++++++++++- xarray/tests/test_dataset.py | 68 ++++++++++++++++++++++++++++------ 6 files changed, 131 insertions(+), 37 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 11f03cb56c4..458be214f81 100644 ---
a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -17,7 +17,7 @@ import pandas as pd from . import formatting -from .indexes import Index, Indexes +from .indexes import Index, Indexes, assert_no_index_corrupted from .merge import merge_coordinates_without_align, merge_coords from .utils import Frozen, ReprObject from .variable import Variable, calculate_dimensions @@ -362,6 +362,7 @@ def to_dataset(self) -> "Dataset": def __delitem__(self, key: Hashable) -> None: if key not in self: raise KeyError(f"{key!r} is not a coordinate variable.") + assert_no_index_corrupted(self._data.xindexes, {key}) del self._data._coords[key] if self._data._indexes is not None and key in self._data._indexes: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 98f24242d78..135bd4d5eae 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -63,6 +63,7 @@ Indexes, PandasIndex, PandasMultiIndex, + assert_no_index_corrupted, create_default_index_implicit, default_indexes, filter_indexes_from_coords, @@ -1500,6 +1501,8 @@ def _setitem_check(self, key, value): def __delitem__(self, key: Hashable) -> None: """Remove a variable from this dataset.""" + assert_no_index_corrupted(self.xindexes, {key}) + del self._variables[key] self._coord_names.discard(key) if key in self.xindexes: @@ -4484,6 +4487,8 @@ def drop_vars( if errors == "raise": self._assert_all_in_dataset(names) + assert_no_index_corrupted(self.xindexes, names) + variables = {k: v for k, v in self._variables.items() if k not in names} coord_names = {k for k in self._coord_names if k in variables} indexes = {k: v for k, v in self.xindexes.items() if k not in names} diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 62ad38e61b9..f798892fd3d 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1290,3 +1290,23 @@ def filter_indexes_from_coords( filtered_indexes.update({k: indexes[k] for k in idx_coord_names}) return filtered_indexes + + +def assert_no_index_corrupted( + indexes: Indexes[Index], + coord_names: Set[Hashable], +) -> None: + """Assert removing coordinates will not corrupt indexes.""" + + # An index may be corrupted when the set of its corresponding coordinate name(s) + # partially overlaps the set of coordinate names to remove + for index, index_coords in indexes.group_by_index(): + common_names = set(index_coords) & coord_names + if common_names and len(common_names) != len(index_coords): + common_names_str = ", ".join(f"{k!r}" for k in common_names) + index_names_str = ", ".join(f"{k!r}" for k in index_coords) + raise ValueError( + f"cannot remove coordinate(s) {common_names_str}, which would corrupt " + f"the following index built from coordinates {index_names_str}:\n" + f"{index}" + ) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index f40d4ad005f..cb5f751bfb4 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -185,32 +185,29 @@ def _assert_prioritized_valid( index given in grouped. 
""" - prioritized_by_index: Dict[int, Set[Hashable]] = defaultdict(set) - grouped_by_index: Dict[int, Set[Hashable]] = defaultdict(set) - - for name, (_, index) in prioritized.items(): - if index is not None: - prioritized_by_index[id(index)].add(name) + prioritized_names = set(prioritized) + grouped_by_index: Dict[int, List[Hashable]] = defaultdict(list) + indexes: Dict[int, Index] = {} for name, elements_list in grouped.items(): for (_, index) in elements_list: if index is not None: - grouped_by_index[id(index)].add(name) - - prioritized_cnames = list(prioritized_by_index.values()) - # add non-indexed elements in prioritized individually - prioritized_cnames += [k for k, (_, index) in prioritized.items() if index is None] - - # discard single-coordinate indexes found in `grouped` as they can't be corrupted - grouped_cnames = [v for v in grouped_by_index.values() if len(v) > 1] - - for p_cnames in prioritized_cnames: - for g_cnames in grouped_cnames: - if p_cnames < g_cnames: - raise ValueError( - "cannot set or update coordinate(s) {list(p_cnames)!r} which would corrupt " - "an index built from coordinates {list(g_cnames)!r}" - ) + grouped_by_index[id(index)].append(name) + indexes[id(index)] = index + + # An index may be corrupted when the set of its corresponding coordinate name(s) + # partially overlaps the set of names given in prioritized + for index_id, index_coord_names in grouped_by_index.items(): + index_names = set(index_coord_names) + common_names = index_names & prioritized_names + if common_names and len(common_names) != len(index_names): + common_names_str = ", ".join(f"{k!r}" for k in common_names) + index_names_str = ", ".join(f"{k!r}" for k in index_coord_names) + raise ValueError( + f"cannot set or update variable(s) {common_names_str}, which would corrupt " + f"the following index built from coordinates {index_names_str}:\n" + f"{indexes[index_id]!r}" + ) def merge_collected( diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8159e70e55f..21b606454ad 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1442,8 +1442,10 @@ def test_assign_coords(self): expected = DataArray(10, {"c": 42}) assert_identical(actual, expected) - # TODO: benbovy (explicit indexes) check that multi-index is reset - self.mda.assign_coords(level_1=("x", range(4))) + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + self.mda.assign_coords(level_1=("x", range(4))) # GH: 2112 da = xr.DataArray([0, 1, 2], dims="x") @@ -1469,6 +1471,12 @@ def test_set_coords_update_index(self): actual.coords["x"] = ["a", "b", "c"] assert actual.xindexes["x"].to_pandas_index().equals(pd.Index(["a", "b", "c"])) + def test_set_coords_multiindex_level(self): + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + self.mda["level_1"] = range(4) + def test_coords_replacement_alignment(self): # regression test for GH725 arr = DataArray([0, 1, 2], dims=["abc"]) @@ -1489,6 +1497,12 @@ def test_coords_delitem_delete_indexes(self): del arr.coords["x"] assert "x" not in arr.xindexes + def test_coords_delitem_multiindex_level(self): + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + del self.mda.coords["level_1"] + def test_broadcast_like(self): arr1 = DataArray( np.ones((2, 3)), @@ -2329,6 +2343,19 @@ def test_drop_coordinates(self): actual = renamed.drop_vars("foo", errors="ignore") assert_identical(actual, renamed) + def 
test_drop_multiindex_level(self): + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + self.mda.drop_vars("level_1") + + def test_drop_all_multiindex_levels(self): + dim_levels = ["x", "level_1", "level_2"] + actual = self.mda.drop_vars(dim_levels) + # no error, multi-index dropped + for key in dim_levels: + assert key not in actual.xindexes + def test_drop_index_labels(self): arr = DataArray(np.random.randn(2, 3), coords={"y": [0, 1, 2]}, dims=["x", "y"]) actual = arr.drop_sel(y=[0, 1]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index d430ecf83b4..f398b8e0ee0 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -739,7 +739,9 @@ def test_coords_setitem_with_new_dimension(self): def test_coords_setitem_multiindex(self): data = create_test_multiindex() - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): data.coords["level_1"] = range(4) def test_coords_set(self): @@ -2337,6 +2339,14 @@ def test_drop_variables(self): actual = data.drop({"time", "not_found_here"}, errors="ignore") assert_identical(expected, actual) + def test_drop_multiindex_level(self): + data = create_test_multiindex() + + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + data.drop_vars("level_1") + def test_drop_index_labels(self): data = Dataset({"A": (["x", "y"], np.random.randn(2, 3)), "x": ["a", "b"]}) @@ -3345,6 +3355,14 @@ def test_update_overwrite_coords(self): expected = Dataset({"a": ("x", [1, 2]), "c": 5}, {"b": 3}) assert_identical(data, expected) + def test_update_multiindex_level(self): + data = create_test_multiindex() + + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + data.update({"level_1": range(4)}) + def test_update_auto_align(self): ds = Dataset({"x": ("t", [3, 4])}, {"t": [0, 1]}) @@ -3485,7 +3503,7 @@ def test_setitem(self): with pytest.raises(ValueError, match=r"already exists as a scalar"): data1["newvar"] = ("scalar", [3, 4, 5]) # can't resize a used dimension - with pytest.raises(ValueError, match=r"arguments without labels"): + with pytest.raises(ValueError, match=r"conflicting dimension sizes"): data1["dim1"] = data1["dim1"][:5] # override an existing value data1["A"] = 3 * data2["A"] @@ -3522,7 +3540,7 @@ def test_setitem(self): with pytest.raises(ValueError, match=err_msg): data4[{"dim2": [2, 3]}] = data3[{"dim2": [2, 3]}] data3["var2"] = data3["var2"].T - err_msg = "indexes along dimension 'dim2' are not equal" + err_msg = r"cannot align objects.*not equal along these coordinates.*" with pytest.raises(ValueError, match=err_msg): data4[{"dim2": [2, 3]}] = data3[{"dim2": [2, 3, 4]}] err_msg = "Dataset assignment only accepts DataArrays, Datasets, and scalars." 
@@ -3738,22 +3756,39 @@ def test_assign_attrs(self): def test_assign_multiindex_level(self): data = create_test_multiindex() - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): data.assign(level_1=range(4)) data.assign_coords(level_1=range(4)) - # raise an Error when any level name is used as dimension GH:2299 - with pytest.raises(ValueError): - data["y"] = ("level_1", [0, 1]) + + def test_assign_all_multiindex_coords(self): + data = create_test_multiindex() + actual = data.assign(x=range(4), level_1=range(4), level_2=range(4)) + # no error but multi-index dropped in favor of single indexes for each level + assert ( + actual.xindexes["x"] + is not actual.xindexes["level_1"] + is not actual.xindexes["level_2"] + ) def test_merge_multiindex_level(self): data = create_test_multiindex() - other = Dataset({"z": ("level_1", [0, 1])}) # conflict dimension - with pytest.raises(ValueError): + + other = Dataset({"level_1": ("x", [0, 1])}) + with pytest.raises(ValueError, match=r".*conflicting dimension sizes.*"): data.merge(other) - other = Dataset({"level_1": ("x", [0, 1])}) # conflict variable name - with pytest.raises(ValueError): + + other = Dataset({"level_1": ("x", range(4))}) + with pytest.raises( + ValueError, match=r"unable to determine.*coordinates or not.*" + ): data.merge(other) + # `other` Dataset coordinates are ignored (bug or feature?) + other = Dataset(coords={"level_1": ("x", range(4))}) + assert_identical(data.merge(other), data) + def test_setitem_original_non_unique_index(self): # regression test for GH943 original = Dataset({"data": ("x", np.arange(5))}, coords={"x": [0, 1, 2, 0, 1]}) @@ -3785,7 +3820,9 @@ def test_setitem_both_non_unique_index(self): def test_setitem_multiindex_level(self): data = create_test_multiindex() - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): data["level_1"] = range(4) def test_delitem(self): @@ -3803,6 +3840,13 @@ def test_delitem(self): del actual["y"] assert_identical(expected, actual) + def test_delitem_multiindex_level(self): + data = create_test_multiindex() + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + del data["level_1"] + def test_squeeze(self): data = Dataset({"foo": (["x", "y", "z"], [[[1], [2]]])}) for args in [[], [["x"]], [["x", "z"]]]: From 78f4fb0b30bf07d25150de52b9b77ad3132e3513 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 3 Nov 2021 17:09:41 +0100 Subject: [PATCH 085/159] misc. 
fixes --- xarray/core/dataset.py | 3 +-- xarray/core/indexes.py | 7 +------ xarray/core/indexing.py | 15 +++++++++++++++ xarray/tests/test_backends.py | 4 +++- xarray/tests/test_dask.py | 2 +- xarray/tests/test_dataarray.py | 10 ++++++---- xarray/tests/test_indexes.py | 2 +- 7 files changed, 28 insertions(+), 15 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e391fd0e036..beb323ec625 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4,7 +4,6 @@ import sys import warnings from collections import defaultdict -from dataclasses import astuple from html import escape from numbers import Number from operator import methodcaller @@ -2371,7 +2370,7 @@ def sel( ) result = self.isel(indexers=query_results.dim_indexers, drop=drop) - return result._overwrite_indexes(*astuple(query_results)[1:]) + return result._overwrite_indexes(*query_results.as_tuple()[1:]) def head( self, diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index f798892fd3d..f9dbb941a91 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -327,12 +327,7 @@ def isel( # scalar indexer: drop index return None - indexed_index = self.index[indxr] - if not len(indexed_index): - # empty index - return None - else: - return self._replace(indexed_index) + return self._replace(self.index[indxr]) def query(self, labels: Dict[Any, Any], method=None, tolerance=None) -> QueryResult: from .dataarray import DataArray diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index b1d0b453d80..4b092cb72d8 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -69,6 +69,21 @@ class QueryResult: drop_indexes: List[Hashable] = field(default_factory=list) rename_dims: Dict[Any, Hashable] = field(default_factory=dict) + def as_tuple(self): + """Unlike ``dataclasses.astuple``, return a shallow copy. 
+ + See https://stackoverflow.com/a/51802661 + + """ + return ( + self.dim_indexers, + self.indexes, + self.variables, + self.drop_coords, + self.drop_indexes, + self.rename_dims, + ) + def merge_query_results(results: List[QueryResult]) -> QueryResult: all_dims_count = Counter([dim for res in results for dim in res.dim_indexers]) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 4e9b98b02e9..2ec28c098a9 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -3223,7 +3223,9 @@ def test_open_mfdataset_exact_join_raises_error(self, combine, concat_dim, opt): with self.setup_files_and_datasets(fuzz=0.1) as (files, [ds1, ds2]): if combine == "by_coords": files.reverse() - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises( + ValueError, match=r"cannot align objects.*join.*exact.*" + ): open_mfdataset( files, data_vars=opt, diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 6be961fbd67..5a60ca2b368 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1208,7 +1208,7 @@ def sumda(da1, da2): with pytest.raises(ValueError, match=r"Chunk sizes along dimension 'x'"): xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})]) - with pytest.raises(ValueError, match=r"indexes along dimension 'x' are not equal"): + with pytest.raises(ValueError, match=r"cannot align.*index.*are not equal"): xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))]) # reduction diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index d8ec22f06f5..8dbab3cd0e6 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1256,7 +1256,7 @@ def test_selection_multiindex_from_level(self): data = xr.concat([da, db], dim="x").set_index(xy=["x", "y"]) assert data.dims == ("xy",) actual = data.sel(y="a") - expected = data.isel(xy=[0, 1]).unstack("xy").squeeze("y").drop_vars("y") + expected = data.isel(xy=[0, 1]).unstack("xy").squeeze("y") assert_equal(actual, expected) def test_virtual_default_coords(self): @@ -1310,9 +1310,11 @@ def test_coords(self): expected = DataArray(da.values, {"y": [0, 1, 2]}, dims=["x", "y"], name="foo") assert_identical(da, expected) - # TODO: benbovy (explicit indexes) check that multi-index is reset - self.mda["level_1"] = ("x", np.arange(4)) - self.mda.coords["level_1"] = ("x", np.arange(4)) + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + self.mda["level_1"] = ("x", np.arange(4)) + self.mda.coords["level_1"] = ("x", np.arange(4)) def test_coords_to_index(self): da = DataArray(np.zeros((2, 3)), [("x", [1, 2]), ("y", list("abc"))]) diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 1af6dbea9d7..f96f2cc0054 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -34,7 +34,7 @@ def test_constructor(self) -> None: pd_idx = pd.Index([1, 2, 3]) index = PandasIndex(pd_idx, "x") - assert index.index is pd_idx + assert index.index.equals(pd_idx) assert index.dim == "x" def test_from_variables(self) -> None: From 4dcb1360c521636fd6d4c4c2113b9c14a2070188 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 3 Nov 2021 23:40:58 +0100 Subject: [PATCH 086/159] fix Dataset.__delitem__ --- xarray/core/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index beb323ec625..2e28ed7771b 100644 --- a/xarray/core/dataset.py +++
b/xarray/core/dataset.py @@ -1502,11 +1502,11 @@ def __delitem__(self, key: Hashable) -> None: """Remove a variable from this dataset.""" assert_no_index_corrupted(self.xindexes, {key}) - del self._variables[key] - self._coord_names.discard(key) if key in self.xindexes: assert self._indexes is not None del self._indexes[key] + del self._variables[key] + self._coord_names.discard(key) self._dims = calculate_dimensions(self._variables) # mutable objects should not be hashable From 72aff10111ba7081c0ee00554440c148a963dcd8 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 3 Nov 2021 23:41:28 +0100 Subject: [PATCH 087/159] fix Dataset.reset_coords (default coords) --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2e28ed7771b..d92e2f49e79 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1662,7 +1662,7 @@ def reset_coords( Dataset """ if names is None: - names = self._coord_names - set(self.dims) + names = self._coord_names - set(self.xindexes) else: if isinstance(names, str) or not isinstance(names, Iterable): names = [names] From c0eb3dd33e4ea7a6df75bfc75ec749cbba5e6528 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 9 Nov 2021 16:54:56 +0100 Subject: [PATCH 088/159] refactor Dataset.from_dataframe --- xarray/core/dataset.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d92e2f49e79..b52de32d596 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5652,7 +5652,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas # forwarding arguments to pandas.Series.to_numpy? arrays = [(k, np.asarray(v)) for k, v in dataframe.items()] - obj = cls() + indexes = {} + index_vars = {} if isinstance(idx, pd.MultiIndex): dims = tuple( @@ -5660,11 +5661,17 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas for n, name in enumerate(idx.names) ) for dim, lev in zip(dims, idx.levels): - obj[dim] = (dim, lev) + xr_idx = PandasIndex(lev, dim) + indexes[dim] = xr_idx + index_vars.update(xr_idx.create_variables()) else: index_name = idx.name if idx.name is not None else "index" dims = (index_name,) - obj[index_name] = (dims, idx) + xr_idx = PandasIndex(idx, index_name) + indexes[index_name] = xr_idx + index_vars.update(xr_idx.create_variables()) + + obj = cls._construct_direct(index_vars, set(index_vars), indexes=indexes) if sparse: obj._set_sparse_data_from_dataframe(idx, arrays, dims) From 34749b27fe7d4e76f24c27b18c6411332e5ef76c Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 9 Nov 2021 17:19:37 +0100 Subject: [PATCH 089/159] Fix .sel with DataArray and multi-index --- xarray/core/indexing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 4b092cb72d8..7cfe0d2b7f9 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -191,9 +191,12 @@ def map_index_queries( merged = merge_query_results(results) # drop dimension coordinates found in dimension indexers + # (also drop multi-index if any) # (.sel() already ensures alignment) for k, v in merged.dim_indexers.items(): if isinstance(v, DataArray): + if k in v.xindexes: + v = v.reset_index(k) drop_coords = [name for name in v._coords if name in merged.dim_indexers] merged.dim_indexers[k] = v.drop_vars(drop_coords) From ef952b29c3b550f22ce3a1fb3c7e13c28857df82 Mon Sep 17 00:00:00 2001 From: Benoit Bovy 
Date: Tue, 9 Nov 2021 17:20:48 +0100 Subject: [PATCH 090/159] PandasIndex.from_variables: preserve wrapped index Prevent converting special index types like ``pd.CategoricalIndex`` when such objects are wrapped in xarray variables. --- xarray/core/indexes.py | 5 ++++- xarray/tests/test_indexes.py | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index f9dbb941a91..092e49c124f 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -246,7 +246,10 @@ def from_variables( ) dim = var.dims[0] - obj = cls(var.data, dim, coord_dtype=var.dtype) + # preserve wrapped pd.Index (if any) + data = getattr(var._data, "array", var.data) + + obj = cls(data, dim, coord_dtype=var.dtype) obj.index.name = name data = PandasIndexingAdapter(obj.index, dtype=var.dtype) index_var = IndexVariable( diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index f96f2cc0054..4f253d8da65 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -59,6 +59,15 @@ def test_from_variables(self) -> None: ): PandasIndex.from_variables({"foo": var2}) + def test_from_variables_index_adapter(self) -> None: + # test index type is preserved when variable wraps a pd.Index + data = pd.Series(["foo", "bar"], dtype="category") + pd_idx = pd.Index(data) + var = xr.Variable("x", pd_idx) + + index, _ = PandasIndex.from_variables({"x": var}) + assert isinstance(index.index, pd.CategoricalIndex) + def test_from_pandas_index(self) -> None: pd_idx = pd.Index([1, 2, 3], name="foo") From 1d6694a6b92099d5b5529049d10150868fd1c0b2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Feb 2022 16:34:27 +0000 Subject: [PATCH 091/159] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/computation.py | 2 +- xarray/tests/test_indexes.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 1cc6fe9f082..aeb274a9ec2 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -964,7 +964,7 @@ def apply_ufunc( Calculate the vector magnitude of two arguments: >>> def magnitude(a, b): - ... func = lambda x, y: np.sqrt(x**2 + y**2) + ... func = lambda x, y: np.sqrt(x ** 2 + y ** 2) ... return xr.apply_ufunc(func, a, b) ... 
diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 4f253d8da65..52ffeda4426 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List import numpy as np import pandas as pd From ebb260520261ed4aa1b3aaaf39a3b1b1936fdb4b Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 3 Feb 2022 16:38:50 +0100 Subject: [PATCH 092/159] two minor fixes after merging main --- xarray/core/merge.py | 3 ++- xarray/tests/test_formatting.py | 4 +--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index f182f5a56fd..dc2b7c6e973 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -11,6 +11,7 @@ NamedTuple, Optional, Sequence, + Tuple, Union, ) @@ -168,7 +169,7 @@ def _assert_compat_valid(compat): raise ValueError(f"compat={compat!r} invalid: must be {set(_VALID_COMPAT)}") -MergeElement = tuple[Variable, Optional[Index]] +MergeElement = Tuple[Variable, Optional[Index]] def _assert_prioritized_valid( diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 529382279de..105cec7e850 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -558,9 +558,7 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None: display_expand_attrs=False, ): actual = formatting.dataset_repr(ds) - col_width = formatting._calculate_col_width( - formatting._get_col_items(ds.variables) - ) + col_width = formatting._calculate_col_width(ds.variables) dims_start = formatting.pretty_print("Dimensions:", col_width) dims_values = formatting.dim_summary_limited( ds, col_width=col_width + 1, max_rows=display_max_rows From 5c3ff55a84943084e82e72483ef6eadcd7b89561 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 4 Feb 2022 13:47:41 +0100 Subject: [PATCH 093/159] filter_indexes_from_coords: preserve index order --- xarray/core/indexes.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 04d91448692..bd3d670bef5 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1278,21 +1278,22 @@ def filter_indexes_from_coords( indexes: Mapping[Any, Index], filtered_coord_names: Set, ) -> Dict[Hashable, Index]: - """Return filtered indexes from a mapping of filtered coordinate variables. + """Filter index items given a (sub)set of coordinate names. - Ensure that all multi-coordinate index items are dropped if any of those - coordinate variables is not present in the filtered collection. + Drop all multi-coordinate related index items for any key missing in the set + of coordinate names. 
""" - filtered_indexes = {} + filtered_indexes: Dict[Any, Index] = dict(**indexes) - index_coord_names = defaultdict(set) + index_coord_names: dict[Hashable, set[Hashable]] = defaultdict(set) for name, idx in indexes.items(): index_coord_names[id(idx)].add(name) for idx_coord_names in index_coord_names.values(): - if idx_coord_names <= filtered_coord_names: - filtered_indexes.update({k: indexes[k] for k in idx_coord_names}) + if not idx_coord_names <= filtered_coord_names: + for k in idx_coord_names: + del filtered_indexes[k] return filtered_indexes From 8e8690ada58e1b16c8bc44afcca02d5487316246 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 4 Feb 2022 13:49:00 +0100 Subject: [PATCH 094/159] implicit multi-index from coord: fix edge case --- xarray/core/indexes.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index bd3d670bef5..21cf0a63df1 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -935,6 +935,8 @@ def create_default_index_implicit( """ if all_variables is None: all_variables = {} + if not isinstance(all_variables, Mapping): + all_variables = {k: None for k in all_variables} name = dim_variable.dims[0] array = getattr(dim_variable._data, "array", None) @@ -945,10 +947,20 @@ def create_default_index_implicit( # check for conflict between level names and variable names duplicate_names = [k for k in index_vars if k in all_variables and k != name] if duplicate_names: - conflict_str = "\n".join(duplicate_names) - raise ValueError( - f"conflicting MultiIndex level / variable name(s):\n{conflict_str}" - ) + # dirty workaround for an edge case where both the dimension + # coordinate and the level coordinates are given for the same + # multi-index object. 
+ # TODO: remove this check when removing the multi-index dimension coordinate + duplicate_data = [ + getattr(all_variables[k], "_data", None) for k in duplicate_names + ] + duplicate_arrays = [getattr(data, "array", None) for data in duplicate_data] + conflict = any([arr is not array for arr in duplicate_arrays]) + if conflict: + conflict_str = "\n".join(duplicate_names) + raise ValueError( + f"conflicting MultiIndex level / variable name(s):\n{conflict_str}" + ) else: index, index_vars = PandasIndex.from_variables({name: dim_variable}) From 93c3f693c98c66b0e50db298e99b5d98c712996e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 4 Feb 2022 17:45:12 +0100 Subject: [PATCH 095/159] groupby combine: fix missing index in result When coord is None --- xarray/core/groupby.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 45120ef07d0..78ad0b4c79f 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -9,7 +9,7 @@ from .arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic from .concat import concat from .formatting import format_array_flat -from .indexes import filter_indexes_from_coords +from .indexes import create_default_index_implicit, filter_indexes_from_coords from .options import _get_keep_attrs from .pycompat import integer_types from .utils import ( @@ -20,7 +20,7 @@ peek_at, safe_cast_to_index, ) -from .variable import IndexVariable, Variable, as_variable +from .variable import IndexVariable, Variable def check_reduce_dims(reduce_dims, dimensions): @@ -469,6 +469,7 @@ def _infer_concat_args(self, applied_example): (dim,) = coord.dims if isinstance(coord, _DummyGroup): coord = None + coord = getattr(coord, "variable", coord) return coord, dim, positions def _binary_op(self, other, f, reflexive=False): @@ -822,13 +823,11 @@ def _combine(self, applied, shortcut=False): if isinstance(combined, type(self._obj)): # only restore dimension order for arrays combined = self._restore_dim_order(combined) - # assign coord when the applied function does not return that coord + # assign coord and index when the applied function does not return that coord if coord is not None and dim not in applied_example.dims: - if shortcut: - coord_var = as_variable(coord) - combined._coords[coord.name] = coord_var - else: - combined.coords[coord.name] = coord + index, index_vars = create_default_index_implicit(coord) + indexes = {k: index for k in index_vars} + combined = combined._overwrite_indexes(indexes, coords=index_vars) combined = self._maybe_restore_empty_groups(combined) combined = self._maybe_unstack(combined) return combined @@ -944,7 +943,9 @@ def _combine(self, applied): combined = _maybe_reorder(combined, dim, positions) # assign coord when the applied function does not return that coord if coord is not None and dim not in applied_example.dims: - combined[coord.name] = coord + index, index_vars = create_default_index_implicit(coord) + indexes = {k: index for k in index_vars} + combined = combined._overwrite_indexes(indexes, variables=index_vars) combined = self._maybe_restore_empty_groups(combined) combined = self._maybe_unstack(combined) return combined From 3316ee4650527d2f7f18a52e5db6f61796008f80 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 4 Feb 2022 21:17:52 +0100 Subject: [PATCH 096/159] implicit multi-index coord: more robust fix Also cover the cases where indexes are equal but not identical (this caused a couple of test failures). Performance may not be a concern for such an edge case?
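The distinction that matters here, as a tiny pandas illustration (equality of labels does not imply object identity, so an identity check alone is not enough):

    import pandas as pd

    midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two"))
    other = midx.copy()

    assert midx.equals(other)  # equal labels
    assert midx is not other   # but not the same object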
--- xarray/core/indexes.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 21cf0a63df1..ffa00094426 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -949,13 +949,16 @@ def create_default_index_implicit( if duplicate_names: # dirty workaround for an edge case where both the dimension # coordinate and the level coordinates are given for the same - # multi-index object. + # multi-index object => do not raise an error # TODO: remove this check when removing the multi-index dimension coordinate - duplicate_data = [ - getattr(all_variables[k], "_data", None) for k in duplicate_names - ] - duplicate_arrays = [getattr(data, "array", None) for data in duplicate_data] - conflict = any([arr is not array for arr in duplicate_arrays]) + if len(duplicate_names) < len(index.index.names): + conflict = True + else: + duplicate_vars = [all_variables[k] for k in duplicate_names] + conflict = any( + v is None or not dim_variable.equals(v) for v in duplicate_vars + ) + if conflict: conflict_str = "\n".join(duplicate_names) raise ValueError( f"conflicting MultiIndex level / variable name(s):\n{conflict_str}" ) From f2b25f5c93a1764418b5b8e0ad0016a77f6b85cc Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 9 Feb 2022 23:21:43 +0100 Subject: [PATCH 097/159] backward compat fix: multi-index given as data-var ... in Dataset constructor --- xarray/core/merge.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index dc2b7c6e973..1e13ca30a66 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -557,7 +557,16 @@ def _create_indexes_from_coords(coords, data_vars=None): indexes = {} updated_coords = {} - for name, obj in coords.items(): + # this is needed for backward compatibility: when a pandas multi-index + # is given as data variable, it is promoted as index / level coordinates + # TODO: deprecate this implicit behavior + index_vars = { + k: v + for k, v in all_variables.items() + if k in coords or isinstance(v, pd.MultiIndex) + } + + for name, obj in index_vars.items(): variable = as_variable(obj, name=name) if variable.dims == (name,): From 627a001fa0cd8e7a135e2dc7433b0a56630df1ec Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 9 Feb 2022 23:26:52 +0100 Subject: [PATCH 098/159] refactor concat Summary: - Add the ``Index.concat()`` class method with implementations for ``PandasIndex`` and ``PandasMultiIndex`` (adapted from ``IndexVariable.concat()``). - Use ``Index.concat()`` to create new indexes (and their corresponding variables) in ``_dataset_concat``. Fall back to variable concatenation (with a default index) when no implementation is given for ``Index.concat`` or when not all of the variables have an index. - Refactored merging the other variables (the indexes are merged too) in ``_dataset_concat``. This refactor is incomplete. It should work with Pandas (multi-)indexes but it is likely that it won't work with meta-indexes involving multiple dimensions. We probably need to update ``_calc_concat_over`` to take into account such meta-indexes and their related coordinates. Another limitation (?): we use the index of the 1st dataset here for the concatenation (i.e., ``Index.concat``). No specific check is made on the type and/or coordinates of the other indexes.
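For illustration, a minimal sketch of the hook a (hypothetical) custom index would implement; the signature is inferred from the call sites in ``_dataset_concat`` below, so treat it as an assumption rather than the final API:

    from xarray.core.indexes import Index

    class MyIndex(Index):
        """Hypothetical index subclass, for illustration only."""

        @classmethod
        def concat(cls, indexes, dim, positions=None):
            # Return a new index of the same type covering the concatenated
            # coordinate(s). Raising NotImplementedError makes _dataset_concat
            # silently fall back to plain variable concatenation.
            raise NotImplementedError()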
--- xarray/core/concat.py | 136 +++++++++++++++++++++++------------- xarray/core/groupby.py | 2 + xarray/core/indexes.py | 71 ++++++++++++++++++- xarray/tests/test_concat.py | 2 +- 4 files changed, 160 insertions(+), 51 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 4621e622d42..50dc8736831 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,14 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Hashable, Iterable, Literal, overload +from typing import TYPE_CHECKING, Any, Hashable, Iterable, Literal, overload import pandas as pd from . import dtypes, utils from .alignment import align from .duck_array_ops import lazy_array_equiv -from .merge import _VALID_COMPAT, merge_attrs, unique_variable -from .variable import IndexVariable, Variable, as_variable +from .indexes import PandasIndex +from .merge import ( + _VALID_COMPAT, + collect_variables_and_indexes, + merge_attrs, + merge_collected, +) +from .variable import Variable from .variable import concat as concat_vars if TYPE_CHECKING: @@ -240,30 +246,31 @@ def concat( ) -def _calc_concat_dim_coord(dim): - """ - Infer the dimension name and 1d coordinate variable (if appropriate) +def _calc_concat_dim_index( + dim_or_data: Hashable | Any, +) -> tuple[Hashable, PandasIndex | None]: + """Infer the dimension name and 1d index / coordinate variable (if appropriate) for concatenating along the new dimension. + """ from .dataarray import DataArray - if isinstance(dim, str): - coord = None - elif not isinstance(dim, (DataArray, Variable)): - dim_name = getattr(dim, "name", None) - if dim_name is None: - dim_name = "concat_dim" - coord = IndexVariable(dim_name, dim) - dim = dim_name - elif not isinstance(dim, DataArray): - coord = as_variable(dim).to_index_variable() - (dim,) = coord.dims + dim: Hashable | None + + if isinstance(dim_or_data, str): + dim = dim_or_data + index = None else: - coord = dim - if coord.name is None: - coord.name = dim.dims[0] - (dim,) = coord.dims - return dim, coord + if not isinstance(dim_or_data, (DataArray, Variable)): + dim = getattr(dim_or_data, "name", None) + if dim is None: + dim = "concat_dim" + else: + (dim,) = dim_or_data.dims + coord_dtype = getattr(dim_or_data, "dtype", None) + index = PandasIndex(dim_or_data, dim, coord_dtype=coord_dtype) + + return dim, index def _calc_concat_over(datasets, dim, dim_names, data_vars, coords, compat): @@ -431,7 +438,8 @@ def _dataset_concat( "The elements in the input list need to be either all 'Dataset's or all 'DataArray's" ) - dim, coord = _calc_concat_dim_coord(dim) + dim, index = _calc_concat_dim_index(dim) + # Make sure we're working on a copy (we'll be loading variables) datasets = [ds.copy() for ds in datasets] datasets = list( @@ -464,22 +472,19 @@ def _dataset_concat( variables_to_merge = (coord_names | data_names) - concat_over - dim_names result_vars = {} + result_indexes = {} + if variables_to_merge: - to_merge: dict[Hashable, list[Variable]] = { - var: [] for var in variables_to_merge + grouped = { + k: v + for k, v in collect_variables_and_indexes(list(datasets)).items() + if k in variables_to_merge } + merged_vars, merged_indexes = merge_collected(grouped, compat=compat) - for ds in datasets: - for var in variables_to_merge: - if var in ds: - to_merge[var].append(ds.variables[var]) + result_vars.update(merged_vars) + result_indexes.update(merged_indexes) - for var in variables_to_merge: - result_vars[var] = unique_variable( - var, to_merge[var], compat=compat, equals=equals.get(var, 
None) - ) - else: - result_vars = {} result_vars.update(dim_coords) # assign attrs and encoding from first dataset @@ -506,22 +511,53 @@ def ensure_common_dims(vars): var = var.set_dims(common_dims, common_shape) yield var - # stack up each variable to fill-out the dataset (in order) + # stack up each variable and/or index to fill-out the dataset (in order) # n.b. this loop preserves variable order, needed for groupby. - for k in datasets[0].variables: - if k in concat_over: + for name in datasets[0].variables: + if name in concat_over and name not in result_indexes: try: - vars = ensure_common_dims([ds[k].variable for ds in datasets]) + vars = ensure_common_dims([ds[name].variable for ds in datasets]) except KeyError: - raise ValueError(f"{k!r} is not present in all datasets.") - combined = concat_vars(vars, dim, positions, combine_attrs=combine_attrs) - assert isinstance(combined, Variable) - result_vars[k] = combined - elif k in result_vars: + raise ValueError(f"{name!r} is not present in all datasets.") + + # Try concatenate the indexes first, silently fallback to concatenate + # the variables when no index is found on all datasets ot when the + # 1st index doesn't implement concat. + # TODO: (benbovy - explicit indexes): check index types and/or coordinates + # of all datasets? + try: + indexes = [ds.xindexes[name] for ds in datasets] + except KeyError: + combined_var = concat_vars( + vars, dim, positions, combine_attrs=combine_attrs + ) + result_vars[name] = combined_var + else: + try: + combined_idx = indexes[0].concat(indexes, dim, positions) + except NotImplementedError: + # fallback to concat variable(s) + combined_var = concat_vars( + vars, dim, positions, combine_attrs=combine_attrs + ) + result_vars[name] = combined_var + else: + idx_vars = datasets[0].xindexes.get_all_coords(name) + result_indexes.update({k: combined_idx for k in idx_vars}) + combined_idx_vars = combined_idx.create_variables(idx_vars) + for k, v in combined_idx_vars.items(): + v.attrs = merge_attrs( + [ds.variables[k].attrs for ds in datasets], + combine_attrs=combine_attrs, + ) + result_vars[k] = v + + elif name in result_vars: # preserves original variable order - result_vars[k] = result_vars.pop(k) + result_vars[name] = result_vars.pop(name) result = Dataset(result_vars, attrs=result_attrs) + absent_coord_names = coord_names - set(result.variables) if absent_coord_names: raise ValueError( @@ -532,9 +568,13 @@ def ensure_common_dims(vars): result = result.drop_vars(unlabeled_dims, errors="ignore") - if coord is not None: - # add concat dimension last to ensure that its in the final Dataset - result[coord.name] = coord + if index is not None: + # add concat index / coordinate last to ensure that its in the final Dataset + result[dim] = index.create_variables()[dim] + result_indexes[dim] = index + + # TODO: add indexes at Dataset creation (when it is supported) + result = result._overwrite_indexes(result_indexes) return result diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 980db3d38b0..161172a86d5 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -761,6 +761,8 @@ def _concat_shortcut(self, applied, dim, positions=None): # speed things up, but it's not very interpretable and there are much # faster alternatives (e.g., doing the grouped aggregation in a # compiled language) + # TODO: benbovy - explicit indexes: this fast implementation doesn't + # create an explicit index for the stacked dim coordinate stacked = Variable.concat(applied, dim, shortcut=True) reordered = 
_maybe_reorder(stacked, dim, positions) return self._obj._replace_maybe_drop_dims(reordered) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index ffa00094426..03caeff7ab3 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -14,6 +14,7 @@ Sequence, Set, Tuple, + Type, TypeVar, Union, cast, @@ -22,7 +23,7 @@ import numpy as np import pandas as pd -from . import formatting, utils +from . import formatting, nputils, utils from .indexing import PandasIndexingAdapter, PandasMultiIndexingAdapter, QueryResult from .types import T_Index from .utils import Frozen, get_valid_numpy_dtype, is_dict_like, is_scalar @@ -42,6 +43,15 @@ def from_variables( ) -> Tuple["Index", IndexVars]: raise NotImplementedError() + @classmethod + def concat( + cls: Type[T_Index], + indexes: Sequence[T_Index], + dim: Hashable, + positions: Iterable[int] = None, + ) -> T_Index: + raise NotImplementedError() + @classmethod def stack( cls, variables: Mapping[Any, "Variable"], dim: Hashable @@ -56,7 +66,11 @@ def unstack(self) -> Tuple[Dict[Hashable, "Index"], pd.MultiIndex]: def create_variables( self, variables: Optional[Mapping[Any, "Variable"]] = None ) -> IndexVars: - return {} + if variables is not None: + # pass through + return dict(**variables) + else: + return {} def to_pandas_index(self) -> pd.Index: """Cast this xarray index to a pandas.Index object or raise a TypeError @@ -258,6 +272,39 @@ def from_variables( return obj, {name: index_var} + @staticmethod + def _concat_indexes(indexes, dim, positions=None) -> pd.Index: + new_pd_index: pd.Index + + if not indexes: + new_pd_index = pd.Index([]) + else: + assert all(idx.dim == dim for idx in indexes) + pd_indexes = [idx.index for idx in indexes] + new_pd_index = pd_indexes[0].append(pd_indexes[1:]) + + if positions is not None: + indices = nputils.inverse_permutation(np.concatenate(positions)) + new_pd_index = new_pd_index.take(indices) + + return new_pd_index + + @classmethod + def concat( + cls, + indexes: Sequence["PandasIndex"], + dim: Hashable, + positions: Iterable[int] = None, + ) -> "PandasIndex": + new_pd_index = cls._concat_indexes(indexes, dim, positions) + + if not indexes: + coord_dtype = None + else: + coord_dtype = np.result_type(*[idx.coord_dtype for idx in indexes]) + + return cls(new_pd_index, dim=dim, coord_dtype=coord_dtype) + def create_variables( self, variables: Optional[Mapping[Any, "Variable"]] = None ) -> IndexVars: @@ -573,6 +620,26 @@ def from_variables( return obj, index_vars + @classmethod + def concat( # type: ignore[override] + cls, + indexes: Sequence["PandasMultiIndex"], + dim: Hashable, + positions: Iterable[int] = None, + ) -> "PandasMultiIndex": + new_pd_index = cls._concat_indexes(indexes, dim, positions) + + if not indexes: + level_coords_dtype = None + else: + level_coords_dtype = {} + for name in indexes[0].level_coords_dtype: + level_coords_dtype[name] = np.result_type( + *[idx.level_coords_dtype[name] for idx in indexes] + ) + + return cls(new_pd_index, dim=dim, level_coords_dtype=level_coords_dtype) + @classmethod def stack( cls, variables: Mapping[Any, "Variable"], dim: Hashable diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 31bccd92cdb..7de72c352f6 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -465,7 +465,7 @@ def test_concat_dim_is_variable(self) -> None: def test_concat_multiindex(self) -> None: x = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]]) - expected = Dataset({"x": x}) + expected = Dataset(coords={"x": x}) actual = 
concat( [expected.isel(x=slice(2)), expected.isel(x=slice(2, None))], "x" ) From ba1c75f28c64f88370d9edfe323635b2576b13f2 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 10 Feb 2022 10:04:36 +0100 Subject: [PATCH 099/159] sel drop=True: remove multi-index scalar coords --- xarray/core/dataset.py | 10 ++++++++++ xarray/tests/test_dataset.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 223ae28bdb1..5df0022b61f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2364,6 +2364,16 @@ def sel( self, indexers=indexers, method=method, tolerance=tolerance ) + if drop: + no_scalar_variables = {} + for k, v in query_results.variables.items(): + if v.dims: + no_scalar_variables[k] = v + else: + if k in self._coord_names: + query_results.drop_coords.append(k) + query_results.variables = no_scalar_variables + result = self.isel(indexers=query_results.dim_indexers, drop=drop) return result._overwrite_indexes(*query_results.as_tuple()[1:]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index d71053e538a..f1472584766 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1483,6 +1483,16 @@ def test_sel_drop(self): selected = data.sel(x=0, drop=True) assert_identical(expected, selected) + def test_sel_drop_mindex(self): + midx = pd.MultiIndex.from_arrays([["a", "a"], [1, 2]], names=("foo", "bar")) + data = Dataset(coords={"x": midx}) + + actual = data.sel(foo="a", drop=True) + assert "foo" not in actual.coords + + actual = data.sel(foo="a", drop=False) + assert_equal(actual.foo, DataArray("a", coords={"foo": "a"})) + def test_isel_drop(self): data = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]}) expected = Dataset({"foo": 1}) From 3371623762eea09389347d3628fcf7128c65b6ce Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 10 Feb 2022 10:59:52 +0100 Subject: [PATCH 100/159] PandasIndex.from_variables(multi-index level var) Make sure it gets the level index (not the whole multi-index). --- xarray/core/indexes.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 03caeff7ab3..fb7f0547872 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -260,10 +260,21 @@ def from_variables( ) dim = var.dims[0] + + # TODO: (benbovy - explicit indexes): add __index__ to ExplicitlyIndexesNDArrayMixin? + # this could be eventually used by Variable.to_index() and would remove the need to perform + # the checks below. 
+ # preserve wrapped pd.Index (if any) data = getattr(var._data, "array", var.data) + # multi-index level variable: get level index + if isinstance(var._data, PandasMultiIndexingAdapter): + level = var._data.level + if level is not None: + data = var._data.array.get_level_values(level) obj = cls(data, dim, coord_dtype=var.dtype) + assert not isinstance(obj.index, pd.MultiIndex) obj.index.name = name data = PandasIndexingAdapter(obj.index, dtype=var.dtype) index_var = IndexVariable( From 8fec03c140142cdf4e03878f456c48f6541a0818 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 10 Feb 2022 11:39:20 +0100 Subject: [PATCH 101/159] reindex: disable invalid dimension check From the docstrings: mis-matched dimensions are simply ignored --- xarray/core/alignment.py | 15 +++++++++------ xarray/tests/test_dataset.py | 6 ++++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index a8c51b6bede..cca3c182f24 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -849,12 +849,15 @@ def reindex( Not public API. """ - bad_keys = [k for k in indexers if k not in obj.xindexes and k not in obj.dims] - if bad_keys: - raise ValueError( - f"indexer keys {bad_keys} do not correspond to any indexed coordinate " - "or unindexed dimension in the object to reindex" - ) + + # TODO: (benbovy - explicit indexes): uncomment? + # --> from reindex docstrings: "any mis-matched dimension is simply ignored" + # bad_keys = [k for k in indexers if k not in obj.xindexes and k not in obj.dims] + # if bad_keys: + # raise ValueError( + # f"indexer keys {bad_keys} do not correspond to any indexed coordinate " + # "or unindexed dimension in the object to reindex" + # ) aligner = Aligner( (obj,), diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f1472584766..802f2c1b948 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1831,8 +1831,10 @@ def test_reindex(self): data.reindex("foo") # invalid dimension - with pytest.raises(ValueError, match=r"indexer keys.*not correspond.*"): - data.reindex(invalid=0) + # TODO: (benbovy - explicit indexes): uncomment? 
+ # --> from reindex docstrings: "any mis-matched dimension is simply ignored" + # with pytest.raises(ValueError, match=r"indexer keys.*not correspond.*"): + # data.reindex(invalid=0) # out of order expected = data.sel(dim2=data["dim2"][:5:-1]) From d76be3d9d4e8ac6bd8089a16304dad6acfdfcee6 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 11 Feb 2022 10:18:15 +0100 Subject: [PATCH 102/159] add index concat tests + fix positions type --- xarray/core/concat.py | 8 +++--- xarray/core/indexes.py | 6 ++--- xarray/tests/test_indexes.py | 48 ++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 7 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 50dc8736831..52e5a0d5ae4 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -34,7 +34,7 @@ def concat( data_vars: concat_options | list[Hashable] = "all", coords: concat_options | list[Hashable] = "different", compat: compat_options = "equals", - positions: Iterable[int] | None = None, + positions: Iterable[Iterable[int]] | None = None, fill_value: object = dtypes.NA, join: str = "outer", combine_attrs: str = "override", @@ -49,7 +49,7 @@ def concat( data_vars: concat_options | list[Hashable] = "all", coords: concat_options | list[Hashable] = "different", compat: compat_options = "equals", - positions: Iterable[int] | None = None, + positions: Iterable[Iterable[int]] | None = None, fill_value: object = dtypes.NA, join: str = "outer", combine_attrs: str = "override", @@ -421,7 +421,7 @@ def _dataset_concat( data_vars: str | list[str], coords: str | list[str], compat: str, - positions: Iterable[int] | None, + positions: Iterable[Iterable[int]] | None, fill_value: object = dtypes.NA, join: str = "outer", combine_attrs: str = "override", @@ -585,7 +585,7 @@ def _dataarray_concat( data_vars: str | list[str], coords: str | list[str], compat: str, - positions: Iterable[int] | None, + positions: Iterable[Iterable[int]] | None, fill_value: object = dtypes.NA, join: str = "outer", combine_attrs: str = "override", diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index fb7f0547872..23201e612a9 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -48,7 +48,7 @@ def concat( cls: Type[T_Index], indexes: Sequence[T_Index], dim: Hashable, - positions: Iterable[int] = None, + positions: Iterable[Iterable[int]] = None, ) -> T_Index: raise NotImplementedError() @@ -305,7 +305,7 @@ def concat( cls, indexes: Sequence["PandasIndex"], dim: Hashable, - positions: Iterable[int] = None, + positions: Iterable[Iterable[int]] = None, ) -> "PandasIndex": new_pd_index = cls._concat_indexes(indexes, dim, positions) @@ -636,7 +636,7 @@ def concat( # type: ignore[override] cls, indexes: Sequence["PandasMultiIndex"], dim: Hashable, - positions: Iterable[int] = None, + positions: Iterable[Iterable[int]] = None, ) -> "PandasMultiIndex": new_pd_index = cls._concat_indexes(indexes, dim, positions) diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 52ffeda4426..4926dc08cc9 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -68,6 +68,36 @@ def test_from_variables_index_adapter(self) -> None: index, _ = PandasIndex.from_variables({"x": var}) assert isinstance(index.index, pd.CategoricalIndex) + def test_concat_periods(self): + periods = pd.period_range("2000-01-01", periods=10) + indexes = [PandasIndex(periods[:5], "t"), PandasIndex(periods[5:], "t")] + expected = PandasIndex(periods, "t") + actual = PandasIndex.concat(indexes, dim="t") + assert 
actual.equals(expected)
+        assert isinstance(actual.index, pd.PeriodIndex)
+
+        positions = [list(range(5)), list(range(5, 10))]
+        actual = PandasIndex.concat(indexes, dim="t", positions=positions)
+        assert actual.equals(expected)
+        assert isinstance(actual.index, pd.PeriodIndex)
+
+    @pytest.mark.parametrize("dtype", [str, bytes])
+    def test_concat_str_dtype(self, dtype) -> None:
+
+        a = PandasIndex(np.array(["a"], dtype=dtype), "x", coord_dtype=dtype)
+        b = PandasIndex(np.array(["b"], dtype=dtype), "x", coord_dtype=dtype)
+        expected = PandasIndex(
+            np.array(["a", "b"], dtype=dtype), "x", coord_dtype=dtype
+        )
+
+        actual = PandasIndex.concat([a, b], "x")
+        assert actual.equals(expected)
+        assert np.issubdtype(actual.coord_dtype, dtype)
+
+    def test_concat_empty(self) -> None:
+        idx = PandasIndex.concat([], "x")
+        assert idx.coord_dtype is np.dtype("O")
+
     def test_from_pandas_index(self) -> None:
 
         pd_idx = pd.Index([1, 2, 3], name="foo")
 
@@ -264,6 +294,24 @@ def test_from_variables(self) -> None:
         ):
             PandasMultiIndex.from_variables({"level1": v_level1, "level3": v_level3})
 
+    def test_concat(self) -> None:
+        pd_midx = pd.MultiIndex.from_product(
+            [[0, 1, 2], ["a", "b"]], names=("foo", "bar")
+        )
+        level_coords_dtype = {"foo": np.int32, "bar": "<U1"}
+        midx1 = PandasMultiIndex(
+            pd_midx[:2], "x", level_coords_dtype=level_coords_dtype
+        )
+        midx2 = PandasMultiIndex(
+            pd_midx[2:], "x", level_coords_dtype=level_coords_dtype
+        )
+        expected = PandasMultiIndex(
+            pd_midx, "x", level_coords_dtype=level_coords_dtype
+        )
+        actual = PandasMultiIndex.concat([midx1, midx2], "x")
+        assert actual.equals(expected)
+        assert actual.level_coords_dtype == level_coords_dtype
+
     def test_stack(self) -> None:
         prod_vars = {
             "x": xr.Variable("x", pd.Index(["b", "a"]), attrs={"foo": "bar"}),

From 3b78f4b2caea1de13b65070c8cff09d81c5875f4 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Fri, 11 Feb 2022 10:19:31 +0100
Subject: [PATCH 103/159] add Indexes.get_all_dims convenience method

---
 xarray/core/indexes.py | 23 +++++++++++++++++++++++
 xarray/tests/test_indexes.py | 4 ++++
 2 files changed, 27 insertions(+)

diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 23201e612a9..148b5ab6fa2 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -1182,6 +1182,29 @@ def get_all_coords(
         all_coord_names = self._id_coord_names[self._coord_name_id[key]]
         return {k: self._variables[k] for k in all_coord_names}
 
+    def get_all_dims(
+        self, key: Hashable, errors: str = "raise"
+    ) -> Mapping[Hashable, int]:
+        """Return all dimensions shared by an index.
+
+        Parameters
+        ----------
+        key : hashable
+            Index key.
+        errors : {"raise", "ignore"}, optional
+            If "raise", raises a ValueError if `key` is not in indexes.
+            If "ignore", an empty dictionary is returned instead.
+
+        Returns
+        -------
+        dims : dict
+            A dictionary of all dimensions shared by an index.
+ + """ + from .variable import calculate_dimensions + + return calculate_dimensions(self.get_all_coords(key, errors=errors)) + def group_by_index( self, ) -> List[Tuple[T_PandasOrXarrayIndex, Dict[Hashable, "Variable"]]]: diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 4926dc08cc9..ee0e7ed7853 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -548,6 +548,10 @@ def test_get_all_coords(self, indexes) -> None: assert indexes.get_all_coords("no_coord", errors="ignore") == {} + def test_get_all_dims(self, indexes) -> None: + expected = {"z": 4} + assert indexes.get_all_dims("one") == expected + def test_group_by_index(self, unique_indexes, indexes): expected = [ (unique_indexes[0], {"x": indexes.variables["x"]}), From 7ce000e4a55f161dc6503ba7372725e6e6552232 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 11 Feb 2022 11:27:42 +0100 Subject: [PATCH 104/159] refactor pad --- xarray/core/dataarray.py | 3 +++ xarray/core/dataset.py | 21 ++++++++++++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 6240c9cec85..d7866cd516d 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4088,6 +4088,9 @@ def pad( For ``mode="constant"`` and ``constant_values=None``, integer types will be promoted to ``float`` and padded with ``np.nan``. + Padding coordinates will drop their corresponding index (if any) and will reset default + indexes for dimension coordinates. + Examples -------- >>> arr = xr.DataArray([5, 6, 7], coords=[("x", [0, 1, 2])]) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 5df0022b61f..31fea87ee9e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7306,6 +7306,9 @@ def pad( promoted to ``float`` and padded with ``np.nan``. To avoid type promotion specify ``constant_values=np.nan`` + Padding coordinates will drop their corresponding index (if any) and will reset default + indexes for dimension coordinates. + Examples -------- >>> ds = xr.Dataset({"foo": ("x", range(5))}) @@ -7331,6 +7334,15 @@ def pad( coord_pad_options = {} variables = {} + + # keep indexes that won't be affected by pad and drop all other indexes + xindexes = self.xindexes + pad_dims = set(pad_width) + indexes = {} + for k, idx in xindexes.items(): + if not pad_dims.intersection(xindexes.get_all_dims(k)): + indexes[k] = idx + for name, var in self.variables.items(): var_pad_width = {k: v for k, v in pad_width.items() if k in var.dims} if not var_pad_width: @@ -7350,8 +7362,15 @@ def pad( mode=coord_pad_mode, **coord_pad_options, # type: ignore[arg-type] ) + # reset default index of dimension coordinates + if (name,) == var.dims: + index, index_vars = PandasIndex.from_variables( + {name: variables[name]} + ) + indexes[name] = index + variables[name] = index_vars[name] - return self._replace_vars_and_dims(variables) + return self._replace_with_new_dims(variables, indexes=indexes) def idxmin( self, From 5fa287fc1bf9bb3e0c5b23cd4d3a29c30d195199 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 11 Feb 2022 13:49:46 +0100 Subject: [PATCH 105/159] unstack: return copies of mindex.levels Fix error when trying to set the name of the extracted level indexes. 
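
The underlying pandas behavior (minimal sketch, assuming a recent
pandas version):

    import pandas as pd

    midx = pd.MultiIndex.from_product([[1, 2], ["a", "b"]], names=("x", "y"))
    level = midx.levels[0]

    # pandas forbids mutating a level of a multi-index in-place:
    # level.name = "z"  # raises RuntimeError
    # a copy is a regular, independent pd.Index:
    level_copy = level.copy()
    level_copy.name = "z"  # ok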
---
 xarray/core/indexes.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 148b5ab6fa2..0abaa287a57 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -687,7 +687,9 @@ def unstack(self) -> Tuple[Dict[Hashable, Index], pd.MultiIndex]:
 
         new_indexes: Dict[Hashable, Index] = {}
         for name, lev in zip(clean_index.names, clean_index.levels):
-            idx = PandasIndex(lev, name, coord_dtype=self.level_coords_dtype[name])
+            idx = PandasIndex(
+                lev.copy(), name, coord_dtype=self.level_coords_dtype[name]
+            )
             new_indexes[name] = idx
 
         return new_indexes, clean_index

From 11bf070c7cc02937c030366cdffa25c42f42ec70 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Fri, 11 Feb 2022 16:48:20 +0100
Subject: [PATCH 106/159] strip_units: prevent index->array conversion

which caused alignment conflicts due to a multi-index variable being
converted to a non-index variable.
---
 xarray/tests/test_units.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py
index a083c50c3d1..f8d147edc2a 100644
--- a/xarray/tests/test_units.py
+++ b/xarray/tests/test_units.py
@@ -147,10 +147,10 @@ def strip_units(obj):
 
         new_obj = xr.Dataset(data_vars=data_vars, coords=coords)
     elif isinstance(obj, xr.DataArray):
-        data = array_strip_units(obj.data)
+        data = array_strip_units(obj.variable._data)
         coords = {
             strip_units(name): (
-                (value.dims, array_strip_units(value.data))
+                (value.dims, array_strip_units(value.variable._data))
                 if isinstance(value.data, Quantity)
                 else value  # to preserve multiindexes
             )

From 35a74640049f08ccbd90a9e398cbe322db5571c4 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Fri, 11 Feb 2022 17:39:00 +0100
Subject: [PATCH 107/159] attach_units: do not preserve coord multi-indexes

To prevent conflicts when trying to implicitly create multi-index
coordinates.
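
For context, a pandas multi-index passed as a single coordinate is
implicitly expanded into a dimension coordinate plus one coordinate per
level (illustrative sketch):

    import pandas as pd
    import xarray as xr

    midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two"))
    ds = xr.Dataset(coords={"x": midx})
    # ds now has the coordinates "x", "one" and "two", all sharing a
    # single PandasMultiIndex; preserving pre-built multi-index coords
    # in the test utilities could conflict with that implicit creation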
--- xarray/tests/test_units.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index f8d147edc2a..c820ddd26ca 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -198,8 +198,7 @@ def attach_units(obj, units): name: ( (value.dims, array_attach_units(value.data, units.get(name) or 1)) if name in units - # to preserve multiindexes - else value + else (value.dims, value.data) ) for name, value in obj.coords.items() } From cfb0cf1ed3cc8ca98032a901aca398ab731686a5 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 14 Feb 2022 09:52:07 +0100 Subject: [PATCH 108/159] concat: avoid multiple loads of dask arrays --- xarray/core/concat.py | 8 ++++---- xarray/core/merge.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 52e5a0d5ae4..e1aa16de7cd 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -480,8 +480,9 @@ def _dataset_concat( for k, v in collect_variables_and_indexes(list(datasets)).items() if k in variables_to_merge } - merged_vars, merged_indexes = merge_collected(grouped, compat=compat) - + merged_vars, merged_indexes = merge_collected( + grouped, compat=compat, equals=equals + ) result_vars.update(merged_vars) result_indexes.update(merged_indexes) @@ -521,7 +522,7 @@ def ensure_common_dims(vars): raise ValueError(f"{name!r} is not present in all datasets.") # Try concatenate the indexes first, silently fallback to concatenate - # the variables when no index is found on all datasets ot when the + # the variables when no index is found on all datasets or when the # 1st index doesn't implement concat. # TODO: (benbovy - explicit indexes): check index types and/or coordinates # of all datasets? @@ -536,7 +537,6 @@ def ensure_common_dims(vars): try: combined_idx = indexes[0].concat(indexes, dim, positions) except NotImplementedError: - # fallback to concat variable(s) combined_var = concat_vars( vars, dim, positions, combine_attrs=combine_attrs ) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 1e13ca30a66..6d39fa47e52 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -209,7 +209,8 @@ def merge_collected( grouped: dict[Hashable, list[MergeElement]], prioritized: Mapping[Any, MergeElement] = None, compat: str = "minimal", - combine_attrs="override", + combine_attrs: str | None = "override", + equals: dict[Hashable, bool] = None, ) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: """Merge dicts of variables, while resolving conflicts appropriately. @@ -219,6 +220,8 @@ def merge_collected( prioritized : mapping compat : str Type of equality check to use when checking for conflicts. 
+ equals : mapping, optional + corresponding to result of compat test Returns ------- @@ -228,6 +231,8 @@ def merge_collected( """ if prioritized is None: prioritized = {} + if equals is None: + equals = {} _assert_compat_valid(compat) _assert_prioritized_valid(grouped, prioritized) @@ -278,7 +283,9 @@ def merge_collected( else: variables = [variable for variable, _ in elements_list] try: - merged_vars[name] = unique_variable(name, variables, compat) + merged_vars[name] = unique_variable( + name, variables, compat, equals.get(name, None) + ) except MergeError: if compat != "minimal": # we need more than "minimal" compatibility (for which From 8cc125308352d52d0c0995c6e8588cc524ef7b45 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 14 Feb 2022 10:34:17 +0100 Subject: [PATCH 109/159] groupby mindex coord: propagate level names --- xarray/core/groupby.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 161172a86d5..b6086cf910d 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -54,6 +54,8 @@ def unique_value_groups(ar, sort=True): the corresponding value in `unique_values`. """ inverse, values = pd.factorize(ar, sort=sort) + if isinstance(values, pd.MultiIndex): + values.names = ar.names groups = [[] for _ in range(len(values))] for n, g in enumerate(inverse): if g >= 0: From ead47a20a80cc0011b14a7b701497ea7ddac4991 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 14 Feb 2022 11:45:35 +0100 Subject: [PATCH 110/159] Dataset.copy: don't reset indexes if data is given The `data` parameter accepts new data for data variables only, it won't affect the indexes. --- xarray/core/dataset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 31fea87ee9e..624333167a2 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1161,7 +1161,6 @@ def copy(self, deep: bool = False, data: Mapping = None) -> Dataset: """ if data is None: variables = {k: v.copy(deep=deep) for k, v in self._variables.items()} - indexes = self.xindexes.copy_indexes(deep=deep) elif not utils.is_dict_like(data): raise ValueError("Data must be dict-like") else: @@ -1183,8 +1182,8 @@ def copy(self, deep: bool = False, data: Mapping = None) -> Dataset: k: v.copy(deep=deep, data=data.get(k)) for k, v in self._variables.items() } - # drop all existing indexes (will create new, default ones) - indexes = {} + + indexes = self.xindexes.copy_indexes(deep=deep) attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs) From 23dc4cf60dc07b2a145770efdd485b94600314f3 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 14 Feb 2022 12:40:23 +0100 Subject: [PATCH 111/159] reindex/align: fix coord dtype of new indexes Use `as_compatible_data` to get the right dtype to assign to the new index coordinates created from reindex indexers. 
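
A sketch of the intended behavior (hypothetical example):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(
        [1, 2], coords={"x": np.array(["a", "b"], dtype="<U1")}, dims="x"
    )
    reindexed = da.reindex(x=np.array(["a", "b", "c"], dtype="<U1"))

    # the new "x" index coordinate should keep the original string dtype
    # instead of falling back to object
    assert reindexed.x.dtype == np.dtype("<U1")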
--- xarray/core/alignment.py | 9 +++++---- xarray/core/utils.py | 5 +++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index cca3c182f24..34268898df8 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -27,7 +27,7 @@ from .common import DataWithCoords from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex, indexes_all_equal from .utils import is_dict_like, is_full_slice, safe_cast_to_index -from .variable import Variable, calculate_dimensions +from .variable import Variable, as_compatible_data, calculate_dimensions if TYPE_CHECKING: from .dataarray import DataArray @@ -195,12 +195,13 @@ def _normalize_indexes( f"Indexer has dimensions {idx.dims} that are different " f"from that to be indexed along '{k}'" ) - pd_idx = safe_cast_to_index(idx).copy() + data = as_compatible_data(idx) + pd_idx = safe_cast_to_index(data) pd_idx.name = k if isinstance(pd_idx, pd.MultiIndex): - idx, _ = PandasMultiIndex.from_pandas_index(pd_idx, k) + idx = PandasMultiIndex(pd_idx, k) else: - idx, _ = PandasIndex.from_pandas_index(pd_idx, k) + idx = PandasIndex(pd_idx, k, coord_dtype=data.dtype) xr_variables.update(idx.create_variables()) xr_indexes[k] = idx diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 913a5789b2f..a0f5bfdcf27 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -118,9 +118,14 @@ def safe_cast_to_index(array: Any) -> pd.Index: if isinstance(array, pd.Index): index = array elif hasattr(array, "to_index"): + # xarray Variable index = array.to_index() elif hasattr(array, "to_pandas_index"): + # xarray Index index = array.to_pandas_index() + elif hasattr(array, "array") and isinstance(array.array, pd.Index): + # xarray PandasIndexingAdapter + index = array.array else: kwargs = {} if hasattr(array, "dtype") and array.dtype.kind == "O": From 8b325eff422f507a8158202b9e6d034028fcdc0e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 14 Feb 2022 12:44:15 +0100 Subject: [PATCH 112/159] fix doctests --- xarray/core/alignment.py | 3 +-- xarray/core/dataarray.py | 6 +++--- xarray/core/dataset.py | 6 +++--- xarray/core/merge.py | 2 +- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 34268898df8..37e8e11c807 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -721,8 +721,7 @@ def align( >>> a, b = xr.align(x, y, join="exact") Traceback (most recent call last): ... - "indexes along dimension {!r} are not equal".format(dim) - ValueError: indexes along dimension 'lat' are not equal + ValueError: cannot align objects with join='exact' ... 
>>> a, b = xr.align(x, y, join="override") >>> a diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d7866cd516d..93024fe9459 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2156,7 +2156,7 @@ def stack( ('b', 0), ('b', 1), ('b', 2)], - names=['x', 'y']) + name='z') See Also -------- @@ -2221,7 +2221,7 @@ def unstack( ('b', 0), ('b', 1), ('b', 2)], - names=['x', 'y']) + name='z') >>> roundtripped = stacked.unstack() >>> arr.identical(roundtripped) True @@ -2273,7 +2273,7 @@ def to_unstacked_dataset(self, dim, level=0): ('a', 1.0), ('a', 2.0), ('b', nan)], - names=['variable', 'y']) + name='z') >>> roundtripped = stacked.to_unstacked_dataset(dim="z") >>> data.identical(roundtripped) True diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 624333167a2..12cb9203a55 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4121,9 +4121,9 @@ def to_stacked_array( array([[0, 1, 2, 6], [3, 4, 5, 7]]) Coordinates: - * z (z) MultiIndex - - variable (z) object 'a' 'a' 'a' 'b' - - y (z) object 'u' 'v' 'w' nan + * z (z) object MultiIndex + * variable (z) object 'a' 'a' 'a' 'b' + * y (z) object 'u' 'v' 'w' nan Dimensions without coordinates: x """ diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 6d39fa47e52..de25869af29 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -962,7 +962,7 @@ def merge( >>> xr.merge([x, y, z], join="exact") Traceback (most recent call last): ... - ValueError: indexes along dimension 'lat' are not equal + ValueError: cannot align objects with join='exact' where ... Raises ------ From 30023a484cda9e6d4d6093cc9a244572f961b7a4 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 14 Feb 2022 14:05:11 +0100 Subject: [PATCH 113/159] to_stacked_array: do not coerce level dtypes If the intent was to propagate the original coordinate dtypes, this is now handled by PandasMultiIndex.stack. This fixes the case where multi-index ``.levels`` do not return labels with "nan" values (if any) but where the coordinate does, which resulted in dtype mismatch between the coordinate variable dtype (e.g., coerced to int64) and the actual label values (float64 due to nan values), eventually causing numpy errors trying to coerce nan value to int64 when indexing the coordinate. --- xarray/core/dataset.py | 12 ------------ xarray/tests/test_dataset.py | 2 -- 2 files changed, 14 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 05fff1001de..d7ad828791a 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4159,18 +4159,6 @@ def ensure_stackable(val): stackable_vars = [ensure_stackable(self[key]) for key in self.data_vars] data_array = xr.concat(stackable_vars, dim=new_dim) - # coerce the levels of the MultiIndex to have the same type as the - # input dimensions. This code is messy, so it might be better to just - # input a dummy value for the singleton dimension. 
- # TODO: benbovy - flexible indexes: update when MultIndex has its own - # class inheriting from xarray.Index - idx = data_array.xindexes[new_dim].to_pandas_index() - levels = [idx.levels[0]] + [ - level.astype(self[level.name].dtype) for level in idx.levels[1:] - ] - new_idx = idx.set_levels(levels) - data_array[new_dim] = IndexVariable(new_dim, new_idx) - if name is not None: data_array.name = name diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 802f2c1b948..3f993fa86d8 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -3331,8 +3331,6 @@ def test_to_stacked_array_dtype_dims(self): D = xr.Dataset({"a": a, "b": b}) sample_dims = ["x"] y = D.to_stacked_array("features", sample_dims) - # TODO: benbovy - flexible indexes: update when MultiIndex has its own class - # inherited from xarray.Index assert y.xindexes["features"].to_pandas_index().levels[1].dtype == D.y.dtype assert y.dims == ("x", "features") From d7fed2946142546daff7a6d5d85e8b77d3d123bf Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 14 Feb 2022 16:35:42 +0100 Subject: [PATCH 114/159] fix indent --- xarray/core/alignment.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 37e8e11c807..5a27f930ed3 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -891,11 +891,11 @@ def reindex_like( if dim in obj.dims: other_size = other.sizes[dim] obj_size = obj.sizes[dim] - if other_size != obj_size: - raise ValueError( - "different size for unlabeled " - f"dimension on argument {dim!r}: {other_size!r} vs {obj_size!r}" - ) + if other_size != obj_size: + raise ValueError( + "different size for unlabeled " + f"dimension on argument {dim!r}: {other_size!r} vs {obj_size!r}" + ) return reindex( obj, From 76e15ef27fee5858b813aa24e763b4ab1a6d7ed7 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 14 Feb 2022 16:35:56 +0100 Subject: [PATCH 115/159] doc: fix user-guide build errors --- doc/user-guide/plotting.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/user-guide/plotting.rst b/doc/user-guide/plotting.rst index d81ba30f12f..f514b4ecbef 100644 --- a/doc/user-guide/plotting.rst +++ b/doc/user-guide/plotting.rst @@ -251,7 +251,7 @@ Finally, if a dataset does not have any coordinates it enumerates all data point .. ipython:: python :okwarning: - air1d_multi = air1d_multi.drop("date") + air1d_multi = air1d_multi.drop(["date", "time", "decimal_day"]) air1d_multi.plot() The same applies to 2D plots below. 
From d69bb45d706cfffd032d8562986ad2e1e45de26e Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Mon, 14 Feb 2022 17:24:38 +0100
Subject: [PATCH 116/159] stack: fix new index variables not in coordinates

---
 xarray/core/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index d7ad828791a..1041d9e5021 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -3996,7 +3996,7 @@ def _stack_once(self, dims, new_dim, index_cls, create_index=True):
             for k in idx_vars:
                 new_variables.pop(k, None)
             new_variables.update(idx_vars)
-            new_coord_names.update({new_dim})
+            new_coord_names.update(idx_vars)
 
         indexes = {k: v for k, v in self.xindexes.items() if k not in drop_indexes}
         indexes.update(new_indexes)

From 44e68933c3dc08cf3cfbdd341afcb1859e05a8e0 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Mon, 14 Feb 2022 18:11:40 +0100
Subject: [PATCH 117/159] unstack full-reindex: fix alignment errors

We need a mapping of all index coordinate names to the full index so
that Aligner can find matching indexes.

(Note: reindex may eventually be refactored to be a bit more clever so
that providing only a `{dim: indexer}` will be enough in this case?)
---
 xarray/core/dataset.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 1041d9e5021..65a370248e8 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -4233,8 +4233,13 @@ def _unstack_full_reindex(
             obj = self
         else:
             # TODO: we may depreciate implicit re-indexing with a pandas.MultiIndex
+            xr_full_idx = PandasMultiIndex(full_idx, dim)
+            indexers = Indexes(
+                {k: xr_full_idx for k in index_vars},
+                xr_full_idx.create_variables(index_vars),
+            )
             obj = self._reindex(
-                {dim: full_idx}, copy=False, fill_value=fill_value, sparse=sparse
+                indexers, copy=False, fill_value=fill_value, sparse=sparse
             )
 
         for name, var in obj.variables.items():

From 50bd989e4d18670c2f382aa7552c54379b96fdfd Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Mon, 14 Feb 2022 21:09:22 +0100
Subject: [PATCH 118/159] PandasIndex coord dtype: avoid index->array conversion

---
 xarray/core/indexes.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 0abaa287a57..b1eae7ecf99 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -230,7 +230,7 @@ def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None):
         self.dim = dim
 
         if coord_dtype is None:
-            coord_dtype = get_valid_numpy_dtype(np.asarray(array))
+            coord_dtype = get_valid_numpy_dtype(index)
         self.coord_dtype = coord_dtype
 
     def _replace(self, index, dim=None, coord_dtype=None):

From 1184f50e79939bef8166747fc8e1a118d88cb3e9 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Mon, 14 Feb 2022 23:43:03 +0100
Subject: [PATCH 119/159] refactor diff

---
 xarray/core/dataset.py | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 65a370248e8..ccb3b791acf 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -6026,10 +6026,13 @@ def diff(self, dim, n=1, label="upper"):
         else:
             raise ValueError("The 'label' argument has to be either 'upper' or 'lower'")
 
+        indexes, index_vars = isel_indexes(self.xindexes, kwargs_new)
         variables = {}
 
         for name, var in self.variables.items():
-            if dim in var.dims:
+            if name in index_vars:
+                variables[name] = index_vars[name]
+            elif dim in var.dims:
                 if name in self.data_vars:
                    variables[name] = var.isel(**kwargs_end) - var.isel(**kwargs_start)
                 else:
                     variables[name] = var.isel(**kwargs_new)
             else:
                 variables[name] = var
 
-        indexes = dict(self.xindexes)
-        if dim in indexes:
-            if isinstance(indexes[dim], PandasIndex):
-                # maybe optimize? (pandas index already indexed above with var.isel)
-                new_index = indexes[dim].index[kwargs_new[dim]]
-                if isinstance(new_index, pd.MultiIndex):
-                    indexes[dim] = PandasMultiIndex(new_index, dim)
-                else:
-                    indexes[dim] = PandasIndex(new_index, dim)
-
         difference = self._replace_with_new_dims(variables, indexes=indexes)
 
         if n > 1:

From ed7f24bde803767d83352dc4cc1e2af1088c4b6c Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Mon, 14 Feb 2022 23:55:42 +0100
Subject: [PATCH 120/159] quick fix level coord dtype int32/64 on win

---
 xarray/tests/test_indexes.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py
index ee0e7ed7853..cea04a08fff 100644
--- a/xarray/tests/test_indexes.py
+++ b/xarray/tests/test_indexes.py
@@ -370,7 +370,7 @@ def test_unstack(self) -> None:
         assert new_pd_idx.equals(pd_midx)
 
     def test_from_pandas_index(self) -> None:
-        foo_data = np.array([0, 0, 1], dtype="int")
+        foo_data = np.array([0, 0, 1], dtype="int64")
         bar_data = np.array([1.1, 1.2, 1.3], dtype="float64")
         pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data], names=("foo", "bar"))

From ffa06daa92e166a2853f1d7b888fc8eb971276d7 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Tue, 15 Feb 2022 00:20:45 +0100
Subject: [PATCH 121/159] dask persist/compute dataarray: propagate indexes

---
 xarray/core/dataarray.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 5e00a6b21d2..2e6abf6d685 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -905,7 +905,8 @@ def _dask_finalize(results, name, func, *args, **kwargs):
         ds = func(results, *args, **kwargs)
         variable = ds._variables.pop(_THIS_ARRAY)
         coords = ds._variables
-        return DataArray(variable, coords, name=name, fastpath=True)
+        indexes = ds._indexes
+        return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True)
 
     def load(self, **kwargs) -> DataArray:
         """Manually trigger loading of this array's data from disk or a

From 7c087699b31eda52a02675333619208f593025fe Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Tue, 15 Feb 2022 01:08:22 +0100
Subject: [PATCH 122/159] refactor roll

---
 xarray/core/dataset.py | 28 +++++++++++-----------------
 xarray/core/indexes.py | 49 ++++++++++++++++++++++++++++--------------
 2 files changed, 46 insertions(+), 31 deletions(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index ccb3b791acf..d2a078cad9c 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -64,7 +64,7 @@
     filter_indexes_from_coords,
     isel_indexes,
     remove_unused_levels_categories,
-    roll_index,
+    roll_indexes,
 )
 from .indexing import is_fancy_indexer, map_index_queries
 from .merge import (
@@ -6178,29 +6178,27 @@ def roll(
         if invalid:
             raise ValueError(f"dimensions {invalid!r} do not exist")
 
-        unrolled_vars = () if roll_coords else self.coords
+        unrolled_vars: tuple[Hashable, ...]
+ + if roll_coords: + indexes, index_vars = roll_indexes(self.xindexes, shifts) + unrolled_vars = () + else: + indexes = dict(self.xindexes) + index_vars = dict(self.xindexes.variables) + unrolled_vars = tuple(self.coords) variables = {} for k, var in self.variables.items(): - if k not in unrolled_vars: + if k in index_vars: + variables[k] = index_vars[k] + elif k not in unrolled_vars: variables[k] = var.roll( shifts={k: s for k, s in shifts.items() if k in var.dims} ) else: variables[k] = var - if roll_coords: - indexes: dict[Hashable, Index] = {} - idx: pd.Index - for k, idx in self.xindexes.items(): - (dim,) = self.variables[k].dims - if dim in shifts: - indexes[k] = roll_index(idx, shifts[dim]) - else: - indexes[k] = idx - else: - indexes = dict(self.xindexes) - return self._replace(variables, indexes=indexes) def sortby(self, variables, ascending=True): diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index b1eae7ecf99..7d90093f0cc 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -107,6 +107,9 @@ def union(self, other): # pragma: no cover def intersection(self, other): # pragma: no cover raise NotImplementedError() + def roll(self, shifts: Mapping[Any, int]) -> Union["Index", None]: + return None + def rename( self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable] ) -> Tuple["Index", IndexVars]: @@ -483,6 +486,16 @@ def reindex_like( return {self.dim: get_indexer_nd(self.index, other.index, method, tolerance)} + def roll(self, shifts: Mapping[Any, int]) -> "PandasIndex": + shift = shifts[self.dim] % self.index.shape[0] + + if shift != 0: + new_pd_idx = self.index[-shift:].append(self.index[:-shift]) + else: + new_pd_idx = self.index[:] + + return self._replace(new_pd_idx) + def rename(self, name_dict, dims_dict): if self.index.name not in name_dict and self.dim not in dims_dict: return self, {} @@ -1286,17 +1299,6 @@ def default_indexes( return {key: coords[key]._to_xindex() for key in dims if key in coords} -def roll_index(index: PandasIndex, count: int, axis: int = 0) -> PandasIndex: - """Roll an pandas.Index.""" - pd_index = index.to_pandas_index() - count %= pd_index.shape[0] - if count != 0: - new_idx = pd_index[-count:].append(pd_index[:-count]) - else: - new_idx = pd_index[:] - return PandasIndex(new_idx, index.dim) - - def indexes_equal( index: Index, other_index: Index, @@ -1369,18 +1371,19 @@ def check_variables(): return not not_equal -def isel_indexes( +def _apply_indexes( indexes: Indexes[Index], - indexers: Mapping[Any, Any], + args: Mapping[Any, Any], + func: str, ) -> Tuple[Dict[Hashable, Index], Dict[Hashable, "Variable"]]: new_indexes: Dict[Hashable, Index] = {k: v for k, v in indexes.items()} new_index_variables: Dict[Hashable, Variable] = {} for index, index_vars in indexes.group_by_index(): index_dims = {d for var in index_vars.values() for d in var.dims} - index_indexers = {k: v for k, v in indexers.items() if k in index_dims} - if index_indexers: - new_index = index.isel(index_indexers) + index_args = {k: v for k, v in args.items() if k in index_dims} + if index_args: + new_index = getattr(index, func)(index_args) if new_index is not None: new_indexes.update({k: new_index for k in index_vars}) new_index_vars = new_index.create_variables(index_vars) @@ -1392,6 +1395,20 @@ def isel_indexes( return new_indexes, new_index_variables +def isel_indexes( + indexes: Indexes[Index], + indexers: Mapping[Any, Any], +) -> Tuple[Dict[Hashable, Index], Dict[Hashable, "Variable"]]: + return _apply_indexes(indexes, indexers, 
"isel") + + +def roll_indexes( + indexes: Indexes[Index], + shifts: Mapping[Any, int], +) -> Tuple[Dict[Hashable, Index], Dict[Hashable, "Variable"]]: + return _apply_indexes(indexes, shifts, "roll") + + def filter_indexes_from_coords( indexes: Mapping[Any, Index], filtered_coord_names: Set, From 71ac97a3826f904ea91ed8d42cd44bbd93afbc30 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 15 Feb 2022 12:28:16 +0100 Subject: [PATCH 123/159] rename Index.query -> Index.sel Also rename: - QueryResult -> IndexSelResult - merge_query_results -> merge_sel_results --- xarray/core/indexes.py | 18 ++++++++++-------- xarray/core/indexing.py | 14 +++++++------- xarray/tests/test_indexes.py | 30 +++++++++++++++--------------- xarray/tests/test_indexing.py | 22 ++++++++++++---------- 4 files changed, 44 insertions(+), 40 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 7d90093f0cc..efcbc17f041 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -24,7 +24,7 @@ import pandas as pd from . import formatting, nputils, utils -from .indexing import PandasIndexingAdapter, PandasMultiIndexingAdapter, QueryResult +from .indexing import IndexSelResult, PandasIndexingAdapter, PandasMultiIndexingAdapter from .types import T_Index from .utils import Frozen, get_valid_numpy_dtype, is_dict_like, is_scalar @@ -87,7 +87,7 @@ def isel( ) -> Union["Index", None]: return None - def query(self, labels: Dict[Any, Any]) -> QueryResult: + def sel(self, labels: Dict[Any, Any]) -> IndexSelResult: raise NotImplementedError(f"{self!r} doesn't support label-based selection") def join(self: T_Index, other: T_Index, how: str = "inner") -> T_Index: @@ -393,7 +393,9 @@ def isel( return self._replace(self.index[indxr]) - def query(self, labels: Dict[Any, Any], method=None, tolerance=None) -> QueryResult: + def sel( + self, labels: Dict[Any, Any], method=None, tolerance=None + ) -> IndexSelResult: from .dataarray import DataArray from .variable import Variable @@ -448,7 +450,7 @@ def query(self, labels: Dict[Any, Any], method=None, tolerance=None) -> QueryRes elif isinstance(label, DataArray): indexer = DataArray(indexer, coords=label._coords, dims=label.dims) - return QueryResult({self.dim: indexer}) + return IndexSelResult({self.dim: indexer}) def equals(self, other: Index): if not isinstance(other, PandasIndex): @@ -838,7 +840,7 @@ def create_variables( self.index, self.dim, var_meta=var_meta ) - def query(self, labels, method=None, tolerance=None) -> QueryResult: + def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: from .dataarray import DataArray from .variable import Variable @@ -898,7 +900,7 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult: raise ValueError( f"invalid multi-index level names {invalid_levels}" ) - return self.query(label) + return self.sel(label) elif isinstance(label, slice): indexer = _query_slice(self.index, label, coord_name) @@ -969,7 +971,7 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult: for name, val in scalar_coord_values.items(): variables[name] = Variable([], val) - return QueryResult( + return IndexSelResult( {self.dim: indexer}, indexes=indexes, variables=variables, @@ -979,7 +981,7 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult: ) else: - return QueryResult({self.dim: indexer}) + return IndexSelResult({self.dim: indexer}) def join(self, other, how: str = "inner"): if how == "outer": diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 
58cc321d77e..4a44364ad3b 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -41,7 +41,7 @@ @dataclass -class QueryResult: +class IndexSelResult: """Index query results. Attributes @@ -86,7 +86,7 @@ def as_tuple(self): ) -def merge_query_results(results: List[QueryResult]) -> QueryResult: +def merge_sel_results(results: List[IndexSelResult]) -> IndexSelResult: all_dims_count = Counter([dim for res in results for dim in res.dim_indexers]) duplicate_dims = {k: v for k, v in all_dims_count.items() if v > 1} @@ -119,7 +119,7 @@ def merge_query_results(results: List[QueryResult]) -> QueryResult: drop_indexes += res.drop_indexes rename_dims.update(res.rename_dims) - return QueryResult( + return IndexSelResult( dim_indexers, indexes, variables, drop_coords, drop_indexes, rename_dims ) @@ -165,7 +165,7 @@ def map_index_queries( method=None, tolerance=None, **indexers_kwargs: Any, -) -> QueryResult: +) -> IndexSelResult: """Execute index queries from a DataArray / Dataset and label-based indexers and return the (merged) query results. @@ -185,11 +185,11 @@ def map_index_queries( for index, labels in grouped_indexers: if index is None: # forward dimension indexers with no index/coordinate - results.append(QueryResult(labels)) + results.append(IndexSelResult(labels)) else: - results.append(index.query(labels, **options)) # type: ignore[call-arg] + results.append(index.sel(labels, **options)) # type: ignore[call-arg] - merged = merge_query_results(results) + merged = merge_sel_results(results) # drop dimension coordinates found in dimension indexers # (also drop multi-index if any) diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index cea04a08fff..3cac1e648a0 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -132,21 +132,21 @@ def test_to_pandas_index(self) -> None: index = PandasIndex(pd_idx, "x") assert index.to_pandas_index() is pd_idx - def test_query(self) -> None: + def test_sel(self) -> None: # TODO: add tests that aren't just for edge cases index = PandasIndex(pd.Index([1, 2, 3]), "x") with pytest.raises(KeyError, match=r"not all values found"): - index.query({"x": [0]}) + index.sel({"x": [0]}) with pytest.raises(KeyError): - index.query({"x": 0}) + index.sel({"x": 0}) with pytest.raises(ValueError, match=r"does not have a MultiIndex"): - index.query({"x": {"one": 0}}) + index.sel({"x": {"one": 0}}) def test_query_boolean(self) -> None: # index should be ignored and indexer dtype should not be coerced # see https://github.com/pydata/xarray/issues/5727 index = PandasIndex(pd.Index([0.0, 2.0, 1.0, 3.0]), "x") - actual = index.query({"x": [False, True, False, True]}) + actual = index.sel({"x": [False, True, False, True]}) expected_dim_indexers = {"x": [False, True, False, True]} np.testing.assert_array_equal( actual.dim_indexers["x"], expected_dim_indexers["x"] @@ -156,11 +156,11 @@ def test_query_datetime(self) -> None: index = PandasIndex( pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]), "x" ) - actual = index.query({"x": "2001-01-01"}) + actual = index.sel({"x": "2001-01-01"}) expected_dim_indexers = {"x": 1} assert actual.dim_indexers == expected_dim_indexers - actual = index.query({"x": index.to_pandas_index().to_numpy()[1]}) + actual = index.sel({"x": index.to_pandas_index().to_numpy()[1]}) assert actual.dim_indexers == expected_dim_indexers def test_query_unsorted_datetime_index_raises(self) -> None: @@ -169,7 +169,7 @@ def test_query_unsorted_datetime_index_raises(self) -> None: # pandas will try to 
convert this into an array indexer. We should # raise instead, so we can be sure the result of indexing with a # slice is always a view. - index.query({"x": slice("2001", "2002")}) + index.sel({"x": slice("2001", "2002")}) def test_equals(self) -> None: index1 = PandasIndex([1, 2, 3], "x") @@ -405,26 +405,26 @@ def test_create_variables(self) -> None: for k, expected in index_vars.items(): assert_identical(actual[k], expected) - def test_query(self) -> None: + def test_sel(self) -> None: index = PandasMultiIndex( pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), "x" ) # test tuples inside slice are considered as scalar indexer values - actual = index.query({"x": slice(("a", 1), ("b", 2))}) + actual = index.sel({"x": slice(("a", 1), ("b", 2))}) expected_dim_indexers = {"x": slice(0, 4)} assert actual.dim_indexers == expected_dim_indexers with pytest.raises(KeyError, match=r"not all values found"): - index.query({"x": [0]}) + index.sel({"x": [0]}) with pytest.raises(KeyError): - index.query({"x": 0}) + index.sel({"x": 0}) with pytest.raises(ValueError, match=r"cannot provide labels for both.*"): - index.query({"one": 0, "x": "a"}) + index.sel({"one": 0, "x": "a"}) with pytest.raises(ValueError, match=r"invalid multi-index level names"): - index.query({"x": {"three": 0}}) + index.sel({"x": {"three": 0}}) with pytest.raises(IndexError): - index.query({"x": (slice(None), 1, "no_level")}) + index.sel({"x": (slice(None), 1, "no_level")}) def test_join(self): midx = pd.MultiIndex.from_product([["a", "aa"], [1, 2]], names=("one", "two")) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index e588bb4c661..79bf71d5242 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -103,7 +103,7 @@ def create_query_results( variables.update(index_vars) variables.update(other_vars) - return indexing.QueryResult( + return indexing.IndexSelResult( dim_indexers=dim_indexers, indexes=indexes, variables=variables, @@ -115,7 +115,7 @@ def create_query_results( def test_indexer( data: T_Xarray, x: Any, - expected: indexing.QueryResult, + expected: indexing.IndexSelResult, ) -> None: results = indexing.map_index_queries(data, {"x": x}) @@ -140,10 +140,10 @@ def test_indexer( ) mdata = DataArray(range(8), [("x", mindex)]) - test_indexer(data, 1, indexing.QueryResult({"x": 0})) - test_indexer(data, np.int32(1), indexing.QueryResult({"x": 0})) - test_indexer(data, Variable([], 1), indexing.QueryResult({"x": 0})) - test_indexer(mdata, ("a", 1, -1), indexing.QueryResult({"x": 0})) + test_indexer(data, 1, indexing.IndexSelResult({"x": 0})) + test_indexer(data, np.int32(1), indexing.IndexSelResult({"x": 0})) + test_indexer(data, Variable([], 1), indexing.IndexSelResult({"x": 0})) + test_indexer(mdata, ("a", 1, -1), indexing.IndexSelResult({"x": 0})) expected = create_query_results( [True, True, False, False, False, False, False, False], @@ -182,18 +182,20 @@ def test_indexer( test_indexer(mdata, ("a",), expected) test_indexer( - mdata, [("a", 1, -1), ("b", 2, -2)], indexing.QueryResult({"x": [0, 7]}) + mdata, [("a", 1, -1), ("b", 2, -2)], indexing.IndexSelResult({"x": [0, 7]}) ) test_indexer( - mdata, slice("a", "b"), indexing.QueryResult({"x": slice(0, 8, None)}) + mdata, slice("a", "b"), indexing.IndexSelResult({"x": slice(0, 8, None)}) ) test_indexer( mdata, slice(("a", 1), ("b", 1)), - indexing.QueryResult({"x": slice(0, 6, None)}), + indexing.IndexSelResult({"x": slice(0, 6, None)}), ) test_indexer( - mdata, {"one": "a", "two": 1, "three": -1}, 
indexing.QueryResult({"x": 0})
+        mdata,
+        {"one": "a", "two": 1, "three": -1},
+        indexing.IndexSelResult({"x": 0}),
     )
 
     expected = create_query_results(

From 7600e462f3bd475a652c0dc0fccef723423a7332 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Tue, 15 Feb 2022 12:34:11 +0100
Subject: [PATCH 124/159] remove Index.union and Index.intersection

Those are not used elsewhere. ``Index.join`` is used
instead for alignment.
---
 xarray/core/indexes.py       | 14 --------------
 xarray/tests/test_indexes.py | 14 --------------
 2 files changed, 28 deletions(-)

diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index efcbc17f041..7538c99f872 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -101,12 +101,6 @@ def reindex_like(self: T_Index, other: T_Index) -> Dict[Hashable, Any]:
     def equals(self, other):  # pragma: no cover
         raise NotImplementedError()
 
-    def union(self, other):  # pragma: no cover
-        raise NotImplementedError()
-
-    def intersection(self, other):  # pragma: no cover
-        raise NotImplementedError()
-
     def roll(self, shifts: Mapping[Any, int]) -> Union["Index", None]:
         return None
 
@@ -469,14 +463,6 @@ def join(
         coord_dtype = np.result_type(self.coord_dtype, other.coord_dtype)
         return type(self)(index, self.dim, coord_dtype=coord_dtype)
 
-    def union(self, other):
-        new_index = self.index.union(other.index)
-        return type(self)(new_index, self.dim)
-
-    def intersection(self, other):
-        new_index = self.index.intersection(other.index)
-        return type(self)(new_index, self.dim)
-
     def reindex_like(
         self, other: "PandasIndex", method=None, tolerance=None
     ) -> Dict[Hashable, Any]:
diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py
index 3cac1e648a0..82c36ce24ae 100644
--- a/xarray/tests/test_indexes.py
+++ b/xarray/tests/test_indexes.py
@@ -192,20 +192,6 @@ def test_join(self) -> None:
         assert actual.equals(expected)
         assert actual.coord_dtype == "<U4"
 
-    def test_union(self) -> None:
-        index1 = PandasIndex([1, 2, 3], "x")
-        index2 = PandasIndex([4, 5, 6], "y")
-        actual = index1.union(index2)
-        assert actual.index.equals(pd.Index([1, 2, 3, 4, 5, 6]))
-        assert actual.dim == "x"
-
-    def test_intersection(self) -> None:
-        index1 = PandasIndex([1, 2, 3], "x")
-        index2 = PandasIndex([2, 3, 4], "y")
-        actual = index1.intersection(index2)
-        assert actual.index.equals(pd.Index([2, 3]))
-        assert actual.dim == "x"
-
     def test_reindex_like(self) -> None:
         index1 = PandasIndex([0, 1, 2], "x")
         index2 = PandasIndex([1, 2, 3, 4], "x")

From 1b6c97dd16e7eee1f6513444076694a5aeeb99fb Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Tue, 15 Feb 2022 12:41:19 +0100
Subject: [PATCH 125/159] use future annotations in indexes.py

---
 xarray/core/indexes.py | 200 ++++++++++++++++++++----------------------
 1 file changed, 97 insertions(+), 103 deletions(-)

diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 7538c99f872..a53e07c06ea 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import collections.abc
 from collections import defaultdict
 from typing import (
@@ -8,15 +10,9 @@
     Hashable,
     Iterable,
     Iterator,
-    List,
     Mapping,
-    Optional,
     Sequence,
-    Set,
-    Tuple,
-    Type,
     TypeVar,
-    Union,
     cast,
 )
 
@@ -39,13 +35,13 @@ class Index:
 
     @classmethod
     def from_variables(
-        cls, variables: Mapping[Any, "Variable"]
-    ) -> Tuple["Index", IndexVars]:
+        cls, variables: Mapping[Any, Variable]
+    ) -> tuple[Index, IndexVars]:
         raise NotImplementedError()
 
     @classmethod
     def concat(
-        cls: Type[T_Index],
+        cls: type[T_Index],
         indexes: Sequence[T_Index],
         dim: Hashable,
positions: Iterable[Iterable[int]] = None, @@ -54,17 +50,17 @@ def concat( @classmethod def stack( - cls, variables: Mapping[Any, "Variable"], dim: Hashable - ) -> Tuple["Index", IndexVars]: + cls, variables: Mapping[Any, Variable], dim: Hashable + ) -> tuple[Index, IndexVars]: raise NotImplementedError( f"{cls!r} cannot be used for creating an index of stacked coordinates" ) - def unstack(self) -> Tuple[Dict[Hashable, "Index"], pd.MultiIndex]: + def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]: raise NotImplementedError() def create_variables( - self, variables: Optional[Mapping[Any, "Variable"]] = None + self, variables: Mapping[Any, Variable] | None = None ) -> IndexVars: if variables is not None: # pass through @@ -83,11 +79,11 @@ def to_pandas_index(self) -> pd.Index: raise TypeError(f"{self!r} cannot be cast to a pandas.Index object") def isel( - self, indexers: Mapping[Any, Union[int, slice, np.ndarray, "Variable"]] - ) -> Union["Index", None]: + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> Index | None: return None - def sel(self, labels: Dict[Any, Any]) -> IndexSelResult: + def sel(self, labels: dict[Any, Any]) -> IndexSelResult: raise NotImplementedError(f"{self!r} doesn't support label-based selection") def join(self: T_Index, other: T_Index, how: str = "inner") -> T_Index: @@ -95,18 +91,18 @@ def join(self: T_Index, other: T_Index, how: str = "inner") -> T_Index: f"{self!r} doesn't support alignment with inner/outer join method" ) - def reindex_like(self: T_Index, other: T_Index) -> Dict[Hashable, Any]: + def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]: raise NotImplementedError(f"{self!r} doesn't support re-indexing labels") def equals(self, other): # pragma: no cover raise NotImplementedError() - def roll(self, shifts: Mapping[Any, int]) -> Union["Index", None]: + def roll(self, shifts: Mapping[Any, int]) -> Index | None: return None def rename( self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable] - ) -> Tuple["Index", IndexVars]: + ) -> tuple[Index, IndexVars]: return self, {} def copy(self, deep: bool = True): # pragma: no cover @@ -239,8 +235,8 @@ def _replace(self, index, dim=None, coord_dtype=None): @classmethod def from_variables( - cls, variables: Mapping[Any, "Variable"] - ) -> Tuple["PandasIndex", IndexVars]: + cls, variables: Mapping[Any, Variable] + ) -> tuple[PandasIndex, IndexVars]: from .variable import IndexVariable if len(variables) != 1: @@ -300,10 +296,10 @@ def _concat_indexes(indexes, dim, positions=None) -> pd.Index: @classmethod def concat( cls, - indexes: Sequence["PandasIndex"], + indexes: Sequence[PandasIndex], dim: Hashable, positions: Iterable[Iterable[int]] = None, - ) -> "PandasIndex": + ) -> PandasIndex: new_pd_index = cls._concat_indexes(indexes, dim, positions) if not indexes: @@ -314,13 +310,13 @@ def concat( return cls(new_pd_index, dim=dim, coord_dtype=coord_dtype) def create_variables( - self, variables: Optional[Mapping[Any, "Variable"]] = None + self, variables: Mapping[Any, Variable] | None = None ) -> IndexVars: from .variable import IndexVariable name = self.index.name - attrs: Union[Mapping[Hashable, Any], None] - encoding: Union[Mapping[Hashable, Any], None] + attrs: Mapping[Hashable, Any] | None + encoding: Mapping[Hashable, Any] | None if variables is not None and name in variables: var = variables[name] @@ -339,8 +335,8 @@ def from_pandas_index( cls, index: pd.Index, dim: Hashable, - var_meta: Optional[Dict[Any, Dict]] = None, - ) -> 
Tuple["PandasIndex", IndexVars]: + var_meta: dict[Any, dict] | None = None, + ) -> tuple[PandasIndex, IndexVars]: from .variable import IndexVariable if index.name is None: @@ -370,8 +366,8 @@ def to_pandas_index(self) -> pd.Index: return self.index def isel( - self, indexers: Mapping[Any, Union[int, slice, np.ndarray, "Variable"]] - ) -> Optional["PandasIndex"]: + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> PandasIndex | None: from .variable import Variable indxr = indexers[self.dim] @@ -388,7 +384,7 @@ def isel( return self._replace(self.index[indxr]) def sel( - self, labels: Dict[Any, Any], method=None, tolerance=None + self, labels: dict[Any, Any], method=None, tolerance=None ) -> IndexSelResult: from .dataarray import DataArray from .variable import Variable @@ -451,9 +447,7 @@ def equals(self, other: Index): return False return self.index.equals(other.index) and self.dim == other.dim - def join( - self: "PandasIndex", other: "PandasIndex", how: str = "inner" - ) -> "PandasIndex": + def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasIndex: if how == "outer": index = self.index.union(other.index) else: @@ -464,8 +458,8 @@ def join( return type(self)(index, self.dim, coord_dtype=coord_dtype) def reindex_like( - self, other: "PandasIndex", method=None, tolerance=None - ) -> Dict[Hashable, Any]: + self, other: PandasIndex, method=None, tolerance=None + ) -> dict[Hashable, Any]: if not self.index.is_unique: raise ValueError( f"cannot reindex or align along dimension {self.dim!r} because the " @@ -474,7 +468,7 @@ def reindex_like( return {self.dim: get_indexer_nd(self.index, other.index, method, tolerance)} - def roll(self, shifts: Mapping[Any, int]) -> "PandasIndex": + def roll(self, shifts: Mapping[Any, int]) -> PandasIndex: shift = shifts[self.dim] % self.index.shape[0] if shift != 0: @@ -502,7 +496,7 @@ def __getitem__(self, indexer: Any): return self._replace(self.index[indexer]) -def _check_dim_compat(variables: Mapping[Any, "Variable"], all_dims: str = "equal"): +def _check_dim_compat(variables: Mapping[Any, Variable], all_dims: str = "equal"): """Check that all multi-index variable candidates are 1-dimensional and either share the same (single) dimension or each have a different dimension. 
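Note: the annotation style switched to in this patch (builtin generics
such as ``tuple[...]`` / ``dict[...]`` instead of ``Tuple`` / ``Dict``)
only works on the Python versions supported at the time because
``from __future__ import annotations`` (PEP 563) stores annotations as
strings that are never evaluated at runtime. A minimal standalone sketch,
independent of the diff above (names invented for illustration):

    from __future__ import annotations

    from typing import Any, Mapping

    # Under PEP 563 this annotation is kept as a string, so the function
    # definition parses on Python 3.8 even though dict[str, int] | None
    # is not a valid runtime expression there.
    def lookup(mapping: Mapping[Any, Any]) -> dict[str, int] | None:
        return None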
@@ -525,7 +519,7 @@ def _check_dim_compat(variables: Mapping[Any, "Variable"], all_dims: str = "equa ) -def _get_var_metadata(variables: Mapping[Any, "Variable"]) -> Dict[Any, Dict[str, Any]]: +def _get_var_metadata(variables: Mapping[Any, Variable]) -> dict[Any, dict[str, Any]]: return { name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding} for name, var in variables.items() @@ -591,7 +585,7 @@ def remove_unused_levels_categories(index: pd.Index) -> pd.Index: class PandasMultiIndex(PandasIndex): """Wrap a pandas.MultiIndex as an xarray compatible index.""" - level_coords_dtype: Dict[str, Any] + level_coords_dtype: dict[str, Any] __slots__ = ("index", "dim", "coord_dtype", "level_coords_dtype") @@ -604,7 +598,7 @@ def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any = None): } self.level_coords_dtype = level_coords_dtype - def _replace(self, index, dim=None, level_coords_dtype=None) -> "PandasMultiIndex": + def _replace(self, index, dim=None, level_coords_dtype=None) -> PandasMultiIndex: if dim is None: dim = self.dim index.name = dim @@ -614,8 +608,8 @@ def _replace(self, index, dim=None, level_coords_dtype=None) -> "PandasMultiInde @classmethod def from_variables( - cls, variables: Mapping[Any, "Variable"] - ) -> Tuple["PandasMultiIndex", IndexVars]: + cls, variables: Mapping[Any, Variable] + ) -> tuple[PandasMultiIndex, IndexVars]: _check_dim_compat(variables) dim = next(iter(variables.values())).dims[0] @@ -635,10 +629,10 @@ def from_variables( @classmethod def concat( # type: ignore[override] cls, - indexes: Sequence["PandasMultiIndex"], + indexes: Sequence[PandasMultiIndex], dim: Hashable, positions: Iterable[Iterable[int]] = None, - ) -> "PandasMultiIndex": + ) -> PandasMultiIndex: new_pd_index = cls._concat_indexes(indexes, dim, positions) if not indexes: @@ -654,8 +648,8 @@ def concat( # type: ignore[override] @classmethod def stack( - cls, variables: Mapping[Any, "Variable"], dim: Hashable - ) -> Tuple["PandasMultiIndex", IndexVars]: + cls, variables: Mapping[Any, Variable], dim: Hashable + ) -> tuple[PandasMultiIndex, IndexVars]: """Create a new Pandas MultiIndex from the product of 1-d variables (levels) along a new dimension. @@ -683,10 +677,10 @@ def stack( return cls.from_pandas_index(index, dim, var_meta=_get_var_metadata(variables)) - def unstack(self) -> Tuple[Dict[Hashable, Index], pd.MultiIndex]: + def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]: clean_index = remove_unused_levels_categories(self.index) - new_indexes: Dict[Hashable, Index] = {} + new_indexes: dict[Hashable, Index] = {} for name, lev in zip(clean_index.names, clean_index.levels): idx = PandasIndex( lev.copy(), name, coord_dtype=self.level_coords_dtype[name] @@ -699,18 +693,18 @@ def unstack(self) -> Tuple[Dict[Hashable, Index], pd.MultiIndex]: def from_variables_maybe_expand( cls, dim: Hashable, - current_variables: Mapping[Any, "Variable"], - variables: Mapping[Any, "Variable"], - ) -> Tuple["PandasMultiIndex", IndexVars]: + current_variables: Mapping[Any, Variable], + variables: Mapping[Any, Variable], + ) -> tuple[PandasMultiIndex, IndexVars]: """Create a new multi-index maybe by expanding an existing one with new variables as index levels. The index and its corresponding coordinates may be created along a new dimension. 
""" - names: List[Hashable] = [] - codes: List[List[int]] = [] - levels: List[List[int]] = [] - level_variables: Dict[Any, "Variable"] = {} + names: list[Hashable] = [] + codes: list[list[int]] = [] + levels: list[list[int]] = [] + level_variables: dict[Any, Variable] = {} _check_dim_compat({**current_variables, **variables}) @@ -750,8 +744,8 @@ def from_variables_maybe_expand( ) def keep_levels( - self, level_variables: Mapping[Any, "Variable"] - ) -> Tuple[Union["PandasMultiIndex", PandasIndex], IndexVars]: + self, level_variables: Mapping[Any, Variable] + ) -> tuple[PandasMultiIndex | PandasIndex, IndexVars]: """Keep only the provided levels and return a new multi-index with its corresponding coordinates. @@ -767,8 +761,8 @@ def keep_levels( return PandasIndex.from_pandas_index(index, self.dim, var_meta=var_meta) def reorder_levels( - self, level_variables: Mapping[Any, "Variable"] - ) -> Tuple["PandasMultiIndex", IndexVars]: + self, level_variables: Mapping[Any, Variable] + ) -> tuple[PandasMultiIndex, IndexVars]: """Re-arrange index levels using input order and return a new multi-index with its corresponding coordinates. @@ -783,8 +777,8 @@ def from_pandas_index( cls, index: pd.MultiIndex, dim: Hashable, - var_meta: Optional[Dict[Any, Dict]] = None, - ) -> Tuple["PandasMultiIndex", IndexVars]: + var_meta: dict[Any, dict] | None = None, + ) -> tuple[PandasMultiIndex, IndexVars]: names = [] idx_dtypes = {} @@ -810,7 +804,7 @@ def from_pandas_index( return cls(index, dim, level_coords_dtype=level_coords_dtype), index_vars def create_variables( - self, variables: Optional[Mapping[Any, "Variable"]] = None + self, variables: Mapping[Any, Variable] | None = None ) -> IndexVars: var_meta = {} if variables is not None: @@ -1004,9 +998,9 @@ def rename(self, name_dict, dims_dict): def create_default_index_implicit( - dim_variable: "Variable", - all_variables: Optional[Union[Mapping, Iterable[Hashable]]] = None, -) -> Tuple[PandasIndex, IndexVars]: + dim_variable: Variable, + all_variables: Mapping | Iterable[Hashable] | None = None, +) -> tuple[PandasIndex, IndexVars]: """Create a default index from a dimension variable. Create a PandasMultiIndex if the given variable wraps a pandas.MultiIndex, @@ -1065,8 +1059,8 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): """ - _indexes: Dict[Any, T_PandasOrXarrayIndex] - _variables: Dict[Any, "Variable"] + _indexes: dict[Any, T_PandasOrXarrayIndex] + _variables: dict[Any, Variable] __slots__ = ( "_indexes", @@ -1079,8 +1073,8 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): def __init__( self, - indexes: Dict[Any, T_PandasOrXarrayIndex], - variables: Dict[Any, "Variable"], + indexes: dict[Any, T_PandasOrXarrayIndex], + variables: dict[Any, Variable], ): """Constructor not for public consumption. 
@@ -1095,27 +1089,27 @@ def __init__( self._indexes = indexes self._variables = variables - self._dims: Optional[Mapping[Hashable, int]] = None - self.__coord_name_id: Optional[Dict[Any, int]] = None - self.__id_index: Optional[Dict[int, T_PandasOrXarrayIndex]] = None - self.__id_coord_names: Optional[Dict[int, Tuple[Hashable, ...]]] = None + self._dims: Mapping[Hashable, int] | None = None + self.__coord_name_id: dict[Any, int] | None = None + self.__id_index: dict[int, T_PandasOrXarrayIndex] | None = None + self.__id_coord_names: dict[int, tuple[Hashable, ...]] | None = None @property - def _coord_name_id(self) -> Dict[Any, int]: + def _coord_name_id(self) -> dict[Any, int]: if self.__coord_name_id is None: self.__coord_name_id = {k: id(idx) for k, idx in self._indexes.items()} return self.__coord_name_id @property - def _id_index(self) -> Dict[int, T_PandasOrXarrayIndex]: + def _id_index(self) -> dict[int, T_PandasOrXarrayIndex]: if self.__id_index is None: self.__id_index = {id(idx): idx for idx in self.get_unique()} return self.__id_index @property - def _id_coord_names(self) -> Dict[int, Tuple[Hashable, ...]]: + def _id_coord_names(self) -> dict[int, tuple[Hashable, ...]]: if self.__id_coord_names is None: - id_coord_names: Mapping[int, List[Hashable]] = defaultdict(list) + id_coord_names: Mapping[int, list[Hashable]] = defaultdict(list) for k, v in self._coord_name_id.items(): id_coord_names[v].append(k) self.__id_coord_names = {k: tuple(v) for k, v in id_coord_names.items()} @@ -1123,7 +1117,7 @@ def _id_coord_names(self) -> Dict[int, Tuple[Hashable, ...]]: return self.__id_coord_names @property - def variables(self) -> Mapping[Hashable, "Variable"]: + def variables(self) -> Mapping[Hashable, Variable]: return Frozen(self._variables) @property @@ -1135,11 +1129,11 @@ def dims(self) -> Mapping[Hashable, int]: return Frozen(self._dims) - def get_unique(self) -> List[T_PandasOrXarrayIndex]: + def get_unique(self) -> list[T_PandasOrXarrayIndex]: """Return a list of unique indexes, preserving order.""" - unique_indexes: List[T_PandasOrXarrayIndex] = [] - seen: Set[T_PandasOrXarrayIndex] = set() + unique_indexes: list[T_PandasOrXarrayIndex] = [] + seen: set[T_PandasOrXarrayIndex] = set() for index in self._indexes.values(): if index not in seen: @@ -1156,7 +1150,7 @@ def is_multi(self, key: Hashable) -> bool: def get_all_coords( self, key: Hashable, errors: str = "raise" - ) -> Dict[Hashable, "Variable"]: + ) -> dict[Hashable, Variable]: """Return all coordinates having the same index. Parameters @@ -1210,7 +1204,7 @@ def get_all_dims( def group_by_index( self, - ) -> List[Tuple[T_PandasOrXarrayIndex, Dict[Hashable, "Variable"]]]: + ) -> list[tuple[T_PandasOrXarrayIndex, dict[Hashable, Variable]]]: """Returns a list of unique indexes and their corresponding coordinates.""" index_coords = [] @@ -1222,14 +1216,14 @@ def group_by_index( return index_coords - def to_pandas_indexes(self) -> "Indexes[pd.Index]": + def to_pandas_indexes(self) -> Indexes[pd.Index]: """Returns an immutable proxy for Dataset or DataArrary pandas indexes. Raises an error if this proxy contains indexes that cannot be coerced to pandas.Index objects. 
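Note: the identity-based caches reworked above (``__coord_name_id``,
``__id_coord_names``) exist so that several coordinate names, e.g. a
multi-index dimension coordinate and its levels, can point to one shared
index object and be grouped back cheaply. A simplified sketch of that
grouping idea with plain objects standing in for indexes (not the actual
``Indexes`` implementation):

    from collections import defaultdict

    idx_a, idx_b = object(), object()  # stand-ins for shared Index objects
    indexes = {"x": idx_b, "one": idx_b, "two": idx_b, "y": idx_a}

    # group coordinate names by the identity of the index they refer to
    id_coord_names = defaultdict(list)
    for name, index in indexes.items():
        id_coord_names[id(index)].append(name)

    assert tuple(id_coord_names[id(idx_b)]) == ("x", "one", "two")
    assert tuple(id_coord_names[id(idx_a)]) == ("y",)

Grouping by ``id()`` rather than by equality keeps the lookup O(1) and
avoids comparing potentially large pandas indexes.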
""" - indexes: Dict[Hashable, pd.Index] = {} + indexes: dict[Hashable, pd.Index] = {} for k, idx in self._indexes.items(): if isinstance(idx, pd.Index): @@ -1239,7 +1233,7 @@ def to_pandas_indexes(self) -> "Indexes[pd.Index]": return Indexes(indexes, self._variables) - def copy_indexes(self, deep: bool = True) -> Dict[Hashable, Index]: + def copy_indexes(self, deep: bool = True) -> dict[Hashable, Index]: """Return a new dictionary with copies of indexes, preserving unique indexes. @@ -1268,8 +1262,8 @@ def __repr__(self): def default_indexes( - coords: Mapping[Any, "Variable"], dims: Iterable -) -> Dict[Hashable, Index]: + coords: Mapping[Any, Variable], dims: Iterable +) -> dict[Hashable, Index]: """Default indexes for a Dataset/DataArray. Parameters @@ -1290,9 +1284,9 @@ def default_indexes( def indexes_equal( index: Index, other_index: Index, - variable: "Variable", - other_variable: "Variable", - cache: Dict[Tuple[int, int], Union[bool, None]] = None, + variable: Variable, + other_variable: Variable, + cache: dict[tuple[int, int], bool | None] = None, ) -> bool: """Check if two indexes are equal, possibly with cached results. @@ -1305,7 +1299,7 @@ def indexes_equal( cache = {} key = (id(index), id(other_index)) - equal: Union[bool, None] = None + equal: bool | None = None if key not in cache: if type(index) is type(other_index): @@ -1327,7 +1321,7 @@ def indexes_equal( def indexes_all_equal( - elements: Sequence[Tuple[Index, Dict[Hashable, "Variable"]]] + elements: Sequence[tuple[Index, dict[Hashable, Variable]]] ) -> bool: """Check if indexes are all equal. @@ -1363,9 +1357,9 @@ def _apply_indexes( indexes: Indexes[Index], args: Mapping[Any, Any], func: str, -) -> Tuple[Dict[Hashable, Index], Dict[Hashable, "Variable"]]: - new_indexes: Dict[Hashable, Index] = {k: v for k, v in indexes.items()} - new_index_variables: Dict[Hashable, Variable] = {} +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes.items()} + new_index_variables: dict[Hashable, Variable] = {} for index, index_vars in indexes.group_by_index(): index_dims = {d for var in index_vars.values() for d in var.dims} @@ -1386,28 +1380,28 @@ def _apply_indexes( def isel_indexes( indexes: Indexes[Index], indexers: Mapping[Any, Any], -) -> Tuple[Dict[Hashable, Index], Dict[Hashable, "Variable"]]: +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: return _apply_indexes(indexes, indexers, "isel") def roll_indexes( indexes: Indexes[Index], shifts: Mapping[Any, int], -) -> Tuple[Dict[Hashable, Index], Dict[Hashable, "Variable"]]: +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: return _apply_indexes(indexes, shifts, "roll") def filter_indexes_from_coords( indexes: Mapping[Any, Index], - filtered_coord_names: Set, -) -> Dict[Hashable, Index]: + filtered_coord_names: set, +) -> dict[Hashable, Index]: """Filter index items given a (sub)set of coordinate names. Drop all multi-coordinate related index items for any key missing in the set of coordinate names. 
""" - filtered_indexes: Dict[Any, Index] = dict(**indexes) + filtered_indexes: dict[Any, Index] = dict(**indexes) index_coord_names: dict[Hashable, set[Hashable]] = defaultdict(set) for name, idx in indexes.items(): @@ -1423,7 +1417,7 @@ def filter_indexes_from_coords( def assert_no_index_corrupted( indexes: Indexes[Index], - coord_names: Set[Hashable], + coord_names: set[Hashable], ) -> None: """Assert removing coordinates will not corrupt indexes.""" From eda0e85e43c7a6cec4683588cd69a73528433ab3 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 15 Feb 2022 15:46:05 +0100 Subject: [PATCH 126/159] Index.rename: return only the new index Create new coordinate variables using Index.create_variables instead (get rid of PandasIndex.from_pandas_index). --- xarray/core/dataset.py | 9 +++++++-- xarray/core/indexes.py | 21 ++++++++++----------- xarray/tests/test_indexes.py | 22 ++++++---------------- 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d2a078cad9c..471d7e22c03 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3256,10 +3256,15 @@ def _rename_indexes(self, name_dict, dims_dict): variables = {} for index, coord_names in self.xindexes.group_by_index(): - new_index, new_index_vars = index.rename(name_dict, dims_dict) - # map new index to its corresponding coordinates + new_index = index.rename(name_dict, dims_dict) new_coord_names = [name_dict.get(k, k) for k in coord_names] indexes.update({k: new_index for k in new_coord_names}) + new_index_vars = new_index.create_variables( + { + new: self._variables[old] + for old, new in zip(coord_names, new_coord_names) + } + ) variables.update(new_index_vars) return indexes, variables diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index a53e07c06ea..8fd315867a6 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -102,8 +102,8 @@ def roll(self, shifts: Mapping[Any, int]) -> Index | None: def rename( self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable] - ) -> tuple[Index, IndexVars]: - return self, {} + ) -> Index: + return self def copy(self, deep: bool = True): # pragma: no cover raise NotImplementedError() @@ -480,14 +480,12 @@ def roll(self, shifts: Mapping[Any, int]) -> PandasIndex: def rename(self, name_dict, dims_dict): if self.index.name not in name_dict and self.dim not in dims_dict: - return self, {} + return self new_name = name_dict.get(self.index.name, self.index.name) index = self.index.rename(new_name) new_dim = dims_dict.get(self.dim, self.dim) - var_meta = {new_name: {"dtype": self.coord_dtype}} - - return self.from_pandas_index(index, dim=new_dim, var_meta=var_meta) + return self._replace(index, dim=new_dim) def copy(self, deep=True): return self._replace(self.index.copy(deep=deep)) @@ -983,18 +981,19 @@ def join(self, other, how: str = "inner"): def rename(self, name_dict, dims_dict): if not set(self.index.names) & set(name_dict) and self.dim not in dims_dict: - return self, {} + return self # pandas 1.3.0: could simply do `self.index.rename(names_dict)` new_names = [name_dict.get(k, k) for k in self.index.names] index = self.index.rename(new_names) new_dim = dims_dict.get(self.dim, self.dim) - var_meta = { - k: {"dtype": v} for k, v in zip(new_names, self.level_coords_dtype.values()) + new_level_coords_dtype = { + k: v for k, v in zip(new_names, self.level_coords_dtype.values()) } - - return self.from_pandas_index(index, new_dim, var_meta=var_meta) + return self._replace( + index, 
dim=new_dim, level_coords_dtype=new_level_coords_dtype
+        )
 
 
 def create_default_index_implicit(
diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py
index 82c36ce24ae..106acd3264f 100644
--- a/xarray/tests/test_indexes.py
+++ b/xarray/tests/test_indexes.py
@@ -209,21 +209,18 @@ def test_rename(self) -> None:
         index = PandasIndex(pd.Index([1, 2, 3], name="a"), "x", coord_dtype=np.int32)
 
         # shortcut
-        new_index, index_vars = index.rename({}, {})
+        new_index = index.rename({}, {})
         assert new_index is index
-        assert index_vars == {}
 
-        new_index, index_vars = index.rename({"a": "b"}, {})
+        new_index = index.rename({"a": "b"}, {})
         assert new_index.index.name == "b"
         assert new_index.dim == "x"
         assert new_index.coord_dtype == np.int32
-        assert_identical(index_vars["b"], IndexVariable("x", [1, 2, 3]))
 
-        new_index, index_vars = index.rename({}, {"x": "y"})
+        new_index = index.rename({}, {"x": "y"})
         assert new_index.index.name == "a"
         assert new_index.dim == "y"
         assert new_index.coord_dtype == np.int32
-        assert_identical(index_vars["a"], IndexVariable("y", [1, 2, 3]))
 
     def test_copy(self) -> None:
         expected = PandasIndex([1, 2, 3], "x", coord_dtype=np.int32)
@@ -435,25 +432,18 @@ def test_rename(self) -> None:
         )
 
         # shortcut
-        new_index, index_vars = index.rename({}, {})
+        new_index = index.rename({}, {})
         assert new_index is index
-        assert index_vars == {}
 
-        new_index, index_vars = index.rename({"two": "three"}, {})
+        new_index = index.rename({"two": "three"}, {})
         assert new_index.index.names == ["one", "three"]
         assert new_index.dim == "x"
         assert new_index.level_coords_dtype == {"one": "<U1", "three": np.int32}
 
     def test_copy(self) -> None:
         level_coords_dtype = {"one": "<U1", "two": np.int32}

From 24d771977255d77182bed297751706a0335879ff Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Tue, 15 Feb 2022 16:17:23 +0100
Subject: [PATCH 127/159] Index.stack: return only the new index

Create new coordinate variables using Index.create_variables instead
(get rid of PandasIndex.from_pandas_index)
---
 xarray/core/dataset.py       |  6 ++++--
 xarray/core/indexes.py       |  9 ++++-----
 xarray/tests/test_indexes.py | 13 ++-----------
 3 files changed, 10 insertions(+), 18 deletions(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 471d7e22c03..84563f17f3c 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -3995,8 +3995,10 @@ def _stack_once(self, dims, new_dim, index_cls, create_index=True):
                 product_vars.update(idx_vars)
 
         if len(product_vars) == len(dims):
-            idx, idx_vars = index_cls.stack(product_vars, new_dim)
-            new_indexes.update({k: idx for k in idx_vars})
+            idx = index_cls.stack(product_vars, new_dim)
+            new_indexes[new_dim] = idx
+            new_indexes.update({k: idx for k in product_vars})
+            idx_vars = idx.create_variables(product_vars)
             # keep consistent multi-index coordinate order
             for k in idx_vars:
                 new_variables.pop(k, None)
diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 8fd315867a6..55be3deb052 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -49,9 +49,7 @@ def concat(
         raise NotImplementedError()
 
     @classmethod
-    def stack(
-        cls, variables: Mapping[Any, Variable], dim: Hashable
-    ) -> tuple[Index, IndexVars]:
+    def stack(cls, variables: Mapping[Any, Variable], dim: Hashable) -> Index:
         raise NotImplementedError(
             f"{cls!r} cannot be used for creating an index of stacked coordinates"
         )
@@ -647,7 +645,7 @@ def concat(  # type: ignore[override]
     @classmethod
     def stack(
         cls, variables: Mapping[Any, Variable], dim: Hashable
-    ) -> tuple[PandasMultiIndex, IndexVars]:
+    ) -> PandasMultiIndex:
"""Create a new Pandas MultiIndex from the product of 1-d variables (levels) along a new dimension. @@ -672,8 +670,9 @@ def stack( labels = [x.ravel() for x in labels_mesh] index = pd.MultiIndex(levels, labels, sortorder=0, names=variables.keys()) + level_coords_dtype = {k: var.dtype for k, var in variables.items()} - return cls.from_pandas_index(index, dim, var_meta=_get_var_metadata(variables)) + return cls(index, dim, level_coords_dtype=level_coords_dtype) def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]: clean_index = remove_unused_levels_categories(self.index) diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 106acd3264f..646a9fcba04 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -301,7 +301,7 @@ def test_stack(self) -> None: "y": xr.Variable("y", pd.Index([1, 3, 2])), } - index, index_vars = PandasMultiIndex.stack(prod_vars, "z") + index = PandasMultiIndex.stack(prod_vars, "z") assert index.dim == "z" assert index.index.names == ["x", "y"] @@ -309,15 +309,6 @@ def test_stack(self) -> None: index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] ) - assert list(index_vars) == ["z", "x", "y"] - midx = pd.MultiIndex.from_product([["b", "a"], [1, 3, 2]]) - assert_equal(xr.IndexVariable("z", midx), index_vars["z"]) - assert_identical( - xr.IndexVariable("z", ["b", "b", "b", "a", "a", "a"], attrs={"foo": "bar"}), - index_vars["x"], - ) - assert_identical(xr.IndexVariable("z", [1, 3, 2, 1, 3, 2]), index_vars["y"]) - with pytest.raises( ValueError, match=r"conflicting dimensions for multi-index product.*" ): @@ -332,7 +323,7 @@ def test_stack_non_unique(self) -> None: "y": xr.Variable("y", pd.Index([1, 1, 2])), } - index, _ = PandasMultiIndex.stack(prod_vars, "z") + index = PandasMultiIndex.stack(prod_vars, "z") np.testing.assert_array_equal( index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] From 2aea2e34a15d5e95f5fb3c975f5b2e61798cc5a8 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 15 Feb 2022 16:59:10 +0100 Subject: [PATCH 128/159] PandasMultiIndex class methods: return only the index --- xarray/core/dataset.py | 19 +++++++++++++------ xarray/core/indexes.py | 24 ++++++++++++------------ 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 84563f17f3c..d6f653adc12 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3740,7 +3740,8 @@ def set_index( f"dimension mismatch: try setting an index for dimension {dim!r} with " f"variable {var_name!r} that has dimensions {var.dims}" ) - idx, idx_vars = PandasIndex.from_variables({dim: var}) + idx, _ = PandasIndex.from_variables({dim: var}) + idx_var_names = (var_name,) else: if append: current_variables = { @@ -3748,14 +3749,18 @@ def set_index( } else: current_variables = {} - idx, idx_vars = PandasMultiIndex.from_variables_maybe_expand( + idx = PandasMultiIndex.from_variables_maybe_expand( dim, current_variables, {k: self._variables[k] for k in var_names}, ) + idx_var_names = idx.index.names for n in idx.index.names: replace_dims[n] = dim + idx_vars = idx.create_variables( + {k: self._variables[k] for k in idx_var_names} + ) new_indexes.update({k: idx for k in idx_vars}) new_variables.update(idx_vars) @@ -3836,7 +3841,8 @@ def reset_index( if k not in dims_or_levels } if level_vars: - idx, idx_vars = index.keep_levels(level_vars) + idx = index.keep_levels(level_vars) + idx_vars = idx.create_variables(level_vars) new_indexes.update({k: idx for k in idx_vars}) 
new_variables.update(idx_vars) replaced_indexes.append(index) @@ -3891,10 +3897,11 @@ def reorder_levels( if not isinstance(index, PandasMultiIndex): raise ValueError(f"coordinate {dim} has no MultiIndex") - idx, idx_vars = index.reorder_levels({k: self._variables[k] for k in order}) - - new_variables.update(idx_vars) + level_vars = {k: self._variables[k] for k in order} + idx = index.reorder_levels(level_vars) + idx_vars = idx.create_variables(level_vars) new_indexes.update({k: idx for k in idx_vars}) + new_variables.update(idx_vars) indexes = {k: v for k, v in self.xindexes.items() if k not in new_indexes} indexes.update(new_indexes) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 55be3deb052..7354661e6cc 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -692,7 +692,7 @@ def from_variables_maybe_expand( dim: Hashable, current_variables: Mapping[Any, Variable], variables: Mapping[Any, Variable], - ) -> tuple[PandasMultiIndex, IndexVars]: + ) -> PandasMultiIndex: """Create a new multi-index maybe by expanding an existing one with new variables as index levels. @@ -735,39 +735,39 @@ def from_variables_maybe_expand( level_variables[name] = var index = pd.MultiIndex(levels, codes, names=names) + level_coords_dtype = {k: var.dtype for k, var in level_variables.items()} - return cls.from_pandas_index( - index, dim, var_meta=_get_var_metadata(level_variables) - ) + return cls(index, dim, level_coords_dtype=level_coords_dtype) def keep_levels( self, level_variables: Mapping[Any, Variable] - ) -> tuple[PandasMultiIndex | PandasIndex, IndexVars]: + ) -> PandasMultiIndex | PandasIndex: """Keep only the provided levels and return a new multi-index with its corresponding coordinates. """ - var_meta = _get_var_metadata(level_variables) index = self.index.droplevel( [k for k in self.index.names if k not in level_variables] ) if isinstance(index, pd.MultiIndex): - return self.from_pandas_index(index, self.dim, var_meta=var_meta) + level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names} + return self._replace(index, level_coords_dtype=level_coords_dtype) else: - return PandasIndex.from_pandas_index(index, self.dim, var_meta=var_meta) + return PandasIndex( + index, self.dim, coord_dtype=self.level_coords_dtype[index.name] + ) def reorder_levels( self, level_variables: Mapping[Any, Variable] - ) -> tuple[PandasMultiIndex, IndexVars]: + ) -> PandasMultiIndex: """Re-arrange index levels using input order and return a new multi-index with its corresponding coordinates. 
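Note: after this patch, ``keep_levels`` and ``reorder_levels`` above
return only the index (``keep_levels`` collapses to a plain
``PandasIndex`` once a single level is left), and callers rebuild the
coordinate variables via ``create_variables``. A hedged sketch of that
calling pattern against the internal ``xarray.core.indexes`` API (data
and names invented for illustration):

    import pandas as pd
    import xarray as xr
    from xarray.core.indexes import PandasIndex, PandasMultiIndex

    midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two"))
    index = PandasMultiIndex(midx, "x")

    # keep only the "one" level: the multi-index collapses to a plain index
    level_vars = {"one": xr.Variable("x", midx.get_level_values("one"))}
    new_index = index.keep_levels(level_vars)
    assert isinstance(new_index, PandasIndex)

    # rebuild metadata-preserving coordinate variables for the kept level
    new_vars = new_index.create_variables(level_vars)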
""" index = self.index.reorder_levels(level_variables.keys()) - return self.from_pandas_index( - index, self.dim, var_meta=_get_var_metadata(level_variables) - ) + level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names} + return self._replace(index, level_coords_dtype=level_coords_dtype) @classmethod def from_pandas_index( From 32d887cc083b4a3b50446eb48ae348237cac33fb Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 15 Feb 2022 18:27:43 +0100 Subject: [PATCH 129/159] wip get rid of Pandas(Multi)Index.from_pandas_index --- xarray/core/indexes.py | 124 +++++++++++++++++------------------ xarray/tests/test_indexes.py | 26 +++++++- 2 files changed, 86 insertions(+), 64 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 7354661e6cc..f2654aeb824 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -515,42 +515,6 @@ def _check_dim_compat(variables: Mapping[Any, Variable], all_dims: str = "equal" ) -def _get_var_metadata(variables: Mapping[Any, Variable]) -> dict[Any, dict[str, Any]]: - return { - name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding} - for name, var in variables.items() - } - - -def _create_variables_from_multiindex(index, dim, var_meta=None): - from .variable import IndexVariable - - if var_meta is None: - var_meta = {} - - def create_variable(name): - if name == dim: - level = None - else: - level = name - meta = var_meta.get(name, {}) - data = PandasMultiIndexingAdapter(index, dtype=meta.get("dtype"), level=level) - return IndexVariable( - dim, - data, - attrs=meta.get("attrs"), - encoding=meta.get("encoding"), - fastpath=True, - ) - - variables = {} - variables[dim] = create_variable(dim) - for level in index.names: - variables[level] = create_variable(level) - - return variables - - def remove_unused_levels_categories(index: pd.Index) -> pd.Index: """ Remove unused levels from MultiIndex and unused categories from CategoricalIndex @@ -588,6 +552,17 @@ class PandasMultiIndex(PandasIndex): def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any = None): super().__init__(array, dim) + # default index level names + names = [] + for i, idx in enumerate(self.index.levels): + name = idx.name or f"{dim}_level_{i}" + if name == dim: + raise ValueError( + f"conflicting multi-index level name {name!r} with dimension {dim!r}" + ) + names.append(name) + self.index.names = names + if level_coords_dtype is None: level_coords_dtype = { idx.name: get_valid_numpy_dtype(idx) for idx in self.index.levels @@ -616,10 +591,7 @@ def from_variables( level_coords_dtype = {name: var.dtype for name, var in variables.items()} obj = cls(index, dim, level_coords_dtype=level_coords_dtype) - index_vars = _create_variables_from_multiindex( - index, dim, var_meta=_get_var_metadata(variables) - ) - + index_vars = obj.create_variables(variables) return obj, index_vars @classmethod @@ -797,25 +769,46 @@ def from_pandas_index( index = index.rename(names) index.name = dim - index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta) - return cls(index, dim, level_coords_dtype=level_coords_dtype), index_vars + + obj = cls(index, dim, level_coords_dtype=level_coords_dtype) + index_vars = obj.create_variables() + return obj, index_vars def create_variables( self, variables: Mapping[Any, Variable] | None = None ) -> IndexVars: - var_meta = {} - if variables is not None: - for name in self.index.names: - var = variables[name] - var_meta[name] = { - "dtype": self.level_coords_dtype[name], - "attrs": 
var.attrs, - "encoding": var.encoding, - } + from .variable import IndexVariable - return _create_variables_from_multiindex( - self.index, self.dim, var_meta=var_meta - ) + if variables is None: + variables = {} + + index_vars: IndexVars = {} + for name in (self.dim,) + self.index.names: + if name == self.dim: + level = None + dtype = None + else: + level = name + dtype = self.level_coords_dtype[name] + + var = variables.get(name, None) + if var is not None: + attrs = var.attrs + encoding = var.encoding + else: + attrs = {} + encoding = {} + + data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) + index_vars[name] = IndexVariable( + self.dim, + data, + attrs=attrs, + encoding=encoding, + fastpath=True, + ) + + return index_vars def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: from .dataarray import DataArray @@ -924,23 +917,27 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: indexer = DataArray(indexer, coords=coords, dims=label.dims) if new_index is not None: - # variable(s) attrs and encoding metadata are propagated - # when replacing the indexes in the resulting xarray object - var_meta = {k: {"dtype": v} for k, v in self.level_coords_dtype.items()} - if isinstance(new_index, pd.MultiIndex): - new_index, new_vars = self.from_pandas_index( - new_index, self.dim, var_meta=var_meta + level_coords_dtype = { + k: self.level_coords_dtype[k] for k in new_index.names + } + new_index = self._replace( + new_index, level_coords_dtype=level_coords_dtype ) dims_dict = {} drop_coords = [] else: - new_index, new_vars = PandasIndex.from_pandas_index( - new_index, new_index.name, var_meta=var_meta + new_index = PandasIndex( + new_index, + new_index.name, + coord_dtype=self.level_coords_dtype[new_index.name], ) dims_dict = {self.dim: new_index.index.name} drop_coords = [self.dim] + # variable(s) attrs and encoding metadata are propagated + # when replacing the indexes in the resulting xarray object + new_vars = new_index.create_variables() indexes = cast(Dict[Any, Index], {k: new_index for k in new_vars}) # add scalar variable for each dropped level @@ -1016,7 +1013,8 @@ def create_default_index_implicit( index: PandasIndex if isinstance(array, pd.MultiIndex): - index, index_vars = PandasMultiIndex.from_pandas_index(array, name) + index = PandasMultiIndex(array, name) + index_vars = index.create_variables() # check for conflict between level names and variable names duplicate_names = [k for k in index_vars if k in all_variables and k != name] if duplicate_names: diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 646a9fcba04..671b28c9174 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -242,6 +242,30 @@ def test_getitem(self) -> None: class TestPandasMultiIndex: + def test_constructor(self) -> None: + foo_data = np.array([0, 0, 1], dtype="int64") + bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") + pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data], names=("foo", "bar")) + + index = PandasMultiIndex(pd_idx, "x") + + assert index.dim == "x" + assert index.index.equals(pd_idx) + assert index.index.names == ("foo", "bar") + assert index.index.name == "x" + assert index.level_coords_dtype == { + "foo": foo_data.dtype, + "bar": bar_data.dtype, + } + + with pytest.raises(ValueError, match=".*conflicting multi-index level name.*"): + PandasMultiIndex(pd_idx, "foo") + + # default level names + pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data]) + index = 
PandasMultiIndex(pd_idx, "x") + assert index.index.names == ("x_level_0", "x_level_1") + def test_from_variables(self) -> None: v_level1 = xr.Variable( "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32} @@ -373,7 +397,7 @@ def test_create_variables(self) -> None: "bar": IndexVariable("x", bar_data, encoding={"fill_value": 0}), } - index, _ = PandasMultiIndex.from_pandas_index(pd_idx, "x") + index = PandasMultiIndex(pd_idx, "x") actual = index.create_variables(index_vars) for k, expected in index_vars.items(): From 851a94261900ff7cdbe7ef40f8324edea8453997 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 15 Feb 2022 20:29:40 +0100 Subject: [PATCH 130/159] remove Pandas(Multi)Index.from_pandas_index --- xarray/core/indexes.py | 65 ----------------------------------- xarray/tests/test_indexes.py | 45 ++++-------------------- xarray/tests/test_indexing.py | 14 ++++---- 3 files changed, 14 insertions(+), 110 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index f2654aeb824..a6327563bc4 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -328,38 +328,6 @@ def create_variables( var = IndexVariable(self.dim, data, attrs=attrs, encoding=encoding) return {name: var} - @classmethod - def from_pandas_index( - cls, - index: pd.Index, - dim: Hashable, - var_meta: dict[Any, dict] | None = None, - ) -> tuple[PandasIndex, IndexVars]: - from .variable import IndexVariable - - if index.name is None: - name = dim - index = index.copy() - index.name = dim - else: - name = index.name - - if var_meta is None: - var_meta = {name: {}} - - data = PandasIndexingAdapter(index, dtype=var_meta[name].get("dtype")) - index_var = IndexVariable( - dim, - data, - fastpath=True, - attrs=var_meta[name].get("attrs"), - encoding=var_meta[name].get("encoding"), - ) - - return cls(index, dim, coord_dtype=var_meta[name].get("dtype")), { - name: index_var - } - def to_pandas_index(self) -> pd.Index: return self.index @@ -741,39 +709,6 @@ def reorder_levels( level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names} return self._replace(index, level_coords_dtype=level_coords_dtype) - @classmethod - def from_pandas_index( - cls, - index: pd.MultiIndex, - dim: Hashable, - var_meta: dict[Any, dict] | None = None, - ) -> tuple[PandasMultiIndex, IndexVars]: - - names = [] - idx_dtypes = {} - for i, idx in enumerate(index.levels): - name = idx.name or f"{dim}_level_{i}" - if name == dim: - raise ValueError( - f"conflicting multi-index level name {name!r} with dimension {dim!r}" - ) - names.append(name) - idx_dtypes[name] = idx.dtype - - if var_meta is None: - var_meta = {k: {} for k in names} - for name, dtype in idx_dtypes.items(): - var_meta[name]["dtype"] = var_meta[name].get("dtype", dtype) - - level_coords_dtype = {k: var_meta[k]["dtype"] for k in names} - - index = index.rename(names) - index.name = dim - - obj = cls(index, dim, level_coords_dtype=level_coords_dtype) - index_vars = obj.create_variables() - return obj, index_vars - def create_variables( self, variables: Mapping[Any, Variable] | None = None ) -> IndexVars: diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 671b28c9174..b504328bb11 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -37,6 +37,12 @@ def test_constructor(self) -> None: assert index.index.equals(pd_idx) assert index.dim == "x" + # test no name set for pd.Index + pd_idx.name = None + index = PandasIndex(pd_idx, "x") + assert index.index is not pd_idx + assert index.index.name 
== "x" + def test_from_variables(self) -> None: # pandas has only Float64Index but variable dtype should be preserved data = np.array([1.1, 2.2, 3.3], dtype=np.float32) @@ -98,26 +104,9 @@ def test_concat_empty(self) -> None: idx = PandasIndex.concat([], "x") assert idx.coord_dtype is np.dtype("O") - def test_from_pandas_index(self) -> None: - pd_idx = pd.Index([1, 2, 3], name="foo") - - index, index_vars = PandasIndex.from_pandas_index(pd_idx, "x") - - assert index.dim == "x" - assert index.index is pd_idx - assert index.index.name == "foo" - assert_identical(index_vars["foo"], IndexVariable("x", [1, 2, 3])) - - # test no name set for pd.Index - pd_idx.name = None - index, index_vars = PandasIndex.from_pandas_index(pd_idx, "x") - assert "x" in index_vars - assert index.index is not pd_idx - assert index.index.name == "x" - def test_create_variables(self) -> None: pd_idx = pd.Index([1, 2, 3], name="foo") - index, _ = PandasIndex.from_pandas_index(pd_idx, "x") + index = PandasIndex(pd_idx, "x") index_vars = { "foo": IndexVariable( "x", pd_idx, attrs={"unit": "m"}, encoding={"fill_value": 0} @@ -367,26 +356,6 @@ def test_unstack(self) -> None: assert new_indexes["two"].equals(PandasIndex([1, 2, 3], "two")) assert new_pd_idx.equals(pd_midx) - def test_from_pandas_index(self) -> None: - foo_data = np.array([0, 0, 1], dtype="int64") - bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") - pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data], names=("foo", "bar")) - - index, index_vars = PandasMultiIndex.from_pandas_index(pd_idx, "x") - - assert index.dim == "x" - assert index.index.equals(pd_idx) - assert index.index.names == ("foo", "bar") - assert index.index.name == "x" - assert_identical(index_vars["x"], IndexVariable("x", pd_idx)) - assert_identical(index_vars["foo"], IndexVariable("x", foo_data)) - assert_identical(index_vars["bar"], IndexVariable("x", bar_data)) - assert index_vars["foo"].dtype == foo_data.dtype - assert index_vars["bar"].dtype == bar_data.dtype - - with pytest.raises(ValueError, match=".*conflicting multi-index level name.*"): - PandasMultiIndex.from_pandas_index(pd_idx, "foo") - def test_create_variables(self) -> None: foo_data = np.array([0, 0, 1], dtype="int") bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 79bf71d5242..de9393bb9d2 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -91,13 +91,13 @@ def test_map_index_queries(self) -> None: def create_query_results( x_indexer, x_index, - index_vars, other_vars, drop_coords, drop_indexes, rename_dims, ): dim_indexers = {"x": x_indexer} + index_vars = x_index.create_variables() indexes = {k: x_index for k in index_vars} variables = {} variables.update(index_vars) @@ -147,7 +147,7 @@ def test_indexer( expected = create_query_results( [True, True, False, False, False, False, False, False], - *PandasIndex.from_pandas_index(pd.Index([-1, -2]), "three"), + PandasIndex(pd.Index([-1, -2]), "three"), {"one": Variable((), "a"), "two": Variable((), 1)}, ["x"], ["one", "two"], @@ -157,7 +157,7 @@ def test_indexer( expected = create_query_results( slice(0, 4, None), - *PandasMultiIndex.from_pandas_index( + PandasMultiIndex( pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), "x", ), @@ -170,7 +170,7 @@ def test_indexer( expected = create_query_results( [True, True, True, True, False, False, False, False], - *PandasMultiIndex.from_pandas_index( + PandasMultiIndex( 
pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), "x", ), @@ -200,7 +200,7 @@ def test_indexer( expected = create_query_results( [True, True, False, False, False, False, False, False], - *PandasIndex.from_pandas_index(pd.Index([-1, -2]), "three"), + PandasIndex(pd.Index([-1, -2]), "three"), {"one": Variable((), "a"), "two": Variable((), 1)}, ["x"], ["one", "two"], @@ -210,7 +210,7 @@ def test_indexer( expected = create_query_results( [True, False, True, False, False, False, False, False], - *PandasIndex.from_pandas_index(pd.Index([1, 2]), "two"), + PandasIndex(pd.Index([1, 2]), "two"), {"one": Variable((), "a"), "three": Variable((), -1)}, ["x"], ["one", "three"], @@ -220,7 +220,7 @@ def test_indexer( expected = create_query_results( [True, True, True, True, False, False, False, False], - *PandasMultiIndex.from_pandas_index( + PandasMultiIndex( pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), "x", ), From 3134f33ebc1d31cf7d75071903bde50a75de19ed Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 15 Feb 2022 21:05:45 +0100 Subject: [PATCH 131/159] Index.from_variables: return the new index only Use Index.create_variables to get the new coordinate variables. --- xarray/core/dataset.py | 8 ++++---- xarray/core/indexes.py | 27 ++++++++------------------- xarray/tests/test_indexes.py | 31 ++++++++++++++++--------------- 3 files changed, 28 insertions(+), 38 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d6f653adc12..f12515ae4ec 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3740,7 +3740,7 @@ def set_index( f"dimension mismatch: try setting an index for dimension {dim!r} with " f"variable {var_name!r} that has dimensions {var.dims}" ) - idx, _ = PandasIndex.from_variables({dim: var}) + idx = PandasIndex.from_variables({dim: var}) idx_var_names = (var_name,) else: if append: @@ -7361,9 +7361,9 @@ def pad( ) # reset default index of dimension coordinates if (name,) == var.dims: - index, index_vars = PandasIndex.from_variables( - {name: variables[name]} - ) + dim_var = {name: variables[name]} + index = PandasIndex.from_variables(dim_var) + index_vars = index.create_variables(dim_var) indexes[name] = index variables[name] = index_vars[name] diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index a6327563bc4..a83c65738df 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -34,9 +34,7 @@ class Index: """Base class inherited by all xarray-compatible indexes.""" @classmethod - def from_variables( - cls, variables: Mapping[Any, Variable] - ) -> tuple[Index, IndexVars]: + def from_variables(cls, variables: Mapping[Any, Variable]) -> Index: raise NotImplementedError() @classmethod @@ -232,11 +230,7 @@ def _replace(self, index, dim=None, coord_dtype=None): return type(self)(index, dim, coord_dtype) @classmethod - def from_variables( - cls, variables: Mapping[Any, Variable] - ) -> tuple[PandasIndex, IndexVars]: - from .variable import IndexVariable - + def from_variables(cls, variables: Mapping[Any, Variable]) -> PandasIndex: if len(variables) != 1: raise ValueError( f"PandasIndex only accepts one variable, found {len(variables)} variables" @@ -267,12 +261,8 @@ def from_variables( obj = cls(data, dim, coord_dtype=var.dtype) assert not isinstance(obj.index, pd.MultiIndex) obj.index.name = name - data = PandasIndexingAdapter(obj.index, dtype=var.dtype) - index_var = IndexVariable( - dim, data, attrs=var.attrs, encoding=var.encoding, fastpath=True - ) - return obj, {name: index_var} 
+ return obj @staticmethod def _concat_indexes(indexes, dim, positions=None) -> pd.Index: @@ -546,9 +536,7 @@ def _replace(self, index, dim=None, level_coords_dtype=None) -> PandasMultiIndex return type(self)(index, dim, level_coords_dtype) @classmethod - def from_variables( - cls, variables: Mapping[Any, Variable] - ) -> tuple[PandasMultiIndex, IndexVars]: + def from_variables(cls, variables: Mapping[Any, Variable]) -> PandasMultiIndex: _check_dim_compat(variables) dim = next(iter(variables.values())).dims[0] @@ -559,8 +547,7 @@ def from_variables( level_coords_dtype = {name: var.dtype for name, var in variables.items()} obj = cls(index, dim, level_coords_dtype=level_coords_dtype) - index_vars = obj.create_variables(variables) - return obj, index_vars + return obj @classmethod def concat( # type: ignore[override] @@ -971,7 +958,9 @@ def create_default_index_implicit( f"conflicting MultiIndex level / variable name(s):\n{conflict_str}" ) else: - index, index_vars = PandasIndex.from_variables({name: dim_variable}) + dim_var = {name: dim_variable} + index = PandasIndex.from_variables(dim_var) + index_vars = index.create_variables(dim_var) return index, index_vars diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index b504328bb11..93414045481 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -50,11 +50,10 @@ def test_from_variables(self) -> None: "x", data, attrs={"unit": "m"}, encoding={"dtype": np.float64} ) - index, index_vars = PandasIndex.from_variables({"x": var}) - assert_identical(var.to_index_variable(), index_vars["x"]) - assert index_vars["x"].dtype == var.dtype + index = PandasIndex.from_variables({"x": var}) assert index.dim == "x" - assert index.index.equals(index_vars["x"].to_index()) + assert index.index.equals(pd.Index(data)) + assert index.coord_dtype == data.dtype var2 = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) with pytest.raises(ValueError, match=r".*only accepts one variable.*"): @@ -71,7 +70,7 @@ def test_from_variables_index_adapter(self) -> None: pd_idx = pd.Index(data) var = xr.Variable("x", pd_idx) - index, _ = PandasIndex.from_variables({"x": var}) + index = PandasIndex.from_variables({"x": var}) assert isinstance(index.index, pd.CategoricalIndex) def test_concat_periods(self): @@ -105,16 +104,20 @@ def test_concat_empty(self) -> None: assert idx.coord_dtype is np.dtype("O") def test_create_variables(self) -> None: - pd_idx = pd.Index([1, 2, 3], name="foo") - index = PandasIndex(pd_idx, "x") + # pandas has only Float64Index but variable dtype should be preserved + data = np.array([1.1, 2.2, 3.3], dtype=np.float32) + pd_idx = pd.Index(data, name="foo") + index = PandasIndex(pd_idx, "x", coord_dtype=data.dtype) index_vars = { "foo": IndexVariable( - "x", pd_idx, attrs={"unit": "m"}, encoding={"fill_value": 0} + "x", data, attrs={"unit": "m"}, encoding={"fill_value": 0.0} ) } actual = index.create_variables(index_vars) assert_identical(actual["foo"], index_vars["foo"]) + assert actual["foo"].dtype == index_vars["foo"].dtype + assert actual["foo"].dtype == index.coord_dtype def test_to_pandas_index(self) -> None: pd_idx = pd.Index([1, 2, 3], name="foo") @@ -263,7 +266,7 @@ def test_from_variables(self) -> None: "x", ["a", "b", "c"], attrs={"unit": "m"}, encoding={"dtype": "U"} ) - index, index_vars = PandasMultiIndex.from_variables( + index = PandasMultiIndex.from_variables( {"level1": v_level1, "level2": v_level2} ) @@ -273,11 +276,6 @@ def test_from_variables(self) -> None: assert index.index.name == "x" 
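Note: from this patch on, every index class follows the same two-step
convention: build the index first, then materialize its coordinate
variables. A minimal sketch of the convention, written against the
internal API as of this commit (data invented for illustration):

    import numpy as np
    import xarray as xr
    from xarray.core.indexes import PandasIndex

    var = xr.Variable("x", np.array([1.1, 2.2, 3.3], dtype=np.float32))

    index = PandasIndex.from_variables({"x": var})   # the index, nothing else
    index_vars = index.create_variables({"x": var})  # attrs/encoding/dtype kept

    # the variable dtype is preserved even though pandas only has Float64Index
    assert index.coord_dtype == np.float32
    assert index_vars["x"].dtype == np.float32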
assert index.index.names == ["level1", "level2"] - assert list(index_vars) == ["x", "level1", "level2"] - assert_equal(xr.IndexVariable("x", expected_idx), index_vars["x"]) - assert_identical(v_level1.to_index_variable(), index_vars["level1"]) - assert_identical(v_level2.to_index_variable(), index_vars["level2"]) - var = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) with pytest.raises( ValueError, match=r".*only accepts 1-dimensional variables.*" @@ -357,7 +355,7 @@ def test_unstack(self) -> None: assert new_pd_idx.equals(pd_midx) def test_create_variables(self) -> None: - foo_data = np.array([0, 0, 1], dtype="int") + foo_data = np.array([0, 0, 1], dtype="int64") bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data], names=("foo", "bar")) index_vars = { @@ -371,6 +369,9 @@ def test_create_variables(self) -> None: for k, expected in index_vars.items(): assert_identical(actual[k], expected) + assert actual[k].dtype == expected.dtype + if k != "x": + assert actual[k].dtype == index.level_coords_dtype[k] def test_sel(self) -> None: index = PandasMultiIndex( From 9bbf6597b1e5efd330314fcecee936a48d7c6ea4 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 15 Feb 2022 21:11:37 +0100 Subject: [PATCH 132/159] rename: propagate_attrs_encoding not needed Index.create_variables already propagates coordinate variable metadata. --- xarray/core/dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f12515ae4ec..61cd375718b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3274,7 +3274,6 @@ def _rename_all(self, name_dict, dims_dict): dims = self._rename_dims(dims_dict) indexes, index_vars = self._rename_indexes(name_dict, dims_dict) - propagate_attrs_encoding(variables, index_vars) variables = {k: index_vars.get(k, v) for k, v in variables.items()} return variables, coord_names, dims, indexes From af3483425851067c0d8ca37ce531529506971a9b Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 16 Feb 2022 12:32:46 +0100 Subject: [PATCH 133/159] align exclude dims: pass through indexes --- xarray/core/alignment.py | 1 + xarray/core/dataarray.py | 9 ++++++++- xarray/core/dataset.py | 19 +++++++++++++++---- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 5a27f930ed3..310b198a817 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -514,6 +514,7 @@ def _reindex_one( new_variables, new_indexes, self.fill_value, + self.exclude_dims, self.exclude_vars, ) new_obj.encoding = obj.encoding diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 2e6abf6d685..d303ffca5c1 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1479,6 +1479,7 @@ def _reindex_callback( variables: dict[Hashable, Variable], indexes: dict[Hashable, Index], fill_value: Any, + exclude_dims: frozenset[Hashable], exclude_vars: frozenset[Hashable], ) -> DataArray: """Callback called from ``Aligner`` to create a new reindexed DataArray.""" @@ -1492,7 +1493,13 @@ def _reindex_callback( ds = self._to_temp_dataset() reindexed = ds._reindex_callback( - aligner, dim_pos_indexers, variables, indexes, fill_value, exclude_vars + aligner, + dim_pos_indexers, + variables, + indexes, + fill_value, + exclude_dims, + exclude_vars, ) return self._from_temp_dataset(reindexed) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 61cd375718b..765cb5154da 100644 --- a/xarray/core/dataset.py +++ 
b/xarray/core/dataset.py @@ -2551,17 +2551,28 @@ def _reindex_callback( variables: dict[Hashable, Variable], indexes: dict[Hashable, Index], fill_value: Any, + exclude_dims: frozenset[Hashable], exclude_vars: frozenset[Hashable], ) -> Dataset: """Callback called from ``Aligner`` to create a new reindexed Dataset.""" new_variables = variables.copy() + new_indexes = indexes.copy() + + # pass through indexes from excluded dimensions + # no extra check needed for multi-coordinate indexes, potential conflicts + # should already have been detected when aligning the indexes + for name, idx in self.xindexes.items(): + var = self._variables[name] + if set(var.dims) <= exclude_dims: + new_indexes[name] = idx + new_variables[name] = var if not dim_pos_indexers: # fast path for no reindexing necessary - if set(indexes) - set(self.xindexes): + if set(new_indexes) - set(self.xindexes): # this only adds new indexes and their coordinate variables - reindexed = self._overwrite_indexes(indexes, variables) + reindexed = self._overwrite_indexes(new_indexes, new_variables) else: reindexed = self.copy(deep=aligner.copy) else: @@ -2578,9 +2589,9 @@ def _reindex_callback( sparse=aligner.sparse, ) new_variables.update(reindexed_vars) - new_coord_names = self._coord_names | set(indexes) + new_coord_names = self._coord_names | set(new_indexes) reindexed = self._replace_with_new_dims( - new_variables, new_coord_names, indexes=indexes + new_variables, new_coord_names, indexes=new_indexes ) return reindexed From fe4fcc5d6c272dbae311479a2687d1d64c8affd4 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 16 Feb 2022 13:50:57 +0100 Subject: [PATCH 134/159] fix set_index append=True Level variables may have updated names in this case: single index + one level. --- xarray/core/dataset.py | 8 ++------ xarray/core/indexes.py | 6 ++++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 765cb5154da..bece58d5f1b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3751,7 +3751,7 @@ def set_index( f"variable {var_name!r} that has dimensions {var.dims}" ) idx = PandasIndex.from_variables({dim: var}) - idx_var_names = (var_name,) + idx_vars = idx.create_variables({var_name: var}) else: if append: current_variables = { @@ -3759,18 +3759,14 @@ def set_index( } else: current_variables = {} - idx = PandasMultiIndex.from_variables_maybe_expand( + idx, idx_vars = PandasMultiIndex.from_variables_maybe_expand( dim, current_variables, {k: self._variables[k] for k in var_names}, ) - idx_var_names = idx.index.names for n in idx.index.names: replace_dims[n] = dim - idx_vars = idx.create_variables( - {k: self._variables[k] for k in idx_var_names} - ) new_indexes.update({k: idx for k in idx_vars}) new_variables.update(idx_vars) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index a83c65738df..bb039c29483 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -619,7 +619,7 @@ def from_variables_maybe_expand( dim: Hashable, current_variables: Mapping[Any, Variable], variables: Mapping[Any, Variable], - ) -> PandasMultiIndex: + ) -> tuple[PandasMultiIndex, IndexVars]: """Create a new multi-index maybe by expanding an existing one with new variables as index levels. 
@@ -663,8 +663,10 @@ def from_variables_maybe_expand( index = pd.MultiIndex(levels, codes, names=names) level_coords_dtype = {k: var.dtype for k, var in level_variables.items()} + obj = cls(index, dim, level_coords_dtype=level_coords_dtype) + index_vars = obj.create_variables(level_variables) - return cls(index, dim, level_coords_dtype=level_coords_dtype) + return obj, index_vars def keep_levels( self, level_variables: Mapping[Any, Variable] From df31fa81164af55c7d4fd09a285b75bb49f441eb Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 16 Feb 2022 13:54:05 +0100 Subject: [PATCH 135/159] refactor default_indexes and re-enable invariant check Default indexes invariant check is now optional (enabled by default). --- xarray/core/indexes.py | 20 ++++++++++++---- xarray/core/variable.py | 13 ----------- xarray/testing.py | 42 ++++++++++++++++++++++------------ xarray/tests/__init__.py | 18 +++++++-------- xarray/tests/test_dataarray.py | 16 ++++++------- xarray/tests/test_dataset.py | 16 ++++++++----- xarray/tests/test_indexes.py | 5 ++-- xarray/tests/test_units.py | 15 +++++++++--- 8 files changed, 85 insertions(+), 60 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index bb039c29483..33392f8a0b7 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -208,11 +208,12 @@ class PandasIndex(Index): __slots__ = ("index", "dim", "coord_dtype") def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None): - index = utils.safe_cast_to_index(array) + # make a shallow copy: cheap and because the index name may be updated + # here or in other constructors (cannot use pd.Index.rename as this + # constructor is also called from PandasMultiIndex) + index = utils.safe_cast_to_index(array).copy() + if index.name is None: - # cannot use pd.Index.rename as this constructor is also - # called from PandasMultiIndex - index = index.copy() index.name = dim self.index = index @@ -1200,7 +1201,16 @@ def default_indexes( Mapping from indexing keys (levels/dimension names) to indexes used for indexing along that dimension. """ - return {key: coords[key]._to_xindex() for key in dims if key in coords} + indexes: dict[Hashable, Index] = {} + coord_names = set(coords) + + for name, var in coords.items(): + if name in dims: + index, index_vars = create_default_index_implicit(var, coords) + if set(index_vars) <= coord_names: + indexes.update({k: index for k in index_vars}) + + return indexes def indexes_equal( diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 60b389798e1..a74c665b371 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -17,7 +17,6 @@ from . 
import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from .arithmetic import VariableArithmetic from .common import AbstractArray -from .indexes import PandasIndex, PandasMultiIndex from .indexing import ( BasicIndexer, OuterIndexer, @@ -531,18 +530,6 @@ def to_index_variable(self): to_coord = utils.alias(to_index_variable, "to_coord") - def _to_xindex(self): - # temporary function used internally as a replacement of to_index() - # returns an xarray Index instance instead of a pd.Index instance - index_var = self.to_index_variable() - index = index_var.to_index() - dim = index_var.dims[0] - - if isinstance(index, pd.MultiIndex): - return PandasMultiIndex(index, dim) - else: - return PandasIndex(index, dim) - def to_index(self): """Convert this variable to a pandas.Index""" return self.to_index_variable().to_index() diff --git a/xarray/testing.py b/xarray/testing.py index b1a6b3741f7..0df34a60e73 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -9,7 +9,7 @@ from xarray.core import duck_array_ops, formatting, utils from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset -from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex +from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes from xarray.core.variable import IndexVariable, Variable __all__ = ( @@ -252,7 +252,9 @@ def assert_chunks_equal(a, b): assert left.chunks == right.chunks -def _assert_indexes_invariants_checks(indexes, possible_coord_variables): +def _assert_indexes_invariants_checks( + indexes, possible_coord_variables, dims, check_default=True +): assert isinstance(indexes, dict), indexes assert all(isinstance(v, Index) for v in indexes.values()), { k: type(v) for k, v in indexes.items() @@ -292,11 +294,13 @@ def _assert_indexes_invariants_checks(indexes, possible_coord_variables): # index identity is used to find unique indexes in `indexes` assert index is indexes[name], (pd_index, indexes[name].index) - # TODO: benbovy - explicit indexes: do we still need these checks? Or opt-in? - # non-default indexes are now supported. 
- # defaults = default_indexes(possible_coord_variables, dims) - # assert indexes.keys() == defaults.keys(), (set(indexes), set(defaults)) - # assert all(v.equals(defaults[k]) for k, v in indexes.items()), (indexes, defaults) + if check_default: + defaults = default_indexes(possible_coord_variables, dims) + assert indexes.keys() == defaults.keys(), (set(indexes), set(defaults)) + assert all(v.equals(defaults[k]) for k, v in indexes.items()), ( + indexes, + defaults, + ) def _assert_variable_invariants(var: Variable, name: Hashable = None): @@ -315,7 +319,7 @@ def _assert_variable_invariants(var: Variable, name: Hashable = None): assert isinstance(var._attrs, (type(None), dict)), name_or_empty + (var._attrs,) -def _assert_dataarray_invariants(da: DataArray): +def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool): assert isinstance(da._variable, Variable), da._variable _assert_variable_invariants(da._variable) @@ -332,10 +336,12 @@ def _assert_dataarray_invariants(da: DataArray): _assert_variable_invariants(v, k) if da._indexes is not None: - _assert_indexes_invariants_checks(da._indexes, da._coords) + _assert_indexes_invariants_checks( + da._indexes, da._coords, da.dims, check_default=check_default_indexes + ) -def _assert_dataset_invariants(ds: Dataset): +def _assert_dataset_invariants(ds: Dataset, check_default_indexes: bool): assert isinstance(ds._variables, dict), type(ds._variables) assert all(isinstance(v, Variable) for v in ds._variables.values()), ds._variables for k, v in ds._variables.items(): @@ -366,13 +372,17 @@ def _assert_dataset_invariants(ds: Dataset): } if ds._indexes is not None: - _assert_indexes_invariants_checks(ds._indexes, ds._variables) + _assert_indexes_invariants_checks( + ds._indexes, ds._variables, ds._dims, check_default=check_default_indexes + ) assert isinstance(ds._encoding, (type(None), dict)) assert isinstance(ds._attrs, (type(None), dict)) -def _assert_internal_invariants(xarray_obj: Union[DataArray, Dataset, Variable]): +def _assert_internal_invariants( + xarray_obj: Union[DataArray, Dataset, Variable], check_default_indexes: bool +): """Validate that an xarray object satisfies its own internal invariants. 
This exists for the benefit of xarray's own test suite, but may be useful @@ -382,9 +392,13 @@ def _assert_internal_invariants(xarray_obj: Union[DataArray, Dataset, Variable]) if isinstance(xarray_obj, Variable): _assert_variable_invariants(xarray_obj) elif isinstance(xarray_obj, DataArray): - _assert_dataarray_invariants(xarray_obj) + _assert_dataarray_invariants( + xarray_obj, check_default_indexes=check_default_indexes + ) elif isinstance(xarray_obj, Dataset): - _assert_dataset_invariants(xarray_obj) + _assert_dataset_invariants( + xarray_obj, check_default_indexes=check_default_indexes + ) else: raise TypeError( "{} is not a supported type for xarray invariant checks".format( diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 20dfdaf5076..743410e9029 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -174,25 +174,25 @@ def source_ndarray(array): # invariants -def assert_equal(a, b): +def assert_equal(a, b, check_default_indexes=True): __tracebackhide__ = True xarray.testing.assert_equal(a, b) - xarray.testing._assert_internal_invariants(a) - xarray.testing._assert_internal_invariants(b) + xarray.testing._assert_internal_invariants(a, check_default_indexes) + xarray.testing._assert_internal_invariants(b, check_default_indexes) -def assert_identical(a, b): +def assert_identical(a, b, check_default_indexes=True): __tracebackhide__ = True xarray.testing.assert_identical(a, b) - xarray.testing._assert_internal_invariants(a) - xarray.testing._assert_internal_invariants(b) + xarray.testing._assert_internal_invariants(a, check_default_indexes) + xarray.testing._assert_internal_invariants(b, check_default_indexes) -def assert_allclose(a, b, **kwargs): +def assert_allclose(a, b, check_default_indexes=True, **kwargs): __tracebackhide__ = True xarray.testing.assert_allclose(a, b, **kwargs) - xarray.testing._assert_internal_invariants(a) - xarray.testing._assert_internal_invariants(b) + xarray.testing._assert_internal_invariants(a, check_default_indexes) + xarray.testing._assert_internal_invariants(b, check_default_indexes) def create_test_data(seed=None, add_attrs=True): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index c78c9609b32..d6d846ea24e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1877,41 +1877,41 @@ def test_reset_index(self): expected = DataArray(self.mda.values, coords=coords, dims="x") obj = self.mda.reset_index("x") - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) assert len(obj.xindexes) == 0 obj = self.mda.reset_index(self.mindex.names) - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) assert len(obj.xindexes) == 0 obj = self.mda.reset_index(["x", "level_1"]) - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) assert list(obj.xindexes) == ["level_2"] expected = DataArray(self.mda.values, coords=coords, dims="x") obj = self.mda.reset_index(["level_1"]) - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) assert list(obj.xindexes) == ["level_2"] assert type(obj.xindexes["level_2"]) is PandasIndex coords = {k: v for k, v in coords.items() if k != "x"} expected = DataArray(self.mda.values, coords=coords, dims="x") obj = self.mda.reset_index("x", drop=True) - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) array = self.mda.copy() array = 
array.reset_index(["x"], drop=True) - assert_identical(array, expected) + assert_identical(array, expected, check_default_indexes=False) # single index array = DataArray([1, 2], coords={"x": ["a", "b"]}, dims="x") obj = array.reset_index("x") - assert_identical(obj, array) + assert_identical(obj, array, check_default_indexes=False) assert len(obj.xindexes) == 0 def test_reset_index_keep_attrs(self): coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True}) da = DataArray([1, 0], [coord_1]) obj = da.reset_index("coord_1") - assert_identical(obj, da) + assert_identical(obj, da, check_default_indexes=False) assert len(obj.xindexes) == 0 def test_reorder_levels(self): diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 3f993fa86d8..8f2d7edec63 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2674,12 +2674,14 @@ def test_rename_dims(self): expected = Dataset( {"x": ("x_new", [0, 1, 2]), "y": ("x_new", [10, 11, 12]), "z": 42} ) + # TODO: (benbovy - explicit indexes) update when set_index supports + # seeting index for non-dimension variables expected = expected.set_coords("x") dims_dict = {"x": "x_new"} actual = original.rename_dims(dims_dict) - assert_identical(expected, actual) + assert_identical(expected, actual, check_default_indexes=False) actual_2 = original.rename_dims(**dims_dict) - assert_identical(expected, actual_2) + assert_identical(expected, actual_2, check_default_indexes=False) # Test to raise ValueError dims_dict_bad = {"x_bad": "x_new"} @@ -2694,12 +2696,14 @@ def test_rename_vars(self): expected = Dataset( {"x_new": ("x", [0, 1, 2]), "y": ("x", [10, 11, 12]), "z": 42} ) + # TODO: (benbovy - explicit indexes) update when set_index supports + # seeting index for non-dimension variables expected = expected.set_coords("x_new") name_dict = {"x": "x_new"} actual = original.rename_vars(name_dict) - assert_identical(expected, actual) + assert_identical(expected, actual, check_default_indexes=False) actual_2 = original.rename_vars(**name_dict) - assert_identical(expected, actual_2) + assert_identical(expected, actual_2, check_default_indexes=False) # Test to raise ValueError names_dict_bad = {"x_bad": "x_new"} @@ -3054,7 +3058,7 @@ def test_reset_index(self): expected = Dataset({}, coords=coords) obj = ds.reset_index("x") - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) assert len(obj.xindexes) == 0 ds = Dataset(coords={"y": ("x", [1, 2, 3])}) @@ -3065,7 +3069,7 @@ def test_reset_index_keep_attrs(self): coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True}) ds = Dataset({}, {"coord_1": coord_1}) obj = ds.reset_index("coord_1") - assert_identical(obj, ds) + assert_identical(obj, ds, check_default_indexes=False) assert len(obj.xindexes) == 0 def test_reorder_levels(self): diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 93414045481..0672f0659e8 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -35,12 +35,13 @@ def test_constructor(self) -> None: index = PandasIndex(pd_idx, "x") assert index.index.equals(pd_idx) + # makes a shallow copy + assert index.index is not pd_idx assert index.dim == "x" # test no name set for pd.Index pd_idx.name = None index = PandasIndex(pd_idx, "x") - assert index.index is not pd_idx assert index.index.name == "x" def test_from_variables(self) -> None: @@ -122,7 +123,7 @@ def test_create_variables(self) -> None: def test_to_pandas_index(self) -> None: pd_idx = pd.Index([1, 
2, 3], name="foo") index = PandasIndex(pd_idx, "x") - assert index.to_pandas_index() is pd_idx + assert index.to_pandas_index() is index.index def test_sel(self) -> None: # TODO: add tests that aren't just for edge cases diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index c820ddd26ca..bc3cc367c0e 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -3609,7 +3609,10 @@ def test_stacking_stacked(self, func, dtype): actual = func(stacked) assert_units_equal(expected, actual) - assert_identical(expected, actual) + if func.name == "reset_index": + assert_identical(expected, actual, check_default_indexes=False) + else: + assert_identical(expected, actual) @pytest.mark.skip(reason="indexes don't support units") def test_to_unstacked_dataset(self, dtype): @@ -4718,7 +4721,10 @@ def test_stacking_stacked(self, variant, func, dtype): actual = func(stacked) assert_units_equal(expected, actual) - assert_equal(expected, actual) + if func.name == "reset_index": + assert_equal(expected, actual, check_default_indexes=False) + else: + assert_equal(expected, actual) @pytest.mark.xfail( reason="stacked dimension's labels have to be hashable, but is a numpy.array" @@ -5502,7 +5508,10 @@ def test_content_manipulation(self, func, variant, dtype): actual = func(ds) assert_units_equal(expected, actual) - assert_equal(expected, actual) + if func.name == "rename_dims": + assert_equal(expected, actual, check_default_indexes=False) + else: + assert_equal(expected, actual) @pytest.mark.parametrize( "unit,error", From 631a8e2a507985111d4ec7299369febaaa3f4b54 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 16 Feb 2022 18:27:11 +0100 Subject: [PATCH 136/159] fix formatting errors --- xarray/core/formatting.py | 2 +- xarray/core/formatting_html.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 5107708921e..c85f288ea27 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -402,7 +402,7 @@ def coords_repr(coords, col_width=None, max_rows=None): summarizer=summarize_variable, expand_option_name="display_expand_coords", col_width=col_width, - indexes=coords.indexes, + indexes=coords.xindexes, max_rows=max_rows, ) diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 2b53cdc499c..db62466a8d3 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -106,7 +106,7 @@ def summarize_variable(name, var, is_index=False, dtype=None): def summarize_coords(variables): li_items = [] for k, v in variables.items(): - li_content = summarize_variable(k, v, is_index=k in variables.indexes) + li_content = summarize_variable(k, v, is_index=k in variables.xindexes) li_items.append(f"
<li class='xr-var-item'>{li_content}</li>")
 
     vars_li = "".join(li_items)
 
From 21fa7a84987eb806a3b64869b5813e20cd9d4fbf Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Wed, 16 Feb 2022 18:27:39 +0100
Subject: [PATCH 137/159] fix type

---
 xarray/core/merge.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/core/merge.py b/xarray/core/merge.py
index de25869af29..7ba8b4e6cc2 100644
--- a/xarray/core/merge.py
+++ b/xarray/core/merge.py
@@ -660,7 +660,7 @@ class _MergeResult(NamedTuple):
     variables: dict[Hashable, Variable]
     coord_names: set[Hashable]
     dims: dict[Hashable, int]
-    indexes: dict[Hashable, pd.Index]
+    indexes: dict[Hashable, Index]
     attrs: dict[Hashable, Any]


From fe4fcc5d6c272dbae311479a2687d1d64c8affd4 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Wed, 16 Feb 2022 18:28:56 +0100
Subject: [PATCH 138/159] Dataset._indexes, DataArray._indexes: always dict

Default indexes are created by the default constructors.

Indexes are always passed explicitly by the fastpath constructors.
---
 xarray/core/dataarray.py | 36 ++++++++++++++++++++++++++---------
 xarray/core/dataset.py   | 41 +++++++++++++++++++--------------------
 2 files changed, 46 insertions(+), 31 deletions(-)

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index d303ffca5c1..d37c3a71e9b 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -46,7 +46,6 @@
     Index,
     Indexes,
     PandasMultiIndex,
-    default_indexes,
     filter_indexes_from_coords,
     isel_indexes,
 )
@@ -338,7 +337,7 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic):
     _cache: dict[str, Any]
     _coords: dict[Any, Variable]
     _close: Callable[[], None] | None
-    _indexes: dict[Hashable, Index] | None
+    _indexes: dict[Hashable, Index]
     _name: Hashable | None
     _variable: Variable

@@ -375,7 +374,12 @@ def __init__(
             variable = data
             assert dims is None
             assert attrs is None
+            assert indexes is not None
         else:
+            # TODO: (benbovy - explicit indexes) remove assertion
+            # once it becomes part of the public interface
+            assert indexes is None, "Providing explicit indexes is not supported yet"
+
             # try to fill in arguments from data if they weren't supplied
             if coords is None:

@@ -409,10 +413,29 @@ def __init__(

         # TODO(shoyer): document this argument, once it becomes part of the
         # public interface.
- self._indexes = indexes + self._indexes = indexes # type: ignore[assignment] self._close = None + @classmethod + def _construct_direct( + cls, + variable: Variable, + coords: dict[Any, Variable], + name: Hashable, + indexes: dict[Hashable, Index], + ) -> DataArray: + """Shortcut around __init__ for internal use when we want to skip + costly validation + """ + obj = object.__new__(cls) + obj._variable = variable + obj._coords = coords + obj._name = name + obj._indexes = indexes + obj._close = None + return obj + def _replace( self: T_DataArray, variable: Variable = None, @@ -829,8 +852,6 @@ def indexes(self) -> Indexes: @property def xindexes(self) -> Indexes: """Mapping of xarray Index objects used for label based indexing.""" - if self._indexes is None: - self._indexes = default_indexes(self._coords, self.dims) return Indexes(self._indexes, {k: self._coords[k] for k in self._indexes}) @property @@ -1043,10 +1064,7 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: """ variable = self.variable.copy(deep=deep, data=data) coords = {k: v.copy(deep=deep) for k, v in self._coords.items()} - if self._indexes is None: - indexes = self._indexes - else: - indexes = self.xindexes.copy_indexes(deep=deep) + indexes = self.xindexes.copy_indexes(deep=deep) return self._replace(variable, coords, indexes=indexes) def __copy__(self) -> DataArray: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bece58d5f1b..4311be1bada 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -519,7 +519,7 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping): _dims: dict[Hashable, int] _encoding: dict[Hashable, Any] | None _close: Callable[[], None] | None - _indexes: dict[Hashable, Index] | None + _indexes: dict[Hashable, Index] _variables: dict[Hashable, Variable] __slots__ = ( @@ -897,6 +897,10 @@ def _construct_direct( """ if dims is None: dims = calculate_dimensions(variables) + if indexes is None: + # TODO: (benbovy - explicit indexes) this may not be needed + # if all calls to _construct_direct explicitly pass a dict of indexes + indexes = default_indexes({k: variables[k] for k in coord_names}, dims) obj = object.__new__(cls) obj._variables = variables obj._coord_names = coord_names @@ -913,7 +917,7 @@ def _replace( coord_names: set[Hashable] = None, dims: dict[Any, int] = None, attrs: dict[Hashable, Any] | None | Default = _default, - indexes: dict[Hashable, Index] | None | Default = _default, + indexes: dict[Hashable, Index] = None, encoding: dict | None | Default = _default, inplace: bool = False, ) -> Dataset: @@ -934,7 +938,7 @@ def _replace( self._dims = dims if attrs is not _default: self._attrs = attrs - if indexes is not _default: + if indexes is not None: self._indexes = indexes if encoding is not _default: self._encoding = encoding @@ -948,8 +952,8 @@ def _replace( dims = self._dims.copy() if attrs is _default: attrs = copy.copy(self._attrs) - if indexes is _default: - indexes = copy.copy(self._indexes) + if indexes is None: + indexes = self._indexes.copy() if encoding is _default: encoding = copy.copy(self._encoding) obj = self._construct_direct( @@ -962,7 +966,7 @@ def _replace_with_new_dims( variables: dict[Hashable, Variable], coord_names: set = None, attrs: dict[Hashable, Any] | None | Default = _default, - indexes: dict[Hashable, Index] | None | Default = _default, + indexes: dict[Hashable, Index] = None, inplace: bool = False, ) -> Dataset: """Replace variables with recalculated dimensions.""" @@ -1276,10 +1280,7 @@ def 
_construct_dataarray(self, name: Hashable) -> DataArray: if k in self._coord_names and set(self.variables[k].dims) <= needed_dims: coords[k] = self.variables[k] - if self._indexes is None: - indexes = None - else: - indexes = filter_indexes_from_coords(self.xindexes, set(coords)) + indexes = filter_indexes_from_coords(self.xindexes, set(coords)) return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) @@ -1498,8 +1499,7 @@ def __delitem__(self, key: Hashable) -> None: """Remove a variable from this dataset.""" assert_no_index_corrupted(self.xindexes, {key}) - if key in self.xindexes: - assert self._indexes is not None + if key in self._indexes: del self._indexes[key] del self._variables[key] self._coord_names.discard(key) @@ -1592,8 +1592,6 @@ def indexes(self) -> Indexes[pd.Index]: @property def xindexes(self) -> Indexes[Index]: """Mapping of xarray Index objects used for label based indexing.""" - if self._indexes is None: - self._indexes = default_indexes(self._variables, self._dims) return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) @property @@ -3260,7 +3258,7 @@ def _rename_dims(self, name_dict): return {name_dict.get(k, k): v for k, v in self.dims.items()} def _rename_indexes(self, name_dict, dims_dict): - if self._indexes is None: + if not self._indexes: return {}, {} indexes = {} @@ -5463,17 +5461,16 @@ def to_array(self, dim="variable", name=None): broadcast_vars = broadcast_variables(*data_vars) data = duck_array_ops.stack([b.data for b in broadcast_vars], axis=0) - coords = dict(self.coords) + dims = (dim,) + broadcast_vars[0].dims + variable = Variable(dims, data, self.attrs, fastpath=True) + + coords = {k: v.variable for k, v in self.coords.items()} indexes = filter_indexes_from_coords(self.xindexes, set(coords)) new_dim_index = PandasIndex(list(self.data_vars), dim) - indexes[new_dim_index] = new_dim_index + indexes[dim] = new_dim_index coords.update(new_dim_index.create_variables()) - dims = (dim,) + broadcast_vars[0].dims - - return DataArray( - data, coords, dims, attrs=self.attrs, name=name, indexes=indexes - ) + return DataArray._construct_direct(variable, coords, name, indexes) def _normalize_dim_order( self, dim_order: list[Hashable] = None From 3280eebda609ba86383df69b1efb95856be80bf3 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 17 Feb 2022 00:33:33 +0100 Subject: [PATCH 139/159] remove _level_coords property It was only used for plotting and has been replaced by simpler logic. --- xarray/core/dataarray.py | 15 --------------- xarray/core/dataset.py | 15 --------------- xarray/plot/utils.py | 19 ++++++++++--------- 3 files changed, 10 insertions(+), 39 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d37c3a71e9b..edcdeac950f 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -738,21 +738,6 @@ def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]: key = indexing.expanded_indexer(key, self.ndim) return dict(zip(self.dims, key)) - @property - def _level_coords(self) -> dict[Hashable, Hashable]: - """Return a mapping of all MultiIndex levels and their corresponding - coordinate name. 
- """ - level_coords: dict[Hashable, Hashable] = {} - - for _, var in self._coords.items(): - if var.ndim == 1 and isinstance(var, IndexVariable): - level_names = var.level_names - if level_names is not None: - (dim,) = var.dims - level_coords.update({lname: dim for lname in level_names}) - return level_coords - def _getitem_coord(self, key): from .dataset import _get_virtual_variable diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4311be1bada..498c52c633a 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1205,21 +1205,6 @@ def as_numpy(self: Dataset) -> Dataset: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) - @property - def _level_coords(self) -> dict[str, Hashable]: - """Return a mapping of all MultiIndex levels and their corresponding - coordinate name. - """ - level_coords: dict[str, Hashable] = {} - for name, index in self.xindexes.items(): - # TODO: benbovy - flexible indexes: update when MultIndex has its own xarray class. - pd_index = index.to_pandas_index() - if isinstance(pd_index, pd.MultiIndex): - level_names = pd_index.names - (dim,) = self.variables[name].dims - level_coords.update({lname: dim for lname in level_names}) - return level_coords - def _copy_listed(self, names: Iterable[Hashable]) -> Dataset: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index f09d1eb1853..d942f6656ba 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd +from ..core.indexes import PandasMultiIndex from ..core.options import OPTIONS from ..core.pycompat import DuckArrayModule from ..core.utils import is_scalar @@ -383,11 +384,9 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): _assert_valid_xy(darray, x, "x") _assert_valid_xy(darray, y, "y") - if ( - all(k in darray._level_coords for k in (x, y)) - and darray._level_coords[x] == darray._level_coords[y] - ): - raise ValueError("x and y cannot be levels of the same MultiIndex") + if darray._indexes.get(x, 1) is darray._indexes.get(y, 2): + if isinstance(darray._indexes[x], PandasMultiIndex): + raise ValueError("x and y cannot be levels of the same MultiIndex") return x, y @@ -398,11 +397,13 @@ def _assert_valid_xy(darray, xy, name): """ # MultiIndex cannot be plotted; no point in allowing them here - multiindex = {darray._level_coords[lc] for lc in darray._level_coords} + multiindex_dims = { + idx.dim + for idx in darray.xindexes.get_unique() + if isinstance(idx, PandasMultiIndex) + } - valid_xy = ( - set(darray.dims) | set(darray.coords) | set(darray._level_coords) - ) - multiindex + valid_xy = (set(darray.dims) | set(darray.coords)) - multiindex_dims if xy not in valid_xy: valid_xy_str = "', '".join(sorted(valid_xy)) From a415188d6116450d6112bd344afd974b3658e133 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 17 Feb 2022 01:21:48 +0100 Subject: [PATCH 140/159] clean-up Use ``_indexes`` instead of ``.xindexes`` when possible, for speed. ``xindexes`` is not cached: it would return inconsistent results if updating the object in-place. Fix some types. 
--- xarray/core/alignment.py | 4 +-- xarray/core/combine.py | 2 +- xarray/core/common.py | 5 ++- xarray/core/concat.py | 2 +- xarray/core/dataarray.py | 15 ++++----- xarray/core/dataset.py | 71 +++++++++++++++++++--------------------- xarray/core/groupby.py | 2 +- xarray/core/indexes.py | 10 +++--- xarray/core/indexing.py | 2 +- xarray/core/merge.py | 4 +-- xarray/core/missing.py | 6 ++-- xarray/core/parallel.py | 16 ++++----- 12 files changed, 64 insertions(+), 75 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 310b198a817..01c64321a40 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -853,7 +853,7 @@ def reindex( # TODO: (benbovy - explicit indexes): uncomment? # --> from reindex docstrings: "any mis-matched dimension is simply ignored" - # bad_keys = [k for k in indexers if k not in obj.xindexes and k not in obj.dims] + # bad_keys = [k for k in indexers if k not in obj._indexes and k not in obj.dims] # if bad_keys: # raise ValueError( # f"indexer keys {bad_keys} do not correspond to any indexed coordinate " @@ -886,7 +886,7 @@ def reindex_like( Not public API. """ - if not other.xindexes: + if not other._indexes: # This check is not performed in Aligner. for dim in other.dims: if dim in obj.dims: diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 081b53391ba..ab793fc5eb0 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -86,7 +86,7 @@ def _infer_concat_order_from_coords(datasets): if dim in ds0: # Need to read coordinate values to do ordering - indexes = [ds.xindexes.get(dim) for ds in datasets] + indexes = [ds._indexes.get(dim) for ds in datasets] if any(index is None for index in indexes): raise ValueError( "Every dimension needs a coordinate for " diff --git a/xarray/core/common.py b/xarray/core/common.py index bee59c6cc7d..60a01951eac 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -404,7 +404,7 @@ def get_index(self, key: Hashable) -> pd.Index: raise KeyError(key) try: - return self.xindexes[key].to_pandas_index() + return self._indexes[key].to_pandas_index() except KeyError: return pd.Index(range(self.sizes[key]), name=key) @@ -1151,8 +1151,7 @@ def resample( category=FutureWarning, ) - # TODO (benbovy - flexible indexes): update when CFTimeIndex is an xarray Index subclass - if isinstance(self.xindexes[dim_name].to_pandas_index(), CFTimeIndex): + if isinstance(self._indexes[dim_name].to_pandas_index(), CFTimeIndex): from .resample_cftime import CFTimeGrouper grouper = CFTimeGrouper(freq, closed, label, base, loffset) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index e1aa16de7cd..58f73ce8ad7 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -527,7 +527,7 @@ def ensure_common_dims(vars): # TODO: (benbovy - explicit indexes): check index types and/or coordinates # of all datasets? 
try: - indexes = [ds.xindexes[name] for ds in datasets] + indexes = [ds._indexes[name] for ds in datasets] except KeyError: combined_var = concat_vars( vars, dim, positions, combine_attrs=combine_attrs diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index edcdeac950f..d431fb946b8 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -467,13 +467,13 @@ def _replace_maybe_drop_dims( for k, v in self._coords.items() if v.shape == tuple(new_sizes[d] for d in v.dims) } - indexes = filter_indexes_from_coords(self.xindexes, set(coords)) + indexes = filter_indexes_from_coords(self._indexes, set(coords)) else: allowed_dims = set(variable.dims) coords = { k: v for k, v in self._coords.items() if set(v.dims) <= allowed_dims } - indexes = filter_indexes_from_coords(self.xindexes, set(coords)) + indexes = filter_indexes_from_coords(self._indexes, set(coords)) return self._replace(variable, coords, name, indexes=indexes) def _overwrite_indexes( @@ -494,7 +494,7 @@ def _overwrite_indexes( new_variable = self.variable.copy() new_coords = self._coords.copy() - new_indexes = dict(self.xindexes) + new_indexes = dict(self._indexes) for name in indexes: new_coords[name] = coords[name] @@ -533,7 +533,7 @@ def subset(dim, label): variables = {label: subset(dim, label) for label in self.get_index(dim)} variables.update({k: v for k, v in self._coords.items() if k != dim}) coord_names = set(self._coords) - {dim} - indexes = filter_indexes_from_coords(self.xindexes, coord_names) + indexes = filter_indexes_from_coords(self._indexes, coord_names) dataset = Dataset._construct_direct( variables, coord_names, indexes=indexes, attrs=self.attrs ) @@ -865,7 +865,7 @@ def reset_coords( Dataset, or DataArray if ``drop == True`` """ if names is None: - names = set(self.coords) - set(self.xindexes) + names = set(self.coords) - set(self._indexes) dataset = self.coords.to_dataset().reset_coords(names, drop) if drop: return self._replace(coords=dataset._variables) @@ -2293,10 +2293,7 @@ def to_unstacked_dataset(self, dim, level=0): -------- Dataset.to_stacked_array """ - - # TODO: benbovy - flexible indexes: update when MultIndex has its own - # class inheriting from xarray.Index - idx = self.xindexes[dim].to_pandas_index() + idx = self._indexes[dim].to_pandas_index() if not isinstance(idx, pd.MultiIndex): raise ValueError(f"'{dim}' is not a stacked coordinate") diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 498c52c633a..27dba3ac904 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -60,7 +60,6 @@ PandasMultiIndex, assert_no_index_corrupted, create_default_index_implicit, - default_indexes, filter_indexes_from_coords, isel_indexes, remove_unused_levels_categories, @@ -898,9 +897,7 @@ def _construct_direct( if dims is None: dims = calculate_dimensions(variables) if indexes is None: - # TODO: (benbovy - explicit indexes) this may not be needed - # if all calls to _construct_direct explicitly pass a dict of indexes - indexes = default_indexes({k: variables[k] for k in coord_names}, dims) + indexes = {} obj = object.__new__(cls) obj._variables = variables obj._coord_names = coord_names @@ -1022,7 +1019,7 @@ def _overwrite_indexes( new_variables = self._variables.copy() new_coord_names = self._coord_names.copy() - new_indexes = dict(self.xindexes) + new_indexes = dict(self._indexes) index_variables = {} no_index_variables = {} @@ -1244,7 +1241,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Dataset: variables[k] = self._variables[k] coord_names.add(k) - 
indexes.update(filter_indexes_from_coords(self.xindexes, coord_names)) + indexes.update(filter_indexes_from_coords(self._indexes, coord_names)) return self._replace(variables, coord_names, dims, indexes=indexes) @@ -1265,7 +1262,7 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: if k in self._coord_names and set(self.variables[k].dims) <= needed_dims: coords[k] = self.variables[k] - indexes = filter_indexes_from_coords(self.xindexes, set(coords)) + indexes = filter_indexes_from_coords(self._indexes, set(coords)) return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) @@ -1641,14 +1638,14 @@ def reset_coords( Dataset """ if names is None: - names = self._coord_names - set(self.xindexes) + names = self._coord_names - set(self._indexes) else: if isinstance(names, str) or not isinstance(names, Iterable): names = [names] else: names = list(names) self._assert_all_in_dataset(names) - bad_coords = set(names) & set(self.xindexes) + bad_coords = set(names) & set(self._indexes) if bad_coords: raise ValueError( f"cannot remove index coordinates with reset_coords: {bad_coords}" @@ -2073,9 +2070,7 @@ def _validate_indexers( v = np.asarray(v) if v.dtype.kind in "US": - # TODO: benbovy - flexible indexes - # update when CFTimeIndex has its own xarray index class - index = self.xindexes[k].to_pandas_index() + index = self._indexes[k].to_pandas_index() if isinstance(index, pd.DatetimeIndex): v = v.astype("datetime64[ns]") elif isinstance(index, xr.CFTimeIndex): @@ -2545,7 +2540,7 @@ def _reindex_callback( # pass through indexes from excluded dimensions # no extra check needed for multi-coordinate indexes, potential conflicts # should already have been detected when aligning the indexes - for name, idx in self.xindexes.items(): + for name, idx in self._indexes.items(): var = self._variables[name] if set(var.dims) <= exclude_dims: new_indexes[name] = idx @@ -2553,7 +2548,7 @@ def _reindex_callback( if not dim_pos_indexers: # fast path for no reindexing necessary - if set(new_indexes) - set(self.xindexes): + if set(new_indexes) - set(self._indexes): # this only adds new indexes and their coordinate variables reindexed = self._overwrite_indexes(new_indexes, new_variables) else: @@ -3112,11 +3107,11 @@ def _validate_interp_indexer(x, new_x): method=method_non_numeric, exclude_vars=variables.keys(), ) - indexes = dict(reindexed.xindexes) + indexes = dict(reindexed._indexes) variables.update(reindexed.variables) else: # Get the indexes that are not being interpolated along - indexes = {k: v for k, v in obj.xindexes.items() if k not in indexers} + indexes = {k: v for k, v in obj._indexes.items() if k not in indexers} # Get the coords that also exist in the variables: coord_names = obj._coord_names & variables.keys() @@ -3485,8 +3480,8 @@ def swap_dims( if k in result_dims: var = v.to_index_variable() var.dims = dims - if k in self.xindexes: - indexes[k] = self.xindexes[k] + if k in self._indexes: + indexes[k] = self._indexes[k] variables[k] = var else: index, index_vars = create_default_index_implicit(var) @@ -3576,7 +3571,7 @@ def expand_dims( ) variables: dict[Hashable, Variable] = {} - indexes: dict[Hashable, Index] = dict(self.xindexes) + indexes: dict[Hashable, Index] = dict(self._indexes) coord_names = self._coord_names.copy() # If dim is a dict, then ensure that the values are either integers # or iterables. 
@@ -3725,7 +3720,7 @@ def set_index( drop_variables += var_names - if len(var_names) == 1 and (not append or dim not in self.xindexes): + if len(var_names) == 1 and (not append or dim not in self._indexes): var_name = var_names[0] var = self._variables[var_name] if var.dims != (dim,): @@ -3754,7 +3749,7 @@ def set_index( new_variables.update(idx_vars) indexes_: dict[Any, Index] = { - k: v for k, v in self.xindexes.items() if k not in maybe_drop_indexes + k: v for k, v in self._indexes.items() if k not in maybe_drop_indexes } indexes_.update(new_indexes) @@ -3803,7 +3798,7 @@ def reset_index( if isinstance(dims_or_levels, str) or not isinstance(dims_or_levels, Sequence): dims_or_levels = [dims_or_levels] - invalid_coords = set(dims_or_levels) - set(self.xindexes) + invalid_coords = set(dims_or_levels) - set(self._indexes) if invalid_coords: raise ValueError( f"{tuple(invalid_coords)} are not coordinates with an index" @@ -3816,7 +3811,7 @@ def reset_index( new_variables: dict[Hashable, IndexVariable] = {} for name in dims_or_levels: - index = self.xindexes[name] + index = self._indexes[name] drop_indexes += list(self.xindexes.get_all_coords(name)) if isinstance(index, PandasMultiIndex) and name not in self.dims: @@ -3839,7 +3834,7 @@ def reset_index( if drop: drop_variables.append(name) - indexes = {k: v for k, v in self.xindexes.items() if k not in drop_indexes} + indexes = {k: v for k, v in self._indexes.items() if k not in drop_indexes} indexes.update(new_indexes) variables = { @@ -3876,12 +3871,12 @@ def reorder_levels( """ dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") variables = self._variables.copy() - indexes = dict(self.xindexes) + indexes = dict(self._indexes) new_indexes: dict[Hashable, Index] = {} new_variables: dict[Hashable, IndexVariable] = {} for dim, order in dim_order.items(): - index = self.xindexes[dim] + index = self._indexes[dim] if not isinstance(index, PandasMultiIndex): raise ValueError(f"coordinate {dim} has no MultiIndex") @@ -3892,7 +3887,7 @@ def reorder_levels( new_indexes.update({k: idx for k in idx_vars}) new_variables.update(idx_vars) - indexes = {k: v for k, v in self.xindexes.items() if k not in new_indexes} + indexes = {k: v for k, v in self._indexes.items() if k not in new_indexes} indexes.update(new_indexes) variables = {k: v for k, v in self._variables.items() if k not in new_variables} @@ -3919,7 +3914,7 @@ def _get_stack_index( stack_index: Index | None = None stack_coords: dict[Hashable, Variable] = {} - for name, index in self.xindexes.items(): + for name, index in self._indexes.items(): var = self._variables[name] if ( var.ndim == 1 @@ -4001,7 +3996,7 @@ def _stack_once(self, dims, new_dim, index_cls, create_index=True): new_variables.update(idx_vars) new_coord_names.update(idx_vars) - indexes = {k: v for k, v in self.xindexes.items() if k not in drop_indexes} + indexes = {k: v for k, v in self._indexes.items() if k not in drop_indexes} indexes.update(new_indexes) return self._replace_with_new_dims( @@ -4176,7 +4171,7 @@ def _unstack_once( ) -> Dataset: index, index_vars = index_and_vars variables: dict[Hashable, Variable] = {} - indexes = {k: v for k, v in self.xindexes.items() if k != dim} + indexes = {k: v for k, v in self._indexes.items() if k != dim} new_indexes, clean_index = index.unstack() indexes.update(new_indexes) @@ -4216,7 +4211,7 @@ def _unstack_full_reindex( ) -> Dataset: index, index_vars = index_and_vars variables: dict[Hashable, Variable] = {} - indexes = {k: v for k, v in self.xindexes.items() 
if k != dim} + indexes = {k: v for k, v in self._indexes.items() if k != dim} new_indexes, clean_index = index.unstack() indexes.update(new_indexes) @@ -4330,7 +4325,7 @@ def unstack( # We only check the non-index variables. # https://github.com/pydata/xarray/issues/5902 nonindexes = [ - self.variables[k] for k in set(self.variables) - set(self.xindexes) + self.variables[k] for k in set(self.variables) - set(self._indexes) ] # Notes for each of these cases: # 1. Dask arrays don't support assignment by index, which the fast unstack @@ -4529,7 +4524,7 @@ def drop_vars( variables = {k: v for k, v in self._variables.items() if k not in names} coord_names = {k for k in self._coord_names if k in variables} - indexes = {k: v for k, v in self.xindexes.items() if k not in names} + indexes = {k: v for k, v in self._indexes.items() if k not in names} return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) @@ -5243,7 +5238,7 @@ def reduce( ) coord_names = {k for k in self.coords if k in variables} - indexes = {k: v for k, v in self.xindexes.items() if k in variables} + indexes = {k: v for k, v in self._indexes.items() if k in variables} attrs = self.attrs if keep_attrs else None return self._replace_with_new_dims( variables, coord_names=coord_names, attrs=attrs, indexes=indexes @@ -5450,7 +5445,7 @@ def to_array(self, dim="variable", name=None): variable = Variable(dims, data, self.attrs, fastpath=True) coords = {k: v.variable for k, v in self.coords.items()} - indexes = filter_indexes_from_coords(self.xindexes, set(coords)) + indexes = filter_indexes_from_coords(self._indexes, set(coords)) new_dim_index = PandasIndex(list(self.data_vars), dim) indexes[dim] = new_dim_index coords.update(new_dim_index.create_variables()) @@ -6186,7 +6181,7 @@ def roll( indexes, index_vars = roll_indexes(self.xindexes, shifts) unrolled_vars = () else: - indexes = dict(self.xindexes) + indexes = dict(self._indexes) index_vars = dict(self.xindexes.variables) unrolled_vars = tuple(self.coords) @@ -6450,7 +6445,7 @@ def quantile( # construct the new dataset coord_names = {k for k in self.coords if k in variables} - indexes = {k: v for k, v in self.xindexes.items() if k in variables} + indexes = {k: v for k, v in self._indexes.items() if k in variables} if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) attrs = self.attrs if keep_attrs else None @@ -6683,7 +6678,7 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): variables[k] = Variable(v_dims, integ) else: variables[k] = v - indexes = {k: v for k, v in self.xindexes.items() if k in variables} + indexes = {k: v for k, v in self._indexes.items() if k in variables} return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index b6086cf910d..1c9b13fa058 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -522,7 +522,7 @@ def _maybe_unstack(self, obj): for dim in self._inserted_dims: if dim in obj.coords: del obj.coords[dim] - obj._indexes = filter_indexes_from_coords(obj.xindexes, set(obj.coords)) + obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords)) return obj def fillna(self, value): diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 33392f8a0b7..982f4409226 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1156,7 +1156,7 @@ def to_pandas_indexes(self) -> Indexes[pd.Index]: return Indexes(indexes, self._variables) - def copy_indexes(self, 
deep: bool = True) -> dict[Hashable, Index]: + def copy_indexes(self, deep: bool = True) -> dict[Hashable, T_PandasOrXarrayIndex]: """Return a new dictionary with copies of indexes, preserving unique indexes. @@ -1168,16 +1168,16 @@ def copy_indexes(self, deep: bool = True) -> dict[Hashable, Index]: return new_indexes - def __iter__(self) -> Iterator[pd.Index]: + def __iter__(self) -> Iterator[T_PandasOrXarrayIndex]: return iter(self._indexes) - def __len__(self): + def __len__(self) -> int: return len(self._indexes) - def __contains__(self, key): + def __contains__(self, key) -> bool: return key in self._indexes - def __getitem__(self, key) -> pd.Index: + def __getitem__(self, key) -> T_PandasOrXarrayIndex: return self._indexes[key] def __repr__(self): diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 4a44364ad3b..d95d282d4ed 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -196,7 +196,7 @@ def map_index_queries( # (.sel() already ensures alignment) for k, v in merged.dim_indexers.items(): if isinstance(v, DataArray): - if k in v.xindexes: + if k in v._indexes: v = v.reset_index(k) drop_coords = [name for name in v._coords if name in merged.dim_indexers] merged.dim_indexers[k] = v.drop_vars(drop_coords) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 7ba8b4e6cc2..cf9c74e9d23 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -334,13 +334,13 @@ def append_all(variables, indexes): for mapping in list_of_mappings: if isinstance(mapping, Dataset): - append_all(mapping.variables, mapping.xindexes) + append_all(mapping.variables, mapping._indexes) continue for name, variable in mapping.items(): if isinstance(variable, DataArray): coords = variable._coords.copy() # use private API for speed - indexes = dict(variable.xindexes) + indexes = dict(variable._indexes) # explicitly overwritten variables should take precedence coords.pop(name, None) indexes.pop(name, None) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 39e7730dd58..84bcf60d8be 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -321,12 +321,10 @@ def interp_na( if not is_scalar(max_gap): raise ValueError("max_gap must be a scalar.") - # TODO: benbovy - flexible indexes: update when CFTimeIndex (and DatetimeIndex?) 
- # has its own class inheriting from xarray.Index if ( - dim in self.xindexes + dim in self._indexes and isinstance( - self.xindexes[dim].to_pandas_index(), (pd.DatetimeIndex, CFTimeIndex) + self._indexes[dim].to_pandas_index(), (pd.DatetimeIndex, CFTimeIndex) ) and use_coordinate ): diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 3f6bb34a36e..fd1f3f9e999 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -292,7 +292,7 @@ def _wrapper( ) # check that index lengths and values are as expected - for name, index in result.xindexes.items(): + for name, index in result._indexes.items(): if name in expected["shapes"]: if result.sizes[name] != expected["shapes"][name]: raise ValueError( @@ -359,27 +359,27 @@ def _wrapper( # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) - input_indexes = dict(npargs[0].xindexes) + input_indexes = dict(npargs[0]._indexes) for arg in xarray_objs[1:]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) - input_indexes.update(arg.xindexes) + input_indexes.update(arg._indexes) if template is None: # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) - template_indexes = set(template.xindexes) + template_indexes = set(template._indexes) preserved_indexes = template_indexes & set(input_indexes) new_indexes = template_indexes - set(input_indexes) indexes = {dim: input_indexes[dim] for dim in preserved_indexes} - indexes.update({k: template.xindexes[k] for k in new_indexes}) + indexes.update({k: template._indexes[k] for k in new_indexes}) output_chunks = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks } else: # template xarray object has been provided with proper sizes and chunk shapes - indexes = dict(template.xindexes) + indexes = dict(template._indexes) if isinstance(template, DataArray): output_chunks = dict( zip(template.dims, template.chunks) # type: ignore[arg-type] @@ -558,7 +558,7 @@ def subset_dataset_to_block( attrs=template.attrs, ) - for index in result.xindexes: + for index in result._indexes: result[index].attrs = template[index].attrs result[index].encoding = template[index].encoding @@ -568,7 +568,7 @@ def subset_dataset_to_block( for dim in dims: if dim in output_chunks: var_chunks.append(output_chunks[dim]) - elif dim in result.xindexes: + elif dim in result._indexes: var_chunks.append((result.sizes[dim],)) elif dim in template.dims: # new unindexed dimension From 6405ffda52083ba836d642a6fd17c6ab5e843b02 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 17 Feb 2022 11:52:27 +0100 Subject: [PATCH 141/159] optimize Dataset/DataArray copy Avoid copying pandas indexes twice (indexes + coordinate variables) --- xarray/core/dataarray.py | 11 +++++++++-- xarray/core/dataset.py | 18 +++++++++++------- xarray/core/indexes.py | 9 +++++++-- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d431fb946b8..0c487dc6e9d 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1048,8 +1048,15 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: pandas.DataFrame.copy """ variable = self.variable.copy(deep=deep, data=data) - coords = {k: v.copy(deep=deep) for k, v in self._coords.items()} - indexes = self.xindexes.copy_indexes(deep=deep) + indexes, index_vars = self.xindexes.copy_indexes(deep=deep) + + coords = {} + for k, v in self._coords.items(): + if k in index_vars: + coords[k] 
= index_vars[k] + else: + coords[k] = v.copy(deep=deep) + return self._replace(variable, coords, indexes=indexes) def __copy__(self) -> DataArray: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 27dba3ac904..f4e5c97f06e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1161,10 +1161,11 @@ def copy(self, deep: bool = False, data: Mapping = None) -> Dataset: pandas.DataFrame.copy """ if data is None: - variables = {k: v.copy(deep=deep) for k, v in self._variables.items()} + data = {} elif not utils.is_dict_like(data): raise ValueError("Data must be dict-like") - else: + + if data: var_keys = set(self.data_vars.keys()) data_keys = set(data.keys()) keys_not_in_vars = data_keys - var_keys @@ -1179,12 +1180,15 @@ def copy(self, deep: bool = False, data: Mapping = None) -> Dataset: "Data must contain all variables in original " "dataset. Data is missing {}".format(keys_missing_from_data) ) - variables = { - k: v.copy(deep=deep, data=data.get(k)) - for k, v in self._variables.items() - } - indexes = self.xindexes.copy_indexes(deep=deep) + indexes, index_vars = self.xindexes.copy_indexes(deep=deep) + + variables = {} + for k, v in self._variables.items(): + if k in index_vars: + variables[k] = index_vars[k] + else: + variables[k] = v.copy(deep=deep, data=data.get(k)) attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 982f4409226..fad20f55e7a 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1156,17 +1156,22 @@ def to_pandas_indexes(self) -> Indexes[pd.Index]: return Indexes(indexes, self._variables) - def copy_indexes(self, deep: bool = True) -> dict[Hashable, T_PandasOrXarrayIndex]: + def copy_indexes( + self, deep: bool = True + ) -> tuple[dict[Hashable, T_PandasOrXarrayIndex], dict[Hashable, Variable]]: """Return a new dictionary with copies of indexes, preserving unique indexes. """ new_indexes = {} + new_index_vars = {} for idx, coords in self.group_by_index(): new_idx = idx.copy(deep=deep) + idx_vars = idx.create_variables(coords) new_indexes.update({k: new_idx for k in coords}) + new_index_vars.update(idx_vars) - return new_indexes + return new_indexes, new_index_vars def __iter__(self) -> Iterator[T_PandasOrXarrayIndex]: return iter(self._indexes) From 2b1f90b7c3ddb88551ee21cf06c1147b7b6408ea Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 17 Feb 2022 11:54:39 +0100 Subject: [PATCH 142/159] add default implementation for Index.copy Better than raising an error? 
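
A rough sketch of the behavior this default enables, for a minimal
(hypothetical) subclass that keeps all of its state in ``__dict__``:

    import copy

    from xarray.core.indexes import Index

    class MyIndex(Index):
        def __init__(self, dims):
            self.dims = dims  # mutable state stored in __dict__

    idx = MyIndex({"x": 2})
    shallow = copy.copy(idx)  # dispatches to Index.copy(deep=False)
    deep = copy.deepcopy(idx)  # dispatches to Index.copy(deep=True)

    assert shallow.dims is idx.dims  # shallow copy shares state
    assert deep.dims == idx.dims and deep.dims is not idx.dims

Subclasses that need finer control over what gets copied (e.g.
PandasIndex wrapping a pd.Index) can still override ``copy``.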
--- xarray/core/indexes.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index fad20f55e7a..a7234a98516 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1,6 +1,7 @@ from __future__ import annotations import collections.abc +import copy from collections import defaultdict from typing import ( TYPE_CHECKING, @@ -101,8 +102,23 @@ def rename( ) -> Index: return self - def copy(self, deep: bool = True): # pragma: no cover - raise NotImplementedError() + def __copy__(self) -> Index: + return self.copy(deep=False) + + def __deepcopy__(self, memo=None) -> Index: + # memo does nothing but is required for compatibility with + # copy.deepcopy + return self.copy(deep=True) + + def copy(self, deep: bool = True) -> Index: + cls = self.__class__ + copied = cls.__new__(cls) + if deep: + for k, v in self.__dict__.items(): + setattr(copied, k, copy.deepcopy(v)) + else: + copied.__dict__.update(self.__dict__) + return copied def __getitem__(self, indexer: Any): raise NotImplementedError() From 73ba9d40e0beb44c67a0e235715e8ef6d419a947 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 17 Feb 2022 11:55:32 +0100 Subject: [PATCH 143/159] add/fix indexes tests --- xarray/tests/test_indexes.py | 85 +++++++++++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 0672f0659e8..e5d202b1142 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -1,3 +1,4 @@ +import copy from typing import Any, Dict, List import numpy as np @@ -29,6 +30,84 @@ def test_asarray_tuplesafe() -> None: assert res[1] == (1,) +class CustomIndex(Index): + def __init__(self, dims) -> None: + self.dims = dims + + +class TestIndex: + @pytest.fixture + def index(self) -> CustomIndex: + return CustomIndex({"x": 2}) + + def test_from_variables(self) -> None: + with pytest.raises(NotImplementedError): + Index.from_variables({}) + + def test_concat(self) -> None: + with pytest.raises(NotImplementedError): + Index.concat([], "x") + + def test_stack(self) -> None: + with pytest.raises(NotImplementedError): + Index.stack({}, "x") + + def test_unstack(self, index) -> None: + with pytest.raises(NotImplementedError): + index.unstack() + + def test_create_variables(self, index) -> None: + assert index.create_variables() == {} + assert index.create_variables({"x": "var"}) == {"x": "var"} + + def test_to_pandas_index(self, index) -> None: + with pytest.raises(TypeError): + index.to_pandas_index() + + def test_isel(self, index) -> None: + assert index.isel({}) is None + + def test_sel(self, index) -> None: + with pytest.raises(NotImplementedError): + index.sel({}) + + def test_join(self, index) -> None: + with pytest.raises(NotImplementedError): + index.join(CustomIndex({"y": 2})) + + def test_reindex_like(self, index) -> None: + with pytest.raises(NotImplementedError): + index.reindex_like(CustomIndex({"y": 2})) + + def test_equals(self, index) -> None: + with pytest.raises(NotImplementedError): + index.equals(CustomIndex({"y": 2})) + + def test_roll(self, index) -> None: + assert index.roll({}) is None + + def test_rename(self, index) -> None: + assert index.rename({}, {}) is index + + @pytest.mark.parametrize("deep", [True, False]) + def test_copy(self, index, deep) -> None: + copied = index.copy(deep=deep) + assert isinstance(copied, CustomIndex) + assert copied is not index + + copied.dims["x"] = 3 + if deep: + assert 
copied.dims != index.dims + assert copied.dims != copy.deepcopy(index).dims + else: + assert copied.dims is index.dims + assert copied.dims is copy.copy(index).dims + + def test_getitem(self, index) -> None: + with pytest.raises(NotImplementedError): + index[:] + + class TestPandasIndex: def test_constructor(self) -> None: pd_idx = pd.Index([1, 2, 3]) @@ -537,10 +616,14 @@ def test_to_pandas_indexes(self, indexes) -> None: assert indexes.variables == pd_indexes.variables def test_copy_indexes(self, indexes) -> None: - copied = indexes.copy_indexes() + copied, index_vars = indexes.copy_indexes() assert copied.keys() == indexes.keys() for new, original in zip(copied.values(), indexes.values()): assert new.equals(original) # check unique index objects preserved assert copied["z"] is copied["one"] is copied["two"] + + assert index_vars.keys() == indexes.variables.keys() + for new, original in zip(index_vars.values(), indexes.variables.values()): + assert_identical(new, original) From 3927479f09f0cf486e201e5b6c0e7ef1f11d9d48 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 17 Feb 2022 12:08:32 +0100 Subject: [PATCH 144/159] tweak indexes formatting This will need more work (in a follow-up PR) to be consistent with the other data model components (i.e., coordinates and data variables): add summary (inline) and detailed reprs, maybe group by coordinates, etc. --- xarray/core/formatting.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index c85f288ea27..f2fcf1bad6d 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -408,9 +408,12 @@ def coords_repr(coords, col_width=None, max_rows=None): def indexes_repr(indexes): - summary = [] - for k, v in indexes.items(): - summary.append(wrap_indent(repr(v), f"{k}: ")) + summary = ["Indexes:"] + if indexes: + for k, v in indexes.items(): + summary.append(wrap_indent(repr(v), f"{k}: ")) + else: + summary += [EMPTY_REPR] return "\n".join(summary) From a3173b52021178f91ad63fed217d38c8d772d544 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 17 Feb 2022 12:35:54 +0100 Subject: [PATCH 145/159] fix doctests --- xarray/core/dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f4e5c97f06e..728a28f70ef 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2724,6 +2724,7 @@ def reindex( temperature (station) float64 10.98 14.3 12.06 10.9 pressure (station) float64 211.8 322.9 218.8 445.9 >>> x.indexes + Indexes: station: Index(['boston', 'nyc', 'seattle', 'denver'], dtype='object', name='station') Create a new index and reindex the dataset. 
By default values in the new index that From 709a58e1ac66abd948cb6d47dac073d10a827e8f Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 18 Feb 2022 10:05:46 +0100 Subject: [PATCH 146/159] PandasIndex.copy: avoid too many pd.Index copies --- xarray/core/indexes.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index a7234a98516..bc6e8186493 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -461,7 +461,12 @@ def rename(self, name_dict, dims_dict): return self._replace(index, dim=new_dim) def copy(self, deep=True): - return self._replace(self.index.copy(deep=deep)) + if deep: + index = self.index.copy(deep=True) + else: + # index will be copied in constructor + index = self.index + return self._replace(index) def __getitem__(self, indexer: Any): return self._replace(self.index[indexer]) From ed1af4d660a5c560200031c65982f3df80655987 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 21 Feb 2022 09:38:02 +0100 Subject: [PATCH 147/159] DataArray stack test: revert to original Now that default (range) indexes are created again (create_index=True). --- xarray/tests/test_dataarray.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index d6d846ea24e..634fb601167 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2186,26 +2186,24 @@ def test_dataset_math(self): def test_stack_unstack(self): orig = DataArray( [[0, 1], [2, 3]], - coords={"x": [0, 1], "y": ["a", "b"]}, dims=["x", "y"], attrs={"foo": 2}, ) assert_identical(orig, orig.unstack()) # test GH3000 - # no default range index anymore - # a = orig[:0, :1].stack(dim=("x", "y")).dim.to_index() - # b = pd.MultiIndex( - # levels=[pd.Index([], np.int64), pd.Index([0], np.int64)], - # codes=[[], []], - # names=["x", "y"], - # ) - # pd.testing.assert_index_equal(a, b) - - actual = orig.stack(z=["x", "y"]).unstack("z") + a = orig[:0, :1].stack(dim=("x", "y")).indexes["dim"] + b = pd.MultiIndex( + levels=[pd.Index([], np.int64), pd.Index([0], np.int64)], + codes=[[], []], + names=["x", "y"], + ) + pd.testing.assert_index_equal(a, b) + + actual = orig.stack(z=["x", "y"]).unstack("z").drop_vars(["x", "y"]) assert_identical(orig, actual) - actual = orig.stack(z=[...]).unstack("z") + actual = orig.stack(z=[...]).unstack("z").drop_vars(["x", "y"]) assert_identical(orig, actual) dims = ["a", "b", "c", "d", "e"] From 9708eb32ce06e70ed4d33ede4e97aa728cf1e0f6 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 21 Feb 2022 09:44:16 +0100 Subject: [PATCH 148/159] test indexes/indexing: rename query -> sel --- xarray/tests/test_indexes.py | 6 +++--- xarray/tests/test_indexing.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index e5d202b1142..ff04ec644e4 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -214,7 +214,7 @@ def test_sel(self) -> None: with pytest.raises(ValueError, match=r"does not have a MultiIndex"): index.sel({"x": {"one": 0}}) - def test_query_boolean(self) -> None: + def test_sel_boolean(self) -> None: # index should be ignored and indexer dtype should not be coerced # see https://github.com/pydata/xarray/issues/5727 index = PandasIndex(pd.Index([0.0, 2.0, 1.0, 3.0]), "x") @@ -224,7 +224,7 @@ def test_query_boolean(self) -> None: actual.dim_indexers["x"], expected_dim_indexers["x"] ) - 
def test_query_datetime(self) -> None: + def test_sel_datetime(self) -> None: index = PandasIndex( pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]), "x" ) @@ -235,7 +235,7 @@ def test_query_datetime(self) -> None: actual = index.sel({"x": index.to_pandas_index().to_numpy()[1]}) assert actual.dim_indexers == expected_dim_indexers - def test_query_unsorted_datetime_index_raises(self) -> None: + def test_sel_unsorted_datetime_index_raises(self) -> None: index = PandasIndex(pd.to_datetime(["2001", "2000", "2002"]), "x") with pytest.raises(KeyError): # pandas will try to convert this into an array indexer. We should diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index de9393bb9d2..0b40bd18223 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -88,7 +88,7 @@ def test_group_indexers_by_index(self) -> None: indexing.group_indexers_by_index(data, {"z": 1}, {"method": "nearest"}) def test_map_index_queries(self) -> None: - def create_query_results( + def create_sel_results( x_indexer, x_index, other_vars, @@ -145,7 +145,7 @@ def test_indexer( test_indexer(data, Variable([], 1), indexing.IndexSelResult({"x": 0})) test_indexer(mdata, ("a", 1, -1), indexing.IndexSelResult({"x": 0})) - expected = create_query_results( + expected = create_sel_results( [True, True, False, False, False, False, False, False], PandasIndex(pd.Index([-1, -2]), "three"), {"one": Variable((), "a"), "two": Variable((), 1)}, @@ -155,7 +155,7 @@ def test_indexer( ) test_indexer(mdata, ("a", 1), expected) - expected = create_query_results( + expected = create_sel_results( slice(0, 4, None), PandasMultiIndex( pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), @@ -168,7 +168,7 @@ def test_indexer( ) test_indexer(mdata, "a", expected) - expected = create_query_results( + expected = create_sel_results( [True, True, True, True, False, False, False, False], PandasMultiIndex( pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), @@ -198,7 +198,7 @@ def test_indexer( indexing.IndexSelResult({"x": 0}), ) - expected = create_query_results( + expected = create_sel_results( [True, True, False, False, False, False, False, False], PandasIndex(pd.Index([-1, -2]), "three"), {"one": Variable((), "a"), "two": Variable((), 1)}, @@ -208,7 +208,7 @@ def test_indexer( ) test_indexer(mdata, {"one": "a", "two": 1}, expected) - expected = create_query_results( + expected = create_sel_results( [True, False, True, False, False, False, False, False], PandasIndex(pd.Index([1, 2]), "two"), {"one": Variable((), "a"), "three": Variable((), -1)}, @@ -218,7 +218,7 @@ def test_indexer( ) test_indexer(mdata, {"one": "a", "three": -1}, expected) - expected = create_query_results( + expected = create_sel_results( [True, True, True, True, False, False, False, False], PandasMultiIndex( pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), From 7f57db5dd72e2a0f87a4ece325cb651152f8f3aa Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 21 Feb 2022 21:03:06 +0100 Subject: [PATCH 149/159] alignment: use future annotations --- xarray/core/alignment.py | 64 +++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 01c64321a40..6d0873d90c0 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import operator from collections import defaultdict @@ -7,17 +9,13 @@ Any, 
Callable, Dict, - FrozenSet, Generic, Hashable, Iterable, - List, Mapping, - Set, Tuple, Type, TypeVar, - Union, ) import numpy as np @@ -42,7 +40,7 @@ def reindex_variables( copy: bool = True, fill_value: Any = dtypes.NA, sparse: bool = False, -) -> Dict[Hashable, Variable]: +) -> dict[Hashable, Variable]: """Conform a dictionary of variables onto a new set of variables reindexed with dimension positional indexers and possibly filled with missing values. @@ -107,23 +105,23 @@ class Aligner(Generic[DataAlignable]): NormalizedIndexes = Dict[MatchingIndexKey, Index] NormalizedIndexVars = Dict[MatchingIndexKey, Dict[Hashable, Variable]] - objects: Tuple[DataAlignable, ...] - objects_matching_indexes: Tuple[Dict[MatchingIndexKey, Index], ...] + objects: tuple[DataAlignable, ...] + objects_matching_indexes: tuple[dict[MatchingIndexKey, Index], ...] join: str - exclude_dims: FrozenSet[Hashable] - exclude_vars: FrozenSet[Hashable] + exclude_dims: frozenset[Hashable] + exclude_vars: frozenset[Hashable] copy: bool fill_value: Any sparse: bool - indexes: Dict[MatchingIndexKey, Index] - index_vars: Dict[MatchingIndexKey, Dict[Hashable, Variable]] - all_indexes: Dict[MatchingIndexKey, List[Index]] - all_index_vars: Dict[MatchingIndexKey, List[Dict[Hashable, Variable]]] - aligned_indexes: Dict[MatchingIndexKey, Index] - aligned_index_vars: Dict[MatchingIndexKey, Dict[Hashable, Variable]] - reindex: Dict[MatchingIndexKey, bool] - reindex_kwargs: Dict[str, Any] - unindexed_dim_sizes: Dict[Hashable, Set] + indexes: dict[MatchingIndexKey, Index] + index_vars: dict[MatchingIndexKey, dict[Hashable, Variable]] + all_indexes: dict[MatchingIndexKey, list[Index]] + all_index_vars: dict[MatchingIndexKey, list[dict[Hashable, Variable]]] + aligned_indexes: dict[MatchingIndexKey, Index] + aligned_index_vars: dict[MatchingIndexKey, dict[Hashable, Variable]] + reindex: dict[MatchingIndexKey, bool] + reindex_kwargs: dict[str, Any] + unindexed_dim_sizes: dict[Hashable, set] new_indexes: Indexes[Index] def __init__( @@ -134,7 +132,7 @@ def __init__( exclude_dims: Iterable = frozenset(), exclude_vars: Iterable[Hashable] = frozenset(), method: str = None, - tolerance: Union[Union[int, float], Iterable[Union[int, float]]] = None, + tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value: Any = dtypes.NA, sparse: bool = False, @@ -175,7 +173,7 @@ def __init__( def _normalize_indexes( self, indexes: Mapping[Any, Any], - ) -> Tuple[NormalizedIndexes, NormalizedIndexVars]: + ) -> tuple[NormalizedIndexes, NormalizedIndexVars]: """Normalize the indexes/indexers used for re-indexing or alignment. 
Return dictionaries of xarray Index objects and coordinate variables @@ -187,7 +185,7 @@ def _normalize_indexes( else: xr_variables = {} - xr_indexes: Dict[Hashable, Index] = {} + xr_indexes: dict[Hashable, Index] = {} for k, idx in indexes.items(): if not isinstance(idx, Index): if getattr(idx, "dims", (k,)) != (k,): @@ -443,7 +441,7 @@ def assert_unindexed_dim_sizes_equal(self): f"because of conflicting dimension sizes: {sizes!r}" + add_err_msg ) - def override_indexes(self) -> Tuple[DataAlignable, ...]: + def override_indexes(self) -> tuple[DataAlignable, ...]: objects = list(self.objects) for i, obj in enumerate(objects[1:]): @@ -464,8 +462,8 @@ def override_indexes(self) -> Tuple[DataAlignable, ...]: def _get_dim_pos_indexers( self, - matching_indexes: Dict[MatchingIndexKey, Index], - ) -> Dict[Hashable, Any]: + matching_indexes: dict[MatchingIndexKey, Index], + ) -> dict[Hashable, Any]: dim_pos_indexers = {} for key, aligned_idx in self.aligned_indexes.items(): @@ -480,8 +478,8 @@ def _get_dim_pos_indexers( def _get_indexes_and_vars( self, obj: DataAlignable, - matching_indexes: Dict[MatchingIndexKey, Index], - ) -> Tuple[Dict[Hashable, Index], Dict[Hashable, Variable]]: + matching_indexes: dict[MatchingIndexKey, Index], + ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: new_indexes = {} new_variables = {} @@ -503,7 +501,7 @@ def _get_indexes_and_vars( def _reindex_one( self, obj: DataAlignable, - matching_indexes: Dict[MatchingIndexKey, Index], + matching_indexes: dict[MatchingIndexKey, Index], ) -> DataAlignable: new_indexes, new_variables = self._get_indexes_and_vars(obj, matching_indexes) dim_pos_indexers = self._get_dim_pos_indexers(matching_indexes) @@ -520,7 +518,7 @@ def _reindex_one( new_obj.encoding = obj.encoding return new_obj - def reindex_all(self) -> Tuple[DataAlignable, ...]: + def reindex_all(self) -> tuple[DataAlignable, ...]: result = [] for obj, matching_indexes in zip(self.objects, self.objects_matching_indexes): @@ -528,7 +526,7 @@ def reindex_all(self) -> Tuple[DataAlignable, ...]: return tuple(result) - def align(self) -> Tuple[DataAlignable, ...]: + def align(self) -> tuple[DataAlignable, ...]: if not self.indexes and len(self.objects) == 1: # fast path for the trivial case (obj,) = self.objects @@ -553,7 +551,7 @@ def align( indexes=None, exclude=frozenset(), fill_value=dtypes.NA, -) -> Tuple[DataAlignable, ...]: +) -> tuple[DataAlignable, ...]: """ Given any number of Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. @@ -839,7 +837,7 @@ def reindex( obj: DataAlignable, indexers: Mapping[Any, Any], method: str = None, - tolerance: Union[Union[int, float], Iterable[Union[int, float]]] = None, + tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value: Any = dtypes.NA, sparse: bool = False, @@ -875,9 +873,9 @@ def reindex( def reindex_like( obj: DataAlignable, - other: Union["Dataset", "DataArray"], + other: Dataset | DataArray, method: str = None, - tolerance: Union[Union[int, float], Iterable[Union[int, float]]] = None, + tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value: Any = dtypes.NA, ) -> DataAlignable: From df64f0da0154ff97788661cb88d2607164c78b15 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 21 Feb 2022 21:22:41 +0100 Subject: [PATCH 150/159] misc. 
tweaks
---
 xarray/core/alignment.py | 12 ++++++------
 xarray/core/dataset.py   |  2 +-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py
index 6d0873d90c0..b391168c6cb 100644
--- a/xarray/core/alignment.py
+++ b/xarray/core/alignment.py
@@ -519,12 +519,12 @@ def _reindex_one(
         return new_obj

     def reindex_all(self) -> tuple[DataAlignable, ...]:
-        result = []
-
-        for obj, matching_indexes in zip(self.objects, self.objects_matching_indexes):
-            result.append(self._reindex_one(obj, matching_indexes))
-
-        return tuple(result)
+        return tuple(
+            self._reindex_one(obj, matching_indexes)
+            for obj, matching_indexes in zip(
+                self.objects, self.objects_matching_indexes
+            )
+        )

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 728a28f70ef..f38cf177b66 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -2869,7 +2869,7 @@ def _reindex(
         **indexers_kwargs: Any,
     ) -> Dataset:
         """
-        Same than reindex but supports sparse option.
+        Same as reindex but supports sparse option.
         """
         indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex")
         return alignment.reindex(

From a9add155698a3602307dab8a92131c220c21b600 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 21 Feb 2022 20:24:54 +0000
Subject: [PATCH 151/159] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---
 doc/contributing.rst       | 4 ++--
 doc/whats-new.rst          | 4 ++--
 xarray/core/computation.py | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/doc/contributing.rst b/doc/contributing.rst
index df279caa54f..0913702fd83 100644
--- a/doc/contributing.rst
+++ b/doc/contributing.rst
@@ -274,13 +274,13 @@ Some other important things to know about the docs:
    .. ipython:: python

        x = 2
-       x ** 3
+       x**3

    will be rendered as::

       In [1]: x = 2

-      In [2]: x ** 3
+      In [2]: x**3
       Out[2]: 8

 Almost all code examples in the docs are run (and the output saved) during the
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 02b341f963e..4ed596342cd 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -5121,7 +5121,7 @@ Enhancements
   .. ipython:: python

       ds = xray.Dataset(coords={"x": range(100), "y": range(100)})
-      ds["distance"] = np.sqrt(ds.x ** 2 + ds.y ** 2)
+      ds["distance"] = np.sqrt(ds.x**2 + ds.y**2)

       @savefig where_example.png width=4in height=4in
       ds.distance.where(ds.distance < 100).plot()
@@ -5329,7 +5329,7 @@ Enhancements
   .. ipython:: python

       ds = xray.Dataset({"y": ("x", [1, 2, 3])})
-      ds.assign(z=lambda ds: ds.y ** 2)
+      ds.assign(z=lambda ds: ds.y**2)
       ds.assign_coords(z=("x", ["a", "b", "c"]))

   These methods return a new Dataset (or DataArray) with updated data or
diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index aa5bd039707..b8215b16f7a 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -964,7 +964,7 @@ def apply_ufunc(
     Calculate the vector magnitude of two arguments:

     >>> def magnitude(a, b):
-    ...     func = lambda x, y: np.sqrt(x ** 2 + y ** 2)
+    ...     func = lambda x, y: np.sqrt(x**2 + y**2)
     ...     return xr.apply_ufunc(func, a, b)
     ...

From 709a58e1ac66abd948cb6d47dac073d10a827e8f Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Mon, 7 Mar 2022 09:18:41 +0100
Subject: [PATCH 152/159] Aligner tweaks

Access aligned objects through the ``.results`` attribute.
Enable type checking in internal methods. --- xarray/core/alignment.py | 61 ++++++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index b391168c6cb..d201e3a613f 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -89,6 +89,12 @@ def reindex_variables( return new_variables +CoordNamesAndDims = Tuple[Tuple[Hashable, Tuple[Hashable, ...]], ...] +MatchingIndexKey = Tuple[CoordNamesAndDims, Type[Index]] +NormalizedIndexes = Dict[MatchingIndexKey, Index] +NormalizedIndexVars = Dict[MatchingIndexKey, Dict[Hashable, Variable]] + + class Aligner(Generic[DataAlignable]): """Implements all the complex logic for the re-indexing and alignment of Xarray objects. @@ -96,16 +102,14 @@ class Aligner(Generic[DataAlignable]): For internal use only, not public API. Usage: - aligned_objects = Alignator(*objects, **kwargs).align() + aligner = Aligner(*objects, **kwargs) + aligner.align() + aligned_objects = aligner.results """ - CoordNamesAndDims = Tuple[Tuple[Hashable, Tuple[Hashable, ...]], ...] - MatchingIndexKey = Tuple[CoordNamesAndDims, Type[Index]] - NormalizedIndexes = Dict[MatchingIndexKey, Index] - NormalizedIndexVars = Dict[MatchingIndexKey, Dict[Hashable, Variable]] - objects: tuple[DataAlignable, ...] + results: tuple[DataAlignable, ...] objects_matching_indexes: tuple[dict[MatchingIndexKey, Index], ...] join: str exclude_dims: frozenset[Hashable] @@ -170,6 +174,8 @@ def __init__( self.aligned_index_vars = {} self.reindex = {} + self.results = tuple() + def _normalize_indexes( self, indexes: Mapping[Any, Any], @@ -232,7 +238,12 @@ def _normalize_indexes( return normalized_indexes, normalized_index_vars - def find_matching_indexes(self): + def find_matching_indexes(self) -> None: + all_indexes: dict[MatchingIndexKey, list[Index]] + all_index_vars: dict[MatchingIndexKey, list[dict[Hashable, Variable]]] + all_indexes_dim_sizes: dict[MatchingIndexKey, dict[Hashable, set]] + objects_matching_indexes: list[dict[MatchingIndexKey, Index]] + all_indexes = defaultdict(list) all_index_vars = defaultdict(list) all_indexes_dim_sizes = defaultdict(lambda: defaultdict(set)) @@ -261,7 +272,7 @@ def find_matching_indexes(self): f"along dimension {dim!r} that don't have the same size" ) - def find_matching_unindexed_dims(self): + def find_matching_unindexed_dims(self) -> None: unindexed_dim_sizes = defaultdict(set) for obj in self.objects: @@ -271,7 +282,7 @@ def find_matching_unindexed_dims(self): self.unindexed_dim_sizes = unindexed_dim_sizes - def assert_no_index_conflict(self): + def assert_no_index_conflict(self) -> None: """Check for uniqueness of both coordinate and dimension names accross all sets of matching indexes. 
@@ -287,10 +298,10 @@ def assert_no_index_conflict(self): """ matching_keys = set(self.all_indexes) | set(self.indexes) - coord_count = defaultdict(int) - dim_count = defaultdict(int) + coord_count: dict[Hashable, int] = defaultdict(int) + dim_count: dict[Hashable, int] = defaultdict(int) for coord_names_dims, _ in matching_keys: - dims_set = set() + dims_set: set[Hashable] = set() for name, dims in coord_names_dims: coord_count[name] += 1 dims_set.update(dims) @@ -343,7 +354,7 @@ def _get_index_joiner(self, index_cls) -> Callable: # join='exact' return dummy lambda (error is raised) return lambda _: None - def align_indexes(self): + def align_indexes(self) -> None: """Compute all aligned indexes and their corresponding coordinate variables.""" aligned_indexes = {} @@ -424,7 +435,7 @@ def align_indexes(self): self.reindex = reindex self.new_indexes = Indexes(new_indexes, new_index_vars) - def assert_unindexed_dim_sizes_equal(self): + def assert_unindexed_dim_sizes_equal(self) -> None: for dim, sizes in self.unindexed_dim_sizes.items(): index_size = self.new_indexes.dims.get(dim) if index_size is not None: @@ -441,7 +452,7 @@ def assert_unindexed_dim_sizes_equal(self): f"because of conflicting dimension sizes: {sizes!r}" + add_err_msg ) - def override_indexes(self) -> tuple[DataAlignable, ...]: + def override_indexes(self) -> None: objects = list(self.objects) for i, obj in enumerate(objects[1:]): @@ -458,7 +469,7 @@ def override_indexes(self) -> tuple[DataAlignable, ...]: objects[i + 1] = obj._overwrite_indexes(new_indexes, new_variables) - return tuple(objects) + self.results = tuple(objects) def _get_dim_pos_indexers( self, @@ -518,19 +529,19 @@ def _reindex_one( new_obj.encoding = obj.encoding return new_obj - def reindex_all(self) -> tuple[DataAlignable, ...]: - return tuple( + def reindex_all(self) -> None: + self.results = tuple( self._reindex_one(obj, matching_indexes) for obj, matching_indexes in zip( self.objects, self.objects_matching_indexes ) ) - def align(self) -> tuple[DataAlignable, ...]: + def align(self) -> None: if not self.indexes and len(self.objects) == 1: # fast path for the trivial case (obj,) = self.objects - return (obj.copy(deep=self.copy),) + self.results = (obj.copy(deep=self.copy),) self.find_matching_indexes() self.find_matching_unindexed_dims() @@ -539,9 +550,9 @@ def align(self) -> tuple[DataAlignable, ...]: self.assert_unindexed_dim_sizes_equal() if self.join == "override": - return self.override_indexes() + self.override_indexes() else: - return self.reindex_all() + self.reindex_all() def align( @@ -747,7 +758,8 @@ def align( exclude_dims=exclude, fill_value=fill_value, ) - return aligner.align() + aligner.align() + return aligner.results def deep_align( @@ -868,7 +880,8 @@ def reindex( sparse=sparse, exclude_vars=exclude_vars, ) - return aligner.align()[0] + aligner.align() + return aligner.results[0] def reindex_like( From e97653ec757661cb5d1bf194bfa2f7ebeaecb6d5 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 7 Mar 2022 11:54:46 +0100 Subject: [PATCH 153/159] remove assert_unique_multiindex_level_names It is not needed anymore as multi-indexes have their own coordinates and we now check for possible name conflicts when "unpacking" multi-index levels at Dataset / DataArray creation. 
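
For instance (a sketch; the exact error type and message for the
conflicting case may differ), level names are now unpacked into real
coordinates when the object is created, and a name conflict is caught
at that point:

    import pandas as pd
    import xarray as xr

    midx = pd.MultiIndex.from_product(
        [[0, 1], ["a", "b"]], names=("one", "two")
    )
    ds = xr.Dataset(coords={"x": midx})
    assert "one" in ds.coords and "two" in ds.coords  # real coordinates

    # conflicting level / coordinate name, now raises at creation:
    xr.Dataset(coords={"x": midx, "one": ("y", [1, 2])})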
--- xarray/core/dataarray.py | 9 +-------- xarray/core/variable.py | 37 ------------------------------------- 2 files changed, 1 insertion(+), 45 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0c487dc6e9d..14061debcff 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -60,12 +60,7 @@ _default, either_dict_or_kwargs, ) -from .variable import ( # assert_unique_multiindex_level_names, - IndexVariable, - Variable, - as_compatible_data, - as_variable, -) +from .variable import IndexVariable, Variable, as_compatible_data, as_variable if TYPE_CHECKING: try: @@ -159,8 +154,6 @@ def _infer_coords_and_dims( "matching the dimension size" ) - # assert_unique_multiindex_level_names(new_coords) - return new_coords, dims diff --git a/xarray/core/variable.py b/xarray/core/variable.py index a74c665b371..f367b8b021d 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -4,7 +4,6 @@ import itertools import numbers import warnings -from collections import defaultdict from datetime import timedelta from typing import TYPE_CHECKING, Any, Hashable, Mapping, Sequence @@ -2986,42 +2985,6 @@ def concat( return Variable.concat(variables, dim, positions, shortcut, combine_attrs) -def assert_unique_multiindex_level_names(variables): - """Check for uniqueness of MultiIndex level names in all given - variables. - - Not public API. Used for checking consistency of DataArray and Dataset - objects. - """ - level_names = defaultdict(list) - all_level_names = set() - for var_name, var in variables.items(): - if isinstance(var._data, PandasIndexingAdapter): - idx_level_names = var.to_index_variable().level_names - if idx_level_names is not None: - for n in idx_level_names: - level_names[n].append(f"{n!r} ({var_name})") - if idx_level_names: - all_level_names.update(idx_level_names) - - for k, v in level_names.items(): - if k in variables: - v.append(f"({k})") - - duplicate_names = [v for v in level_names.values() if len(v) > 1] - if duplicate_names: - conflict_str = "\n".join(", ".join(v) for v in duplicate_names) - raise ValueError(f"conflicting MultiIndex level name(s):\n{conflict_str}") - # Check confliction between level names and dimensions GH:2299 - for k, v in variables.items(): - for d in v.dims: - if d in all_level_names: - raise ValueError( - "conflicting level / dimension names. {} " - "already exists as a level name.".format(d) - ) - - def propagate_attrs_encoding( old_variables: Mapping[Any, Variable], new_variables: Mapping[Any, Variable] ) -> None: From 72d0782f3585310542afbe628fc57c1d6c2ceb3f Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 7 Mar 2022 12:29:41 +0100 Subject: [PATCH 154/159] remove propagate_attrs_encoding Only used in one place. 
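
Note that inlining flips the precedence on conflicting keys: the
removed helper gave priority to the attrs of the new variable,

    var.attrs = {**old_var.attrs, **var.attrs}

while the inlined version lets the old attrs win,

    var.attrs.update(old_var.attrs)

which is assumed to be the intent here, since the index variables
rebuilt by _overwrite_indexes should inherit the metadata of the
variables they replace.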
--- xarray/core/dataset.py | 19 ++++++++++--------- xarray/core/variable.py | 14 -------------- 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f38cf177b66..cd5ecd46599 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -97,7 +97,6 @@ as_variable, broadcast_variables, calculate_dimensions, - propagate_attrs_encoding, ) if TYPE_CHECKING: @@ -1015,26 +1014,28 @@ def _overwrite_indexes( if drop_indexes is None: drop_indexes = [] - propagate_attrs_encoding(self._variables, variables) - new_variables = self._variables.copy() new_coord_names = self._coord_names.copy() new_indexes = dict(self._indexes) index_variables = {} no_index_variables = {} - for k, v in variables.items(): - if k in indexes: - index_variables[k] = v + for name, var in variables.items(): + old_var = self._variables.get(name) + if old_var is not None: + var.attrs.update(old_var.attrs) + var.encoding.update(old_var.encoding) + if name in indexes: + index_variables[name] = var else: - no_index_variables[k] = v + no_index_variables[name] = var for name in indexes: new_indexes[name] = indexes[name] - for name in index_variables: + for name, var in index_variables.items(): new_coord_names.add(name) - new_variables[name] = variables[name] + new_variables[name] = var # append no-index variables at the end for k in no_index_variables: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index f367b8b021d..f6c08d040a6 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2985,20 +2985,6 @@ def concat( return Variable.concat(variables, dim, positions, shortcut, combine_attrs) -def propagate_attrs_encoding( - old_variables: Mapping[Any, Variable], new_variables: Mapping[Any, Variable] -) -> None: - """Propagate any attrs and/or encoding items from old variables that are not present - in new variables. - - """ - for name, var in new_variables.items(): - old_var = old_variables.get(name) - if old_var is not None: - var.attrs = {**old_var.attrs, **var.attrs} - var.encoding = {**old_var.encoding, **var.encoding} - - def calculate_dimensions(variables: Mapping[Any, Variable]) -> dict[Hashable, int]: """Calculate the dimensions corresponding to a set of variables. 
From 4f24faf4d71cf93cbb2e51fddad8b2dece39709a Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 7 Mar 2022 12:45:28 +0100 Subject: [PATCH 155/159] assert -> check + ValueError --- xarray/core/dataarray.py | 5 +++-- xarray/core/indexes.py | 7 ++++++- xarray/tests/test_indexes.py | 6 ++++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 14061debcff..a4a6a395f27 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -369,9 +369,10 @@ def __init__( assert attrs is None assert indexes is not None else: - # TODO: (benbovy - explicit indexes) remove assertion + # TODO: (benbovy - explicit indexes) remove # once it becomes part of the public interface - assert indexes is None, "Providing explicit indexes is not supported yet" + if indexes is not None: + raise ValueError("Providing explicit indexes is not supported yet") # try to fill in arguments from data if they weren't supplied if coords is None: diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index bc6e8186493..e02e1f569b2 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -288,7 +288,12 @@ def _concat_indexes(indexes, dim, positions=None) -> pd.Index: if not indexes: new_pd_index = pd.Index([]) else: - assert all(idx.dim == dim for idx in indexes) + if not all(idx.dim == dim for idx in indexes): + dims = ",".join({f"{idx.dim!r}" for idx in indexes}) + raise ValueError( + f"Cannot concatenate along dimension {dim!r} indexes with " + f"dimensions: {dims}" + ) pd_indexes = [idx.index for idx in indexes] new_pd_index = pd_indexes[0].append(pd_indexes[1:]) diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index ff04ec644e4..7edcaa15105 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -183,6 +183,12 @@ def test_concat_empty(self) -> None: idx = PandasIndex.concat([], "x") assert idx.coord_dtype is np.dtype("O") + def test_concat_dim_error(self) -> None: + indexes = [PandasIndex([0, 1], "x"), PandasIndex([2, 3], "y")] + + with pytest.raises(ValueError, match=r"Cannot concatenate.*dimensions.*"): + PandasIndex.concat(indexes, "x") + def test_create_variables(self) -> None: # pandas has only Float64Index but variable dtype should be preserved data = np.array([1.1, 2.2, 3.3], dtype=np.float32) From 601dc3ab807dea8d5d2a66a66863fa4aea93d389 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 7 Mar 2022 13:18:42 +0100 Subject: [PATCH 156/159] misc. fixes and tweaks --- xarray/core/formatting.py | 2 +- xarray/core/indexing.py | 4 ++-- xarray/core/merge.py | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index f2fcf1bad6d..81617ae38f9 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -286,7 +286,7 @@ def summarize_variable( name: Hashable, var, col_width: int, max_width: int = None, is_index: bool = False ): """Summarize a variable in one line, e.g., for the Dataset.__repr__.""" - variable = var.variable if hasattr(var, "variable") else var + variable = getattr(var, "variable", var) if max_width is None: max_width_options = OPTIONS["display_width"] diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index d95d282d4ed..c797e6652de 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -56,9 +56,9 @@ class IndexSelResult: drop_coords : list, optional Coordinate(s) to drop in the resulting DataArray or Dataset. 
drop_indexes : list, optional - Indexes(s) to drop in the resulting DataArray or Dataset. + Index(es) to drop in the resulting DataArray or Dataset. rename_dims : dict, optional - A dictionnary in the form ``{old_dim: new_dim}`` for dimension(s) to + A dictionary in the form ``{old_dim: new_dim}`` for dimension(s) to rename in the resulting DataArray or Dataset. """ diff --git a/xarray/core/merge.py b/xarray/core/merge.py index cf9c74e9d23..8edfa5f0626 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -178,7 +178,6 @@ def _assert_prioritized_valid( ) -> None: """Make sure that elements given in prioritized will not corrupt any index given in grouped. - """ prioritized_names = set(prioritized) grouped_by_index: dict[int, list[Hashable]] = defaultdict(list) From bf0dbb7eba34993c5478742b5657e543f44ea826 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 7 Mar 2022 14:00:14 +0100 Subject: [PATCH 157/159] update what's new --- doc/whats-new.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b22c6e4d858..05bdfcf78bb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,10 +22,18 @@ v2022.03.1 (unreleased) New Features ~~~~~~~~~~~~ +- Add a ``create_index=True`` parameter to :py:meth:`Dataset.stack` and + :py:meth:`DataArray.stack` so that the creation of multi-indexes is optional + (:pull:`5692`). By `Benoît Bovy `_. +- Multi-index levels are now accessible through their own, regular coordinates + instead of virtual coordinates (:pull:`5692`). + By `Benoît Bovy `_. Breaking changes ~~~~~~~~~~~~~~~~ +- The Dataset and DataArray ``rename*`` methods do not implicitly add or drop + indexes. (:pull:`5692`). By `Benoît Bovy `_. Deprecations ~~~~~~~~~~~~ @@ -37,6 +45,9 @@ Bug fixes - Set ``skipna=None`` for all ``quantile`` methods (e.g. :py:meth:`Dataset.quantile`) and ensure it skips missing values for float dtypes (consistent with other methods). This should not change the behavior (:pull:`6303`). By `Mathias Hauser `_. +- Many bugs fixed by the explicit indexes refactor, mainly related to multi-index (virtual) + coordinates. See the corresponding pull-request on GitHub for more details. (:pull:`5692`). + By `Benoît Bovy `_. Documentation ~~~~~~~~~~~~~ @@ -45,6 +56,9 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Many internal changes due to the explicit indexes refactor. See the + corresponding pull-request on GitHub for more details. (:pull:`5692`). + By `Benoît Bovy `_. .. _whats-new.2022.02.0: .. _whats-new.2022.03.0: From c1b778a933ff8839c13c272547104b7b56f19064 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 11 Mar 2022 11:37:08 +0100 Subject: [PATCH 158/159] concat: remove fall-backs to Variable.concat Raise an error instead when trying to concat a mix of indexed and un-indexed coordinates or when the index doesn't support concat. We still support the concatenation of a mix of scalar and dimension indexed coordinates by creating a PandasIndex on-the-fly for scalar coordinates. --- xarray/core/concat.py | 66 ++++++++++++++++++++++--------------- xarray/tests/test_concat.py | 25 ++++++++++++++ 2 files changed, 64 insertions(+), 27 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 5161e70810e..8ee4672c49a 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -7,7 +7,7 @@ from . 
import dtypes, utils
 from .alignment import align
 from .duck_array_ops import lazy_array_equiv
-from .indexes import PandasIndex
+from .indexes import Index, PandasIndex
 from .merge import (
     _VALID_COMPAT,
     collect_variables_and_indexes,
@@ -512,6 +512,20 @@ def ensure_common_dims(vars):
                 var = var.set_dims(common_dims, common_shape)
             yield var

+    # get the indexes to concatenate together, creating a PandasIndex
+    # for any scalar coordinate variable found with ``name`` matching ``dim``.
+    # TODO: deprecate concatenating a mix of scalar and dimensional indexed coordinates?
+    # TODO: (benbovy - explicit indexes): check index types and/or coordinates
+    # of all datasets?
+    def get_indexes(name):
+        for ds in datasets:
+            if name in ds._indexes:
+                yield ds._indexes[name]
+            elif name == dim:
+                var = ds._variables[name]
+                if not var.dims:
+                    yield PandasIndex([var.values], dim)
+
     # stack up each variable and/or index to fill-out the dataset (in order)
     # n.b. this loop preserves variable order, needed for groupby.
     for name in datasets[0].variables:
@@ -521,36 +535,34 @@ def ensure_common_dims(vars):
             except KeyError:
                 raise ValueError(f"{name!r} is not present in all datasets.")

-            # Try concatenate the indexes first, silently fallback to concatenate
-            # the variables when no index is found on all datasets or when the
-            # 1st index doesn't implement concat.
-            # TODO: (benbovy - explicit indexes): check index types and/or coordinates
-            # of all datasets?
-            try:
-                indexes = [ds._indexes[name] for ds in datasets]
-            except KeyError:
+            # Try to concatenate the indexes; concatenate the variables
+            # when no index is found on all datasets.
+            indexes: list[Index] = list(get_indexes(name))
+            if indexes:
+                if len(indexes) < len(datasets):
+                    raise ValueError(
+                        f"{name!r} must have either an index or no index in all datasets, "
+                        f"found {len(indexes)}/{len(datasets)} datasets with an index."
+                    )
+                combined_idx = indexes[0].concat(indexes, dim, positions)
+                if name in datasets[0]._indexes:
+                    idx_vars = datasets[0].xindexes.get_all_coords(name)
+                else:
+                    # index created from a scalar coordinate
+                    idx_vars = {name: datasets[0][name].variable}
+                result_indexes.update({k: combined_idx for k in idx_vars})
+                combined_idx_vars = combined_idx.create_variables(idx_vars)
+                for k, v in combined_idx_vars.items():
+                    v.attrs = merge_attrs(
+                        [ds.variables[k].attrs for ds in datasets],
+                        combine_attrs=combine_attrs,
+                    )
+                    result_vars[k] = v
+            else:
                 combined_var = concat_vars(
                     vars, dim, positions, combine_attrs=combine_attrs
                 )
                 result_vars[name] = combined_var
-            else:
-                try:
-                    combined_idx = indexes[0].concat(indexes, dim, positions)
-                except NotImplementedError:
-                    combined_var = concat_vars(
-                        vars, dim, positions, combine_attrs=combine_attrs
-                    )
-                    result_vars[name] = combined_var
-                else:
-                    idx_vars = datasets[0].xindexes.get_all_coords(name)
-                    result_indexes.update({k: combined_idx for k in idx_vars})
-                    combined_idx_vars = combined_idx.create_variables(idx_vars)
-                    for k, v in combined_idx_vars.items():
-                        v.attrs = merge_attrs(
-                            [ds.variables[k].attrs for ds in datasets],
-                            combine_attrs=combine_attrs,
-                        )
-                        result_vars[k] = v

         elif name in result_vars:
             # preserves original variable order
diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py
index 7de72c352f6..aa15896a7a3 100644
--- a/xarray/tests/test_concat.py
+++ b/xarray/tests/test_concat.py
@@ -7,6 +7,7 @@

 from xarray import DataArray, Dataset, Variable, concat
 from xarray.core import dtypes, merge
+from xarray.core.indexes import PandasIndex

 from .
import ( InaccessibleArray, @@ -783,3 +784,27 @@ def test_concat_typing_check() -> None: match="The elements in the input list need to be either all 'Dataset's or all 'DataArray's", ): concat([da, ds], dim="foo") + + +def test_concat_not_all_indexes() -> None: + ds1 = Dataset(coords={"x": ("x", [1, 2])}) + # ds2.x has no default index + ds2 = Dataset(coords={"x": ("y", [3, 4])}) + + with pytest.raises( + ValueError, match=r"'x' must have either an index or no index in all datasets.*" + ): + concat([ds1, ds2], dim="x") + + +def test_concat_index_not_same_dim() -> None: + ds1 = Dataset(coords={"x": ("x", [1, 2])}) + ds2 = Dataset(coords={"x": ("y", [3, 4])}) + # TODO: use public API for setting a non-default index, when available + ds2._indexes["x"] = PandasIndex([3, 4], "y") + + with pytest.raises( + ValueError, + match="Cannot concatenate along dimension 'x' indexes with dimensions: 'y','x'", + ): + concat([ds1, ds2], dim="x") From 77fdaf0e3a268d1d1fbdb6c7aef9abfd07bf0d32 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 11 Mar 2022 12:16:28 +0100 Subject: [PATCH 159/159] fix flaky test --- xarray/tests/test_concat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index aa15896a7a3..8abede64761 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -805,6 +805,6 @@ def test_concat_index_not_same_dim() -> None: with pytest.raises( ValueError, - match="Cannot concatenate along dimension 'x' indexes with dimensions: 'y','x'", + match=r"Cannot concatenate along dimension 'x' indexes with dimensions.*", ): concat([ds1, ds2], dim="x")
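
Note on the flakiness: the dimension names in this error message come
from a set comprehension in the ``_concat_indexes`` helper (added in
the "assert -> check + ValueError" patch above),

    dims = ",".join({f"{idx.dim!r}" for idx in indexes})

so their order depends on hash randomization and may render as either
"'y','x'" or "'x','y'"; matching on a fixed order made the test flaky,
hence the looser regex.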