Skip to content

Commit beb4698

Browse files
committed
Refactor gwcs_from_array to provide ND GWCS in ND flux case (astropy#1211)
* Working on constructing N dimensional GWCS for ND flux with only spectral axis provided * Get forward transform working for ND case * Convert other operator to Spectrum if needed in specutils where we have more control, rather than in astropy ndarithmetic * Don't pass extra arg to ndarithmetic * Use dimensionless instead of pix, remove unit conversion code that I think was unneeded and was causing errors * Fix two tests for multidimensional GWCS, remove debugging print * Working on debugging SpectrumCollection * Check for Spectrum specifically before arithmetic * Remove incorrect index from spectrum collection slice, cast operand as Spectrum even in 1D case * Don't pass multi-d WCS to collapsed spectrum * Fixing more test failures * Revamp arithmetic * Add changelog * remove debugging print * Handle creating dummy WCS and spectral axis when neither are provided, fix order of WCS * Codestyle * Remove debugging prints * Fixes for compatibility with up to date gwcs * Bump min version pins Bump min asdf Bump min astropy Bump min scipy Bump min asdf-astropy * Need gwcs 0.24 * Add note about arithmetic to changelog * Fix Spectrum1D references in doc page Fix skips in doc file
1 parent c36cf91 commit beb4698

13 files changed

+194
-83
lines changed

CHANGES.rst

+7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ New Features
1717
Other Changes and Additions
1818
^^^^^^^^^^^^^^^^^^^^^^^^^^^
1919

20+
- Initializing a ``Spectrum`` with only a spectral axis (not full WCS) will now
21+
result in a GWCS matching the dimensionality of the flux array, rather than a
22+
1D spectral GWCS in all cases. [#1211]
23+
24+
- Spectrum arithmetic now checks whether the spectral axes of the two operand ``Spectrum``
25+
objects are equal, and fails if they are not. [#1211]
26+
2027
1.20.0 (unreleased)
2128
-------------------
2229

docs/spectral_regions.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ Reference/API
349349
:no-heading:
350350
:no-inheritance-diagram:
351351

352-
:skip: test
352+
:skip: QTable
353353
:skip: Spectrum
354354
:skip: SpectrumCollection
355355
:skip: SpectralAxis

docs/spectrum.rst

+5-8
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ installed, which is an optional dependency for ``specutils``.
9494

9595
Call the help function for a specific loader to access further documentation
9696
on that format and optional parameters accepted by the ``read`` function,
97-
e.g. as ``Spectrum1D.read.help('tabular-fits')``. Additional optional parameters
97+
e.g. as ``Spectrum.read.help('tabular-fits')``. Additional optional parameters
9898
are generally passed through to the backend functions performing the actual
9999
reading operation, which depend on the loader. For loaders for FITS files for example,
100100
this will often be :func:`astropy.io.fits.open`.
@@ -114,7 +114,7 @@ by using the :meth:`specutils.Spectrum.write` method.
114114
>>> spec1d.write("/path/to/output.fits") # doctest: +SKIP
115115
116116
Note that the above example, calling ``write()`` without specifying
117-
any format, will default to the ``wcs1d-fits`` loader if the `~specutils.Spectrum1D`
117+
any format, will default to the ``wcs1d-fits`` loader if the `~specutils.Spectrum`
118118
has a compatible WCS, and to ``tabular-fits`` otherwise, or if writing
119119
to another than the primary HDU (``hdu=0``) has been selected.
120120
For better control of the file type, the ``format`` parameter should be explicitly passed.
@@ -126,14 +126,14 @@ which for the FITS writers is :meth:`astropy.io.fits.HDUList.writeto`.
126126
Metadata
127127
--------
128128

129-
The :attr:`specutils.Spectrum1D.meta` attribute provides a dictionary to store
129+
The :attr:`specutils.Spectrum.meta` attribute provides a dictionary to store
130130
additional information on the data, like origin, date and other circumstances.
131131
For spectra read from files containing header-like attributes like a FITS
132132
:class:`~astropy.io.fits.Header` or :attr:`astropy.table.Table.meta`,
133-
loaders are conventionally storing this in ``Spectrum1D.meta['header']``.
133+
loaders are conventionally storing this in ``Spectrum.meta['header']``.
134134

135135
The two provided FITS writers (``tabular-fits`` and ``wcs1d-fits``) save the contents of
136-
``Spectrum1D.meta['header']`` (which should be an :class:`astropy.io.fits.Header`
136+
``Spectrum.meta['header']`` (which should be an :class:`astropy.io.fits.Header`
137137
or any object, like a `dict`, that can instantiate one) as the header of the
138138
:class:`~astropy.io.fits.hdu.PrimaryHDU`.
139139

@@ -406,9 +406,6 @@ spectral axis, or 'spatial', which will collapse along all non-spectral axes.
406406
>>> spec.mean(axis='spatial') # doctest: +FLOAT_CMP
407407
<Spectrum(flux=<Quantity [0.37273938, 0.53843905, 0.61351648, 0.57311623, 0.44339915,
408408
0.66084728, 0.45881921, 0.38715911, 0.39967185, 0.53257671] Jy> (shape=(10,), mean=0.49803 Jy); spectral_axis=<SpectralAxis
409-
(observer to target:
410-
radial_velocity=0.0 km / s
411-
redshift=0.0)
412409
[5000. 5001. 5002. ... 5007. 5008. 5009.] Angstrom> (length=10))>
413410
414411
Note that in this case, the result of the collapse operation is a

docs/spectrum_collection.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ solution.
2323
2424
>>> flux = u.Quantity(np.random.sample((5, 10)), unit='Jy')
2525
>>> spectral_axis = u.Quantity(np.arange(50).reshape((5, 10)), unit='AA')
26-
>>> wcs = np.array([gwcs_from_array(x) for x in spectral_axis])
26+
>>> wcs = np.array([gwcs_from_array(x, x.shape) for x in spectral_axis])
2727
>>> uncertainty = StdDevUncertainty(np.random.sample((5, 10)), unit='Jy')
2828
>>> mask = np.ones((5, 10)).astype(bool)
2929
>>> meta = [{'test': 5, 'info': [1, 2, 3]} for i in range(5)]

setup.cfg

+5-5
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ packages = find:
1717
python_requires = >=3.10
1818
install_requires =
1919
numpy>=1.24
20-
scipy>=1.3
21-
astropy>=5.1
22-
gwcs>=0.18
23-
asdf-astropy>=0.3
24-
asdf>=2.14.4
20+
scipy>=1.14
21+
astropy>=6.0
22+
gwcs>=0.22
23+
asdf-astropy>=0.5
24+
asdf>=3.3.0
2525
ndcube>=2.0
2626

2727
[options.extras_require]

specutils/spectra/spectrum.py

+86-35
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
8989
# If the flux (data) argument is already a Spectrum (as it would
9090
# be for internal arithmetic operations), avoid setup entirely.
9191
if isinstance(flux, Spectrum):
92+
self._spectral_axis_index = flux.spectral_axis_index
93+
self._spectral_axis = flux.spectral_axis
9294
super().__init__(flux)
9395
return
9496

@@ -157,9 +159,7 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
157159
# In the case where the arithmetic operation is being performed with
158160
# a single float, int, or array object, just go ahead and ignore wcs
159161
# requirements
160-
if (not isinstance(flux, u.Quantity) or isinstance(flux, float)
161-
or isinstance(flux, int)) and np.ndim(flux) == 0:
162-
162+
if np.ndim(flux) == 0 and spectral_axis is None and wcs is None:
163163
super(Spectrum, self).__init__(data=flux, wcs=wcs, **kwargs)
164164
return
165165

@@ -332,7 +332,10 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
332332
self._spectral_axis = spectral_axis
333333

334334
if wcs is None:
335-
wcs = gwcs_from_array(self._spectral_axis)
335+
wcs = gwcs_from_array(self._spectral_axis,
336+
flux.shape,
337+
spectral_axis_index=self.spectral_axis_index
338+
)
336339

337340
elif wcs is None:
338341
# If no spectral axis or wcs information is provided, initialize
@@ -344,7 +347,10 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
344347
raise ValueError("Must specify spectral_axis_index if no WCS or spectral"
345348
" axis is input.")
346349
size = flux.shape[self.spectral_axis_index] if not flux.isscalar else 1
347-
wcs = gwcs_from_array(np.arange(size) * u.Unit(""))
350+
wcs = gwcs_from_array(np.arange(size) * u.Unit(""),
351+
flux.shape,
352+
spectral_axis_index=self.spectral_axis_index
353+
)
348354

349355
super().__init__(
350356
data=flux.value if isinstance(flux, u.Quantity) else flux,
@@ -379,6 +385,10 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
379385
for coords in temp_coords:
380386
if isinstance(coords, SpectralCoord):
381387
spec_axis = coords
388+
break
389+
else:
390+
# WCS axis ordering is reverse of numpy
391+
spec_axis = temp_coords[len(temp_coords) - self.spectral_axis_index - 1]
382392
else:
383393
spec_axis = temp_coords
384394

@@ -646,7 +656,9 @@ def collapse(self, method, axis=None):
646656
elif isinstance(axis, tuple) and self.spectral_axis_index in axis:
647657
return collapsed_flux
648658
else:
649-
return Spectrum(collapsed_flux, wcs=self.wcs)
659+
# Pass the spectral axis rather than WCS in this case, so we don't have to
660+
# figure out which part of a multidimensional WCS is the spectral part.
661+
return Spectrum(collapsed_flux, spectral_axis=self.spectral_axis)
650662

651663
def mean(self, **kwargs):
652664
return self.collapse("mean", **kwargs)
@@ -821,39 +833,74 @@ def _return_with_redshift(self, result):
821833
result.shift_spectrum_to(redshift=self.redshift)
822834
return result
823835

824-
def __add__(self, other):
825-
if not isinstance(other, (NDCube, u.Quantity)):
826-
try:
827-
other = u.Quantity(other, unit=self.unit)
828-
except TypeError:
829-
return NotImplemented
836+
def _other_as_correct_class(self, other, force_quantity=False):
837+
# NDArithmetic mixin will try to turn other into a Spectrum, which will fail
838+
# sometimes because of not specifiying the spectral axis index
839+
if isinstance(other, Spectrum):
840+
# Take this opportunity to check if the spectral axes match
841+
if not np.all(other.spectral_axis == self.spectral_axis):
842+
raise ValueError("Spectral axis of both operands must match")
843+
else:
844+
if not isinstance(other, u.Quantity) and force_quantity:
845+
other = other * self.unit
830846

831-
return self._return_with_redshift(self.add(other))
847+
if isinstance(other, u.Quantity) and other.shape == self.shape:
848+
return Spectrum(flux=other, spectral_axis=self.spectral_axis,
849+
spectral_axis_index=self.spectral_axis_index)
832850

833-
def __sub__(self, other):
834-
if not isinstance(other, NDCube):
835-
try:
836-
other = u.Quantity(other, unit=self.unit)
837-
except TypeError:
838-
return NotImplemented
851+
return other
839852

840-
return self._return_with_redshift(self.subtract(other))
853+
def __add__(self, other):
854+
other = self._other_as_correct_class(other, force_quantity=True)
855+
if isinstance(other, (Spectrum)):
856+
return self._return_with_redshift(self.add(other))
857+
else:
858+
new_flux = self.flux + other
859+
return self._return_with_redshift(Spectrum(new_flux, wcs=self.wcs, meta=self.meta,
860+
uncertainty=self.uncertainty))
841861

842-
def __mul__(self, other):
843-
if not isinstance(other, NDCube):
844-
other = u.Quantity(other)
862+
def __sub__(self, other):
863+
other = self._other_as_correct_class(other, force_quantity=True)
864+
if isinstance(other, (Spectrum)):
865+
return self._return_with_redshift(self.subtract(other))
866+
else:
867+
new_flux = self.flux - other
868+
return self._return_with_redshift(Spectrum(new_flux, wcs=self.wcs, meta=self.meta,
869+
uncertainty=self.uncertainty))
845870

846-
return self._return_with_redshift(self.multiply(other))
871+
def __mul__(self, other):
872+
other = self._other_as_correct_class(other)
873+
if isinstance(other, (Spectrum)):
874+
return self._return_with_redshift(self.multiply(other))
875+
else:
876+
new_flux = self.flux * other
877+
if self.uncertainty is None:
878+
new_uncertainty = None
879+
else:
880+
new_uncertainty = deepcopy(self.uncertainty)
881+
new_uncertainty.array *= other
882+
return self._return_with_redshift(Spectrum(new_flux, wcs=self.wcs,
883+
meta=self.meta,
884+
uncertainty=new_uncertainty))
847885

848886
def __div__(self, other):
849-
if not isinstance(other, NDCube):
850-
other = u.Quantity(other)
851-
852-
return self._return_with_redshift(self.divide(other))
887+
other = self._other_as_correct_class(other)
888+
if isinstance(other, (Spectrum)):
889+
return self._return_with_redshift(self.divide(other))
890+
else:
891+
new_flux = self.flux / other
892+
if self.uncertainty is None:
893+
new_uncertainty = None
894+
else:
895+
new_uncertainty = deepcopy(self.uncertainty)
896+
new_uncertainty.array /= other
897+
return self._return_with_redshift(Spectrum(new_flux, wcs=self.wcs,
898+
meta=self.meta,
899+
uncertainty=self.uncertainty/other))
853900

854901
def __truediv__(self, other):
855-
if not isinstance(other, NDCube):
856-
other = u.Quantity(other)
902+
if not isinstance(other, Spectrum):
903+
other = self._other_as_correct_class(other)
857904

858905
return self._return_with_redshift(self.divide(other))
859906

@@ -901,11 +948,15 @@ def __repr__(self):
901948
flux_str += f" {self.flux.unit}"
902949

903950
flux_str += f" (shape={self.flux.shape}, mean={np.nanmean(self.flux):.5f}); "
904-
spectral_axis_str = (repr(self.spectral_axis).split("[")[0] +
905-
np.array2string(self.spectral_axis, threshold=8) +
906-
f" {self.spectral_axis.unit}>")
907-
spectral_axis_str = f"spectral_axis={spectral_axis_str} (length={len(self.spectral_axis)})"
908-
inner_str = (flux_str + spectral_axis_str)
951+
# Sometimes this errors if an error occurs during initialization
952+
if hasattr(self, "_spectral_axis"):
953+
spectral_axis_str = (repr(self.spectral_axis).split("[")[0] +
954+
np.array2string(self.spectral_axis, threshold=8) +
955+
f" {self.spectral_axis.unit}>")
956+
spectral_axis_str = f"spectral_axis={spectral_axis_str} (length={len(self.spectral_axis)})"
957+
inner_str = (flux_str + spectral_axis_str)
958+
else:
959+
inner_str = flux_str
909960

910961
if self.uncertainty is not None:
911962
inner_str += f"; uncertainty={self.uncertainty.__class__.__name__}"

specutils/spectra/spectrum_collection.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class SpectrumCollection(NDIOMixin):
5353
each spectrum in the collection.
5454
"""
5555
def __init__(self, flux, spectral_axis=None, wcs=None, uncertainty=None,
56-
mask=None, meta=None):
56+
mask=None, meta=None, spectral_axis_index=None):
5757
# Check for quantity
5858
if not isinstance(flux, u.Quantity):
5959
raise u.UnitsError("Flux must be a `Quantity`.")
@@ -89,6 +89,7 @@ def __init__(self, flux, spectral_axis=None, wcs=None, uncertainty=None,
8989

9090
self._flux = flux
9191
self._spectral_axis = spectral_axis
92+
self._spectral_axis_index = spectral_axis_index
9293
self._wcs = wcs
9394
self._uncertainty = uncertainty
9495
self._mask = mask
@@ -153,6 +154,8 @@ def from_spectra(cls, spectra):
153154
observer=sa[0].observer,
154155
target=sa[0].target)
155156

157+
spectral_axis_index = spectra[0].spectral_axis_index
158+
156159
# Check that either all spectra have associated uncertainties, or that
157160
# none of them do. If only some do, log an error and ignore the
158161
# uncertainties.
@@ -183,7 +186,8 @@ def from_spectra(cls, spectra):
183186
meta = [spec.meta for spec in spectra]
184187

185188
return cls(flux=flux, spectral_axis=spectral_axis,
186-
uncertainty=uncertainty, wcs=wcs, mask=mask, meta=meta)
189+
uncertainty=uncertainty, wcs=wcs, mask=mask, meta=meta,
190+
spectral_axis_index=spectral_axis_index)
187191

188192
@property
189193
def flux(self):
@@ -195,6 +199,10 @@ def spectral_axis(self):
195199
"""The spectral axes as a `~astropy.units.Quantity`."""
196200
return self._spectral_axis
197201

202+
@property
203+
def spectral_axis_index(self):
204+
return self._spectral_axis_index
205+
198206
@property
199207
def frequency(self):
200208
"""

specutils/tests/test_arithmetic.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import astropy.units as u
22
from astropy.tests.helper import assert_quantity_allclose
33
import numpy as np
4+
import pytest
45

56
from ..spectra.spectrum import Spectrum
67

@@ -89,8 +90,9 @@ def test_multiplication_basic_spectra(simulated_spectra):
8990

9091
def test_add_diff_spectral_axis(simulated_spectra):
9192

92-
# Calculate using the spectrum/nddata code
93-
spec3 = simulated_spectra.s1_um_mJy_e1 + simulated_spectra.s1_AA_mJy_e3 # noqa
93+
# We now raise an error if the spectra aren't on the same spectral axis
94+
with pytest.raises(ValueError, match="Spectral axis of both operands must match"):
95+
spec3 = simulated_spectra.s1_um_mJy_e1 + simulated_spectra.s1_AA_mJy_e3 # noqa
9496

9597

9698
def test_masks(simulated_spectra):

specutils/tests/test_manipulation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test_snr_threshold():
129129
np.random.seed(42)
130130
flux = u.Quantity(np.random.sample((5, 10)), unit='Jy')
131131
spectral_axis = u.Quantity(np.arange(50).reshape((5, 10)), unit='AA')
132-
wcs = np.array([gwcs_from_array(x) for x in spectral_axis])
132+
wcs = np.array([gwcs_from_array(x, [10,]) for x in spectral_axis])
133133
uncertainty = StdDevUncertainty(np.random.sample((5, 10)), unit='Jy')
134134
mask = np.ones((5, 10)).astype(bool)
135135
meta = [{'test': 5, 'info': [1, 2, 3]} for i in range(5)]

specutils/tests/test_slicing.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ def test_spectral_axes():
2424
sliced_spec2 = spec2[0]
2525

2626
assert isinstance(sliced_spec2, Spectrum)
27-
assert_allclose(sliced_spec2.wcs.pixel_to_world(np.arange(10)), spec2.wcs.pixel_to_world(np.arange(10)))
27+
assert_allclose(sliced_spec2.wcs.pixel_to_world(np.arange(10)),
28+
spec2.wcs.pixel_to_world(np.arange(10), [0,]*10)[0])
2829
assert sliced_spec2.flux.shape[0] == 49
2930

3031

@@ -107,4 +108,4 @@ def test_slicing_multidim():
107108
assert spec1.mask.shape == (10,)
108109

109110
assert quantity_allclose(spec3.spectral_axis, spec.spectral_axis[4:7])
110-
assert quantity_allclose(spec3.wcs.pixel_to_world([0,1,2]), spec3.spectral_axis[0:3])
111+
assert quantity_allclose(spec3.wcs.pixel_to_world([0, 1, 2], [0, 0, 0])[0], spec3.spectral_axis[0:3])

specutils/tests/test_spectrum_collection.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
def spectrum_collection():
1616
flux = u.Quantity(np.random.sample((5, 10)), unit='Jy')
1717
spectral_axis = u.Quantity(np.arange(50).reshape((5, 10)) + 1, unit='AA')
18-
wcs = np.array([gwcs_from_array(x) for x in spectral_axis])
18+
wcs = np.array([gwcs_from_array(x, flux.shape, spectral_axis_index=1) for x in spectral_axis])
1919
uncertainty = StdDevUncertainty(np.random.sample((5, 10)), unit='Jy')
2020
mask = np.ones((5, 10)).astype(bool)
2121
meta = [{'test': 5, 'info': [1, 2, 3]} for i in range(5)]
2222

2323
spec_coll = SpectrumCollection(
2424
flux=flux, spectral_axis=spectral_axis, wcs=wcs,
25-
uncertainty=uncertainty, mask=mask, meta=meta)
25+
uncertainty=uncertainty, mask=mask, meta=meta,
26+
spectral_axis_index=1)
2627

2728
return spec_coll
2829

@@ -59,7 +60,7 @@ def test_collection_without_optional_arguments():
5960
flux = u.Quantity(np.random.sample((5, 10)), unit='Jy')
6061
spectral_axis = u.Quantity(np.arange(50).reshape((5, 10)) + 1, unit='AA')
6162
uncertainty = StdDevUncertainty(np.random.sample((5, 10)), unit='Jy')
62-
wcs = np.array([gwcs_from_array(x) for x in spectral_axis])
63+
wcs = np.array([gwcs_from_array(x, flux.shape, spectral_axis_index=1) for x in spectral_axis])
6364
mask = np.ones((5, 10)).astype(bool)
6465
meta = [{'test': 5, 'info': [1, 2, 3]} for i in range(5)]
6566

0 commit comments

Comments
 (0)