Skip to content

Commit 78fcfce

Browse files
committed
adding temp gwcs slicing for asdf writes
1 parent 9731e33 commit 78fcfce

File tree

2 files changed

+115
-13
lines changed

2 files changed

+115
-13
lines changed

astrocut/asdf_cutouts.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
# Licensed under a 3-clause BSD style license - see LICENSE.rst
22

33
"""This module implements cutout functionality similar to fitscut, but for the ASDF file format."""
4+
import copy
45
import pathlib
5-
from typing import Union
6+
from typing import Union, Tuple
67

78
import asdf
89
import astropy
910
import gwcs
1011
import numpy as np
1112

1213
from astropy.coordinates import SkyCoord
14+
from astropy.modeling import models
1315

1416

1517
def get_center_pixel(gwcsobj: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:
@@ -56,7 +58,8 @@ def get_center_pixel(gwcsobj: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:
5658

5759
def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, SkyCoord],
5860
wcs: astropy.wcs.wcs.WCS = None, size: int = 20, outfile: str = "example_roman_cutout.fits",
59-
write_file: bool = True, fill_value: Union[int, float] = np.nan) -> astropy.nddata.Cutout2D:
61+
write_file: bool = True, fill_value: Union[int, float] = np.nan,
62+
gwcsobj: gwcs.wcs.WCS = None) -> astropy.nddata.Cutout2D:
6063
""" Get a Roman image cutout
6164
6265
Cut out a square section from the input image data array. The ``coords`` can either be a tuple of x, y
@@ -79,6 +82,8 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk
7982
Flag to write the cutout to a file or not
8083
fill_value: int | float, by default np.nan
8184
The fill value for pixels outside the original image.
85+
gwcsobj : gwcs.wcs.WCS, Optional
86+
the original gwcs object for the full image, needed only when writing cutout as asdf file
8287
8388
Returns
8489
-------
@@ -91,6 +96,8 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk
9196
when a wcs is not present when coords is a SkyCoord object
9297
RuntimeError:
9398
when the requested cutout does not overlap with the original image
99+
ValueError:
100+
when no gwcs object is provided when writing to an asdf file
94101
"""
95102

96103
# check for correct inputs
@@ -122,12 +129,23 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk
122129
if write_as == '.fits':
123130
_write_fits(cutout, outfile)
124131
elif write_as == '.asdf':
125-
_write_asdf(cutout, outfile)
132+
if not gwcsobj:
133+
raise ValueError('The original gwcs object is needed when writing to asdf file.')
134+
_write_asdf(cutout, gwcsobj, outfile)
126135

127136
return cutout
128137

129138

130-
def _write_fits(cutout, outfile="example_roman_cutout.fits"):
139+
def _write_fits(cutout: astropy.nddata.Cutout2D, outfile: str = "example_roman_cutout.fits"):
140+
""" Write cutout as FITS file
141+
142+
Parameters
143+
----------
144+
cutout : astropy.nddata.Cutout2D
145+
the 2d cutout
146+
outfile : str, optional
147+
the name of the output cutout file, by default "example_roman_cutout.fits"
148+
"""
131149
# check if the data is a quantity and get the array data
132150
if isinstance(cutout.data, astropy.units.Quantity):
133151
data = cutout.data.value
@@ -137,8 +155,61 @@ def _write_fits(cutout, outfile="example_roman_cutout.fits"):
137155
astropy.io.fits.writeto(outfile, data=data, header=cutout.wcs.to_header(relax=True), overwrite=True)
138156

139157

140-
def _write_asdf(cutout, outfile="example_roman_cutout.asdf"):
141-
tree = {'roman': {'meta': {'wcs': dict(cutout.wcs.to_header(relax=True))}, 'data': cutout.data}}
158+
def _slice_gwcs(gwcsobj: gwcs.wcs.WCS, slices: Tuple[slice, slice]) -> gwcs.wcs.WCS:
159+
""" Slice the original gwcs object
160+
161+
"Slices" the original gwcs object down to the cutout shape. This is a hack
162+
until proper gwcs slicing is in place a la fits WCS slicing. The ``slices``
163+
keyword input is a tuple with the x, y cutout boundaries in the original image
164+
array, e.g. ``cutout.slices_original``. Astropy Cutout2D slices are in the form
165+
((ymin, ymax, None), (xmin, xmax, None))
166+
167+
Parameters
168+
----------
169+
gwcsobj : gwcs.wcs.WCS
170+
the original gwcs from the input image
171+
slices : Tuple[slice, slice]
172+
the cutout x, y slices as ((ymin, ymax), (xmin, xmax))
173+
174+
Returns
175+
-------
176+
gwcs.wcs.WCS
177+
The sliced gwcs object
178+
"""
179+
tmp = copy.deepcopy(gwcsobj)
180+
181+
# get the cutout array bounds and create a new shift transform to the cutout
182+
# add the new transform to the gwcs
183+
xmin, xmax = slices[1].start, slices[1].stop
184+
ymin, ymax = slices[0].start, slices[0].stop
185+
shape = (ymax - ymin, xmax - xmin)
186+
offsets = models.Shift(xmin, name='cutout_offset1') & models.Shift(ymin, name='cutout_offset2')
187+
tmp.insert_transform('detector', offsets, after=True)
188+
189+
# modify the gwcs bounding box to the cutout shape
190+
tmp.bounding_box = ((0, shape[0] - 1), (0, shape[1] - 1))
191+
tmp.pixel_shape = shape[::-1]
192+
tmp.array_shape = shape
193+
return tmp
194+
195+
196+
def _write_asdf(cutout: astropy.nddata.Cutout2D, gwcsobj: gwcs.wcs.WCS, outfile: str = "example_roman_cutout.asdf"):
197+
""" Write cutout as ASDF file
198+
199+
Parameters
200+
----------
201+
cutout : astropy.nddata.Cutout2D
202+
the 2d cutout
203+
gwcsobj : gwcs.wcs.WCS
204+
the original gwcs object for the full image
205+
outfile : str, optional
206+
the name of the output cutout file, by default "example_roman_cutout.asdf"
207+
"""
208+
# slice the origial gwcs to the cutout
209+
sliced_gwcs = _slice_gwcs(gwcsobj, cutout.slices_original)
210+
211+
# create the asdf tree
212+
tree = {'roman': {'meta': {'wcs': sliced_gwcs, 'data': cutout.data}}}
142213
af = asdf.AsdfFile(tree)
143214

