@@ -89,6 +89,8 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
89
89
# If the flux (data) argument is already a Spectrum (as it would
90
90
# be for internal arithmetic operations), avoid setup entirely.
91
91
if isinstance (flux , Spectrum ):
92
+ self ._spectral_axis_index = flux .spectral_axis_index
93
+ self ._spectral_axis = flux .spectral_axis
92
94
super ().__init__ (flux )
93
95
return
94
96
@@ -157,9 +159,7 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
157
159
# In the case where the arithmetic operation is being performed with
158
160
# a single float, int, or array object, just go ahead and ignore wcs
159
161
# requirements
160
- if (not isinstance (flux , u .Quantity ) or isinstance (flux , float )
161
- or isinstance (flux , int )) and np .ndim (flux ) == 0 :
162
-
162
+ if np .ndim (flux ) == 0 and spectral_axis is None and wcs is None :
163
163
super (Spectrum , self ).__init__ (data = flux , wcs = wcs , ** kwargs )
164
164
return
165
165
@@ -332,7 +332,10 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
332
332
self ._spectral_axis = spectral_axis
333
333
334
334
if wcs is None :
335
- wcs = gwcs_from_array (self ._spectral_axis )
335
+ wcs = gwcs_from_array (self ._spectral_axis ,
336
+ flux .shape ,
337
+ spectral_axis_index = self .spectral_axis_index
338
+ )
336
339
337
340
elif wcs is None :
338
341
# If no spectral axis or wcs information is provided, initialize
@@ -344,7 +347,10 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
344
347
raise ValueError ("Must specify spectral_axis_index if no WCS or spectral"
345
348
" axis is input." )
346
349
size = flux .shape [self .spectral_axis_index ] if not flux .isscalar else 1
347
- wcs = gwcs_from_array (np .arange (size ) * u .Unit ("" ))
350
+ wcs = gwcs_from_array (np .arange (size ) * u .Unit ("" ),
351
+ flux .shape ,
352
+ spectral_axis_index = self .spectral_axis_index
353
+ )
348
354
349
355
super ().__init__ (
350
356
data = flux .value if isinstance (flux , u .Quantity ) else flux ,
@@ -379,6 +385,10 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
379
385
for coords in temp_coords :
380
386
if isinstance (coords , SpectralCoord ):
381
387
spec_axis = coords
388
+ break
389
+ else :
390
+ # WCS axis ordering is reverse of numpy
391
+ spec_axis = temp_coords [len (temp_coords ) - self .spectral_axis_index - 1 ]
382
392
else :
383
393
spec_axis = temp_coords
384
394
@@ -646,7 +656,9 @@ def collapse(self, method, axis=None):
646
656
elif isinstance (axis , tuple ) and self .spectral_axis_index in axis :
647
657
return collapsed_flux
648
658
else :
649
- return Spectrum (collapsed_flux , wcs = self .wcs )
659
+ # Pass the spectral axis rather than WCS in this case, so we don't have to
660
+ # figure out which part of a multidimensional WCS is the spectral part.
661
+ return Spectrum (collapsed_flux , spectral_axis = self .spectral_axis )
650
662
651
663
def mean (self , ** kwargs ):
652
664
return self .collapse ("mean" , ** kwargs )
@@ -821,39 +833,74 @@ def _return_with_redshift(self, result):
821
833
result .shift_spectrum_to (redshift = self .redshift )
822
834
return result
823
835
824
- def __add__ (self , other ):
825
- if not isinstance (other , (NDCube , u .Quantity )):
826
- try :
827
- other = u .Quantity (other , unit = self .unit )
828
- except TypeError :
829
- return NotImplemented
836
+ def _other_as_correct_class (self , other , force_quantity = False ):
837
+ # NDArithmetic mixin will try to turn other into a Spectrum, which will fail
838
+ # sometimes because of not specifiying the spectral axis index
839
+ if isinstance (other , Spectrum ):
840
+ # Take this opportunity to check if the spectral axes match
841
+ if not np .all (other .spectral_axis == self .spectral_axis ):
842
+ raise ValueError ("Spectral axis of both operands must match" )
843
+ else :
844
+ if not isinstance (other , u .Quantity ) and force_quantity :
845
+ other = other * self .unit
830
846
831
- return self ._return_with_redshift (self .add (other ))
847
+ if isinstance (other , u .Quantity ) and other .shape == self .shape :
848
+ return Spectrum (flux = other , spectral_axis = self .spectral_axis ,
849
+ spectral_axis_index = self .spectral_axis_index )
832
850
833
- def __sub__ (self , other ):
834
- if not isinstance (other , NDCube ):
835
- try :
836
- other = u .Quantity (other , unit = self .unit )
837
- except TypeError :
838
- return NotImplemented
851
+ return other
839
852
840
- return self ._return_with_redshift (self .subtract (other ))
853
+ def __add__ (self , other ):
854
+ other = self ._other_as_correct_class (other , force_quantity = True )
855
+ if isinstance (other , (Spectrum )):
856
+ return self ._return_with_redshift (self .add (other ))
857
+ else :
858
+ new_flux = self .flux + other
859
+ return self ._return_with_redshift (Spectrum (new_flux , wcs = self .wcs , meta = self .meta ,
860
+ uncertainty = self .uncertainty ))
841
861
842
- def __mul__ (self , other ):
843
- if not isinstance (other , NDCube ):
844
- other = u .Quantity (other )
862
+ def __sub__ (self , other ):
863
+ other = self ._other_as_correct_class (other , force_quantity = True )
864
+ if isinstance (other , (Spectrum )):
865
+ return self ._return_with_redshift (self .subtract (other ))
866
+ else :
867
+ new_flux = self .flux - other
868
+ return self ._return_with_redshift (Spectrum (new_flux , wcs = self .wcs , meta = self .meta ,
869
+ uncertainty = self .uncertainty ))
845
870
846
- return self ._return_with_redshift (self .multiply (other ))
871
+ def __mul__ (self , other ):
872
+ other = self ._other_as_correct_class (other )
873
+ if isinstance (other , (Spectrum )):
874
+ return self ._return_with_redshift (self .multiply (other ))
875
+ else :
876
+ new_flux = self .flux * other
877
+ if self .uncertainty is None :
878
+ new_uncertainty = None
879
+ else :
880
+ new_uncertainty = deepcopy (self .uncertainty )
881
+ new_uncertainty .array *= other
882
+ return self ._return_with_redshift (Spectrum (new_flux , wcs = self .wcs ,
883
+ meta = self .meta ,
884
+ uncertainty = new_uncertainty ))
847
885
848
886
def __div__ (self , other ):
849
- if not isinstance (other , NDCube ):
850
- other = u .Quantity (other )
851
-
852
- return self ._return_with_redshift (self .divide (other ))
887
+ other = self ._other_as_correct_class (other )
888
+ if isinstance (other , (Spectrum )):
889
+ return self ._return_with_redshift (self .divide (other ))
890
+ else :
891
+ new_flux = self .flux / other
892
+ if self .uncertainty is None :
893
+ new_uncertainty = None
894
+ else :
895
+ new_uncertainty = deepcopy (self .uncertainty )
896
+ new_uncertainty .array /= other
897
+ return self ._return_with_redshift (Spectrum (new_flux , wcs = self .wcs ,
898
+ meta = self .meta ,
899
+ uncertainty = self .uncertainty / other ))
853
900
854
901
def __truediv__ (self , other ):
855
- if not isinstance (other , NDCube ):
856
- other = u . Quantity (other )
902
+ if not isinstance (other , Spectrum ):
903
+ other = self . _other_as_correct_class (other )
857
904
858
905
return self ._return_with_redshift (self .divide (other ))
859
906
@@ -901,11 +948,15 @@ def __repr__(self):
901
948
flux_str += f" { self .flux .unit } "
902
949
903
950
flux_str += f" (shape={ self .flux .shape } , mean={ np .nanmean (self .flux ):.5f} ); "
904
- spectral_axis_str = (repr (self .spectral_axis ).split ("[" )[0 ] +
905
- np .array2string (self .spectral_axis , threshold = 8 ) +
906
- f" { self .spectral_axis .unit } >" )
907
- spectral_axis_str = f"spectral_axis={ spectral_axis_str } (length={ len (self .spectral_axis )} )"
908
- inner_str = (flux_str + spectral_axis_str )
951
+ # Sometimes this errors if an error occurs during initialization
952
+ if hasattr (self , "_spectral_axis" ):
953
+ spectral_axis_str = (repr (self .spectral_axis ).split ("[" )[0 ] +
954
+ np .array2string (self .spectral_axis , threshold = 8 ) +
955
+ f" { self .spectral_axis .unit } >" )
956
+ spectral_axis_str = f"spectral_axis={ spectral_axis_str } (length={ len (self .spectral_axis )} )"
957
+ inner_str = (flux_str + spectral_axis_str )
958
+ else :
959
+ inner_str = flux_str
909
960
910
961
if self .uncertainty is not None :
911
962
inner_str += f"; uncertainty={ self .uncertainty .__class__ .__name__ } "
0 commit comments