Add Index.load() and Index.chunk() methods #8128
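Proposed change: add `Index.load()` and `Index.chunk()` to the `Index` API, called from `Dataset.load()` / `Dataset.compute()` and `Dataset.chunk()` (and their DataArray equivalents). The default implementations drop the index by returning `None`; `PandasIndex` reimplements both as no-ops since its lookup structure and coordinate data already live in memory. Internally, the `_apply_indexes` helper is generalized into `_apply_index_method` so that index methods can be dispatched with or without per-dimension arguments.

A minimal sketch of the intended user-visible behavior (illustrative only, not a test from this PR; assumes dask is installed):

```python
import numpy as np
import pandas as pd
import xarray as xr

ds = xr.Dataset(
    {"t2m": (("time", "x"), np.zeros((4, 3)))},
    coords={"time": pd.date_range("2000-01-01", periods=4)},
)

# PandasIndex.chunk() returns itself, so chunking keeps the "time" index
chunked = ds.chunk({"time": 2})
assert "time" in chunked.xindexes

# PandasIndex.load() also returns itself, so loading keeps the index too
loaded = chunked.load()
assert "time" in loaded.xindexes
```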


Draft: wants to merge 4 commits into main
14 changes: 2 additions & 12 deletions xarray/core/coordinates.py
@@ -691,12 +691,7 @@ def _update_coords(
         self._data._variables = variables
         self._data._coord_names.update(new_coord_names)
         self._data._dims = dims
-
-        # TODO(shoyer): once ._indexes is always populated by a dict, modify
-        # it to update inplace instead.
-        original_indexes = dict(self._data.xindexes)
-        original_indexes.update(indexes)
-        self._data._indexes = original_indexes
+        self._data._indexes.update(indexes)

     def _drop_coords(self, coord_names):
         # should drop indexed coordinates only
@@ -777,12 +772,7 @@ def _update_coords(
                 "cannot add coordinates with new dimensions to a DataArray"
             )
         self._data._coords = coords
-
-        # TODO(shoyer): once ._indexes is always populated by a dict, modify
-        # it to update inplace instead.
-        original_indexes = dict(self._data.xindexes)
-        original_indexes.update(indexes)
-        self._data._indexes = original_indexes
+        self._data._indexes.update(indexes)

     def _drop_coords(self, coord_names):
         # should drop indexed coordinates only
52 changes: 37 additions & 15 deletions xarray/core/dataset.py
@@ -64,9 +64,11 @@
     PandasIndex,
     PandasMultiIndex,
     assert_no_index_corrupted,
+    chunk_indexes,
     create_default_index_implicit,
     filter_indexes_from_coords,
     isel_indexes,
+    load_indexes,
     remove_unused_levels_categories,
     roll_indexes,
 )
@@ -816,6 +818,11 @@ def load(self: T_Dataset, **kwargs) -> T_Dataset:
         --------
         dask.compute
         """
+        # apply Index.load, collect new indexes and variables, replace the existing ones
+        # new index variables may still be lazy: they are loaded further below
+        indexes, index_variables = load_indexes(self.xindexes, kwargs)
+        self.coords._update_coords(index_variables, indexes)
+
         # access .data to coerce everything to numpy or dask arrays
         lazy_data = {
             k: v._data for k, v in self.variables.items() if is_chunked_array(v._data)
@@ -2641,21 +2648,36 @@ def chunk(
         if from_array_kwargs is None:
             from_array_kwargs = {}

-        variables = {
-            k: _maybe_chunk(
-                k,
-                v,
-                chunks,
-                token,
-                lock,
-                name_prefix,
-                inline_array=inline_array,
-                chunked_array_type=chunkmanager,
-                from_array_kwargs=from_array_kwargs.copy(),
-            )
-            for k, v in self.variables.items()
-        }
-        return self._replace(variables)
+        # apply Index.chunk, collect new indexes and variables
+        indexes, index_variables = chunk_indexes(
+            self.xindexes,
+            chunks,
+            name_prefix=name_prefix,
+            token=token,
+            lock=lock,
+            inline_array=inline_array,
+            chunked_array_type=chunkmanager,
+            from_array_kwargs=from_array_kwargs,
+        )
+
+        variables = {}
+        for k, v in self.variables.items():
+            if k in index_variables:
+                variables[k] = index_variables[k]
+            else:
+                variables[k] = _maybe_chunk(
+                    k,
+                    v,
+                    chunks,
+                    token,
+                    lock,
+                    name_prefix,
+                    inline_array=inline_array,
+                    chunked_array_type=chunkmanager,
+                    from_array_kwargs=from_array_kwargs.copy(),
+                )
+
+        return self._replace(variables=variables, indexes=indexes)

     def _validate_indexers(
         self, indexers: Mapping[Any, Any], missing_dims: ErrorOptionsWithWarn = "raise"
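To illustrate the dispatch performed by the new helpers (a sketch against private API, reusing `ds` from the example above; the concrete values are assumptions):

```python
from xarray.core.indexes import chunk_indexes, load_indexes

# Index.chunk is dispatched per index; the chunks mapping is filtered so
# each index only receives entries for the dimensions it actually uses
new_indexes, new_index_vars = chunk_indexes(ds.xindexes, {"time": (2, 2)})

# PandasIndex.chunk() returns self, so the index object is unchanged
assert new_indexes["time"] is ds.xindexes["time"]

# Index.load takes no per-dimension arguments, only keyword arguments
# that are forwarded to dask.compute
new_indexes, new_index_vars = load_indexes(ds.xindexes, {})
```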
133 changes: 120 additions & 13 deletions xarray/core/indexes.py
@@ -4,7 +4,7 @@
 import copy
 from collections import defaultdict
 from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast

 import numpy as np
 import pandas as pd
@@ -15,6 +15,7 @@
     PandasIndexingAdapter,
     PandasMultiIndexingAdapter,
 )
+from xarray.core.parallelcompat import ChunkManagerEntrypoint
 from xarray.core.utils import (
     Frozen,
     emit_user_level_warning,
@@ -54,8 +55,6 @@ class Index:
     corresponding operation on a :py:meth:`Dataset` or :py:meth:`DataArray`
     either will raise a ``NotImplementedError`` or will simply drop/pass/copy
     the index from/to the result.
-
-    Do not use this class directly for creating index objects.
     """

     @classmethod
@@ -321,6 +320,55 @@ def equals(self: T_Index, other: T_Index) -> bool:
         """
         raise NotImplementedError()

+    def load(self, **kwargs) -> Index | None:
+        """Method called when calling :py:meth:`Dataset.load` or
+        :py:meth:`Dataset.compute` (or DataArray equivalent methods).
+
+        The default implementation will simply drop the index by returning
+        ``None``.
+
+        Possible re-implementations in subclasses are:
+
+        - For an index with coordinate data fully in memory like a ``PandasIndex``:
+          return itself
+        - For an index with lazy coordinate data (e.g., a dask array):
+          return an index object of another type like ``PandasIndex``
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional keyword arguments passed on to ``dask.compute``.
+        """
+        return None
+
+    def chunk(
+        self,
+        chunks: Literal["auto"] | Mapping[Any, None | tuple[int, ...]],
+        name_prefix: str = "xarray-",
+        token: str | None = None,
+        lock: bool = False,
+        inline_array: bool = False,
+        chunked_array_type: ChunkManagerEntrypoint | None = None,
+        from_array_kwargs=None,
+    ) -> Index | None:
+        """Method called when calling :py:meth:`Dataset.chunk`
+        (or the DataArray equivalent method).
+
+        The default implementation will simply drop the index by returning
+        ``None``.
+
+        Possible re-implementations in subclasses are:
+
+        - For an index with coordinate data fully in memory like a ``PandasIndex``:
+          return itself (do not chunk)
+        - For an index with lazy coordinate data (e.g., a dask array):
+          rebuild the index with an internal lookup structure that is
+          in sync with the new chunks
+
+        For more details about the parameters, see :py:meth:`Dataset.chunk`.
+        """
+        return None
+
     def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index | None:
         """Roll this index by an offset along one or more dimensions.

@@ -821,6 +869,23 @@ def reindex_like(

         return {self.dim: get_indexer_nd(self.index, other.index, method, tolerance)}

+    def load(self: T_PandasIndex, **kwargs) -> T_PandasIndex:
+        # both the index and its coordinate(s) are already loaded in memory
+        return self
+
+    def chunk(
+        self: T_PandasIndex,
+        chunks: Literal["auto"] | Mapping[Any, None | tuple[int, ...]],
+        name_prefix: str = "xarray-",
+        token: str | None = None,
+        lock: bool = False,
+        inline_array: bool = False,
+        chunked_array_type: ChunkManagerEntrypoint | None = None,
+        from_array_kwargs=None,
+    ) -> T_PandasIndex:
+        # skip chunking: nothing to chunk for an in-memory index
+        return self
+
     def roll(self, shifts: Mapping[Any, int]) -> PandasIndex:
         shift = shifts[self.dim] % self.index.shape[0]

@@ -1764,19 +1829,46 @@ def check_variables():
     return not not_equal


-def _apply_indexes(
+def _apply_index_method(
     indexes: Indexes[Index],
-    args: Mapping[Any, Any],
-    func: str,
+    method_name: str,
+    dim_args: Mapping | None = None,
+    kwargs: Mapping | None = None,
 ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
+    """Utility function that applies a given Index method to an Indexes
+    collection and returns new collections of indexes and coordinate
+    variables.
+
+    Index method calls and arguments are filtered according to ``dim_args`` if
+    it is not None. Otherwise, the method is called unconditionally for each
+    index.
+
+    ``kwargs`` is passed to every call of the index method.
+
+    """
+    if kwargs is None:
+        kwargs = {}
+
     new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes.items()}
     new_index_variables: dict[Hashable, Variable] = {}

     for index, index_vars in indexes.group_by_index():
-        index_dims = {d for var in index_vars.values() for d in var.dims}
-        index_args = {k: v for k, v in args.items() if k in index_dims}
-        if index_args:
-            new_index = getattr(index, func)(index_args)
+        func = getattr(index, method_name)
+
+        if dim_args is None:
+            new_index = func(**kwargs)
+            skip_index = False
+        else:
+            index_dims = {d for var in index_vars.values() for d in var.dims}
+            index_args = {k: v for k, v in dim_args.items() if k in index_dims}
+            if index_args:
+                new_index = func(index_args, **kwargs)
+                skip_index = False
+            else:
+                new_index = None
+                skip_index = True
+
+        if not skip_index:
             if new_index is not None:
                 new_indexes.update({k: new_index for k in index_vars})
                 new_index_vars = new_index.create_variables(index_vars)
@@ -1790,16 +1882,31 @@ def _apply_indexes(

 def isel_indexes(
     indexes: Indexes[Index],
-    indexers: Mapping[Any, Any],
+    indexers: Mapping,
 ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
-    return _apply_indexes(indexes, indexers, "isel")
+    return _apply_index_method(indexes, "isel", dim_args=indexers)


 def roll_indexes(
     indexes: Indexes[Index],
     shifts: Mapping[Any, int],
 ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
-    return _apply_indexes(indexes, shifts, "roll")
+    return _apply_index_method(indexes, "roll", dim_args=shifts)
+
+
+def load_indexes(
+    indexes: Indexes[Index],
+    kwargs: Mapping,
+) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
+    return _apply_index_method(indexes, "load", kwargs=kwargs)
+
+
+def chunk_indexes(
+    indexes: Indexes[Index],
+    chunks: Mapping,
+    **kwargs,
+) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
+    return _apply_index_method(indexes, "chunk", dim_args=chunks, kwargs=kwargs)


 def filter_indexes_from_coords(
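Finally, a hypothetical third-party index showing the re-implementations the new docstrings describe. This is a minimal sketch: `DaskBackedIndex`, its constructor, and the single-coordinate assumption are invented for illustration, and most of the required Index API is omitted.

```python
from __future__ import annotations

from xarray.core.indexes import Index, PandasIndex


class DaskBackedIndex(Index):
    """Hypothetical index whose single coordinate wraps a chunked (dask) array."""

    def __init__(self, variables):
        # mapping of one coordinate name to its (lazy) Variable
        self._variables = dict(variables)

    @classmethod
    def from_variables(cls, variables, *, options):
        return cls(variables)

    def create_variables(self, variables=None):
        return dict(self._variables)

    def load(self, **kwargs):
        # return an index of another type (PandasIndex) with its data in
        # memory; Dataset.load() then loads the coordinate variables too
        ((name, var),) = self._variables.items()
        return PandasIndex(var.values, name)

    def chunk(self, chunks, **chunk_kwargs):
        # rebuild the index so its internal state stays in sync with the
        # new chunks (here simply by re-chunking the coordinate data)
        rechunked = {k: v.chunk(chunks) for k, v in self._variables.items()}
        return type(self)(rechunked)
```

With such an index, `Dataset.chunk()` would rebuild it instead of dropping it, and `Dataset.load()` would replace it with a `PandasIndex`.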