Skip to content

Commit fe99d45

Browse files
committed
Implement from_array for v3
1 parent 77eb01a commit fe99d45

File tree

2 files changed

+160
-26
lines changed

2 files changed

+160
-26
lines changed

src/pydantic_zarr/v3.py

Lines changed: 127 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,43 @@
11
from __future__ import annotations
22

3-
from collections.abc import Mapping, Sequence
43
from typing import (
4+
TYPE_CHECKING,
55
Any,
66
Generic,
77
Literal,
8+
Self,
89
TypeVar,
910
Union,
1011
cast,
1112
overload,
1213
)
1314

15+
import numpy as np
1416
import numpy.typing as npt
15-
import zarr
16-
from zarr.abc.store import Store
1717

1818
from pydantic_zarr.core import StrictBase
19-
from pydantic_zarr.v2 import DtypeStr
19+
20+
if TYPE_CHECKING:
21+
from collections.abc import Mapping, Sequence
22+
23+
import zarr
24+
from zarr.abc.store import Store
25+
26+
from pydantic_zarr.v2 import DtypeStr
2027

2128
TAttr = TypeVar("TAttr", bound=dict[str, Any])
2229
TItem = TypeVar("TItem", bound=Union["GroupSpec", "ArraySpec"])
2330

2431
NodeType = Literal["group", "array"]
2532

33+
BoolFillValue = bool
34+
IntFillValue = int
2635
# todo: introduce a type that represents hexadecimal representations of floats
27-
FillValue = Union[
28-
Literal["Infinity", "-Infinity", "NaN"],
29-
bool,
30-
int,
31-
float,
32-
str,
33-
tuple[float, float],
34-
tuple[int, ...],
35-
]
36+
FloatFillValue = Literal["Infinity", "-Infinity", "NaN"] | float
37+
ComplexFillValue = tuple[FloatFillValue, FloatFillValue]
38+
RawFillValue = tuple[int, ...]
39+
40+
FillValue = BoolFillValue | IntFillValue | FloatFillValue | ComplexFillValue | RawFillValue
3641

3742

3843
class NamedConfig(StrictBase):
@@ -50,12 +55,12 @@ class RegularChunking(NamedConfig):
5055

5156

5257
class DefaultChunkKeyEncodingConfig(StrictBase):
53-
separator: Literal[".", "/"]
58+
separator: Literal[".", "/"] = "/"
5459

5560

5661
class DefaultChunkKeyEncoding(NamedConfig):
57-
name: Literal["default"]
58-
configuration: DefaultChunkKeyEncodingConfig | None
62+
name: Literal["default"] = "default"
63+
configuration: DefaultChunkKeyEncodingConfig | None = DefaultChunkKeyEncodingConfig()
5964

6065

6166
class NodeSpec(StrictBase):
@@ -110,11 +115,22 @@ class ArraySpec(NodeSpec, Generic[TAttr]):
110115
chunk_key_encoding: NamedConfig # todo: validate this against shape
111116
fill_value: FillValue # todo: validate this against the data type
112117
codecs: Sequence[NamedConfig]
113-
storage_transformers: Sequence[NamedConfig] | None = None
114-
dimension_names: Sequence[str] | None # todo: validate this against shape
118+
storage_transformers: Sequence[NamedConfig]
119+
dimension_names: Sequence[str | None] # todo: validate this against shape
115120

