Skip to content

Commit 7b32304

Browse files
authored
Fail during planning if map_blocks drop_axis is for a chunked dimension (#569)
1 parent ecfb10f commit 7b32304

File tree

3 files changed

+35
-4
lines changed

3 files changed

+35
-4
lines changed

cubed/primitive/blockwise.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,15 @@ def make_blockwise_key_function(
687687
False,
688688
)
689689

690+
for axes, (arg, _) in zip(concat_axes, argpairs):
691+
for ax in axes:
692+
if numblocks[arg][ax] > 1:
693+
raise ValueError(
694+
f"Cannot have multiple chunks in dropped axis {ax}. "
695+
"To fix, use a reduction after calling map_blocks "
696+
"without specifying drop_axis, or rechunk first."
697+
)
698+
690699
def key_function(out_key):
691700
out_coords = out_key[1:]
692701

cubed/tests/primitive/test_blockwise.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,11 +266,11 @@ def test_make_blockwise_key_function_contract():
266266
func = lambda x: 0
267267

268268
key_fn = make_blockwise_key_function(
269-
func, "z", "ik", "x", "ij", "y", "jk", numblocks={"x": (2, 2), "y": (2, 2)}
269+
func, "z", "ik", "x", "ij", "y", "jk", numblocks={"x": (2, 1), "y": (1, 2)}
270270
)
271271

272272
graph = make_blockwise_graph(
273-
func, "z", "ik", "x", "ij", "y", "jk", numblocks={"x": (2, 2), "y": (2, 2)}
273+
func, "z", "ik", "x", "ij", "y", "jk", numblocks={"x": (2, 1), "y": (1, 2)}
274274
)
275275
check_consistent_with_graph(key_fn, graph)
276276

@@ -290,10 +290,10 @@ def test_make_blockwise_key_function_contract_0d():
290290
func = lambda x: 0
291291

292292
key_fn = make_blockwise_key_function(
293-
func, "z", "", "x", "ij", numblocks={"x": (2, 2)}
293+
func, "z", "", "x", "ij", numblocks={"x": (1, 1)}
294294
)
295295

296-
graph = make_blockwise_graph(func, "z", "", "x", "ij", numblocks={"x": (2, 2)})
296+
graph = make_blockwise_graph(func, "z", "", "x", "ij", numblocks={"x": (1, 1)})
297297
check_consistent_with_graph(key_fn, graph)
298298

299299

cubed/tests/test_core.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,28 @@ def func(x, y):
235235
assert_array_equal(c.compute(), np.array([[[12, 13]]]))
236236

237237

238+
def test_map_blocks_drop_axis_chunking(spec):
239+
# This tests the case illustrated in https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
240+
# Unlike Dask, Cubed does not support concatenating chunks, and will fail if the dropped axis has multiple chunks.
241+
242+
def func(x):
243+
return nxp.sum(x, axis=2)
244+
245+
an = np.arange(8 * 6 * 2).reshape((8, 6, 2))
246+
247+
# single chunk in axis=2 works fine
248+
a = xp.asarray(an, chunks=(5, 4, 2), spec=spec)
249+
b = cubed.map_blocks(func, a, drop_axis=2)
250+
assert_array_equal(b.compute(), np.sum(an, axis=2))
251+
252+
# multiple chunks in axis=2 raises
253+
a = xp.asarray(an, chunks=(5, 4, 1), spec=spec)
254+
with pytest.raises(
255+
ValueError, match=r"Cannot have multiple chunks in dropped axis 2."
256+
):
257+
cubed.map_blocks(func, a, drop_axis=2)
258+
259+
238260
def test_map_blocks_with_non_cubed_array(spec):
239261
a = xp.arange(10, dtype="int64", chunks=(2,), spec=spec)
240262
b = np.array([1, 2], dtype="int64") # numpy array will be coerced to cubed

0 commit comments

Comments
 (0)