Skip to content

Commit 1c88f1e

Browse files
authored
Faster chunk checking for backend datasets (pydata#9808)
* Faster chunk checking for backend datasets
* Limit cache size
* Fix test
* Optimize
1 parent e510a9e commit 1c88f1e

File tree

2 files changed

+45
-19
lines changed

2 files changed

+45
-19
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ New Features
6464
underlying array's backend. Provides better support for certain wrapped array types
6565
like ``jax.numpy.ndarray``. (:issue:`7848`, :pull:`9776`).
6666
By `Sam Levang <https://github.com/slevang>`_.
67+
- Speed up loading of large zarr stores using dask arrays. (:issue:`8902`)
68+
By `Deepak Cherian <https://github.com/dcherian>`_.
6769

6870
Breaking changes
6971
~~~~~~~~~~~~~~~~

xarray/core/dataset.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
MutableMapping,
1919
Sequence,
2020
)
21-
from functools import partial
21+
from functools import lru_cache, partial
2222
from html import escape
2323
from numbers import Number
2424
from operator import methodcaller
@@ -236,7 +236,6 @@ def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint):
236236
"""
237237
Return map from each dim to chunk sizes, accounting for backend's preferred chunks.
238238
"""
239-
240239
if isinstance(var, IndexVariable):
241240
return {}
242241
dims = var.dims
@@ -266,31 +265,56 @@ def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint):
266265
preferred_chunk_sizes = preferred_chunks[dim]
267266
except KeyError:
268267
continue
269-
# Determine the stop indices of the preferred chunks, but omit the last stop
270-
# (equal to the dim size). In particular, assume that when a sequence
271-
# expresses the preferred chunks, the sequence sums to the size.
272-
preferred_stops = (
273-
range(preferred_chunk_sizes, size, preferred_chunk_sizes)
274-
if isinstance(preferred_chunk_sizes, int)
275-
else itertools.accumulate(preferred_chunk_sizes[:-1])
276-
)
277-
# Gather any stop indices of the specified chunks that are not a stop index
278-
# of a preferred chunk. Again, omit the last stop, assuming that it equals
279-
# the dim size.
280-
breaks = set(itertools.accumulate(chunk_sizes[:-1])).difference(
281-
preferred_stops
268+
disagreement = _get_breaks_cached(
269+
size=size,
270+
chunk_sizes=chunk_sizes,
271+
preferred_chunk_sizes=preferred_chunk_sizes,
282272
)
283-
if breaks:
284-
warnings.warn(
273+
if disagreement:
274+
emit_user_level_warning(
285275
"The specified chunks separate the stored chunks along "
286-
f'dimension "{dim}" starting at index {min(breaks)}. This could '
276+
f'dimension "{dim}" starting at index {disagreement}. This could '
287277
"degrade performance. Instead, consider rechunking after loading.",
288-
stacklevel=2,
289278
)
290279

291280
return dict(zip(dims, chunk_shape, strict=True))
292281

293282

283+
@lru_cache(maxsize=512)
284+
def _get_breaks_cached(
285+
*,
286+
size: int,
287+
chunk_sizes: tuple[int, ...],
288+
preferred_chunk_sizes: int | tuple[int, ...],
289+
) -> int | None:
290+
if isinstance(preferred_chunk_sizes, int) and preferred_chunk_sizes == 1:
291+
# short-circuit for the trivial case
292+
return None
293+
# Determine the stop indices of the preferred chunks, but omit the last stop
294+
# (equal to the dim size). In particular, assume that when a sequence
295+
# expresses the preferred chunks, the sequence sums to the size.
296+
preferred_stops = (
297+
range(preferred_chunk_sizes, size, preferred_chunk_sizes)
298+
if isinstance(preferred_chunk_sizes, int)
299+
else set(itertools.accumulate(preferred_chunk_sizes[:-1]))
300+
)
301+
302+
# Gather any stop indices of the specified chunks that are not a stop index
303+
# of a preferred chunk. Again, omit the last stop, assuming that it equals
304+
# the dim size.
305+
actual_stops = itertools.accumulate(chunk_sizes[:-1])
306+
# This copy is required for parallel iteration
307+
actual_stops_2 = itertools.accumulate(chunk_sizes[:-1])
308+
309+
disagrees = itertools.compress(
310+
actual_stops_2, (a not in preferred_stops for a in actual_stops)
311+
)
312+
try:
313+
return next(disagrees)
314+
except StopIteration:
315+
return None
316+
317+
294318
def _maybe_chunk(
295319
name: Hashable,
296320
var: Variable,

0 commit comments

Comments
 (0)