1
1
from __future__ import annotations
2
2
3
- from collections .abc import Mapping , Sequence
4
3
from typing import (
4
+ TYPE_CHECKING ,
5
5
Any ,
6
6
Generic ,
7
7
Literal ,
8
+ Self ,
8
9
TypeVar ,
9
10
Union ,
10
11
cast ,
11
12
overload ,
12
13
)
13
14
15
+ import numpy as np
14
16
import numpy .typing as npt
15
- import zarr
16
- from zarr .abc .store import Store
17
17
18
18
from pydantic_zarr .core import StrictBase
19
- from pydantic_zarr .v2 import DtypeStr
19
+
20
+ if TYPE_CHECKING :
21
+ from collections .abc import Mapping , Sequence
22
+
23
+ import zarr
24
+ from zarr .abc .store import Store
25
+
26
+ from pydantic_zarr .v2 import DtypeStr
20
27
21
28
TAttr = TypeVar ("TAttr" , bound = dict [str , Any ])
22
29
TItem = TypeVar ("TItem" , bound = Union ["GroupSpec" , "ArraySpec" ])
23
30
24
31
NodeType = Literal ["group" , "array" ]
25
32
33
+ BoolFillValue = bool
34
+ IntFillValue = int
26
35
# todo: introduce a type that represents hexadecimal representations of floats
27
- FillValue = Union [
28
- Literal ["Infinity" , "-Infinity" , "NaN" ],
29
- bool ,
30
- int ,
31
- float ,
32
- str ,
33
- tuple [float , float ],
34
- tuple [int , ...],
35
- ]
36
+ FloatFillValue = Literal ["Infinity" , "-Infinity" , "NaN" ] | float
37
+ ComplexFillValue = tuple [FloatFillValue , FloatFillValue ]
38
+ RawFillValue = tuple [int , ...]
39
+
40
+ FillValue = BoolFillValue | IntFillValue | FloatFillValue | ComplexFillValue | RawFillValue
36
41
37
42
38
43
class NamedConfig (StrictBase ):
@@ -50,12 +55,12 @@ class RegularChunking(NamedConfig):
50
55
51
56
52
57
class DefaultChunkKeyEncodingConfig (StrictBase ):
53
- separator : Literal ["." , "/" ]
58
+ separator : Literal ["." , "/" ] = "/"
54
59
55
60
56
61
class DefaultChunkKeyEncoding (NamedConfig ):
57
- name : Literal ["default" ]
58
- configuration : DefaultChunkKeyEncodingConfig | None
62
+ name : Literal ["default" ] = "default"
63
+ configuration : DefaultChunkKeyEncodingConfig | None = DefaultChunkKeyEncodingConfig ()
59
64
60
65
61
66
class NodeSpec (StrictBase ):
@@ -110,11 +115,22 @@ class ArraySpec(NodeSpec, Generic[TAttr]):
110
115
chunk_key_encoding : NamedConfig # todo: validate this against shape
111
116
fill_value : FillValue # todo: validate this against the data type
112
117
codecs : Sequence [NamedConfig ]
113
- storage_transformers : Sequence [NamedConfig ] | None = None
114
- dimension_names : Sequence [str ] | None # todo: validate this against shape
118
+ storage_transformers : Sequence [NamedConfig ]
119
+ dimension_names : Sequence [str | None ] # todo: validate this against shape
115
120
116
121
@classmethod
117
- def from_array (cls , array : npt .NDArray [Any ], ** kwargs ):
122
+ def from_array (
123
+ cls ,
124
+ array : npt .NDArray [Any ],
125
+ * ,
126
+ attributes : Literal ["auto" ] | TAttr = "auto" ,
127
+ chunk_grid : Literal ["auto" ] | NamedConfig = "auto" ,
128
+ chunk_key_encoding : Literal ["auto" ] | NamedConfig = "auto" ,
129
+ fill_value : Literal ["auto" ] | FillValue = "auto" ,
130
+ codecs : Literal ["auto" ] | Sequence [NamedConfig ] = "auto" ,
131
+ storage_transformers : Literal ["auto" ] | Sequence [NamedConfig ] = "auto" ,
132
+ dimension_names : Literal ["auto" ] | Sequence [str | None ] = "auto" ,
133
+ ) -> Self :
118
134
"""
119
135
Create an ArraySpec from a numpy array-like object.
120
136
@@ -131,15 +147,51 @@ def from_array(cls, array: npt.NDArray[Any], **kwargs):
131
147
An instance of ArraySpec with properties derived from the provided array.
132
148
133
149
"""
134
- default_chunks = RegularChunking (
135
- configuration = RegularChunkingConfig (chunk_shape = list (array .shape ))
136
- )
150
+ if attributes == "auto" :
151
+ attributes_actual = cast (TAttr , auto_attributes (array ))
152
+ else :
153
+ attributes_actual = attributes
154
+
155
+ if chunk_grid == "auto" :
156
+ chunk_grid_actual = auto_chunk_grid (array )
157
+ else :
158
+ chunk_grid_actual = chunk_grid
159
+
160
+ if chunk_key_encoding == "auto" :
161
+ chunk_key_actual = DefaultChunkKeyEncoding ()
162
+ else :
163
+ chunk_key_actual = chunk_key_encoding
164
+
165
+ if fill_value == "auto" :
166
+ fill_value_actual = auto_fill_value (array )
167
+ else :
168
+ fill_value_actual = fill_value
169
+
170
+ if codecs == "auto" :
171
+ codecs_actual = auto_codecs (array )
172
+ else :
173
+ codecs_actual = codecs
174
+
175
+ if storage_transformers == "auto" :
176
+ storage_transformers_actual = auto_storage_transformers (array )
177
+ else :
178
+ storage_transformers_actual = storage_transformers
179
+
180
+ if dimension_names == "auto" :
181
+ dimension_names_actual = auto_dimension_names (array )
182
+ else :
183
+ dimension_names_actual = dimension_names
184
+
137
185
return cls (
138
186
shape = array .shape ,
139
187
data_type = str (array .dtype ),
140
- chunk_grid = kwargs .pop ("chunks" , default_chunks ),
141
- attributes = kwargs .pop ("attributes" , {}),
142
- ** kwargs ,
188
+ chunk_grid = chunk_grid_actual ,
189
+ attributes = attributes_actual ,
190
+ chunk_key_encoding = chunk_key_actual ,
191
+ fill_value = fill_value_actual ,
192
+ codecs = codecs_actual ,
193
+ storage_transformers = storage_transformers_actual ,
194
+ dimension_names = dimension_names_actual ,
143
195
)
144
196
145
197
@classmethod
@@ -325,3 +377,53 @@ def to_zarr(
325
377
raise ValueError (msg )
326
378
327
379
return result
380
+
381
+
382
+ def auto_attributes (array : Any ) -> TAttr :
383
+ if hasattr (array , "attributes" ):
384
+ return array .attributes
385
+ return cast (TAttr , {})
386
+
387
+
388
+ def auto_chunk_grid (array : Any ) -> NamedConfig :
389
+ if hasattr (array , "chunk_shape" ):
390
+ return array .chunk_shape
391
+ elif hasattr (array , "shape" ):
392
+ return RegularChunking (configuration = RegularChunkingConfig (chunk_shape = list (array .shape )))
393
+ raise ValueError ("Cannot get chunk grid from object without .shape attribute" )
394
+
395
+
396
+ def auto_fill_value (array : Any ) -> FillValue :
397
+ if hasattr (array , "fill_value" ):
398
+ return array .fill_value
399
+ elif hasattr (array , "dtype" ):
400
+ kind = np .dtype (array .dtype ).kind
401
+ if kind == "?" :
402
+ return False
403
+ elif kind in ["i" , "u" ]:
404
+ return 0
405
+ elif kind in ["f" ]:
406
+ return "NaN"
407
+ elif kind in ["c" ]:
408
+ return ("NaN" , "NaN" )
409
+ else :
410
+ raise ValueError (f"Cannot determine default fill value for data type { kind } " )
411
+ raise ValueError ("Cannot determine default data type for object without shape attribute." )
412
+
413
+
414
+ def auto_codecs (array : Any ) -> Sequence [NamedConfig ]:
415
+ if hasattr (array , "codecs" ):
416
+ return array .codecs
417
+ return []
418
+
419
+
420
+ def auto_storage_transformers (array : Any ) -> list :
421
+ if hasattr (array , "storage_transformers" ):
422
+ return array .storage_transformers
423
+ return []
424
+
425
+
426
+ def auto_dimension_names (array : Any ) -> list [str | None ]:
427
+ if hasattr (array , "dimension_names" ):
428
+ return array .dimension_names
429
+ return [None ] * np .asanyarray (array , copy = False ).ndim
0 commit comments