From 782c14ab1513d1da4a6130ce96430dd78c0c99fa Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Thu, 31 Aug 2023 15:54:56 +0200
Subject: [PATCH 1/4] add Index.load and Index.chunk methods

---
 xarray/core/indexes.py | 123 ++++++++++++++++++++++++++++++++++++-----
 1 file changed, 109 insertions(+), 14 deletions(-)

diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index b5e396963a1..b2db0decbb9 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -4,7 +4,7 @@
 import copy
 from collections import defaultdict
 from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
 
 import numpy as np
 import pandas as pd
@@ -15,6 +15,7 @@
     PandasIndexingAdapter,
     PandasMultiIndexingAdapter,
 )
+from xarray.core.parallelcompat import ChunkManagerEntrypoint
 from xarray.core.utils import (
     Frozen,
     emit_user_level_warning,
@@ -54,8 +55,6 @@ class Index:
     corresponding operation on a :py:meth:`Dataset` or :py:meth:`DataArray`
     either will raise a ``NotImplementedError`` or will simply drop/pass/copy the
     index from/to the result.
-
-    Do not use this class directly for creating index objects.
     """
 
     @classmethod
@@ -321,6 +320,55 @@ def equals(self: T_Index, other: T_Index) -> bool:
         """
         raise NotImplementedError()
 
+    def load(self, **kwargs) -> Index | None:
+        """Method called when calling :py:meth:`Dataset.load` or
+        :py:meth:`Dataset.compute` (or DataArray equivalent methods).
+
+        The default implementation will simply drop the index by returning
+        ``None``.
+
+        Possible re-implementations in subclasses are:
+
+        - For an index with coordinate data fully in memory like a ``PandasIndex``:
+          return itself
+        - For an index with lazy coordinate data (e.g., a dask array):
+          return an index object of another type like ``PandasIndex``
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional keyword arguments passed on to ``dask.compute``.
+        """
+        return None
+
+    def chunk(
+        self,
+        chunks: Literal["auto"] | Mapping[Any, None | tuple[int, ...]],
+        name_prefix: str = "xarray-",
+        token: str | None = None,
+        lock: bool = False,
+        inline_array: bool = False,
+        chunked_array_type: ChunkManagerEntrypoint | None = None,
+        from_array_kwargs=None,
+    ) -> Index | None:
+        """Method called when calling :py:meth:`Dataset.chunk` (or the
+        DataArray equivalent method).
+
+        The default implementation will simply drop the index by returning
+        ``None``.
+
+        Possible re-implementations in subclasses are:
+
+        - For an index with coordinate data fully in memory like a ``PandasIndex``:
+          return itself (do not chunk)
+        - For an index with lazy coordinate data (e.g., a dask array):
+          rebuild the index with an internal lookup structure that is
+          in sync with the new chunks
+
+        For more details about the parameters, see :py:meth:`Dataset.chunk`.
+        """
+        return None
+
     def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index | None:
         """Roll this index by an offset along one or more dimensions.
 
@@ -821,6 +869,23 @@ def reindex_like(
 
         return {self.dim: get_indexer_nd(self.index, other.index, method, tolerance)}
 
+    def load(self: T_PandasIndex, **kwargs) -> T_PandasIndex:
+        # both index and coordinate(s) already loaded in-memory
+        return self
+
+    def chunk(
+        self: T_PandasIndex,
+        chunks: Literal["auto"] | Mapping[Any, None | tuple[int, ...]],
+        name_prefix: str = "xarray-",
+        token: str | None = None,
+        lock: bool = False,
+        inline_array: bool = False,
+        chunked_array_type: ChunkManagerEntrypoint | None = None,
+        from_array_kwargs=None,
+    ) -> T_PandasIndex:
+        # skip chunk
+        return self
+
     def roll(self, shifts: Mapping[Any, int]) -> PandasIndex:
         shift = shifts[self.dim] % self.index.shape[0]
 
@@ -1764,19 +1829,42 @@ def check_variables():
     return not not_equal
 
 
-def _apply_indexes(
+def _apply_index_method(
     indexes: Indexes[Index],
-    args: Mapping[Any, Any],
-    func: str,
+    method_name: str,
+    dim_args: Mapping | None = None,
+    kwargs: Mapping | None = None,
 ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
+    """Utility function that applies a given Index method for an Indexes
+    collection and that returns new collections of indexes and coordinate
+    variables.
+
+    Filter index method calls and arguments according to ``dim_args`` if not
+    None. Otherwise, call the method unconditionally for each index.
+
+    ``kwargs`` is passed to every method call.
+
+    """
+    if kwargs is None:
+        kwargs = {}
+
     new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes.items()}
     new_index_variables: dict[Hashable, Variable] = {}
 
     for index, index_vars in indexes.group_by_index():
-        index_dims = {d for var in index_vars.values() for d in var.dims}
-        index_args = {k: v for k, v in args.items() if k in index_dims}
-        if index_args:
-            new_index = getattr(index, func)(index_args)
+        func = getattr(index, method_name)
+
+        if dim_args is not None:
+            index_dims = {d for var in index_vars.values() for d in var.dims}
+            index_args = {k: v for k, v in dim_args.items() if k in index_dims}
+        else:
+            index_args = {}
+
+        if dim_args is None or index_args:
+            if index_args:
+                new_index = func(index_args, **kwargs)
+            else:
+                new_index = func(**kwargs)
             if new_index is not None:
                 new_indexes.update({k: new_index for k in index_vars})
                 new_index_vars = new_index.create_variables(index_vars)
@@ -1790,16 +1878,23 @@ def _apply_indexes(
 
 def isel_indexes(
     indexes: Indexes[Index],
-    indexers: Mapping[Any, Any],
+    dim_indexers: Mapping,
 ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
-    return _apply_indexes(indexes, indexers, "isel")
+    return _apply_index_method(indexes, "isel", dim_args=dim_indexers)
 
 
 def roll_indexes(
     indexes: Indexes[Index],
-    shifts: Mapping[Any, int],
+    dim_shifts: Mapping[Any, int],
+) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
+    return _apply_index_method(indexes, "roll", dim_args=dim_shifts)
+
+
+def load_indexes(
+    indexes: Indexes[Index],
+    kwargs: Mapping,
 ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
-    return _apply_indexes(indexes, shifts, "roll")
+    return _apply_index_method(indexes, "load", kwargs=kwargs)
 
 
 def filter_indexes_from_coords(
From a859ea1db5b6fb752c16ddfc502b985efdadfb76 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Thu, 31 Aug 2023 15:55:36 +0200
Subject: [PATCH 2/4] refactor Dataset.load

---
 xarray/core/coordinates.py | 14 ++------------
 xarray/core/dataset.py     |  6 ++++++
 2 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py
index c539536a294..35421c7f71e 100644
--- a/xarray/core/coordinates.py
+++ b/xarray/core/coordinates.py
@@ -691,12 +691,7 @@ def _update_coords(
         self._data._variables = variables
         self._data._coord_names.update(new_coord_names)
         self._data._dims = dims
-
-        # TODO(shoyer): once ._indexes is always populated by a dict, modify
-        # it to update inplace instead.
-        original_indexes = dict(self._data.xindexes)
-        original_indexes.update(indexes)
-        self._data._indexes = original_indexes
+        self._data._indexes.update(indexes)
 
     def _drop_coords(self, coord_names):
         # should drop indexed coordinates only
@@ -777,12 +772,7 @@ def _update_coords(
             "cannot add coordinates with new dimensions to a DataArray"
         )
         self._data._coords = coords
-
-        # TODO(shoyer): once ._indexes is always populated by a dict, modify
-        # it to update inplace instead.
-        original_indexes = dict(self._data.xindexes)
-        original_indexes.update(indexes)
-        self._data._indexes = original_indexes
+        self._data._indexes.update(indexes)
 
     def _drop_coords(self, coord_names):
         # should drop indexed coordinates only
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index f1a0cb9dc34..70f9de9055d 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -67,6 +67,7 @@
     create_default_index_implicit,
     filter_indexes_from_coords,
     isel_indexes,
+    load_indexes,
     remove_unused_levels_categories,
     roll_indexes,
 )
@@ -816,6 +817,11 @@ def load(self: T_Dataset, **kwargs) -> T_Dataset:
         --------
         dask.compute
         """
+        # apply Index.load, then replace the existing indexes and index variables
+        # new index variables may still be lazy; they are loaded further below
+        indexes, index_variables = load_indexes(self.xindexes, kwargs)
+        self.coords._update_coords(index_variables, indexes)
+
         # access .data to coerce everything to numpy or dask arrays
         lazy_data = {
             k: v._data for k, v in self.variables.items() if is_chunked_array(v._data)
From d77311d636096c5a1f849d83c9db7527100cf2b0 Mon Sep 17 00:00:00 2001
From: Benoit Bovy
Date: Thu, 31 Aug 2023 17:48:07 +0200
Subject: [PATCH 3/4] tweaks and fixes

---
 xarray/core/indexes.py | 32 ++++++++++++++++++--------------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index b2db0decbb9..03584d84819 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -1835,14 +1835,15 @@ def _apply_index_method(
     dim_args: Mapping | None = None,
     kwargs: Mapping | None = None,
 ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
-    """Utility function that applies a given Index method for an Indexes
+    """Utility function that applies a given Index method to an Indexes
     collection and that returns new collections of indexes and coordinate
     variables.
 
-    Filter index method calls and arguments according to ``dim_args`` if not
-    None. Otherwise, call the method unconditionally for each index.
+    Index method calls and arguments are filtered according to ``dim_args`` if
+    it is not None. Otherwise, the method is called unconditionally for each
+    index.
 
-    ``kwargs`` is passed to every method call.
+    ``kwargs`` is passed to every call of the index method.
 
""" if kwargs is None: @@ -1854,17 +1855,20 @@ def _apply_index_method( for index, index_vars in indexes.group_by_index(): func = getattr(index, method_name) - if dim_args is not None: + if dim_args is None: + new_index = func(**kwargs) + skip_index = False + else: index_dims = {d for var in index_vars.values() for d in var.dims} index_args = {k: v for k, v in dim_args.items() if k in index_dims} - else: - index_args = {} - - if dim_args is None or index_args: if index_args: new_index = func(index_args, **kwargs) + skip_index = False else: - new_index = func(**kwargs) + new_index = None + skip_index = True + + if not skip_index: if new_index is not None: new_indexes.update({k: new_index for k in index_vars}) new_index_vars = new_index.create_variables(index_vars) @@ -1878,16 +1882,16 @@ def _apply_index_method( def isel_indexes( indexes: Indexes[Index], - dim_indexers: Mapping, + indexers: Mapping, ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: - return _apply_index_method(indexes, "isel", dim_args=dim_indexers) + return _apply_index_method(indexes, "isel", dim_args=indexers) def roll_indexes( indexes: Indexes[Index], - dim_shifts: Mapping[Any, int], + shifts: Mapping[Any, int], ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: - return _apply_index_method(indexes, "roll", dim_args=dim_shifts) + return _apply_index_method(indexes, "roll", dim_args=shifts) def load_indexes( From 4506cb600caba75f163c088171f590b67f59264b Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 31 Aug 2023 17:48:34 +0200 Subject: [PATCH 4/4] refactor Dataset.chunk --- xarray/core/dataset.py | 46 ++++++++++++++++++++++++++++-------------- xarray/core/indexes.py | 8 ++++++++ 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 70f9de9055d..28d3b9dcffb 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -64,6 +64,7 @@ PandasIndex, PandasMultiIndex, assert_no_index_corrupted, + chunk_indexes, create_default_index_implicit, filter_indexes_from_coords, isel_indexes, @@ -2647,21 +2648,36 @@ def chunk( if from_array_kwargs is None: from_array_kwargs = {} - variables = { - k: _maybe_chunk( - k, - v, - chunks, - token, - lock, - name_prefix, - inline_array=inline_array, - chunked_array_type=chunkmanager, - from_array_kwargs=from_array_kwargs.copy(), - ) - for k, v in self.variables.items() - } - return self._replace(variables) + # apply Index.chunk, collect new indexes and variables + indexes, index_variables = chunk_indexes( + self.xindexes, + chunks, + name_prefix=name_prefix, + token=token, + lock=lock, + inline_array=inline_array, + chunked_array_type=chunkmanager, + from_array_kwargs=from_array_kwargs, + ) + + variables = {} + for k, v in self.variables.items(): + if k in index_variables: + variables[k] = index_variables[k] + else: + variables[k] = _maybe_chunk( + k, + v, + chunks, + token, + lock, + name_prefix, + inline_array=inline_array, + chunked_array_type=chunkmanager, + from_array_kwargs=from_array_kwargs.copy(), + ) + + return self._replace(variables=variables, indexes=indexes) def _validate_indexers( self, indexers: Mapping[Any, Any], missing_dims: ErrorOptionsWithWarn = "raise" diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 03584d84819..f3f45b124f5 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1901,6 +1901,14 @@ def load_indexes( return _apply_index_method(indexes, "load", kwargs=kwargs) +def chunk_indexes( + indexes: Indexes[Index], + chunks: Mapping, + 
+    **kwargs,
+) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
+    return _apply_index_method(indexes, "chunk", dim_args=chunks, kwargs=kwargs)
+
+
 def filter_indexes_from_coords(
     indexes: Mapping[Any, Index],
     filtered_coord_names: set,
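
For illustration, a downstream ``Index`` subclass with lazy (e.g. dask-backed) coordinate
data could opt in to the new hooks roughly as follows. This is a minimal sketch and is not
part of the patches above: ``MyLazyIndex``, its attributes and its abridged ``chunk``
signature are hypothetical; only the hook semantics (return a fully in-memory index from
``load``, rebuild the index in sync with the new chunks from ``chunk``) come from the
``Index.load`` / ``Index.chunk`` docstrings added in patch 1.

from __future__ import annotations

from collections.abc import Mapping

import pandas as pd

from xarray.core.indexes import Index, PandasIndex


class MyLazyIndex(Index):
    """Hypothetical index whose coordinate labels live in a chunked (dask) array."""

    def __init__(self, array, dim):
        self._array = array  # e.g. a 1-d dask array of coordinate labels
        self._dim = dim

    def load(self, **kwargs) -> PandasIndex:
        # materialize the labels and return an in-memory index of another type,
        # as suggested by the Index.load docstring
        labels = pd.Index(self._array.compute(**kwargs))
        return PandasIndex(labels, self._dim)

    def chunk(self, chunks, **chunk_kwargs) -> "MyLazyIndex":
        # rebuild the index so its internal (lazy) lookup structure stays in
        # sync with the new chunks; here that just means rechunking the array
        dim_chunks = chunks.get(self._dim) if isinstance(chunks, Mapping) else chunks
        if dim_chunks is None:
            return self
        return type(self)(self._array.rechunk(dim_chunks), self._dim)

A real subclass would also implement ``from_variables``, ``create_variables``, ``isel``
and friends; they are omitted here to keep the focus on the two new hooks.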
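
A quick usage sketch of the behaviour described above, assuming dask is installed:
``PandasIndex.load`` and ``PandasIndex.chunk`` both return the index itself, so a default
``"x"`` index is carried through ``Dataset.chunk`` and ``Dataset.load`` unchanged, whereas
an index that keeps the base-class defaults would be dropped from the result.

import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(4))}, coords={"x": [10, 20, 30, 40]})

chunked = ds.chunk({"x": 2})  # data becomes dask-backed, the "x" index is kept
loaded = chunked.load()       # Index.load hooks run before the dask compute
print(chunked.xindexes["x"], loaded.xindexes["x"])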