116121
@classmethod
117-
def from_array(cls, array: npt.NDArray[Any], **kwargs):
122+
def from_array(
123+
cls,
124+
array: npt.NDArray[Any],
125+
*,
126+
attributes: Literal["auto"] | TAttr = "auto",
127+
chunk_grid: Literal["auto"] | NamedConfig = "auto",
128+
chunk_key_encoding: Literal["auto"] | NamedConfig = "auto",
129+
fill_value: Literal["auto"] | FillValue = "auto",
130+
codecs: Literal["auto"] | Sequence[NamedConfig] = "auto",
131+
storage_transformers: Literal["auto"] | Sequence[NamedConfig] = "auto",
132+
dimension_names: Literal["auto"] | Sequence[str | None] = "auto",
133+
) -> Self:
118134
"""
119135
Create an ArraySpec from a numpy array-like object.
120136
@@ -131,15 +147,51 @@ def from_array(cls, array: npt.NDArray[Any], **kwargs):
131147
An instance of ArraySpec with properties derived from the provided array.
132148
133149
"""
134-
default_chunks = RegularChunking(
135-
configuration=RegularChunkingConfig(chunk_shape=list(array.shape))
136-
)
150+
if attributes == "auto":
151+
attributes_actual = cast(TAttr, auto_attributes(array))
152+
else:
153+
attributes_actual = attributes
154+
155+
if chunk_grid == "auto":
156+
chunk_grid_actual = auto_chunk_grid(array)
157+
else:
158+
chunk_grid_actual = chunk_grid
159+
160+
if chunk_key_encoding == "auto":
161+
chunk_key_actual = DefaultChunkKeyEncoding()
162+
else:
163+
chunk_key_actual = chunk_key_encoding
164+
165+
if fill_value == "auto":
166+
fill_value_actual = auto_fill_value(array)
167+
else:
168+
fill_value_actual = fill_value
169+
170+
if codecs == "auto":
171+
codecs_actual = auto_codecs(array)
172+
else:
173+
codecs_actual = codecs
174+
175+
if storage_transformers == "auto":
176+
storage_transformers_actual = auto_storage_transformers(array)
177+
else:
178+
storage_transformers_actual = storage_transformers
179+
180+
if dimension_names == "auto":
181+
dimension_names_actual = auto_dimension_names(array)
182+
else:
183+
dimension_names_actual = dimension_names
184+
137185
return cls(
138186
shape=array.shape,
139187
data_type=str(array.dtype),
140-
chunk_grid=kwargs.pop("chunks", default_chunks),
141-
attributes=kwargs.pop("attributes", {}),
142-
**kwargs,
188+
chunk_grid=chunk_grid_actual,
189+
attributes=attributes_actual,
190+
chunk_key_encoding=chunk_key_actual,
191+
fill_value=fill_value_actual,
192+
codecs=codecs_actual,
193+
storage_transformers=storage_transformers_actual,
194+
dimension_names=dimension_names_actual,
143195
)
144196

