Skip to content

Commit

Permalink
Merge pull request #105 from spacetelescope/asdf-cut
Browse files Browse the repository at this point in the history
New Feature: ASDF Cutout Functionality
  • Loading branch information
havok2063 authored Jan 3, 2024
2 parents b79e8b1 + e182dd4 commit b9742bd
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 0 deletions.
123 changes: 123 additions & 0 deletions astrocut/asdf_cutouts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""This module implements cutout functionality similar to fitscut, but for the ASDF file format."""
from typing import Union

import asdf
import astropy
import gwcs

from astropy.coordinates import SkyCoord


def get_center_pixel(gwcs: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:
""" Get the center pixel from a roman 2d science image
For an input RA, Dec sky coordinate, get the closest pixel location
on the input Roman image.
Parameters
----------
gwcs : gwcs.wcs.WCS
the Roman GWCS object
ra : float
the input Right Ascension
dec : float
the input Declination
Returns
-------
tuple
the pixel position, FITS wcs object
"""

# Convert the gwcs object to an astropy FITS WCS header
header = gwcs.to_fits_sip()

# Update WCS header with some keywords that it's missing.
# Otherwise, it won't work with astropy.wcs tools (TODO: Figure out why. What are these keywords for?)
for k in ['cpdis1', 'cpdis2', 'det2im1', 'det2im2', 'sip']:
if k not in header:
header[k] = 'na'

# New WCS object with updated header
wcs_updated = astropy.wcs.WCS(header)

# Turn input RA, Dec into a SkyCoord object
coordinates = SkyCoord(ra, dec, unit='deg')

# Map the coordinates to a pixel's location on the Roman 2d array (row, col)
row, col = astropy.wcs.utils.skycoord_to_pixel(coords=coordinates, wcs=wcs_updated)

return (row, col), wcs_updated


def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, SkyCoord],
wcs: astropy.wcs.wcs.WCS = None, size: int = 20, outfile: str = "example_roman_cutout.fits"):
""" Get a Roman image cutout
Cut out a square section from the input image data array. The ``coords`` can either be a tuple of x, y
pixel coordinates or an astropy SkyCoord object, in which case, a wcs is required. Writes out a
new output file containing the image cutout of the specified ``size``. Default is 20 pixels.
Parameters
----------
data : asdf.tags.core.ndarray.NDArrayType
the input Roman image data array
coords : Union[tuple, SkyCoord]
the input pixel or sky coordinates
wcs : astropy.wcs.wcs.WCS, Optional
the astropy FITS wcs object
size : int, optional
the image cutout pizel size, by default 20
outfile : str, optional
the name of the output cutout file, by default "example_roman_cutout.fits"
Raises
------
ValueError:
when a wcs is not present when coords is a SkyCoord object
"""

# check for correct inputs
if isinstance(coords, SkyCoord) and not wcs:
raise ValueError('wcs must be input if coords is a SkyCoord.')

# create the cutout
cutout = astropy.nddata.Cutout2D(data, position=coords, wcs=wcs, size=(size, size))

# write the cutout to the output file
astropy.io.fits.writeto(outfile, data=cutout.data, header=cutout.wcs.to_header(), overwrite=True)


def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,
output_file: str = "example_roman_cutout.fits"):
""" Preliminary proof-of-concept functionality.
Takes a single ASDF input file (``input_file``) and generates a cutout of designated size ``cutout_size``
around the given coordinates (``coordinates``).
Parameters
----------
input_file : str
the input ASDF file
ra : float
the Right Ascension of the central cutout
dec : float
the Declination of the central cutout
cutout_size : int, optional
the image cutout pixel size, by default 20
output_file : str, optional
the name of the output cutout file, by default "example_roman_cutout.fits"
"""

# get the 2d image data
with asdf.open(input_file) as f:
data = f['roman']['data']
gwcs = f['roman']['meta']['wcs']

# get the center pixel
pixel_coordinates, wcs = get_center_pixel(gwcs, ra, dec)

# create the 2d image cutout
get_cutout(data, pixel_coordinates, wcs, size=cutout_size, outfile=output_file)
125 changes: 125 additions & 0 deletions astrocut/tests/test_asdf_cut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@

import numpy as np
import pytest

import asdf
from astropy.modeling import models
from astropy import coordinates as coord
from astropy import units as u
from astropy.io import fits
from astropy.wcs import WCS
from gwcs import wcs
from gwcs import coordinate_frames as cf
from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut


def make_wcs(xsize, ysize, ra=30., dec=45.):
""" create a fake gwcs object """
# todo - refine this to better reflect roman wcs

# create transformations
pixelshift = models.Shift(-xsize) & models.Shift(-ysize)
pixelscale = models.Scale(0.1 / 3600.) & models.Scale(0.1 / 3600.) # 0.1 arcsec/pixel
tangent_projection = models.Pix2Sky_TAN()
celestial_rotation = models.RotateNative2Celestial(ra, dec, 180.)

# transform pixels to sky
det2sky = pixelshift | pixelscale | tangent_projection | celestial_rotation

# define the wcs object
detector_frame = cf.Frame2D(name="detector", axes_names=("x", "y"), unit=(u.pix, u.pix))
sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), name='icrs', unit=(u.deg, u.deg))
return wcs.WCS([(detector_frame, det2sky), (sky_frame, None)])


@pytest.fixture()
def fakedata():
""" fixture to create fake data and wcs """
# set up initial parameters
nx = 100
ny = 100
size = nx * ny
ra = 30.
dec = 45.

# create the wcs
wcsobj = make_wcs(nx/2, ny/2, ra=ra, dec=dec)
wcsobj.bounding_box = ((0, nx), (0, ny))

# create the data
data = np.arange(size).reshape(nx, ny)

yield data, wcsobj


@pytest.fixture()
def make_file(tmp_path, fakedata):
""" fixture to create a fake dataset """
# get the fake data
data, wcsobj = fakedata

# create meta
meta = {'wcs': wcsobj}

# create and write the asdf file
tree = {'roman': {'data': data, 'meta': meta}}
af = asdf.AsdfFile(tree)

path = tmp_path / "roman"
path.mkdir(exist_ok=True)
filename = path / "test_roman.asdf"
af.write_to(filename)

yield filename


def test_get_center_pixel(fakedata):
""" test we can get the correct center pixel """
# get the fake data
__, gwcs = fakedata

pixel_coordinates, wcs = get_center_pixel(gwcs, 30., 45.)
assert np.allclose(pixel_coordinates, (np.array(50.), np.array(50.)))
assert np.allclose(wcs.celestial.wcs.crval, np.array([30., 45.]))


@pytest.fixture()
def output_file(tmp_path):
""" fixture to create the output path """
# create output fits path
out = tmp_path / "roman"
out.mkdir(exist_ok=True, parents=True)
output_file = out / "test_output_cutout.fits"
yield output_file


def test_get_cutout(output_file, fakedata):
""" test we can create a cutout """

# get the input wcs
data, gwcs = fakedata
skycoord = gwcs(25, 25, with_units=True)
wcs = WCS(gwcs.to_fits_sip())

# create cutout
get_cutout(data, skycoord, wcs, size=10, outfile=output_file)

# test output
with fits.open(output_file) as hdulist:
data = hdulist[0].data
assert data.shape == (10, 10)
assert data[5, 5] == 2525


def test_asdf_cutout(make_file, output_file):
""" test we can make a cutout """
# make cutout
ra, dec = (29.99901792, 44.99930555)
asdf_cut(make_file, ra, dec, cutout_size=10, output_file=output_file)

# test output
with fits.open(output_file) as hdulist:
data = hdulist[0].data
assert data.shape == (10, 10)
assert data[5, 5] == 2526

2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ packages = find:
python_requires = >=3.9
setup_requires = setuptools_scm
install_requires =
asdf>=2.15.0 # for ASDF file format
astropy>=5.2 # astropy with s3fs support
fsspec[http]>=2022.8.2 # for remote cutouts
s3fs>=2022.8.2 # for remote cutouts
roman_datamodels>=0.17.0 # for roman file support
scipy
Pillow

Expand Down

0 comments on commit b9742bd

Please sign in to comment.