144215
# Write the data to a new file
@@ -186,4 +257,4 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,
186257

187258
# create the 2d image cutout
188259
return get_cutout(data, pixel_coordinates, wcs, size=cutout_size, outfile=output_file,
189-
write_file=write_file, fill_value=fill_value)
260+
write_file=write_file, fill_value=fill_value, gwcsobj=gwcsobj)

astrocut/tests/test_asdf_cut.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from astropy.wcs.utils import pixel_to_skycoord
1313
from gwcs import wcs
1414
from gwcs import coordinate_frames as cf
15-
from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut
15+
from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut, _slice_gwcs
1616

1717

1818
def make_wcs(xsize, ysize, ra=30., dec=45.):
@@ -70,8 +70,8 @@ def _make_fake(nx, ny, ra, dec, zero=False, asint=False):
7070
def fakedata(makefake):
7171
""" fixture to create fake data and wcs """
7272
# set up initial parameters
73-
nx = 100
74-
ny = 100
73+
nx = 1000
74+
ny = 1000
7575
ra = 30.
7676
dec = 45.
7777

@@ -105,7 +105,7 @@ def test_get_center_pixel(fakedata):
105105
__, gwcs = fakedata
106106

107107
pixel_coordinates, wcs = get_center_pixel(gwcs, 30., 45.)
108-
assert np.allclose(pixel_coordinates, (np.array(50.), np.array(50.)))
108+
assert np.allclose(pixel_coordinates, (np.array(500.), np.array(500.)))
109109
assert np.allclose(wcs.celestial.wcs.crval, np.array([30., 45.]))
110110

111111

@@ -144,7 +144,7 @@ def test_get_cutout(output, fakedata, quantity):
144144
with fits.open(output_file) as hdulist:
145145
data = hdulist[0].data
146146
assert data.shape == (10, 10)
147-
assert data[5, 5] == 2525
147+
assert data[5, 5] == 25025
148148

149149

150150
def test_asdf_cutout(make_file, output):
@@ -158,7 +158,7 @@ def test_asdf_cutout(make_file, output):
158158
with fits.open(output_file) as hdulist:
159159
data = hdulist[0].data
160160
assert data.shape == (10, 10)
161-
assert data[5, 5] == 2526
161+
assert data[5, 5] == 475476
162162

163163

164164
@pytest.mark.parametrize('suffix', ['fits', 'asdf', None])
@@ -177,6 +177,16 @@ def test_write_file(make_file, suffix, output):
177177
assert pathlib.Path(output_file).exists()
178178

179179

180+
def test_fail_write_asdf(fakedata, output):
181+
""" test we fail to write an asdf if no gwcs given """
182+
with pytest.raises(ValueError, match='The original gwcs object is needed when writing to asdf file.'):
183+
output_file = output('asdf')
184+
data, gwcs = fakedata
185+
skycoord = gwcs(25, 25, with_units=True)
186+
wcs = WCS(gwcs.to_fits_sip())
187+
get_cutout(data, skycoord, wcs, size=10, outfile=output_file)
188+
189+
180190
def test_cutout_nofile(make_file, output):
181191
""" test we can make a cutout with no file output """
182192
output_file = output()
@@ -281,3 +291,24 @@ def test_cutout_raedge(makefake):
281291
assert bounds[0].ra.value > 359
282292
assert bounds[1].ra.value < 0.1
283293

294+
295+
def test_slice_gwcs(fakedata):
296+
""" test we can slice a gwcs object """
297+
data, gwcsobj = fakedata
298+
skycoord = gwcsobj(250, 250)
299+
wcs = WCS(gwcsobj.to_fits_sip())
300+
301+
cutout = get_cutout(data, skycoord, wcs, size=50, write_file=False)
302+
303+
sliced = _slice_gwcs(gwcsobj, cutout.slices_original)
304+
305+
# check coords between slice and original gwcs
306+
assert cutout.center_cutout == (24.5, 24.5)
307+
assert sliced.array_shape == (50, 50)
308+
assert sliced(*cutout.input_position_cutout) == gwcsobj(*cutout.input_position_original)
309+
assert gwcsobj(*cutout.center_original) == sliced(*cutout.center_cutout)
310+
311+
# assert same sky footprint between slice and original
312+
# gwcs footprint/bounding_box expects ((x0, x1), (y0, y1)) but cutout.bbox is in ((y0, y1), (x0, x1))
313+
assert (gwcsobj.footprint(bounding_box=tuple(reversed(cutout.bbox_original))) == sliced.footprint()).all()
314+

0 commit comments

Comments
 (0)