-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #105 from spacetelescope/asdf-cut
New Feature: ASDF Cutout Functionality
- Loading branch information
Showing
3 changed files
with
250 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Licensed under a 3-clause BSD style license - see LICENSE.rst | ||
|
||
"""This module implements cutout functionality similar to fitscut, but for the ASDF file format.""" | ||
from typing import Union | ||
|
||
import asdf | ||
import astropy | ||
import gwcs | ||
|
||
from astropy.coordinates import SkyCoord | ||
|
||
|
||
def get_center_pixel(gwcs: gwcs.wcs.WCS, ra: float, dec: float) -> tuple: | ||
""" Get the center pixel from a roman 2d science image | ||
For an input RA, Dec sky coordinate, get the closest pixel location | ||
on the input Roman image. | ||
Parameters | ||
---------- | ||
gwcs : gwcs.wcs.WCS | ||
the Roman GWCS object | ||
ra : float | ||
the input Right Ascension | ||
dec : float | ||
the input Declination | ||
Returns | ||
------- | ||
tuple | ||
the pixel position, FITS wcs object | ||
""" | ||
|
||
# Convert the gwcs object to an astropy FITS WCS header | ||
header = gwcs.to_fits_sip() | ||
|
||
# Update WCS header with some keywords that it's missing. | ||
# Otherwise, it won't work with astropy.wcs tools (TODO: Figure out why. What are these keywords for?) | ||
for k in ['cpdis1', 'cpdis2', 'det2im1', 'det2im2', 'sip']: | ||
if k not in header: | ||
header[k] = 'na' | ||
|
||
# New WCS object with updated header | ||
wcs_updated = astropy.wcs.WCS(header) | ||
|
||
# Turn input RA, Dec into a SkyCoord object | ||
coordinates = SkyCoord(ra, dec, unit='deg') | ||
|
||
# Map the coordinates to a pixel's location on the Roman 2d array (row, col) | ||
row, col = astropy.wcs.utils.skycoord_to_pixel(coords=coordinates, wcs=wcs_updated) | ||
|
||
return (row, col), wcs_updated | ||
|
||
|
||
def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, SkyCoord], | ||
wcs: astropy.wcs.wcs.WCS = None, size: int = 20, outfile: str = "example_roman_cutout.fits"): | ||
""" Get a Roman image cutout | ||
Cut out a square section from the input image data array. The ``coords`` can either be a tuple of x, y | ||
pixel coordinates or an astropy SkyCoord object, in which case, a wcs is required. Writes out a | ||
new output file containing the image cutout of the specified ``size``. Default is 20 pixels. | ||
Parameters | ||
---------- | ||
data : asdf.tags.core.ndarray.NDArrayType | ||
the input Roman image data array | ||
coords : Union[tuple, SkyCoord] | ||
the input pixel or sky coordinates | ||
wcs : astropy.wcs.wcs.WCS, Optional | ||
the astropy FITS wcs object | ||
size : int, optional | ||
the image cutout pizel size, by default 20 | ||
outfile : str, optional | ||
the name of the output cutout file, by default "example_roman_cutout.fits" | ||
Raises | ||
------ | ||
ValueError: | ||
when a wcs is not present when coords is a SkyCoord object | ||
""" | ||
|
||
# check for correct inputs | ||
if isinstance(coords, SkyCoord) and not wcs: | ||
raise ValueError('wcs must be input if coords is a SkyCoord.') | ||
|
||
# create the cutout | ||
cutout = astropy.nddata.Cutout2D(data, position=coords, wcs=wcs, size=(size, size)) | ||
|
||
# write the cutout to the output file | ||
astropy.io.fits.writeto(outfile, data=cutout.data, header=cutout.wcs.to_header(), overwrite=True) | ||
|
||
|
||
def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20, | ||
output_file: str = "example_roman_cutout.fits"): | ||
""" Preliminary proof-of-concept functionality. | ||
Takes a single ASDF input file (``input_file``) and generates a cutout of designated size ``cutout_size`` | ||
around the given coordinates (``coordinates``). | ||
Parameters | ||
---------- | ||
input_file : str | ||
the input ASDF file | ||
ra : float | ||
the Right Ascension of the central cutout | ||
dec : float | ||
the Declination of the central cutout | ||
cutout_size : int, optional | ||
the image cutout pixel size, by default 20 | ||
output_file : str, optional | ||
the name of the output cutout file, by default "example_roman_cutout.fits" | ||
""" | ||
|
||
# get the 2d image data | ||
with asdf.open(input_file) as f: | ||
data = f['roman']['data'] | ||
gwcs = f['roman']['meta']['wcs'] | ||
|
||
# get the center pixel | ||
pixel_coordinates, wcs = get_center_pixel(gwcs, ra, dec) | ||
|
||
# create the 2d image cutout | ||
get_cutout(data, pixel_coordinates, wcs, size=cutout_size, outfile=output_file) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import asdf | ||
from astropy.modeling import models | ||
from astropy import coordinates as coord | ||
from astropy import units as u | ||
from astropy.io import fits | ||
from astropy.wcs import WCS | ||
from gwcs import wcs | ||
from gwcs import coordinate_frames as cf | ||
from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut | ||
|
||
|
||
def make_wcs(xsize, ysize, ra=30., dec=45.): | ||
""" create a fake gwcs object """ | ||
# todo - refine this to better reflect roman wcs | ||
|
||
# create transformations | ||
pixelshift = models.Shift(-xsize) & models.Shift(-ysize) | ||
pixelscale = models.Scale(0.1 / 3600.) & models.Scale(0.1 / 3600.) # 0.1 arcsec/pixel | ||
tangent_projection = models.Pix2Sky_TAN() | ||
celestial_rotation = models.RotateNative2Celestial(ra, dec, 180.) | ||
|
||
# transform pixels to sky | ||
det2sky = pixelshift | pixelscale | tangent_projection | celestial_rotation | ||
|
||
# define the wcs object | ||
detector_frame = cf.Frame2D(name="detector", axes_names=("x", "y"), unit=(u.pix, u.pix)) | ||
sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), name='icrs', unit=(u.deg, u.deg)) | ||
return wcs.WCS([(detector_frame, det2sky), (sky_frame, None)]) | ||
|
||
|
||
@pytest.fixture() | ||
def fakedata(): | ||
""" fixture to create fake data and wcs """ | ||
# set up initial parameters | ||
nx = 100 | ||
ny = 100 | ||
size = nx * ny | ||
ra = 30. | ||
dec = 45. | ||
|
||
# create the wcs | ||
wcsobj = make_wcs(nx/2, ny/2, ra=ra, dec=dec) | ||
wcsobj.bounding_box = ((0, nx), (0, ny)) | ||
|
||
# create the data | ||
data = np.arange(size).reshape(nx, ny) | ||
|
||
yield data, wcsobj | ||
|
||
|
||
@pytest.fixture() | ||
def make_file(tmp_path, fakedata): | ||
""" fixture to create a fake dataset """ | ||
# get the fake data | ||
data, wcsobj = fakedata | ||
|
||
# create meta | ||
meta = {'wcs': wcsobj} | ||
|
||
# create and write the asdf file | ||
tree = {'roman': {'data': data, 'meta': meta}} | ||
af = asdf.AsdfFile(tree) | ||
|
||
path = tmp_path / "roman" | ||
path.mkdir(exist_ok=True) | ||
filename = path / "test_roman.asdf" | ||
af.write_to(filename) | ||
|
||
yield filename | ||
|
||
|
||
def test_get_center_pixel(fakedata): | ||
""" test we can get the correct center pixel """ | ||
# get the fake data | ||
__, gwcs = fakedata | ||
|
||
pixel_coordinates, wcs = get_center_pixel(gwcs, 30., 45.) | ||
assert np.allclose(pixel_coordinates, (np.array(50.), np.array(50.))) | ||
assert np.allclose(wcs.celestial.wcs.crval, np.array([30., 45.])) | ||
|
||
|
||
@pytest.fixture() | ||
def output_file(tmp_path): | ||
""" fixture to create the output path """ | ||
# create output fits path | ||
out = tmp_path / "roman" | ||
out.mkdir(exist_ok=True, parents=True) | ||
output_file = out / "test_output_cutout.fits" | ||
yield output_file | ||
|
||
|
||
def test_get_cutout(output_file, fakedata): | ||
""" test we can create a cutout """ | ||
|
||
# get the input wcs | ||
data, gwcs = fakedata | ||
skycoord = gwcs(25, 25, with_units=True) | ||
wcs = WCS(gwcs.to_fits_sip()) | ||
|
||
# create cutout | ||
get_cutout(data, skycoord, wcs, size=10, outfile=output_file) | ||
|
||
# test output | ||
with fits.open(output_file) as hdulist: | ||
data = hdulist[0].data | ||
assert data.shape == (10, 10) | ||
assert data[5, 5] == 2525 | ||
|
||
|
||
def test_asdf_cutout(make_file, output_file): | ||
""" test we can make a cutout """ | ||
# make cutout | ||
ra, dec = (29.99901792, 44.99930555) | ||
asdf_cut(make_file, ra, dec, cutout_size=10, output_file=output_file) | ||
|
||
# test output | ||
with fits.open(output_file) as hdulist: | ||
data = hdulist[0].data | ||
assert data.shape == (10, 10) | ||
assert data[5, 5] == 2526 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters