Skip to content

Commit d14777c

Browse files
committed
Migrate to zarr-python 3
1 parent e709eea commit d14777c

File tree

4 files changed

+145
-64
lines changed

4 files changed

+145
-64
lines changed

pyproject.toml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: 3.12",
2424
"Programming Language :: Python :: Implementation :: CPython",
2525
]
26-
dependencies = [
27-
"zarr<3",
28-
"pydantic>2.0.0"
29-
]
26+
dependencies = ["zarr>=3", "pydantic>2.0.0"]
3027

3128
[project.urls]
3229
Documentation = "https://zarr.dev/pydantic-zarr/"
@@ -194,7 +191,9 @@ addopts = [
194191
"--durations=10", "-ra", "--strict-config", "--strict-markers",
195192
]
196193
filterwarnings = [
197-
"error"
194+
"error",
195+
# https://github.com/zarr-developers/zarr-python/issues/2948
196+
"ignore:The `order` keyword argument has no effect for Zarr format 3 arrays:RuntimeWarning",
198197
]
199198

200199
[tool.repo-review]

src/pydantic_zarr/v2.py

Lines changed: 88 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import math
34
import os
45
from collections.abc import Mapping
56
from typing import (
@@ -21,9 +22,9 @@
2122
from numcodecs.abc import Codec
2223
from pydantic import AfterValidator, model_validator
2324
from pydantic.functional_validators import BeforeValidator
25+
from zarr.abc.store import Store
26+
from zarr.core.sync_group import get_node
2427
from zarr.errors import ContainsArrayError, ContainsGroupError
25-
from zarr.storage import BaseStore, contains_array, contains_group, init_group
26-
from zarr.util import guess_chunks
2728

2829
from pydantic_zarr.core import (
2930
IncEx,
@@ -36,6 +37,17 @@
3637
TItem = TypeVar("TItem", bound=Union["GroupSpec", "ArraySpec"])
3738

3839

40+
def _contains_array(store: Store, path: str) -> bool:
    """
    Report whether a Zarr v2 array is stored at `path` inside `store`.

    Parameters
    ----------
    store : zarr.abc.store.Store
        The storage backend to inspect.
    path : str
        The location inside the store to check.

    Returns
    -------
    bool
        True if the node found at `path` is a `zarr.Array`. False if the
        node is something else, or if the lookup raises `FileNotFoundError`
        (treated here as "nothing is stored at `path`").
    """
    try:
        node = get_node(store, path, zarr_format=2)
    except FileNotFoundError:
        return False
    return isinstance(node, zarr.Array)
45+
46+
47+
def _contains_group(store: Store, path: str) -> bool:
    """
    Report whether a Zarr v2 group is stored at `path` inside `store`.

    Parameters
    ----------
    store : zarr.abc.store.Store
        The storage backend to inspect.
    path : str
        The location inside the store to check.

    Returns
    -------
    bool
        True if the node found at `path` is a `zarr.Group`. False if the
        node is something else, or if the lookup raises `FileNotFoundError`.
    """
    try:
        node = get_node(store, path, zarr_format=2)
    except FileNotFoundError:
        # Mirror `_contains_array`: a failed lookup means nothing is stored at
        # `path`, so the answer is "no group here" rather than an error. Without
        # this, callers that probe a not-yet-created path before creating a
        # group would crash instead of falling through to creation.
        return False
    return isinstance(node, zarr.Group)
49+
50+
3951
def stringify_dtype(value: npt.DTypeLike) -> str:
4052
"""
4153
Convert a `numpy.dtype` object into a `str`.
@@ -318,14 +330,14 @@ def from_zarr(cls, array: zarr.Array) -> Self:
318330
fill_value=array.dtype.type(array.fill_value).tolist(),
319331
order=array.order,
320332
filters=array.filters,
321-
dimension_separator=array._dimension_separator,
322-
compressor=array.compressor,
333+
dimension_separator=array.metadata.dimension_separator,
334+
compressor=array.compressors[0].get_config(),
323335
attributes=array.attrs.asdict(),
324336
)
325337

326338
def to_zarr(
327339
self,
328-
store: BaseStore,
340+
store: Store,
329341
path: str,
330342
*,
331343
overwrite: bool = False,
@@ -337,14 +349,15 @@ def to_zarr(
337349
338350
Parameters
339351
----------
340-
store : instance of zarr.BaseStore
352+
store : instance of zarr.abc.store.Store
341353
The storage backend that will manifest the array.
342354
path : str
343355
The location of the array inside the store.
344356
overwrite: bool, default = False
345357
Whether to overwrite existing objects in storage to create the Zarr array.
346358
**kwargs : Any
347359
Additional keyword arguments are passed to `zarr.create`.
360+
348361
Returns
349362
-------
350363
zarr.Array
@@ -356,24 +369,20 @@ def to_zarr(
356369
spec_dict["compressor"] = numcodecs.get_codec(spec_dict["compressor"])
357370
if self.filters is not None:
358371
spec_dict["filters"] = [numcodecs.get_codec(f) for f in spec_dict["filters"]]
359-
if contains_array(store, path):
360-
extant_array = zarr.open_array(store, path=path, mode="r")
372+
if _contains_array(store, path):
373+
extant_array = zarr.open_array(store, path=path, mode="r", zarr_format=2)
361374

362375
if not self.like(extant_array):
363376
if not overwrite:
364-
msg = (
365-
f"An array already exists at path {path}. "
366-
"That array is structurally dissimilar to the array you are trying to "
367-
"store. Call to_zarr with overwrite=True to overwrite that array."
368-
)
369-
raise ContainsArrayError(msg)
377+
raise ContainsArrayError(store, path)
370378
else:
371379
if not overwrite:
372380
# extant_array is read-only, so we make a new array handle that
373381
# takes **kwargs
374382
return zarr.open_array(
375-
store=extant_array.store, path=extant_array.path, **kwargs
383+
store=extant_array.store, path=extant_array.path, zarr_format=2, **kwargs
376384
)
385+
spec_dict["zarr_format"] = spec_dict.pop("zarr_version", 2)
377386
result = zarr.create(store=store, path=path, overwrite=overwrite, **spec_dict, **kwargs)
378387
result.attrs.put(attrs)
379388
return result
@@ -519,13 +528,14 @@ def from_zarr(cls, group: zarr.Group, *, depth: int = -1) -> Self:
519528
result = cls(attributes=attributes, members=members)
520529
return result
521530

522-
def to_zarr(self, store: BaseStore, path: str, *, overwrite: bool = False, **kwargs):
531+
def to_zarr(self, store: Store, path: str, *, overwrite: bool = False, **kwargs):
523532
"""
524-
Serialize this `GroupSpec` to a Zarr group at a specific path in a `zarr.BaseStore`.
533+
Serialize this `GroupSpec` to a Zarr group at a specific path in a `zarr.abc.store.Store`.
525534
This operation will create metadata documents in the store.
535+
526536
Parameters
527537
----------
528-
store : zarr.BaseStore
538+
store : zarr.abc.store.Store
529539
The storage backend that will manifest the group and its contents.
530540
path : str
531541
The location of the group inside the store.
@@ -542,7 +552,7 @@ def to_zarr(self, store: BaseStore, path: str, *, overwrite: bool = False, **kwa
542552
"""
543553
spec_dict = self.model_dump(exclude={"members": True})
544554
attrs = spec_dict.pop("attributes")
545-
if contains_group(store, path):
555+
if _contains_group(store, path):
546556
extant_group = zarr.group(store, path=path)
547557
if not self.like(extant_group):
548558
if not overwrite:
@@ -558,14 +568,14 @@ def to_zarr(self, store: BaseStore, path: str, *, overwrite: bool = False, **kwa
558568
# then just return the extant group
559569
return extant_group
560570

561-
elif contains_array(store, path) and not overwrite:
571+
elif _contains_array(store, path) and not overwrite:
562572
msg = (
563573
f"An array already exists at path {path}. "
564574
"Call to_zarr with overwrite=True to overwrite the array."
565575
)
566576
raise ContainsArrayError(msg)
567577
else:
568-
init_group(store=store, overwrite=overwrite, path=path)
578+
zarr.create_group(store=store, overwrite=overwrite, path=path, zarr_format=2)
569579

570580
result = zarr.group(store=store, path=path, overwrite=overwrite)
571581
result.attrs.put(attrs)
@@ -746,7 +756,7 @@ def from_zarr(element: zarr.Array | zarr.Group, depth: int = -1) -> ArraySpec |
746756
@overload
747757
def to_zarr(
748758
spec: ArraySpec,
749-
store: BaseStore,
759+
store: Store,
750760
path: str,
751761
*,
752762
overwrite: bool = False,
@@ -757,7 +767,7 @@ def to_zarr(
757767
@overload
758768
def to_zarr(
759769
spec: GroupSpec,
760-
store: BaseStore,
770+
store: Store,
761771
path: str,
762772
*,
763773
overwrite: bool = False,
@@ -767,7 +777,7 @@ def to_zarr(
767777

768778
def to_zarr(
769779
spec: ArraySpec | GroupSpec,
770-
store: BaseStore,
780+
store: Store,
771781
path: str,
772782
*,
773783
overwrite: bool = False,
@@ -781,7 +791,7 @@ def to_zarr(
781791
----------
782792
spec : ArraySpec | GroupSpec
783793
The `GroupSpec` or `ArraySpec` that will be serialized to storage.
784-
store : zarr.BaseStore
794+
store : zarr.abc.store.Store
785795
The storage backend that will manifest the Zarr group or array modeled by `spec`.
786796
path : str
787797
The location of the Zarr group or array inside the store.
@@ -985,7 +995,7 @@ def auto_chunks(data: Any) -> tuple[int, ...]:
985995
return data.chunksize
986996
if hasattr(data, "chunks"):
987997
return data.chunks
988-
return guess_chunks(data.shape, np.dtype(data.dtype).itemsize)
998+
return _guess_chunks(data.shape, np.dtype(data.dtype).itemsize)
989999

9901000

9911001
def auto_attributes(data: Any) -> Mapping[str, Any]:
@@ -1045,3 +1055,55 @@ def auto_dimension_separator(data: Any) -> Literal["/", "."]:
10451055
if hasattr(data, "dimension_separator"):
10461056
return data.dimension_separator
10471057
return "/"
1058+
1059+
1060+
def _guess_chunks(shape: tuple[int, ...], typesize: int) -> tuple[int, ...]:
1061+
"""
1062+
Vendored from zarr-python v2.
1063+
1064+
Guess an appropriate chunk layout for an array, given its shape and
1065+
the size of each element in bytes. Will allocate chunks only as large
1066+
as MAX_SIZE. Chunks are generally close to some power-of-2 fraction of
1067+
each axis, slightly favoring bigger values for the last index.
1068+
Undocumented and subject to change without warning.
1069+
"""
1070+
1071+
CHUNK_BASE = 256 * 1024 # Multiplier by which chunks are adjusted
1072+
CHUNK_MIN = 128 * 1024 # Soft lower limit (128k)
1073+
CHUNK_MAX = 64 * 1024 * 1024 # Hard upper limit
1074+
1075+
ndims = len(shape)
1076+
# require chunks to have non-zero length for all dimensions
1077+
chunks = np.maximum(np.array(shape, dtype="=f8"), 1)
1078+
1079+
# Determine the optimal chunk size in bytes using a PyTables expression.
1080+
# This is kept as a float.
1081+
dset_size = np.prod(chunks) * typesize
1082+
target_size = CHUNK_BASE * (2 ** np.log10(dset_size / (1024.0 * 1024)))
1083+
1084+
if target_size > CHUNK_MAX:
1085+
target_size = CHUNK_MAX
1086+
elif target_size < CHUNK_MIN:
1087+
target_size = CHUNK_MIN
1088+
1089+
idx = 0
1090+
while True:
1091+
# Repeatedly loop over the axes, dividing them by 2. Stop when:
1092+
# 1a. We're smaller than the target chunk size, OR
1093+
# 1b. We're within 50% of the target chunk size, AND
1094+
# 2. The chunk is smaller than the maximum chunk size
1095+
1096+
chunk_bytes = np.prod(chunks) * typesize
1097+
1098+
if (
1099+
chunk_bytes < target_size or abs(chunk_bytes - target_size) / target_size < 0.5
1100+
) and chunk_bytes < CHUNK_MAX:
1101+
break
1102+
1103+
if np.prod(chunks) == 1:
1104+
break # Element size larger than CHUNK_MAX
1105+
1106+
chunks[idx % ndims] = math.ceil(chunks[idx % ndims] / 2.0)
1107+
idx += 1
1108+
1109+
return tuple(int(x) for x in chunks)

src/pydantic_zarr/v3.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import numpy.typing as npt
1515
import zarr
16-
from zarr.storage import BaseStore
16+
from zarr.abc.store import Store
1717

1818
from pydantic_zarr.core import StrictBase
1919
from pydantic_zarr.v2 import DtypeStr
@@ -159,13 +159,13 @@ def from_zarr(cls, zarray: zarr.Array):
159159
"""
160160
raise NotImplementedError
161161

162-
def to_zarr(self, store: BaseStore, path: str, overwrite: bool = False) -> zarr.Array:
162+
def to_zarr(self, store: Store, path: str, overwrite: bool = False) -> zarr.Array:
163163
"""
164164
Serialize an ArraySpec to a zarr array at a specific path in a zarr store.
165165
166166
Parameters
167167
----------
168-
store : instance of zarr.BaseStore
168+
store : instance of zarr.abc.store.Store
169169
The storage backend that will manifest the array.
170170
path : str
171171
The location of the array inside the store.
@@ -222,13 +222,13 @@ def from_zarr(cls, group: zarr.Group) -> GroupSpec[TAttr, TItem]:
222222

223223
raise NotImplementedError
224224

225-
def to_zarr(self, store: BaseStore, path: str, overwrite: bool = False):
225+
def to_zarr(self, store: Store, path: str, overwrite: bool = False):
226226
"""
227227
Serialize a GroupSpec to a zarr group at a specific path in a zarr store.
228228
229229
Parameters
230230
----------
231-
store : instance of zarr.BaseStore
231+
store : instance of zarr.abc.store.Store
232232
The storage backend that will manifest the group and its contents.
233233
path : str
234234
The location of the group inside the store.
@@ -273,7 +273,7 @@ def from_zarr(element: zarr.Array | zarr.Group) -> ArraySpec | GroupSpec:
273273
@overload
274274
def to_zarr(
275275
spec: ArraySpec,
276-
store: BaseStore,
276+
store: Store,
277277
path: str,
278278
overwrite: bool = False,
279279
) -> zarr.Array: ...
@@ -282,15 +282,15 @@ def to_zarr(
282282
@overload
283283
def to_zarr(
284284
spec: GroupSpec,
285-
store: BaseStore,
285+
store: Store,
286286
path: str,
287287
overwrite: bool = False,
288288
) -> zarr.Group: ...
289289

290290

291291
def to_zarr(
292292
spec: ArraySpec | GroupSpec,
293-
store: BaseStore,
293+
store: Store,
294294
path: str,
295295
overwrite: bool = False,
296296
) -> zarr.Array | zarr.Group:
@@ -302,7 +302,7 @@ def to_zarr(
302302
----------
303303
spec : GroupSpec or ArraySpec
304304
The GroupSpec or ArraySpec that will be serialized to storage.
305-
store : instance of zarr.BaseStore
305+
store : instance of zarr.abc.store.Store
306306
The storage backend that will manifest the group or array.
307307
path : str
308308
The location of the group or array inside the store.

0 commit comments

Comments
 (0)