@@ -20,14 +20,36 @@

import ecml_tools

-from .indexing import apply_index_to_slices_changes, index_to_slices, length_to_slices
+from .indexing import (
+    apply_index_to_slices_changes,
+    index_to_slices,
+    length_to_slices,
+    update_tuple,
+)

LOG = logging.getLogger(__name__)

__all__ = ["open_dataset", "open_zarr", "debug_zarr_loading"]

DEBUG_ZARR_LOADING = int(os.environ.get("DEBUG_ZARR_LOADING", "0"))

+DEPTH = 0
+
+
+def debug_indexing(method):
+    def wrapper(self, index):
+        global DEPTH
+        if isinstance(index, tuple):
+            print(" " * DEPTH, "->", self, method.__name__, index)
+        DEPTH += 1
+        result = method(self, index)
+        DEPTH -= 1
+        if isinstance(index, tuple):
+            print(" " * DEPTH, "<-", self, method.__name__, result.shape)
+        return result
+
+    return wrapper
+

def debug_zarr_loading(on_off):
    global DEBUG_ZARR_LOADING
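
The `debug_indexing` decorator added above prints a `->` line before and a `<-` line after any tuple-indexed `__getitem__`/`_get_tuple` call, indented by the module-level `DEPTH` counter so that nested lookups (for example a combined dataset delegating to its member datasets) read as an indented call tree. A self-contained sketch of the same behaviour, using an invented `ToyDataset` class rather than the real dataset wrappers:

import numpy as np

DEPTH = 0


def debug_indexing(method):
    # Same idea as the decorator added in this diff: trace tuple-indexing calls,
    # indenting by the current nesting depth so delegated lookups line up.
    def wrapper(self, index):
        global DEPTH
        if isinstance(index, tuple):
            print(" " * DEPTH, "->", self, method.__name__, index)
        DEPTH += 1
        result = method(self, index)
        DEPTH -= 1
        if isinstance(index, tuple):
            print(" " * DEPTH, "<-", self, method.__name__, result.shape)
        return result

    return wrapper


class ToyDataset:  # invented stand-in, not part of ecml_tools
    def __init__(self):
        self.data = np.arange(24.0).reshape(4, 3, 2)

    @debug_indexing
    def __getitem__(self, index):
        return self.data[index]


ToyDataset()[1:3, 0, :]  # prints the "->" / "<-" trace and returns a (2, 2) array

Because `DEPTH` is a plain module-level global, the trace is not thread-safe, which is acceptable for interactive debugging.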
@@ -192,6 +214,7 @@ def metadata_specific(self, **kwargs):
    def __repr__(self):
        return self.__class__.__name__ + "()"

+    @debug_indexing
    def _get_tuple(self, n):
        raise NotImplementedError(
            f"Tuple not supported: {n} (class {self.__class__.__name__})"
@@ -344,6 +367,7 @@ def __init__(self, path):
    def __len__(self):
        return self.data.shape[0]

+    @debug_indexing
    def __getitem__(self, n):
        if isinstance(n, tuple) and any(not isinstance(i, (int, slice)) for i in n):
            return self._getitem_extended(n)
@@ -638,6 +662,7 @@ class Concat(Combined):
    def __len__(self):
        return sum(len(i) for i in self.datasets)

+    @debug_indexing
    def _get_tuple(self, index):
        index, changes = index_to_slices(index, self.shape)
        result = []
@@ -661,6 +686,7 @@ def _get_tuple(self, index):

        return apply_index_to_slices_changes(np.concatenate(result, axis=0), changes)

+    @debug_indexing
    def __getitem__(self, n):
        if isinstance(n, tuple):
            return self._get_tuple(n)
@@ -675,6 +701,7 @@ def __getitem__(self, n):
            k += 1
        return self.datasets[k][n]

+    @debug_indexing
    def _get_slice(self, s):
        result = []

@@ -742,24 +769,30 @@ def shape(self):
        assert False not in result, result
        return result

+    @debug_indexing
    def _get_tuple(self, index):
+        print(index, self.shape)
        index, changes = index_to_slices(index, self.shape)
-        selected = index[self.axis]
        lengths = [d.shape[self.axis] for d in self.datasets]
-        slices = length_to_slices(selected, lengths)
-        print("per_dataset_index", slices)
+        slices = length_to_slices(index[self.axis], lengths)

-        result = [d[i] for (d, i) in zip(self.datasets, slices) if i is not None]
+        print("SLICES", slices, self.axis, index, lengths)
+        before = index[: self.axis]

-        x = tuple([slice(None)] * self.axis + [selected])
+        result = [
+            d[before + (i,)] for (d, i) in zip(self.datasets, slices) if i is not None
+        ]
+        print([d.shape for d in result])
+        result = np.concatenate(result, axis=self.axis)
+        print(result.shape)

-        return apply_index_to_slices_changes(
-            np.concatenate(result, axis=self.axis)[x], changes
-        )
+        return apply_index_to_slices_changes(result, changes)

+    @debug_indexing
    def _get_slice(self, s):
        return np.stack([self[i] for i in range(*s.indices(self._len))])

+    @debug_indexing
    def __getitem__(self, n):
        if isinstance(n, tuple):
            return self._get_tuple(n)
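
The rewritten `_get_tuple` above hinges on `length_to_slices`, which is imported from `.indexing` and whose body is not part of this diff. From the way its result is consumed (`zip(self.datasets, slices)` with an `if i is not None` filter), it appears to translate one slice along the concatenation axis into a local slice per dataset, with `None` for datasets the slice never touches. A rough, hypothetical sketch under that assumption, handling step-1 slices only:

def length_to_slices_sketch(global_slice, lengths):
    # Hypothetical re-implementation for illustration only; the real helper
    # lives in ecml_tools.indexing and may differ.
    start, stop, step = global_slice.indices(sum(lengths))
    assert step == 1
    result, offset = [], 0
    for length in lengths:
        lo = max(start, offset) - offset          # local start within this dataset
        hi = min(stop, offset + length) - offset  # local stop within this dataset
        result.append(slice(lo, hi) if lo < hi else None)
        offset += length
    return result


length_to_slices_sketch(slice(2, 7), [4, 4])  # [slice(2, 4), slice(0, 3)]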
@@ -810,42 +843,22 @@ def check_same_variables(self, d1, d2):
    def __len__(self):
        return len(self.datasets[0])

+    @debug_indexing
    def _get_tuple(self, index):
-        print("Join._get_tuple", index)
-        assert len(index) > 1, index
-
        index, changes = index_to_slices(index, self.shape)
-
-        selected_variables = index[1]
-
-        index = list(index)
-        index[1] = slice(None)
-        index = tuple(index)
-        print("Join._get_tuple", index)
+        index, previous = update_tuple(index, 1, slice(None))

        # TODO: optimize if index does not access all datasets, so we don't load chunks we don't need
        result = [d[index] for d in self.datasets]

-        print(
-            "Join._get_tuple",
-            self.shape,
-            [r.shape for r in result],
-            selected_variables,
-            changes,
-        )
        result = np.concatenate(result, axis=1)
-        print("Join._get_tuple", result.shape)
-
-        # raise NotImplementedError()
-
-        # result = np.concatenate(result)
-        # result = np.stack(result)
-
-        return apply_index_to_slices_changes(result[:, selected_variables], changes)
+        return apply_index_to_slices_changes(result[:, previous], changes)

+    @debug_indexing
    def _get_slice(self, s):
        return np.stack([self[i] for i in range(*s.indices(self._len))])

+    @debug_indexing
    def __getitem__(self, n):
        if isinstance(n, tuple):
            return self._get_tuple(n)
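
`update_tuple` is also imported from `.indexing` without its body appearing in this diff. Judging from the call sites (`index, previous = update_tuple(index, 1, slice(None))` followed by `result[:, previous]`), it seems to swap one entry of an index tuple for a new value and return both the updated tuple and the entry it displaced, so that `Join` can fetch every variable from each member dataset and re-apply the caller's variable selection after concatenating on axis 1. A minimal hypothetical stand-in:

def update_tuple_sketch(index, position, value):
    # Hypothetical stand-in for the unshown ecml_tools.indexing.update_tuple:
    # put `value` at `position`, return the new tuple and the old entry.
    index = list(index)
    previous = index[position]
    index[position] = value
    return tuple(index), previous


index, previous = update_tuple_sketch((slice(0, 2), slice(1, 3), 0, 7), 1, slice(None))
# index    == (slice(0, 2), slice(None, None, None), 0, 7)
# previous == slice(1, 3)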
@@ -931,10 +944,14 @@ def __init__(self, dataset, indices):

        self.dataset = dataset
        self.indices = list(indices)
+        self.slice = _make_slice_or_index_from_list_or_tuple(self.indices)
+        assert isinstance(self.slice, slice)
+        print("SUBSET", self.slice)

        # Forward other properties to the super dataset
        super().__init__(dataset)

+    @debug_indexing
    def __getitem__(self, n):
        if isinstance(n, tuple):
            return self._get_tuple(n)
@@ -945,25 +962,22 @@ def __getitem__(self, n):
        n = self.indices[n]
        return self.dataset[n]

+    @debug_indexing
    def _get_slice(self, s):
        # TODO: check if the indices can be simplified to a slice
        # the time checking maybe be longer than the time saved
        # using a slice
        indices = [self.indices[i] for i in range(*s.indices(self._len))]
        return np.stack([self.dataset[i] for i in indices])

+    @debug_indexing
    def _get_tuple(self, n):
-        first, rest = n[0], n[1:]
-
-        if isinstance(first, int):
-            return self.dataset[(self.indices[first],) + rest]
-
-        if isinstance(first, slice):
-            indices = tuple(self.indices[i] for i in range(*first.indices(self._len)))
-            indices = _make_slice_or_index_from_list_or_tuple(indices)
-            return self.dataset[(indices,) + rest]
-
-        raise NotImplementedError(f"Only int and slice supported not {type(first)}")
+        index, changes = index_to_slices(n, self.shape)
+        index, previous = update_tuple(index, 0, self.slice)
+        result = self.dataset[index]
+        result = result[previous]
+        result = apply_index_to_slices_changes(result, changes)
+        return result

    def __len__(self):
        return len(self.indices)
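
The new `Subset._get_tuple` replaces the per-type dispatch on `n[0]` with a single pass: normalise the tuple with `index_to_slices`, swap the date axis for the subset's own `self.slice`, index the wrapped dataset once, then apply the displaced axis-0 selection (`previous`) to the result. The step that may look surprising is that composing the two slices this way selects the same rows as picking them directly; a quick NumPy check with invented shapes:

import numpy as np

data = np.arange(40).reshape(10, 4)      # stands in for the wrapped dataset
subset = slice(2, 9)                     # stands in for self.slice (the subset's dates)
requested = slice(1, 5)                  # the caller's axis-0 selection ("previous")

composed = data[subset][requested]       # what the new _get_tuple effectively computes
direct = data[np.arange(10)[subset][requested]]  # selecting the same rows in one go
assert (composed == direct).all()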
@@ -1003,6 +1017,17 @@ def __init__(self, dataset, indices):
        # Forward other properties to the main dataset
        super().__init__(dataset)

+    @debug_indexing
+    def _get_tuple(self, index):
+        index, changes = index_to_slices(index, self.shape)
+        index, previous = update_tuple(index, 1, slice(None))
+        result = self.dataset[index]
+        result = result[:, self.indices]
+        result = result[:, previous]
+        result = apply_index_to_slices_changes(result, changes)
+        return result
+
+    @debug_indexing
    def __getitem__(self, n):
        if isinstance(n, tuple):
            return self._get_tuple(n)
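
The added `_get_tuple` here uses the same two-step trick on the variable axis: request all variables from the wrapped dataset via `slice(None)` at position 1, narrow the result to the variables this wrapper exposes (`self.indices`), then re-apply the caller's original variable index (`previous`) relative to that narrowed set. A small NumPy illustration with invented shapes:

import numpy as np

batch = np.arange(2 * 5 * 3).reshape(2, 5, 3)  # (dates, variables, grid), invented
selected_variables = [0, 2, 4]                 # stands in for self.indices
previous = slice(1, 3)                         # the caller's variable selection

narrowed = batch[:, selected_variables]        # keep only this wrapper's variables
out = narrowed[:, previous]                    # re-apply the caller's index
assert out.shape == (2, 2, 3)                  # variables 2 and 4 of the original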