Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3050.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Fixed potential error in `AsyncGroup.create_dataset()` where `dtype` argument could be missing when calling `create_array()`
29 changes: 21 additions & 8 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,8 +1155,11 @@
# create_dataset in zarr 2.x requires shape but not dtype if data is
# provided. Allow this configuration by inferring dtype from data if
# necessary and passing it to create_array
if "dtype" not in kwargs and data is not None:
kwargs["dtype"] = data.dtype
if "dtype" not in kwargs:
if data is not None:
kwargs["dtype"] = data.dtype

Check warning on line 1160 in src/zarr/core/group.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/core/group.py#L1158-L1160

Added lines #L1158 - L1160 were not covered by tests
else:
raise ValueError("dtype must be provided if data is None")

Check warning on line 1162 in src/zarr/core/group.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/core/group.py#L1162

Added line #L1162 was not covered by tests
array = await self.create_array(name, shape=shape, **kwargs)
if data is not None:
await array.setitem(slice(None), data)
Expand Down Expand Up @@ -2544,12 +2547,17 @@
----------
name : str
Array name.
**kwargs :
See :func:`zarr.Group.create_dataset`.
shape : int or tuple of ints
Array shape.
dtype : str or dtype, optional
NumPy dtype.
exact : bool, optional
If True, require `dtype` to match exactly. If false, require
`dtype` can be cast from array dtype.

Returns
-------
a : Array
a : AsyncArray
"""
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))

Expand All @@ -2562,12 +2570,17 @@
----------
name : str
Array name.
**kwargs :
See :func:`zarr.Group.create_array`.
shape : int or tuple of ints
Array shape.
dtype : str or dtype, optional
NumPy dtype.
exact : bool, optional
If True, require `dtype` to match exactly. If false, require
`dtype` can be cast from array dtype.

Returns
-------
a : Array
a : AsyncArray
"""
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))

Expand Down
13 changes: 8 additions & 5 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
from hypothesis import assume, given, settings
from hypothesis import assume, given, settings, HealthCheck

from zarr.abc.store import Store
from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON
Expand Down Expand Up @@ -75,7 +75,7 @@ def deep_equal(a: Any, b: Any) -> bool:

return a == b


@settings(deadline=None) # Increased from default 200ms to None
@given(data=st.data(), zarr_format=zarr_formats)
def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None:
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
Expand Down Expand Up @@ -117,10 +117,11 @@ def test_basic_indexing(data: st.DataObject) -> None:
assert_array_equal(nparray, zarray[:])


@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(data=st.data())
def test_oindex(data: st.DataObject) -> None:
# integer_array_indices can't handle 0-size dimensions.
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=3, min_side=1, max_side=8)))
nparray = zarray[:]

zindexer, npindexer = data.draw(orthogonal_indices(shape=nparray.shape))
Expand All @@ -138,15 +139,17 @@ def test_oindex(data: st.DataObject) -> None:
assert_array_equal(nparray, zarray[:])


@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(data=st.data())
def test_vindex(data: st.DataObject) -> None:
# integer_array_indices can't handle 0-size dimensions.
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=3, min_side=1, max_side=8)))
nparray = zarray[:]

indexer = data.draw(
npst.integer_array_indices(
shape=nparray.shape, result_shape=npst.array_shapes(min_side=1, max_dims=None)
shape=nparray.shape,
result_shape=npst.array_shapes(min_side=1, max_dims=2, max_side=8)
)
)
actual = zarray.vindex[indexer]
Expand Down
Loading