Skip to content

Refactor gwcs_from_array to provide ND GWCS in ND flux case #1211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
334f2dc
Working on constructing N dimensional GWCS for ND flux with only spec…
rosteen Jan 31, 2025
0547b4a
Get forward transform working for ND case
rosteen Jan 31, 2025
babc123
Convert other operator to Spectrum if needed in specutils where we ha…
rosteen Feb 3, 2025
429d87a
Don't pass extra arg to ndarithmetic
rosteen Feb 3, 2025
dbe74b6
Use dimensionless instead of pix, remove unit conversion code that I …
rosteen Feb 3, 2025
bb68350
Fix two tests for multidimensional GWCS, remove debugging print
rosteen Feb 3, 2025
02f7c3e
Working on debugging SpectrumCollection
rosteen Feb 3, 2025
2e332a7
Check for Spectrum specifically before arithmetic
rosteen Feb 3, 2025
35e991a
Remove incorrect index from spectrum collection slice, cast operand a…
rosteen Feb 4, 2025
c3424bb
Don't pass multi-d WCS to collapsed spectrum
rosteen Feb 4, 2025
52137fb
Fixing more test failures
rosteen Feb 4, 2025
fd52aef
Revamp arithmetic
rosteen Feb 4, 2025
4460d26
Add changelog
rosteen Feb 4, 2025
e74b2a1
remove debugging print
rosteen Feb 4, 2025
368d0fe
Handle creating dummy WCS and spectral axis when neither are provided…
rosteen Feb 6, 2025
f856f3a
Codestyle
rosteen Feb 6, 2025
4708bf5
Remove debugging prints
rosteen Feb 6, 2025
64d1180
Fixes for compatibility with up to date gwcs
rosteen Feb 7, 2025
55cbb3e
Bump min version pins
rosteen Feb 7, 2025
35ea81a
Need gwcs 0.24
rosteen Feb 9, 2025
afb7cb4
Add note about arithmetic to changelog
rosteen Mar 4, 2025
75aa0d7
Fix Spectrum1D references in doc page
rosteen Mar 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ New Features
Other Changes and Additions
^^^^^^^^^^^^^^^^^^^^^^^^^^^

