     MutableMapping,
     Sequence,
 )
-from functools import partial
+from functools import lru_cache, partial
 from html import escape
 from numbers import Number
 from operator import methodcaller
@@ -236,7 +236,6 @@ def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint):
     """
     Return map from each dim to chunk sizes, accounting for backend's preferred chunks.
     """
-
     if isinstance(var, IndexVariable):
         return {}
     dims = var.dims
@@ -266,31 +265,56 @@ def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint):
             preferred_chunk_sizes = preferred_chunks[dim]
         except KeyError:
             continue
-        # Determine the stop indices of the preferred chunks, but omit the last stop
-        # (equal to the dim size). In particular, assume that when a sequence
-        # expresses the preferred chunks, the sequence sums to the size.
-        preferred_stops = (
-            range(preferred_chunk_sizes, size, preferred_chunk_sizes)
-            if isinstance(preferred_chunk_sizes, int)
-            else itertools.accumulate(preferred_chunk_sizes[:-1])
-        )
-        # Gather any stop indices of the specified chunks that are not a stop index
-        # of a preferred chunk. Again, omit the last stop, assuming that it equals
-        # the dim size.
-        breaks = set(itertools.accumulate(chunk_sizes[:-1])).difference(
-            preferred_stops
+        disagreement = _get_breaks_cached(
+            size=size,
+            chunk_sizes=chunk_sizes,
+            preferred_chunk_sizes=preferred_chunk_sizes,
         )
-        if breaks:
-            warnings.warn(
+        if disagreement:
+            emit_user_level_warning(
                 "The specified chunks separate the stored chunks along "
-                f'dimension "{dim}" starting at index {min(breaks)}. This could '
+                f'dimension "{dim}" starting at index {disagreement}. This could '
                 "degrade performance. Instead, consider rechunking after loading.",
-                stacklevel=2,
             )

     return dict(zip(dims, chunk_shape, strict=True))


+@lru_cache(maxsize=512)
+def _get_breaks_cached(
+    *,
+    size: int,
+    chunk_sizes: tuple[int, ...],
+    preferred_chunk_sizes: int | tuple[int, ...],
+) -> int | None:
+    if isinstance(preferred_chunk_sizes, int) and preferred_chunk_sizes == 1:
+        # short-circuit for the trivial case
+        return None
+    # Determine the stop indices of the preferred chunks, but omit the last stop
+    # (equal to the dim size). In particular, assume that when a sequence
+    # expresses the preferred chunks, the sequence sums to the size.
+    preferred_stops = (
+        range(preferred_chunk_sizes, size, preferred_chunk_sizes)
+        if isinstance(preferred_chunk_sizes, int)
+        else set(itertools.accumulate(preferred_chunk_sizes[:-1]))
+    )
+
+    # Gather any stop indices of the specified chunks that are not a stop index
+    # of a preferred chunk. Again, omit the last stop, assuming that it equals
+    # the dim size.
+    actual_stops = itertools.accumulate(chunk_sizes[:-1])
+    # This copy is required for parallel iteration
+    actual_stops_2 = itertools.accumulate(chunk_sizes[:-1])
+
+    disagrees = itertools.compress(
+        actual_stops_2, (a not in preferred_stops for a in actual_stops)
+    )
+    try:
+        return next(disagrees)
+    except StopIteration:
+        return None
+
+
 def _maybe_chunk(
     name: Hashable,
     var: Variable,
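
To sanity-check the new helper outside of xarray, the sketch below is a minimal standalone re-implementation of the logic added above; the `get_breaks` name, the positional signature, and the omission of the `preferred_chunk_sizes == 1` short-circuit are simplifications for illustration, not the committed code. It returns the first stop index of the requested chunks that falls inside a stored (preferred) chunk, or `None` when the two layouts agree, and because the arguments are hashable ints/tuples, `functools.lru_cache` lets variables that share a chunk layout reuse the answer instead of recomputing it.

```python
# Minimal sketch only -- a standalone re-implementation of the helper added in this
# commit, so its behaviour can be checked without importing xarray internals.
import itertools
from functools import lru_cache


@lru_cache(maxsize=512)
def get_breaks(size, chunk_sizes, preferred_chunk_sizes):
    """First requested chunk boundary that splits a stored chunk, else None."""
    # Stop indices of the stored (preferred) chunks, omitting the final stop (== size).
    preferred_stops = (
        range(preferred_chunk_sizes, size, preferred_chunk_sizes)
        if isinstance(preferred_chunk_sizes, int)
        else set(itertools.accumulate(preferred_chunk_sizes[:-1]))
    )
    # First stop index of the requested chunks that is not also a preferred stop.
    return next(
        (s for s in itertools.accumulate(chunk_sizes[:-1]) if s not in preferred_stops),
        None,
    )


print(get_breaks(100, (30, 30, 40), 50))    # 30 -> would trigger the warning
print(get_breaks(100, (50, 50), (50, 50)))  # None -> layouts agree, no warning
print(get_breaks(100, (30, 30, 40), 50))    # same layout again -> served from the cache
print(get_breaks.cache_info().hits)         # 1
```

The caching presumably pays off because `_get_chunk` runs once per variable, and a dataset often holds many variables with identical dimension sizes and encoded chunking, so each distinct layout is checked only once.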