20
20
21
21
import ecml_tools
22
22
23
+ from .indexing import (
24
+ apply_index_to_slices_changes ,
25
+ index_to_slices ,
26
+ length_to_slices ,
27
+ update_tuple ,
28
+ )
29
+
23
30
LOG = logging .getLogger (__name__ )
24
31
25
32
__all__ = ["open_dataset" , "open_zarr" , "debug_zarr_loading" ]
26
33
27
34
DEBUG_ZARR_LOADING = int (os .environ .get ("DEBUG_ZARR_LOADING" , "0" ))
28
35
36
+ DEPTH = 0
37
+
38
+
39
+ def _debug_indexing (method ):
40
+ def wrapper (self , index ):
41
+ global DEPTH
42
+ if isinstance (index , tuple ):
43
+ print (" " * DEPTH , "->" , self , method .__name__ , index )
44
+ DEPTH += 1
45
+ result = method (self , index )
46
+ DEPTH -= 1
47
+ if isinstance (index , tuple ):
48
+ print (" " * DEPTH , "<-" , self , method .__name__ , result .shape )
49
+ return result
50
+
51
+ return wrapper
52
+
53
+
54
+ if True :
55
+
56
+ def debug_indexing (x ):
57
+ return x
58
+
59
+ else :
60
+ debug_indexing = _debug_indexing
61
+
29
62
30
63
def debug_zarr_loading (on_off ):
31
64
global DEBUG_ZARR_LOADING
@@ -190,11 +223,18 @@ def metadata_specific(self, **kwargs):
190
223
def __repr__ (self ):
191
224
return self .__class__ .__name__ + "()"
192
225
226
+ @debug_indexing
193
227
def _get_tuple (self , n ):
194
- raise NotImplementedError (f"Tuple not supported: { n } (class { self .__class__ .__name__ } )" )
228
+ raise NotImplementedError (
229
+ f"Tuple not supported: { n } (class { self .__class__ .__name__ } )"
230
+ )
195
231
196
232
197
233
class Source :
234
+ """
235
+ Class used to follow the provenance of a data point.
236
+ """
237
+
198
238
def __init__ (self , dataset , index , source = None , info = None ):
199
239
self .dataset = dataset
200
240
self .index = index
@@ -340,6 +380,7 @@ def __init__(self, path):
340
380
def __len__ (self ):
341
381
return self .data .shape [0 ]
342
382
383
+ @debug_indexing
343
384
def __getitem__ (self , n ):
344
385
if isinstance (n , tuple ) and any (not isinstance (i , (int , slice )) for i in n ):
345
386
return self ._getitem_extended (n )
@@ -352,8 +393,7 @@ def _getitem_extended(self, index):
352
393
Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves.
353
394
"""
354
395
355
- if not isinstance (index , tuple ):
356
- return self [index ]
396
+ assert False , index
357
397
358
398
shape = self .data .shape
359
399
@@ -377,7 +417,7 @@ def _unwind(self, index, rest, shape, axis, axes):
377
417
if isinstance (index , (list , tuple )):
378
418
axes .append (axis ) # Dimension of the concatenation
379
419
for i in index :
380
- yield from self ._unwind (i , rest , shape , axis , axes )
420
+ yield from self ._unwind (( slice ( i , i + 1 ),) , rest , shape , axis , axes )
381
421
return
382
422
383
423
if len (rest ) == 0 :
@@ -635,6 +675,31 @@ class Concat(Combined):
635
675
def __len__ (self ):
636
676
return sum (len (i ) for i in self .datasets )
637
677
678
+ @debug_indexing
679
+ def _get_tuple (self , index ):
680
+ index , changes = index_to_slices (index , self .shape )
681
+ result = []
682
+
683
+ first , rest = index [0 ], index [1 :]
684
+ start , stop , step = first .start , first .stop , first .step
685
+
686
+ for d in self .datasets :
687
+ length = d ._len
688
+
689
+ result .append (d [(slice (start , stop , step ),) + rest ])
690
+
691
+ start -= length
692
+ while start < 0 :
693
+ start += step
694
+
695
+ stop -= length
696
+
697
+ if start > stop :
698
+ break
699
+
700
+ return apply_index_to_slices_changes (np .concatenate (result , axis = 0 ), changes )
701
+
702
+ @debug_indexing
638
703
def __getitem__ (self , n ):
639
704
if isinstance (n , tuple ):
640
705
return self ._get_tuple (n )
@@ -649,6 +714,7 @@ def __getitem__(self, n):
649
714
k += 1
650
715
return self .datasets [k ][n ]
651
716
717
+ @debug_indexing
652
718
def _get_slice (self , s ):
653
719
result = []
654
720
@@ -716,9 +782,23 @@ def shape(self):
716
782
assert False not in result , result
717
783
return result
718
784
785
+ @debug_indexing
786
+ def _get_tuple (self , index ):
787
+ index , changes = index_to_slices (index , self .shape )
788
+ lengths = [d .shape [self .axis ] for d in self .datasets ]
789
+ slices = length_to_slices (index [self .axis ], lengths )
790
+ before = index [: self .axis ]
791
+ result = [
792
+ d [before + (i ,)] for (d , i ) in zip (self .datasets , slices ) if i is not None
793
+ ]
794
+ result = np .concatenate (result , axis = self .axis )
795
+ return apply_index_to_slices_changes (result , changes )
796
+
797
+ @debug_indexing
719
798
def _get_slice (self , s ):
720
799
return np .stack ([self [i ] for i in range (* s .indices (self ._len ))])
721
800
801
+ @debug_indexing
722
802
def __getitem__ (self , n ):
723
803
if isinstance (n , tuple ):
724
804
return self ._get_tuple (n )
@@ -769,9 +849,22 @@ def check_same_variables(self, d1, d2):
769
849
def __len__ (self ):
770
850
return len (self .datasets [0 ])
771
851
852
+ @debug_indexing
853
+ def _get_tuple (self , index ):
854
+ index , changes = index_to_slices (index , self .shape )
855
+ index , previous = update_tuple (index , 1 , slice (None ))
856
+
857
+ # TODO: optimize if index does not access all datasets, so we don't load chunks we don't need
858
+ result = [d [index ] for d in self .datasets ]
859
+
860
+ result = np .concatenate (result , axis = 1 )
861
+ return apply_index_to_slices_changes (result [:, previous ], changes )
862
+
863
+ @debug_indexing
772
864
def _get_slice (self , s ):
773
865
return np .stack ([self [i ] for i in range (* s .indices (self ._len ))])
774
866
867
+ @debug_indexing
775
868
def __getitem__ (self , n ):
776
869
if isinstance (n , tuple ):
777
870
return self ._get_tuple (n )
@@ -857,10 +950,14 @@ def __init__(self, dataset, indices):
857
950
858
951
self .dataset = dataset
859
952
self .indices = list (indices )
953
+ self .slice = _make_slice_or_index_from_list_or_tuple (self .indices )
954
+ assert isinstance (self .slice , slice )
955
+ print ("SUBSET" , self .slice )
860
956
861
957
# Forward other properties to the super dataset
862
958
super ().__init__ (dataset )
863
959
960
+ @debug_indexing
864
961
def __getitem__ (self , n ):
865
962
if isinstance (n , tuple ):
866
963
return self ._get_tuple (n )
@@ -871,25 +968,22 @@ def __getitem__(self, n):
871
968
n = self .indices [n ]
872
969
return self .dataset [n ]
873
970
971
+ @debug_indexing
874
972
def _get_slice (self , s ):
875
973
# TODO: check if the indices can be simplified to a slice
876
974
# the time checking maybe be longer than the time saved
877
975
# using a slice
878
976
indices = [self .indices [i ] for i in range (* s .indices (self ._len ))]
879
977
return np .stack ([self .dataset [i ] for i in indices ])
880
978
979
+ @debug_indexing
881
980
def _get_tuple (self , n ):
882
- first , rest = n [0 ], n [1 :]
883
-
884
- if isinstance (first , int ):
885
- return self .dataset [(self .indices [first ],) + rest ]
886
-
887
- if isinstance (first , slice ):
888
- indices = tuple (self .indices [i ] for i in range (* first .indices (self ._len )))
889
- indices = _make_slice_or_index_from_list_or_tuple (indices )
890
- return self .dataset [(indices ,) + rest ]
891
-
892
- raise NotImplementedError (f"Only int and slice supported not { type (first )} " )
981
+ index , changes = index_to_slices (n , self .shape )
982
+ index , previous = update_tuple (index , 0 , self .slice )
983
+ result = self .dataset [index ]
984
+ result = result [previous ]
985
+ result = apply_index_to_slices_changes (result , changes )
986
+ return result
893
987
894
988
def __len__ (self ):
895
989
return len (self .indices )
@@ -929,12 +1023,23 @@ def __init__(self, dataset, indices):
929
1023
# Forward other properties to the main dataset
930
1024
super ().__init__ (dataset )
931
1025
1026
+ @debug_indexing
1027
+ def _get_tuple (self , index ):
1028
+ index , changes = index_to_slices (index , self .shape )
1029
+ index , previous = update_tuple (index , 1 , slice (None ))
1030
+ result = self .dataset [index ]
1031
+ result = result [:, self .indices ]
1032
+ result = result [:, previous ]
1033
+ result = apply_index_to_slices_changes (result , changes )
1034
+ return result
1035
+
1036
+ @debug_indexing
932
1037
def __getitem__ (self , n ):
933
- # if isinstance(n, tuple):
934
- # return self._get_tuple(n)
1038
+ if isinstance (n , tuple ):
1039
+ return self ._get_tuple (n )
935
1040
936
1041
row = self .dataset [n ]
937
- if isinstance (n , ( slice , tuple ) ):
1042
+ if isinstance (n , slice ):
938
1043
return row [:, self .indices ]
939
1044
940
1045
return row [self .indices ]
0 commit comments