Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor gwcs_from_array to provide ND GWCS in ND flux case #1211

Merged
merged 22 commits into from
Mar 18, 2025
Merged
Changes from 20 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
334f2dc
Working on constructing N dimensional GWCS for ND flux with only spec…
rosteen Jan 31, 2025
0547b4a
Get forward transform working for ND case
rosteen Jan 31, 2025
babc123
Convert other operator to Spectrum if needed in specutils where we ha…
rosteen Feb 3, 2025
429d87a
Don't pass extra arg to ndarithmetic
rosteen Feb 3, 2025
dbe74b6
Use dimensionless instead of pix, remove unit conversion code that I …
rosteen Feb 3, 2025
bb68350
Fix two tests for multidimensional GWCS, remove debugging print
rosteen Feb 3, 2025
02f7c3e
Working on debugging SpectrumCollection
rosteen Feb 3, 2025
2e332a7
Check for Spectrum specifically before arithmetic
rosteen Feb 3, 2025
35e991a
Remove incorrect index from spectrum collection slice, cast operand a…
rosteen Feb 4, 2025
c3424bb
Don't pass multi-d WCS to collapsed spectrum
rosteen Feb 4, 2025
52137fb
Fixing more test failures
rosteen Feb 4, 2025
fd52aef
Revamp arithmetic
rosteen Feb 4, 2025
4460d26
Add changelog
rosteen Feb 4, 2025
e74b2a1
remove debugging print
rosteen Feb 4, 2025
368d0fe
Handle creating dummy WCS and spectral axis when neither are provided…
rosteen Feb 6, 2025
f856f3a
Codestyle
rosteen Feb 6, 2025
4708bf5
Remove debugging prints
rosteen Feb 6, 2025
64d1180
Fixes for compatibility with up to date gwcs
rosteen Feb 7, 2025
55cbb3e
Bump min version pins
rosteen Feb 7, 2025
35ea81a
Need gwcs 0.24
rosteen Feb 9, 2025
afb7cb4
Add note about arithmetic to changelog
rosteen Mar 4, 2025
75aa0d7
Fix Spectrum1D references in doc page
rosteen Mar 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -17,6 +17,10 @@ New Features
Other Changes and Additions
^^^^^^^^^^^^^^^^^^^^^^^^^^^