- Initializing a ``Spectrum`` with only a spectral axis (not full WCS) will now
result in a GWCS matching the dimensionality of the flux array, rather than a
1D spectral GWCS in all cases. [#1211]

- Spectrum arithmetic now checks whether the spectral axes of the two operand ``Spectrum``
objects are equal, and fails if they are not. [#1211]

1.20.0 (unreleased)
-------------------

Expand Down
13 changes: 5 additions & 8 deletions docs/spectrum.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ installed, which is an optional dependency for ``specutils``.

Call the help function for a specific loader to access further documentation
on that format and optional parameters accepted by the ``read`` function,
e.g. as ``Spectrum1D.read.help('tabular-fits')``. Additional optional parameters
e.g. as ``Spectrum.read.help('tabular-fits')``. Additional optional parameters
are generally passed through to the backend functions performing the actual
reading operation, which depend on the loader. For loaders for FITS files for example,
this will often be :func:`astropy.io.fits.open`.
Expand All @@ -114,7 +114,7 @@ by using the :meth:`specutils.Spectrum.write` method.
>>> spec1d.write("/path/to/output.fits") # doctest: +SKIP

Note that the above example, calling ``write()`` without specifying
any format, will default to the ``wcs1d-fits`` loader if the `~specutils.Spectrum1D`
any format, will default to the ``wcs1d-fits`` loader if the `~specutils.Spectrum`
has a compatible WCS, and to ``tabular-fits`` otherwise, or if writing
to another than the primary HDU (``hdu=0``) has been selected.
For better control of the file type, the ``format`` parameter should be explicitly passed.
Expand All @@ -126,14 +126,14 @@ which for the FITS writers is :meth:`astropy.io.fits.HDUList.writeto`.
Metadata
--------

The :attr:`specutils.Spectrum1D.meta` attribute provides a dictionary to store
The :attr:`specutils.Spectrum.meta` attribute provides a dictionary to store
additional information on the data, like origin, date and other circumstances.
For spectra read from files containing header-like attributes like a FITS
:class:`~astropy.io.fits.Header` or :attr:`astropy.table.Table.meta`,
loaders are conventionally storing this in ``Spectrum1D.meta['header']``.
loaders are conventionally storing this in ``Spectrum.meta['header']``.

The two provided FITS writers (``tabular-fits`` and ``wcs1d-fits``) save the contents of
``Spectrum1D.meta['header']`` (which should be an :class:`astropy.io.fits.Header`
``Spectrum.meta['header']`` (which should be an :class:`astropy.io.fits.Header`
or any object, like a `dict`, that can instantiate one) as the header of the
:class:`~astropy.io.fits.hdu.PrimaryHDU`.

Expand Down Expand Up @@ -406,9 +406,6 @@ spectral axis, or 'spatial', which will collapse along all non-spectral axes.
>>> spec.mean(axis='spatial') # doctest: +FLOAT_CMP
<Spectrum(flux=<Quantity [0.37273938, 0.53843905, 0.61351648, 0.57311623, 0.44339915,
0.66084728, 0.45881921, 0.38715911, 0.39967185, 0.53257671] Jy> (shape=(10,), mean=0.49803 Jy); spectral_axis=<SpectralAxis
(observer to target:
radial_velocity=0.0 km / s
redshift=0.0)
Comment on lines -409 to -411
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this removed? Is it an incidental thing you noticed at the time (no prob if so) or a result of this change somehow?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't recall, I'll double check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was due to creating the new collapsed Spectrum with self.spectral_axis rather than self.wcs now, and is actually probably more correct than the previously printed out repr, which was replacing None values for these quantities with 0 in the process of collapsing (implying information we actually don't have).

[5000. 5001. 5002. ... 5007. 5008. 5009.] Angstrom> (length=10))>

Note that in this case, the result of the collapse operation is a
Expand Down
2 changes: 1 addition & 1 deletion docs/spectrum_collection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ solution.

>>> flux = u.Quantity(np.random.sample((5, 10)), unit='Jy')
>>> spectral_axis = u.Quantity(np.arange(50).reshape((5, 10)), unit='AA')
>>> wcs = np.array([gwcs_from_array(x) for x in spectral_axis])
>>> wcs = np.array([gwcs_from_array(x, x.shape) for x in spectral_axis])
>>> uncertainty = StdDevUncertainty(np.random.sample((5, 10)), unit='Jy')
>>> mask = np.ones((5, 10)).astype(bool)
>>> meta = [{'test': 5, 'info': [1, 2, 3]} for i in range(5)]
Expand Down
10 changes: 5 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ packages = find:
python_requires = >=3.10
install_requires =
numpy>=1.24
scipy>=1.3
astropy>=5.1
gwcs>=0.18
asdf-astropy>=0.3
asdf>=2.14.4
scipy>=1.14
astropy>=6.0
gwcs>=0.22
asdf-astropy>=0.5
asdf>=3.3.0
ndcube>=2.0

[options.extras_require]
Expand Down
121 changes: 86 additions & 35 deletions specutils/spectra/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
# If the flux (data) argument is already a Spectrum (as it would
# be for internal arithmetic operations), avoid setup entirely.
if isinstance(flux, Spectrum):
self._spectral_axis_index = flux.spectral_axis_index
self._spectral_axis = flux.spectral_axis
super().__init__(flux)
return

Expand Down Expand Up @@ -157,9 +159,7 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
# In the case where the arithmetic operation is being performed with
# a single float, int, or array object, just go ahead and ignore wcs
# requirements
if (not isinstance(flux, u.Quantity) or isinstance(flux, float)
or isinstance(flux, int)) and np.ndim(flux) == 0:

if np.ndim(flux) == 0 and spectral_axis is None and wcs is None:
super(Spectrum, self).__init__(data=flux, wcs=wcs, **kwargs)
return

Expand Down Expand Up @@ -332,7 +332,10 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
self._spectral_axis = spectral_axis

if wcs is None:
wcs = gwcs_from_array(self._spectral_axis)
wcs = gwcs_from_array(self._spectral_axis,
flux.shape,
spectral_axis_index=self.spectral_axis_index
)

elif wcs is None:
# If no spectral axis or wcs information is provided, initialize
Expand All @@ -344,7 +347,10 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
raise ValueError("Must specify spectral_axis_index if no WCS or spectral"
" axis is input.")
size = flux.shape[self.spectral_axis_index] if not flux.isscalar else 1
wcs = gwcs_from_array(np.arange(size) * u.Unit(""))
wcs = gwcs_from_array(np.arange(size) * u.Unit(""),
flux.shape,
spectral_axis_index=self.spectral_axis_index
)

super().__init__(
data=flux.value if isinstance(flux, u.Quantity) else flux,
Expand Down Expand Up @@ -379,6 +385,10 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
for coords in temp_coords:
if isinstance(coords, SpectralCoord):
spec_axis = coords
break
else:
# WCS axis ordering is reverse of numpy
spec_axis = temp_coords[len(temp_coords) - self.spectral_axis_index - 1]
else:
spec_axis = temp_coords

Expand Down Expand Up @@ -646,7 +656,9 @@ def collapse(self, method, axis=None):
elif isinstance(axis, tuple) and self.spectral_axis_index in axis:
return collapsed_flux
else:
return Spectrum(collapsed_flux, wcs=self.wcs)
# Pass the spectral axis rather than WCS in this case, so we don't have to
# figure out which part of a multidimensional WCS is the spectral part.
return Spectrum(collapsed_flux, spectral_axis=self.spectral_axis)

def mean(self, **kwargs):
return self.collapse("mean", **kwargs)
Expand Down Expand Up @@ -821,39 +833,74 @@ def _return_with_redshift(self, result):
result.shift_spectrum_to(redshift=self.redshift)
return result

def __add__(self, other):
if not isinstance(other, (NDCube, u.Quantity)):
try:
other = u.Quantity(other, unit=self.unit)
except TypeError:
return NotImplemented
def _other_as_correct_class(self, other, force_quantity=False):
# NDArithmetic mixin will try to turn other into a Spectrum, which will fail
# sometimes because of not specifiying the spectral axis index
if isinstance(other, Spectrum):
# Take this opportunity to check if the spectral axes match
if not np.all(other.spectral_axis == self.spectral_axis):
raise ValueError("Spectral axis of both operands must match")
else:
if not isinstance(other, u.Quantity) and force_quantity:
other = other * self.unit

return self._return_with_redshift(self.add(other))
if isinstance(other, u.Quantity) and other.shape == self.shape:
return Spectrum(flux=other, spectral_axis=self.spectral_axis,
spectral_axis_index=self.spectral_axis_index)

def __sub__(self, other):
if not isinstance(other, NDCube):
try:
other = u.Quantity(other, unit=self.unit)
except TypeError:
return NotImplemented
return other

return self._return_with_redshift(self.subtract(other))
def __add__(self, other):
other = self._other_as_correct_class(other, force_quantity=True)
if isinstance(other, (Spectrum)):
return self._return_with_redshift(self.add(other))
else:
new_flux = self.flux + other
return self._return_with_redshift(Spectrum(new_flux, wcs=self.wcs, meta=self.meta,
uncertainty=self.uncertainty))

def __mul__(self, other):
if not isinstance(other, NDCube):
other = u.Quantity(other)
def __sub__(self, other):
other = self._other_as_correct_class(other, force_quantity=True)
if isinstance(other, (Spectrum)):
return self._return_with_redshift(self.subtract(other))
else:
new_flux = self.flux - other
return self._return_with_redshift(Spectrum(new_flux, wcs=self.wcs, meta=self.meta,
uncertainty=self.uncertainty))

return self._return_with_redshift(self.multiply(other))
def __mul__(self, other):
other = self._other_as_correct_class(other)
if isinstance(other, (Spectrum)):
return self._return_with_redshift(self.multiply(other))
else:
new_flux = self.flux * other
if self.uncertainty is None:
new_uncertainty = None
else:
new_uncertainty = deepcopy(self.uncertainty)
new_uncertainty.array *= other
return self._return_with_redshift(Spectrum(new_flux, wcs=self.wcs,
meta=self.meta,
uncertainty=new_uncertainty))

def __div__(self, other):
if not isinstance(other, NDCube):
other = u.Quantity(other)

return self._return_with_redshift(self.divide(other))
other = self._other_as_correct_class(other)
if isinstance(other, (Spectrum)):
return self._return_with_redshift(self.divide(other))
else:
new_flux = self.flux / other
if self.uncertainty is None:
new_uncertainty = None
else:
new_uncertainty = deepcopy(self.uncertainty)
new_uncertainty.array /= other
return self._return_with_redshift(Spectrum(new_flux, wcs=self.wcs,
meta=self.meta,
uncertainty=self.uncertainty/other))

def __truediv__(self, other):
if not isinstance(other, NDCube):
other = u.Quantity(other)
if not isinstance(other, Spectrum):
other = self._other_as_correct_class(other)

return self._return_with_redshift(self.divide(other))

Expand Down Expand Up @@ -901,11 +948,15 @@ def __repr__(self):
flux_str += f" {self.flux.unit}"

flux_str += f" (shape={self.flux.shape}, mean={np.nanmean(self.flux):.5f}); "
spectral_axis_str = (repr(self.spectral_axis).split("[")[0] +
np.array2string(self.spectral_axis, threshold=8) +
f" {self.spectral_axis.unit}>")
spectral_axis_str = f"spectral_axis={spectral_axis_str} (length={len(self.spectral_axis)})"
inner_str = (flux_str + spectral_axis_str)
# Sometimes this errors if an error occurs during initialization
if hasattr(self, "_spectral_axis"):
spectral_axis_str = (repr(self.spectral_axis).split("[")[0] +
np.array2string(self.spectral_axis, threshold=8) +
f" {self.spectral_axis.unit}>")
spectral_axis_str = f"spectral_axis={spectral_axis_str} (length={len(self.spectral_axis)})"
inner_str = (flux_str + spectral_axis_str)
else:
inner_str = flux_str

if self.uncertainty is not None:
inner_str += f"; uncertainty={self.uncertainty.__class__.__name__}"
Expand Down
12 changes: 10 additions & 2 deletions specutils/spectra/spectrum_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class SpectrumCollection(NDIOMixin):
each spectrum in the collection.
"""
def __init__(self, flux, spectral_axis=None, wcs=None, uncertainty=None,
mask=None, meta=None):
mask=None, meta=None, spectral_axis_index=None):
# Check for quantity
if not isinstance(flux, u.Quantity):
raise u.UnitsError("Flux must be a `Quantity`.")
Expand Down Expand Up @@ -89,6 +89,7 @@ def __init__(self, flux, spectral_axis=None, wcs=None, uncertainty=None,

self._flux = flux
self._spectral_axis = spectral_axis
self._spectral_axis_index = spectral_axis_index
self._wcs = wcs
self._uncertainty = uncertainty
self._mask = mask
Expand Down Expand Up @@ -153,6 +154,8 @@ def from_spectra(cls, spectra):
observer=sa[0].observer,
target=sa[0].target)

spectral_axis_index = spectra[0].spectral_axis_index

# Check that either all spectra have associated uncertainties, or that
# none of them do. If only some do, log an error and ignore the
# uncertainties.
Expand Down Expand Up @@ -183,7 +186,8 @@ def from_spectra(cls, spectra):
meta = [spec.meta for spec in spectra]

return cls(flux=flux, spectral_axis=spectral_axis,
uncertainty=uncertainty, wcs=wcs, mask=mask, meta=meta)
uncertainty=uncertainty, wcs=wcs, mask=mask, meta=meta,
spectral_axis_index=spectral_axis_index)

@property
def flux(self):
Expand All @@ -195,6 +199,10 @@ def spectral_axis(self):
"""The spectral axes as a `~astropy.units.Quantity`."""
return self._spectral_axis

@property
def spectral_axis_index(self):
return self._spectral_axis_index

@property
def frequency(self):
"""
Expand Down
6 changes: 4 additions & 2 deletions specutils/tests/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import astropy.units as u
from astropy.tests.helper import assert_quantity_allclose
import numpy as np
import pytest

from ..spectra.spectrum import Spectrum

Expand Down Expand Up @@ -89,8 +90,9 @@ def test_multiplication_basic_spectra(simulated_spectra):

def test_add_diff_spectral_axis(simulated_spectra):

# Calculate using the spectrum/nddata code
spec3 = simulated_spectra.s1_um_mJy_e1 + simulated_spectra.s1_AA_mJy_e3 # noqa
# We now raise an error if the spectra aren't on the same spectral axis
with pytest.raises(ValueError, match="Spectral axis of both operands must match"):
spec3 = simulated_spectra.s1_um_mJy_e1 + simulated_spectra.s1_AA_mJy_e3 # noqa
Comment on lines +93 to +95
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused by this - I would have thought this failed previously - i.e. the shapes did not change. So is this also an incidental improvement that comes from these changes that's not strictly about the gwcs?

If so, that's good, but should be in the changelog as an API change I think (since users might encounter this as an error where they didn't before)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed that this should be in the changelog too. Previously spectrum arithmetic didn't care if the spectral axis values were different, only that the shape of the arrays were the same. It now checks to see that the values actually match as well.



def test_masks(simulated_spectra):
Expand Down
2 changes: 1 addition & 1 deletion specutils/tests/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_snr_threshold():
np.random.seed(42)
flux = u.Quantity(np.random.sample((5, 10)), unit='Jy')
spectral_axis = u.Quantity(np.arange(50).reshape((5, 10)), unit='AA')
wcs = np.array([gwcs_from_array(x) for x in spectral_axis])
wcs = np.array([gwcs_from_array(x, [10,]) for x in spectral_axis])
uncertainty = StdDevUncertainty(np.random.sample((5, 10)), unit='Jy')
mask = np.ones((5, 10)).astype(bool)
meta = [{'test': 5, 'info': [1, 2, 3]} for i in range(5)]
Expand Down
5 changes: 3 additions & 2 deletions specutils/tests/test_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def test_spectral_axes():
sliced_spec2 = spec2[0]

assert isinstance(sliced_spec2, Spectrum)
assert_allclose(sliced_spec2.wcs.pixel_to_world(np.arange(10)), spec2.wcs.pixel_to_world(np.arange(10)))
assert_allclose(sliced_spec2.wcs.pixel_to_world(np.arange(10)),
spec2.wcs.pixel_to_world(np.arange(10), [0,]*10)[0])
assert sliced_spec2.flux.shape[0] == 49


Expand Down Expand Up @@ -107,4 +108,4 @@ def test_slicing_multidim():
assert spec1.mask.shape == (10,)

assert quantity_allclose(spec3.spectral_axis, spec.spectral_axis[4:7])
assert quantity_allclose(spec3.wcs.pixel_to_world([0,1,2]), spec3.spectral_axis[0:3])
assert quantity_allclose(spec3.wcs.pixel_to_world([0, 1, 2], [0, 0, 0])[0], spec3.spectral_axis[0:3])
7 changes: 4 additions & 3 deletions specutils/tests/test_spectrum_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
def spectrum_collection():
flux = u.Quantity(np.random.sample((5, 10)), unit='Jy')
spectral_axis = u.Quantity(np.arange(50).reshape((5, 10)) + 1, unit='AA')
wcs = np.array([gwcs_from_array(x) for x in spectral_axis])
wcs = np.array([gwcs_from_array(x, flux.shape, spectral_axis_index=1) for x in spectral_axis])
uncertainty = StdDevUncertainty(np.random.sample((5, 10)), unit='Jy')
mask = np.ones((5, 10)).astype(bool)
meta = [{'test': 5, 'info': [1, 2, 3]} for i in range(5)]

spec_coll = SpectrumCollection(
flux=flux, spectral_axis=spectral_axis, wcs=wcs,
uncertainty=uncertainty, mask=mask, meta=meta)
uncertainty=uncertainty, mask=mask, meta=meta,
spectral_axis_index=1)

return spec_coll

Expand Down Expand Up @@ -59,7 +60,7 @@ def test_collection_without_optional_arguments():
flux = u.Quantity(np.random.sample((5, 10)), unit='Jy')
spectral_axis = u.Quantity(np.arange(50).reshape((5, 10)) + 1, unit='AA')
uncertainty = StdDevUncertainty(np.random.sample((5, 10)), unit='Jy')
wcs = np.array([gwcs_from_array(x) for x in spectral_axis])
wcs = np.array([gwcs_from_array(x, flux.shape, spectral_axis_index=1) for x in spectral_axis])
mask = np.ones((5, 10)).astype(bool)
meta = [{'test': 5, 'info': [1, 2, 3]} for i in range(5)]

Expand Down
Loading