diff --git a/doc/whats-new.rst b/doc/whats-new.rst index add40bb6b81..c49a957b1e1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -12,7 +12,8 @@ v2025.07.2 (unreleased) New Features ~~~~~~~~~~~~ - +- Support chunking by :py:class:`~xarray.groupers.SeasonResampler` for seasonal data analysis (:issue:`10425`, :pull:`10519`). +  By `Dhruva Kumar Kaushal `_. Breaking changes ~~~~~~~~~~~~~~~~ @@ -67,7 +68,7 @@ Bug fixes creates extra variables that don't match the provided coordinate names, instead of silently ignoring them. The error message suggests using the factory method pattern with :py:meth:`xarray.Coordinates.from_xindex` and - :py:meth:`Dataset.assign_coords` for advanced use cases (:issue:`10499`). + :py:meth:`Dataset.assign_coords` for advanced use cases (:issue:`10499`, :pull:`10503`). By `Dhruva Kumar Kaushal `_. Documentation diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 26db282c3df..3e84ed057dd 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2411,13 +2411,14 @@ def chunk( sizes along that dimension will not be updated; non-dask arrays will be converted into dask arrays with a single block. - Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` object is also accepted. + Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` or :py:class:`groupers.SeasonResampler` object is also accepted. Parameters ---------- - chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional + chunks : int, tuple of int, "auto" or mapping of hashable to int or a Resampler, optional Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or - ``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}``. + ``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}`` or + ``{"time": SeasonResampler(["DJF", "MAM", "JJA", "SON"])}``. name_prefix : str, default: "xarray-" Prefix for the name of any new dask arrays. 
token : str, optional @@ -2452,8 +2453,7 @@ def chunk( xarray.unify_chunks dask.array.from_array """ - from xarray.core.dataarray import DataArray - from xarray.groupers import TimeResampler + from xarray.groupers import Resampler if chunks is None and not chunks_kwargs: warnings.warn( @@ -2481,41 +2481,28 @@ def chunk( f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}" ) - def _resolve_frequency( - name: Hashable, resampler: TimeResampler - ) -> tuple[int, ...]: + def _resolve_resampler(name: Hashable, resampler: Resampler) -> tuple[int, ...]: variable = self._variables.get(name, None) if variable is None: raise ValueError( - f"Cannot chunk by resampler {resampler!r} for virtual variables." + f"Cannot chunk by resampler {resampler!r} for virtual variable {name!r}." ) - elif not _contains_datetime_like_objects(variable): + if variable.ndim != 1: raise ValueError( - f"chunks={resampler!r} only supported for datetime variables. " - f"Received variable {name!r} with dtype {variable.dtype!r} instead." + f"chunks={resampler!r} only supported for 1D variables. " + f"Received variable {name!r} with {variable.ndim} dimensions instead." ) - - assert variable.ndim == 1 - chunks = ( - DataArray( - np.ones(variable.shape, dtype=int), - dims=(name,), - coords={name: variable}, + newchunks = resampler.compute_chunks(name, variable) + if sum(newchunks) != variable.shape[0]: + raise ValueError( + f"Logic bug in rechunking using {resampler!r}. New chunks tuple does not match size of data. Please open an issue." ) - .resample({name: resampler}) - .sum() - ) - # When bins (binning) or time periods are missing (resampling) - # we can end up with NaNs. Drop them. - if chunks.dtype.kind == "f": - chunks = chunks.dropna(name).astype(int) - chunks_tuple: tuple[int, ...] 
= tuple(chunks.data.tolist()) - return chunks_tuple + return newchunks chunks_mapping_ints: Mapping[Any, T_ChunkDim] = { name: ( - _resolve_frequency(name, chunks) - if isinstance(chunks, TimeResampler) + _resolve_resampler(name, chunks) + if isinstance(chunks, Resampler) else chunks ) for name, chunks in chunks_mapping.items() diff --git a/xarray/core/types.py b/xarray/core/types.py index 736a11f5f17..2305ce56199 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -32,7 +32,7 @@ from xarray.core.indexes import Index, Indexes from xarray.core.utils import Frozen from xarray.core.variable import IndexVariable, Variable - from xarray.groupers import Grouper, TimeResampler + from xarray.groupers import Grouper, Resampler from xarray.structure.alignment import Aligner GroupInput: TypeAlias = ( @@ -201,7 +201,7 @@ def copy( # FYI in some cases we don't allow `None`, which this doesn't take account of. # FYI the `str` is for a size string, e.g. "16MB", supported by dask. T_ChunkDim: TypeAlias = str | int | Literal["auto"] | tuple[int, ...] 
| None # noqa: PYI051 -T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim] +T_ChunkDimFreq: TypeAlias = Union["Resampler", T_ChunkDim] T_ChunksFreq: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDimFreq] # We allow the tuple form of this (though arguably we could transition to named dims only) T_Chunks: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDim] diff --git a/xarray/groupers.py b/xarray/groupers.py index 4424c65a94b..b016306349e 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -12,7 +12,7 @@ import operator from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import Mapping, Sequence +from collections.abc import Hashable, Mapping, Sequence from dataclasses import dataclass, field from itertools import chain, pairwise from typing import TYPE_CHECKING, Any, Literal, cast @@ -52,6 +52,8 @@ "EncodedGroups", "Grouper", "Resampler", + "SeasonGrouper", + "SeasonResampler", "TimeResampler", "UniqueGrouper", ] @@ -169,7 +171,26 @@ class Resampler(Grouper): Currently only used for TimeResampler, but could be used for SpaceResampler in the future. """ - pass + def compute_chunks(self, name: Hashable, variable: Variable) -> tuple[int, ...]: + """ + Compute chunk sizes for this resampler. + + This method should be implemented by subclasses to provide appropriate + chunking behavior for their specific resampling strategy. + + Parameters + ---------- + name : Hashable + The name of the dimension being chunked. + variable : Variable + The variable being chunked. + + Returns + ------- + tuple[int, ...] + A tuple of chunk sizes for the dimension. + """ + raise NotImplementedError("Subclasses must implement compute_chunks method") @dataclass @@ -565,6 +586,49 @@ def factorize(self, group: T_Group) -> EncodedGroups: coords=coordinates_from_variable(unique_coord), ) + def compute_chunks(self, name: Hashable, variable: Variable) -> tuple[int, ...]: + """ + Compute chunk sizes for this time resampler. 
+ + This method is used during chunking operations to determine appropriate + chunk sizes for the given variable when using this resampler. + + Parameters + ---------- + name : Hashable + The name of the dimension being chunked. + variable : Variable + The variable being chunked. + + Returns + ------- + tuple[int, ...] + A tuple of chunk sizes for the dimension. + """ + from xarray.core.dataarray import DataArray + + if not _contains_datetime_like_objects(variable): + raise ValueError( + f"chunks={self!r} only supported for datetime variables. " + f"Received variable {name!r} with dtype {variable.dtype!r} instead." + ) + + chunks = ( + DataArray( + np.ones(variable.shape, dtype=int), + dims=(name,), + coords={name: variable}, + ) + .resample({name: self}) + .sum() + ) + # When bins (binning) or time periods are missing (resampling) + # we can end up with NaNs. Drop them. + if chunks.dtype.kind == "f": + chunks = chunks.dropna(name).astype(int) + chunks_tuple: tuple[int, ...] = tuple(chunks.data.tolist()) + return chunks_tuple + def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray: # Copied from flox @@ -968,5 +1032,52 @@ def get_label(year, season): return EncodedGroups(codes=codes, full_index=full_index) + def compute_chunks(self, name: Hashable, variable: Variable) -> tuple[int, ...]: + """ + Compute chunk sizes for this season resampler. + + This method is used during chunking operations to determine appropriate + chunk sizes for the given variable when using this resampler. + + Parameters + ---------- + name : Hashable + The name of the dimension being chunked. + variable : Variable + The variable being chunked. + + Returns + ------- + tuple[int, ...] + A tuple of chunk sizes for the dimension. + """ + from xarray.core.dataarray import DataArray + + if not _contains_datetime_like_objects(variable): + raise ValueError( + f"chunks={self!r} only supported for datetime variables. 
" + f"Received variable {name!r} with dtype {variable.dtype!r} instead." + ) + + # Create a temporary resampler that ignores drop_incomplete for chunking + # This prevents data from being silently dropped during chunking + resampler_for_chunking = type(self)(seasons=self.seasons, drop_incomplete=False) + + chunks = ( + DataArray( + np.ones(variable.shape, dtype=int), + dims=(name,), + coords={name: variable}, + ) + .resample({name: resampler_for_chunking}) + .sum() + ) + # When bins (binning) or time periods are missing (resampling) + # we can end up with NaNs. Drop them. + if chunks.dtype.kind == "f": + chunks = chunks.dropna(name).astype(int) + chunks_tuple: tuple[int, ...] = tuple(chunks.data.tolist()) + return chunks_tuple + def reset(self) -> Self: return type(self)(seasons=self.seasons, drop_incomplete=self.drop_incomplete) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 3e0734c8a1a..9cc1d64172e 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -46,7 +46,7 @@ from xarray.core.indexes import Index, PandasIndex from xarray.core.types import ArrayLike from xarray.core.utils import is_scalar -from xarray.groupers import TimeResampler +from xarray.groupers import SeasonResampler, TimeResampler from xarray.namedarray.pycompat import array_type, integer_types from xarray.testing import _assert_internal_invariants from xarray.tests import ( @@ -1135,6 +1135,86 @@ def test_chunks_does_not_load_data(self) -> None: ds = open_dataset(store) assert ds.chunks == {} + @requires_dask + @pytest.mark.parametrize( + "use_cftime,calendar", + [(True, "standard"), (False, "standard"), (True, "noleap"), (True, "360_day")], + ) + def test_chunk_by_season_resampler(self, use_cftime: bool, calendar: str) -> None: + """Test chunking using SeasonResampler.""" + import dask.array + + if use_cftime: + pytest.importorskip("cftime") + + # Skip non-standard calendars with use_cftime=False as they're incompatible + if not use_cftime 
and calendar != "standard": + pytest.skip(f"Calendar '{calendar}' requires use_cftime=True") + + N = 366 + 365 # 2 years + time = xr.date_range( + "2000-01-01", periods=N, freq="D", use_cftime=use_cftime, calendar=calendar + ) + ds = Dataset( + { + "pr": ("time", dask.array.random.random((N), chunks=(20))), + "pr2d": (("x", "time"), dask.array.random.random((10, N), chunks=(20))), + "ones": ("time", np.ones((N,))), + }, + coords={"time": time}, + ) + + # Test standard seasons + rechunked = ds.chunk(x=2, time=SeasonResampler(["DJF", "MAM", "JJA", "SON"])) + # With 2 years of data starting Jan 1, we get 9 seasonal chunks: + # partial DJF (Jan-Feb), MAM, JJA, SON, DJF, MAM, JJA, SON, partial DJF (Dec) + assert len(rechunked.chunksizes["time"]) == 9 + assert rechunked.chunksizes["x"] == (2,) * 5 + + # Test custom seasons + rechunked = ds.chunk( + {"x": 2, "time": SeasonResampler(["DJFM", "AM", "JJA", "SON"])} + ) + # Custom seasons also produce boundary chunks + assert len(rechunked.chunksizes["time"]) == 9 + assert rechunked.chunksizes["x"] == (2,) * 5 + + # Test that drop_incomplete doesn't affect chunking + rechunked_drop_true = ds.chunk( + time=SeasonResampler(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True) + ) + rechunked_drop_false = ds.chunk( + time=SeasonResampler(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False) + ) + assert ( + rechunked_drop_true.chunksizes["time"] + == rechunked_drop_false.chunksizes["time"] + ) + + @requires_dask + def test_chunk_by_season_resampler_errors(self): + """Test error handling for SeasonResampler chunking.""" + ds = Dataset({"foo": ("x", [1, 2, 3])}) + + # Test error on virtual variable + with pytest.raises(ValueError, match="virtual variable"): + ds.chunk(x=SeasonResampler(["DJF", "MAM", "JJA", "SON"])) + + # Test error on non-datetime variable + ds["x"] = ("x", [1, 2, 3]) + with pytest.raises(ValueError, match="datetime variables"): + ds.chunk(x=SeasonResampler(["DJF", "MAM", "JJA", "SON"])) + + # Test successful case 
with 1D datetime variable + ds["x"] = ("x", xr.date_range("2001-01-01", periods=3, freq="D")) + # This should work + result = ds.chunk(x=SeasonResampler(["DJF", "MAM", "JJA", "SON"])) + assert result.chunks is not None + + # Test error on missing season (should fail with incomplete seasons) + with pytest.raises(ValueError): + ds.chunk(x=SeasonResampler(["DJF", "MAM", "SON"])) + @requires_dask def test_chunk(self) -> None: data = create_test_data()