145197
@classmethod
@@ -325,3 +377,53 @@ def to_zarr(
325377
raise ValueError(msg)
326378

327379
return result
380+
381+
382+
def auto_attributes(array: Any) -> TAttr:
    """
    Infer the attributes for an array-like object.

    Parameters
    ----------
    array : Any
        The object to inspect.

    Returns
    -------
    TAttr
        The object's ``attributes`` attribute if it has one, otherwise an empty dict.
    """
    # Guard clause: objects without an ``attributes`` attribute get an empty mapping.
    if not hasattr(array, "attributes"):
        return cast(TAttr, {})
    return array.attributes
386+
387+
388+
def auto_chunk_grid(array: Any) -> NamedConfig:
    """
    Infer a chunk grid for an array-like object.

    Parameters
    ----------
    array : Any
        The object to inspect. Its ``chunk_shape`` attribute is preferred; if that is
        absent, its ``shape`` attribute is used, which yields a single chunk spanning
        the whole array.

    Returns
    -------
    NamedConfig
        A ``RegularChunking`` built from the chosen shape. If ``chunk_shape`` is
        already a ``NamedConfig`` it is returned unchanged.

    Raises
    ------
    ValueError
        If the object has neither a ``chunk_shape`` nor a ``shape`` attribute.
    """
    if hasattr(array, "chunk_shape"):
        chunk_shape = array.chunk_shape
        # Pass an already-built chunk grid straight through instead of re-wrapping it.
        if isinstance(chunk_shape, NamedConfig):
            return chunk_shape
    elif hasattr(array, "shape"):
        chunk_shape = array.shape
    else:
        raise ValueError("Cannot get chunk grid from object without .shape attribute")

    # Both sources are plain shapes, so both must be wrapped to satisfy the declared
    # NamedConfig return type (previously the ``chunk_shape`` branch returned it raw).
    return RegularChunking(configuration=RegularChunkingConfig(chunk_shape=list(chunk_shape)))
394+
395+
396+
def auto_fill_value(array: Any) -> FillValue:
    """
    Infer a fill value for an array-like object.

    Parameters
    ----------
    array : Any
        The object to inspect. Its ``fill_value`` attribute is returned if present;
        otherwise a default is chosen from the kind of its ``dtype``.

    Returns
    -------
    FillValue
        ``False`` for boolean data, ``0`` for signed and unsigned integers, ``"NaN"``
        for floats and ``("NaN", "NaN")`` for complex numbers.

    Raises
    ------
    ValueError
        If the object has neither a ``fill_value`` nor a ``dtype`` attribute, or if
        its dtype kind has no default fill value.
    """
    if hasattr(array, "fill_value"):
        return array.fill_value
    if not hasattr(array, "dtype"):
        raise ValueError(
            "Cannot determine default fill value for object without dtype attribute."
        )

    kind = np.dtype(array.dtype).kind
    # NumPy's kind code for booleans is "b"; "?" is the dtype *char* code and never
    # matches ``dtype.kind``.
    if kind == "b":
        return False
    if kind in ("i", "u"):
        return 0
    if kind == "f":
        return "NaN"
    if kind == "c":
        return ("NaN", "NaN")
    raise ValueError(f"Cannot determine default fill value for data type {kind}")
412+
413+
414+
def auto_codecs(array: Any) -> Sequence[NamedConfig]:
    """
    Infer the codecs for an array-like object.

    Parameters
    ----------
    array : Any
        The object to inspect.

    Returns
    -------
    Sequence[NamedConfig]
        The object's ``codecs`` attribute if it has one, otherwise an empty list.
    """
    return array.codecs if hasattr(array, "codecs") else []
418+
419+
420+
def auto_storage_transformers(array: Any) -> list:
    """
    Infer the storage transformers for an array-like object.

    Parameters
    ----------
    array : Any
        The object to inspect.

    Returns
    -------
    list
        The object's ``storage_transformers`` attribute if it has one, otherwise an
        empty list.
    """
    # Guard clause: anything without the attribute has no storage transformers.
    if not hasattr(array, "storage_transformers"):
        return []
    return array.storage_transformers
424+
425+
426+
def auto_dimension_names(array: Any) -> list[str | None]:
    """
    Infer the dimension names for an array-like object.

    Parameters
    ----------
    array : Any
        The object to inspect.

    Returns
    -------
    list[str | None]
        The object's ``dimension_names`` attribute if it has one that is not ``None``;
        otherwise one ``None`` per dimension of the object.
    """
    names = getattr(array, "dimension_names", None)
    # An explicit ``None`` is treated like a missing attribute, because the
    # ``dimension_names`` field this value feeds is a non-optional sequence.
    if names is not None:
        return names
    # ``np.ndim`` reads ``.ndim`` when available and otherwise converts the input; this
    # avoids ``np.asanyarray(..., copy=False)``, whose ``copy`` keyword is missing from
    # older NumPy releases and which refuses inputs that need a copy (e.g. nested lists).
    return [None] * np.ndim(array)

tests/test_pydantic_zarr/test_v3.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
from pydantic_zarr.v3 import ArraySpec, GroupSpec, NamedConfig
1+
import numpy as np
2+
3+
from pydantic_zarr.v3 import (
4+
ArraySpec,
5+
DefaultChunkKeyEncoding,
6+
DefaultChunkKeyEncodingConfig,
7+
GroupSpec,
8+
NamedConfig,
9+
RegularChunking,
10+
RegularChunkingConfig,
11+
)
212

313

414
def test_serialize_deserialize() -> None:
@@ -15,6 +25,28 @@ def test_serialize_deserialize() -> None:
1525
chunk_key_encoding=NamedConfig(name="default", configuration={"separator": "/"}),
1626
codecs=[NamedConfig(name="GZip", configuration={"level": 1})],
1727
fill_value="NaN",
28+
storage_transformers=[],
1829
)
1930

2031
GroupSpec(attributes=group_attributes, members={"array": array_spec})
32+
33+
34+
def test_from_array() -> None:
    """
    Check that ``ArraySpec.from_array`` fills in every "auto" field for a 1D integer array.
    """
    observed = ArraySpec.from_array(np.arange(10))

    # NOTE(review): from_array derives data_type via str(array.dtype); confirm that
    # ArraySpec treats that string and "<i8" as equal on every supported platform.
    expected = ArraySpec(
        zarr_format=3,
        node_type="array",
        attributes={},
        shape=(10,),
        data_type="<i8",
        chunk_grid=RegularChunking(
            name="regular", configuration=RegularChunkingConfig(chunk_shape=[10])
        ),
        chunk_key_encoding=DefaultChunkKeyEncoding(
            name="default", configuration=DefaultChunkKeyEncodingConfig(separator="/")
        ),
        fill_value=0,
        codecs=[],
        storage_transformers=[],
        dimension_names=[None],
    )

    assert observed == expected

0 commit comments

Comments
 (0)