- Initializing a ``Spectrum`` with only a spectral axis (not full WCS) will now
result in a GWCS matching the dimensionality of the flux array, rather than a
1D spectral GWCS in all cases. [#1211]

1.20.0 (unreleased)
-------------------

3 changes: 0 additions & 3 deletions docs/spectrum.rst
Original file line number Diff line number Diff line change
@@ -406,9 +406,6 @@ spectral axis, or 'spatial', which will collapse along all non-spectral axes.
>>> spec.mean(axis='spatial') # doctest: +FLOAT_CMP
<Spectrum(flux=<Quantity [0.37273938, 0.53843905, 0.61351648, 0.57311623, 0.44339915,
0.66084728, 0.45881921, 0.38715911, 0.39967185, 0.53257671] Jy> (shape=(10,), mean=0.49803 Jy); spectral_axis=<SpectralAxis
(observer to target:
radial_velocity=0.0 km / s
redshift=0.0)
Comment on lines -409 to -411
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this removed? Is it an incidental thing you noticed at the time (no prob if so) or a result of this change somehow?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't recall, I'll double check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was due to creating the new collapsed Spectrum with self.spectral_axis rather than self.wcs now, and is actually probably more correct than the previously printed out repr, which was replacing None values for these quantities with 0 in the process of collapsing (implying information we actually don't have).

[5000. 5001. 5002. ... 5007. 5008. 5009.] Angstrom> (length=10))>

Note that in this case, the result of the collapse operation is a
2 changes: 1 addition & 1 deletion docs/spectrum_collection.rst
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ solution.

>>> flux = u.Quantity(np.random.sample((5, 10)), unit='Jy')
>>> spectral_axis = u.Quantity(np.arange(50).reshape((5, 10)), unit='AA')
>>> wcs = np.array([gwcs_from_array(x) for x in spectral_axis])
>>> wcs = np.array([gwcs_from_array(x, x.shape) for x in spectral_axis])
>>> uncertainty = StdDevUncertainty(np.random.sample((5, 10)), unit='Jy')
>>> mask = np.ones((5, 10)).astype(bool)
>>> meta = [{'test': 5, 'info': [1, 2, 3]} for i in range(5)]
10 changes: 5 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -17,11 +17,11 @@ packages = find:
python_requires = >=3.10
install_requires =
numpy>=1.24
scipy>=1.3
astropy>=5.1
gwcs>=0.18
asdf-astropy>=0.3
asdf>=2.14.4
scipy>=1.14
astropy>=6.0
gwcs>=0.22
asdf-astropy>=0.5
asdf>=3.3.0
ndcube>=2.0

[options.extras_require]
121 changes: 86 additions & 35 deletions specutils/spectra/spectrum.py
Original file line number Diff line number Diff line change
@@ -89,6 +89,8 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
# If the flux (data) argument is already a Spectrum (as it would
# be for internal arithmetic operations), avoid setup entirely.
if isinstance(flux, Spectrum):
self._spectral_axis_index = flux.spectral_axis_index
self._spectral_axis = flux.spectral_axis
super().__init__(flux)
return

@@ -157,9 +159,7 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
# In the case where the arithmetic operation is being performed with
# a single float, int, or array object, just go ahead and ignore wcs
# requirements
if (not isinstance(flux, u.Quantity) or isinstance(flux, float)
or isinstance(flux, int)) and np.ndim(flux) == 0:

if np.ndim(flux) == 0 and spectral_axis is None and wcs is None:
super(Spectrum, self).__init__(data=flux, wcs=wcs, **kwargs)
return

@@ -332,7 +332,10 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
self._spectral_axis = spectral_axis

if wcs is None:
wcs = gwcs_from_array(self._spectral_axis)
wcs = gwcs_from_array(self._spectral_axis,
flux.shape,
spectral_axis_index=self.spectral_axis_index
)

elif wcs is None:
# If no spectral axis or wcs information is provided, initialize
@@ -344,7 +347,10 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
raise ValueError("Must specify spectral_axis_index if no WCS or spectral"
" axis is input.")
size = flux.shape[self.spectral_axis_index] if not flux.isscalar else 1
wcs = gwcs_from_array(np.arange(size) * u.Unit(""))
wcs = gwcs_from_array(np.arange(size) * u.Unit(""),
flux.shape,
spectral_axis_index=self.spectral_axis_index
)

super().__init__(
data=flux.value if isinstance(flux, u.Quantity) else flux,
@@ -379,6 +385,10 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
for coords in temp_coords:
if isinstance(coords, SpectralCoord):
spec_axis = coords
break
else:
# WCS axis ordering is reverse of numpy
spec_axis = temp_coords[len(temp_coords) - self.spectral_axis_index - 1]
else:
spec_axis = temp_coords

@@ -646,7 +656,9 @@ def collapse(self, method, axis=None):
elif isinstance(axis, tuple) and self.spectral_axis_index in axis:
return collapsed_flux
else:
return Spectrum(collapsed_flux, wcs=self.wcs)
# Pass the spectral axis rather than WCS in this case, so we don't have to
# figure out which part of a multidimensional WCS is the spectral part.
return Spectrum(collapsed_flux, spectral_axis=self.spectral_axis)

def mean(self, **kwargs):
return self.collapse("mean", **kwargs)
@@ -821,39 +833,74 @@ def _return_with_redshift(self, result):
result.shift_spectrum_to(redshift=self.redshift)
return result

def __add__(self, other):
if not isinstance(other, (NDCube, u.Quantity)):
try:
other = u.Quantity(other, unit=self.unit)
except TypeError:
return NotImplemented
def _other_as_correct_class(self, other, force_quantity=False):
# NDArithmetic mixin will try to turn other into a Spectrum, which will fail
# sometimes because of not specifiying the spectral axis index
if isinstance(other, Spectrum):
# Take this opportunity to check if the spectral axes match
if not np.all(other.spectral_axis == self.spectral_axis):
raise ValueError("Spectral axis of both operands must match")
else:
if not isinstance(other, u.Quantity) and force_quantity:
other = other * self.unit

return self._return_with_redshift(self.add(other))
if isinstance(other, u.Quantity) and other.shape == self.shape:
return Spectrum(flux=other, spectral_axis=self.spectral_axis,
spectral_axis_index=self.spectral_axis_index)

def __sub__(self, other):
if not isinstance(other, NDCube):
try:
other = u.Quantity(other, unit=self.unit)
except TypeError:
return NotImplemented
return other

return self._return_with_redshift(self.subtract(other))
def __add__(self, other):
other = self._other_as_correct_class(other, force_quantity=True)
if isinstance(other, (Spectrum)):
return self._return_with_redshift(self.add(other))
else:
new_flux = self.flux + other
return self._return_with_redshift(Spectrum(new_flux, wcs=self.wcs, meta=self.meta,
uncertainty=self.uncertainty))

def __mul__(self, other):
if not isinstance(other, NDCube):
other = u.Quantity(other)
def __sub__(self, other):
other = self._other_as_correct_class(other, force_quantity=True)
if isinstance(other, (Spectrum)):
return self._return_with_redshift(self.subtract(other))
else:
new_flux = self.flux - other
return self._return_with_redshift(Spectrum(new_flux, wcs=self.wcs, meta=self.meta,
uncertainty=self.uncertainty))

return self._return_with_redshift(self.multiply(other))
def __mul__(self, other):
other = self._other_as_correct_class(other)
if isinstance(other, (Spectrum)):
return self._return_with_redshift(self.multiply(other))
else:
new_flux = self.flux * other
if self.uncertainty is None:
new_uncertainty = None
else:
new_uncertainty = deepcopy(self.uncertainty)
new_uncertainty.array *= other
return self._return_with_redshift(Spectrum(new_flux, wcs=self.wcs,
meta=self.meta,
uncertainty=new_uncertainty))

def __div__(self, other):
if not isinstance(other, NDCube):
other = u.Quantity(other)

return self._return_with_redshift(self.divide(other))
other = self._other_as_correct_class(other)
if isinstance(other, (Spectrum)):
return self._return_with_redshift(self.divide(other))
else:
new_flux = self.flux / other
if self.uncertainty is None:
new_uncertainty = None
else:
new_uncertainty = deepcopy(self.uncertainty)
new_uncertainty.array /= other
return self._return_with_redshift(Spectrum(new_flux, wcs=self.wcs,
meta=self.meta,
uncertainty=self.uncertainty/other))

def __truediv__(self, other):
if not isinstance(other, NDCube):
other = u.Quantity(other)
if not isinstance(other, Spectrum):
other = self._other_as_correct_class(other)

return self._return_with_redshift(self.divide(other))

@@ -901,11 +948,15 @@ def __repr__(self):
flux_str += f" {self.flux.unit}"

flux_str += f" (shape={self.flux.shape}, mean={np.nanmean(self.flux):.5f}); "
spectral_axis_str = (repr(self.spectral_axis).split("[")[0] +
np.array2string(self.spectral_axis, threshold=8) +
f" {self.spectral_axis.unit}>")
spectral_axis_str = f"spectral_axis={spectral_axis_str} (length={len(self.spectral_axis)})"
inner_str = (flux_str + spectral_axis_str)
# Sometimes this errors if an error occurs during initialization
if hasattr(self, "_spectral_axis"):
spectral_axis_str = (repr(self.spectral_axis).split("[")[0] +
np.array2string(self.spectral_axis, threshold=8) +
f" {self.spectral_axis.unit}>")
spectral_axis_str = f"spectral_axis={spectral_axis_str} (length={len(self.spectral_axis)})"
inner_str = (flux_str + spectral_axis_str)
else:
inner_str = flux_str

if self.uncertainty is not None:
inner_str += f"; uncertainty={self.uncertainty.__class__.__name__}"
12 changes: 10 additions & 2 deletions specutils/spectra/spectrum_collection.py
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ class SpectrumCollection(NDIOMixin):
each spectrum in the collection.
"""
def __init__(self, flux, spectral_axis=None, wcs=None, uncertainty=None,
mask=None, meta=None):
mask=None, meta=None, spectral_axis_index=None):
# Check for quantity
if not isinstance(flux, u.Quantity):
raise u.UnitsError("Flux must be a `Quantity`.")
@@ -89,6 +89,7 @@ def __init__(self, flux, spectral_axis=None, wcs=None, uncertainty=None,

self._flux = flux
self._spectral_axis = spectral_axis
self._spectral_axis_index = spectral_axis_index
self._wcs = wcs
self._uncertainty = uncertainty
self._mask = mask
@@ -153,6 +154,8 @@ def from_spectra(cls, spectra):
observer=sa[0].observer,
target=sa[0].target)

spectral_axis_index = spectra[0].spectral_axis_index

# Check that either all spectra have associated uncertainties, or that
# none of them do. If only some do, log an error and ignore the
# uncertainties.
@@ -183,7 +186,8 @@ def from_spectra(cls, spectra):
meta = [spec.meta for spec in spectra]

return cls(flux=flux, spectral_axis=spectral_axis,
uncertainty=uncertainty, wcs=wcs, mask=mask, meta=meta)
uncertainty=uncertainty, wcs=wcs, mask=mask, meta=meta,
spectral_axis_index=spectral_axis_index)

@property
def flux(self):
@@ -195,6 +199,10 @@ def spectral_axis(self):
"""The spectral axes as a `~astropy.units.Quantity`."""
return self._spectral_axis

@property
def spectral_axis_index(self):
return self._spectral_axis_index

@property
def frequency(self):
"""
6 changes: 4 additions & 2 deletions specutils/tests/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import astropy.units as u
from astropy.tests.helper import assert_quantity_allclose
import numpy as np
import pytest

from ..spectra.spectrum import Spectrum

@@ -89,8 +90,9 @@ def test_multiplication_basic_spectra(simulated_spectra):

def test_add_diff_spectral_axis(simulated_spectra):

# Calculate using the spectrum/nddata code
spec3 = simulated_spectra.s1_um_mJy_e1 + simulated_spectra.s1_AA_mJy_e3 # noqa
# We now raise an error if the spectra aren't on the same spectral axis
with pytest.raises(ValueError, match="Spectral axis of both operands must match"):
spec3 = simulated_spectra.s1_um_mJy_e1 + simulated_spectra.s1_AA_mJy_e3 # noqa
Comment on lines +93 to +95
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused by this - I would have thought this failed previously - i.e. the shapes did not change. So is this also an incidental improvement that comes from these changes that's not strictly about the gwcs?

If so, that's good, but should be in the changelog as an API change I think (since users might encounter this as an error where they didn't before)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed that this should be in the changelog too. Previously spectrum arithmetic didn't care if the spectral axis values were different, only that the shape of the arrays were the same. It now checks to see that the values actually match as well.



def test_masks(simulated_spectra):
2 changes: 1 addition & 1 deletion specutils/tests/test_manipulation.py
Original file line number Diff line number Diff line change
@@ -129,7 +129,7 @@ def test_snr_threshold():
np.random.seed(42)
flux = u.Quantity(np.random.sample((5, 10)), unit='Jy')
spectral_axis = u.Quantity(np.arange(50).reshape((5, 10)), unit='AA')
wcs = np.array([gwcs_from_array(x) for x in spectral_axis])
wcs = np.array([gwcs_from_array(x, [10,]) for x in spectral_axis])
uncertainty = StdDevUncertainty(np.random.sample((5, 10)), unit='Jy')
mask = np.ones((5, 10)).astype(bool)
meta = [{'test': 5, 'info': [1, 2, 3]} for i in range(5)]
5 changes: 3 additions & 2 deletions specutils/tests/test_slicing.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,8 @@ def test_spectral_axes():
sliced_spec2 = spec2[0]

assert isinstance(sliced_spec2, Spectrum)
assert_allclose(sliced_spec2.wcs.pixel_to_world(np.arange(10)), spec2.wcs.pixel_to_world(np.arange(10)))
assert_allclose(sliced_spec2.wcs.pixel_to_world(np.arange(10)),
spec2.wcs.pixel_to_world(np.arange(10), [0,]*10)[0])
assert sliced_spec2.flux.shape[0] == 49


@@ -107,4 +108,4 @@ def test_slicing_multidim():
assert spec1.mask.shape == (10,)

assert quantity_allclose(spec3.spectral_axis, spec.spectral_axis[4:7])
assert quantity_allclose(spec3.wcs.pixel_to_world([0,1,2]), spec3.spectral_axis[0:3])
assert quantity_allclose(spec3.wcs.pixel_to_world([0, 1, 2], [0, 0, 0])[0], spec3.spectral_axis[0:3])
7 changes: 4 additions & 3 deletions specutils/tests/test_spectrum_collection.py
Original file line number Diff line number Diff line change
@@ -15,14 +15,15 @@
def spectrum_collection():
flux = u.Quantity(np.random.sample((5, 10)), unit='Jy')
spectral_axis = u.Quantity(np.arange(50).reshape((5, 10)) + 1, unit='AA')
wcs = np.array([gwcs_from_array(x) for x in spectral_axis])
wcs = np.array([gwcs_from_array(x, flux.shape, spectral_axis_index=1) for x in spectral_axis])
uncertainty = StdDevUncertainty(np.random.sample((5, 10)), unit='Jy')
mask = np.ones((5, 10)).astype(bool)
meta = [{'test': 5, 'info': [1, 2, 3]} for i in range(5)]

spec_coll = SpectrumCollection(
flux=flux, spectral_axis=spectral_axis, wcs=wcs,
uncertainty=uncertainty, mask=mask, meta=meta)
uncertainty=uncertainty, mask=mask, meta=meta,
spectral_axis_index=1)

return spec_coll

@@ -59,7 +60,7 @@ def test_collection_without_optional_arguments():
flux = u.Quantity(np.random.sample((5, 10)), unit='Jy')
spectral_axis = u.Quantity(np.arange(50).reshape((5, 10)) + 1, unit='AA')
uncertainty = StdDevUncertainty(np.random.sample((5, 10)), unit='Jy')
wcs = np.array([gwcs_from_array(x) for x in spectral_axis])
wcs = np.array([gwcs_from_array(x, flux.shape, spectral_axis_index=1) for x in spectral_axis])
mask = np.ones((5, 10)).astype(bool)
meta = [{'test': 5, 'info': [1, 2, 3]} for i in range(5)]

Loading