Skip to content

Commit

Permalink
adding option and tests for fill_value
Browse files Browse the repository at this point in the history
  • Loading branch information
havok2063 committed Jan 23, 2024
1 parent 27d1398 commit 1097cfb
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 5 deletions.
14 changes: 10 additions & 4 deletions astrocut/asdf_cutouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import asdf
import astropy
import gwcs
import numpy as np

from astropy.coordinates import SkyCoord

Expand Down Expand Up @@ -54,7 +55,7 @@ def get_center_pixel(gwcs: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:

def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, SkyCoord],
wcs: astropy.wcs.wcs.WCS = None, size: int = 20, outfile: str = "example_roman_cutout.fits",
write_file: bool = True) -> astropy.nddata.Cutout2D:
write_file: bool = True, fill_value: int = np.nan) -> astropy.nddata.Cutout2D:
""" Get a Roman image cutout
Cut out a square section from the input image data array. The ``coords`` can either be a tuple of x, y
Expand All @@ -75,6 +76,8 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk
the name of the output cutout file, by default "example_roman_cutout.fits"
write_file : bool, by default True
Flag to write the cutout to a file or not
fill_value: int, by default np.nan
The fill value for pixels outside the original image.
Returns
-------
Expand All @@ -95,7 +98,8 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk

# create the cutout
try:
cutout = astropy.nddata.Cutout2D(data, position=coords, wcs=wcs, size=(size, size), mode='partial')
cutout = astropy.nddata.Cutout2D(data, position=coords, wcs=wcs, size=(size, size), mode='partial',
fill_value=fill_value)
except astropy.nddata.utils.NoOverlapError as e:
raise RuntimeError('Could not create 2d cutout. The requested cutout does not overlap with the original image.') from e

Expand All @@ -114,7 +118,7 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk

def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,
output_file: str = "example_roman_cutout.fits",
write_file: bool = True) -> astropy.nddata.Cutout2D:
write_file: bool = True, fill_value: int = np.nan) -> astropy.nddata.Cutout2D:
""" Preliminary proof-of-concept functionality.
Takes a single ASDF input file (``input_file``) and generates a cutout of designated size ``cutout_size``
Expand All @@ -134,6 +138,8 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,
the name of the output cutout file, by default "example_roman_cutout.fits"
write_file : bool, by default True
Flag to write the cutout to a file or not
fill_value: int, by default np.nan
The fill value for pixels outside the original image.
Returns
-------
Expand All @@ -151,4 +157,4 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,

# create the 2d image cutout
return get_cutout(data, pixel_coordinates, wcs, size=cutout_size, outfile=output_file,
write_file=write_file)
write_file=write_file, fill_value=fill_value)
34 changes: 33 additions & 1 deletion astrocut/tests/test_asdf_cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def make_wcs(xsize, ysize, ra=30., dec=45.):
def makefake():
""" fixture factory to make a fake gwcs and dataset """

def _make_fake(nx, ny, ra, dec, zero=False):
def _make_fake(nx, ny, ra, dec, zero=False, asint=False):
# create the wcs
wcsobj = make_wcs(nx/2, ny/2, ra=ra, dec=dec)
wcsobj.bounding_box = ((0, nx), (0, ny))
Expand All @@ -53,8 +53,14 @@ def _make_fake(nx, ny, ra, dec, zero=False):
else:
size = nx * ny
data = np.arange(size).reshape(nx, ny)

# make a quantity
data *= (u.electron / u.second)

# make integer array
if asint:
data = data.astype(int)

return data, wcsobj

yield _make_fake
Expand Down Expand Up @@ -205,6 +211,32 @@ def assert_same_coord(x, y, cutout, wcs):
assert cutout_coord == orig_coord


@pytest.mark.parametrize('asint, fill', [(False, None), (True, -9999)], ids=['fillfloat', 'fillint'])
def test_partial_cutout(makefake, asint, fill):
""" test we get a partial cutout with nans or fill value """
ra, dec = 30.0, 45.0
data, gwcs = makefake(100, 100, ra, dec, asint=asint)

wcs = WCS(gwcs.to_fits_sip())
cc = coord.SkyCoord(29.999, 44.998, unit=u.degree)
cutout = get_cutout(data, cc, wcs, size=50, write_file=False, fill_value=fill)
assert cutout.shape == (50, 50)
if asint:
assert -9999 in cutout.data
else:
assert np.isnan(cutout.data).any()


def test_bad_fill(makefake):
""" test error is raised on bad fill value """
ra, dec = 30.0, 45.0
data, gwcs = makefake(100, 100, ra, dec, asint=True)
wcs = WCS(gwcs.to_fits_sip())
cc = coord.SkyCoord(29.999, 44.998, unit=u.degree)
with pytest.raises(ValueError, match='fill_value is inconsistent with the data type of the input array'):
get_cutout(data, cc, wcs, size=50, write_file=False)





0 comments on commit 1097cfb

Please sign in to comment.