Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle an s3 URI in asdf_cut() #117

Merged
merged 5 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion astrocut/asdf_cutouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,28 @@
import astropy
import gwcs
import numpy as np
import s3fs

from astropy.coordinates import SkyCoord
from astropy.modeling import models


def _get_cloud_http(s3_uri: str) -> str:
""" Get the HTTP URI of a cloud resource from an S3 URI

Parameters
----------
s3_uri : string
the S3 URI of the cloud resource
"""
# create file system
fs = s3fs.S3FileSystem(anon=True)

# open resource and get URL
with fs.open(s3_uri, 'rb') as f:
return f.url()


def get_center_pixel(gwcsobj: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:
""" Get the center pixel from a roman 2d science image

Expand Down Expand Up @@ -247,8 +264,13 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,
an image cutout object
"""

# if file comes from AWS cloud bucket, get HTTP URL to open with asdf
file = input_file
if isinstance(input_file, str) and input_file.startswith('s3://'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing we may want to do at some point is expand this to support pathlib.Path as input, in addition to a string. I don't think pathlib.Path handles s3 urls, but it looks like there are packages that do, e.g. s3path or s3pathlib. I don't think we need to do it for this PR, but maybe we can create an issue for the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, will make a new issue for this!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue is ASB-27085

file = _get_cloud_http(input_file)

# get the 2d image data
with asdf.open(input_file) as f:
with asdf.open(file) as f:
data = f['roman']['data']
gwcsobj = f['roman']['meta']['wcs']

Expand Down
42 changes: 31 additions & 11 deletions astrocut/tests/test_asdf_cut.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

import pathlib
from unittest.mock import MagicMock, patch
import numpy as np
import pytest

Expand All @@ -12,7 +13,7 @@
from astropy.wcs.utils import pixel_to_skycoord
from gwcs import wcs
from gwcs import coordinate_frames as cf
from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut, _slice_gwcs
from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut, _slice_gwcs, _get_cloud_http


def make_wcs(xsize, ysize, ra=30., dec=45.):
Expand Down Expand Up @@ -99,16 +100,6 @@ def make_file(tmp_path, fakedata):
yield filename


def test_get_center_pixel(fakedata):
""" test we can get the correct center pixel """
# get the fake data
__, gwcs = fakedata

pixel_coordinates, wcs = get_center_pixel(gwcs, 30., 45.)
assert np.allclose(pixel_coordinates, (np.array(500.), np.array(500.)))
assert np.allclose(wcs.celestial.wcs.crval, np.array([30., 45.]))


@pytest.fixture()
def output(tmp_path):
""" fixture to create the output path """
Expand All @@ -121,6 +112,16 @@ def _output_file(ext='fits'):
yield _output_file


def test_get_center_pixel(fakedata):
""" test we can get the correct center pixel """
# get the fake data
__, gwcs = fakedata

pixel_coordinates, wcs = get_center_pixel(gwcs, 30., 45.)
assert np.allclose(pixel_coordinates, (np.array(500.), np.array(500.)))
assert np.allclose(wcs.celestial.wcs.crval, np.array([30., 45.]))


@pytest.mark.parametrize('quantity', [True, False], ids=['quantity', 'array'])
def test_get_cutout(output, fakedata, quantity):
""" test we can create a cutout """
Expand Down Expand Up @@ -312,3 +313,22 @@ def test_slice_gwcs(fakedata):
# gwcs footprint/bounding_box expects ((x0, x1), (y0, y1)) but cutout.bbox is in ((y0, y1), (x0, x1))
assert (gwcsobj.footprint(bounding_box=tuple(reversed(cutout.bbox_original))) == sliced.footprint()).all()


@patch('s3fs.S3FileSystem')
def test_get_cloud_http(mock_s3fs):
""" test we can get HTTP URI of cloud resource """
# mock s3 file system operations
HTTP_URI = "http_test"
mock_file = MagicMock()
mock_fs = MagicMock()
mock_file.url.return_value = HTTP_URI
mock_fs.open.return_value.__enter__.return_value = mock_file
mock_s3fs.return_value = mock_fs

s3_uri = "s3://test_bucket/test_file.asdf"
http_uri = _get_cloud_http(s3_uri)

assert http_uri == HTTP_URI
mock_s3fs.assert_called_once_with(anon=True)
mock_fs.open.assert_called_once_with(s3_uri, 'rb')
mock_file.url.assert_called_once()
22 changes: 11 additions & 11 deletions astrocut/tests/test_make_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ def test_make_cube(tmpdir):
ecube[:, :, i, 0] = -plane
ecube[:, :, i, 1] = plane
plane += img_sz*img_sz
assert np.alltrue(cube == ecube), "Cube values do not match expected values"
assert np.all(cube == ecube), "Cube values do not match expected values"

tab = Table(hdu[2].data)
assert np.alltrue(tab['TSTART'] == np.arange(num_im)), "TSTART mismatch in table"
assert np.alltrue(tab['TSTOP'] == np.arange(num_im)+1), "TSTOP mismatch in table"
assert np.all(tab['TSTART'] == np.arange(num_im)), "TSTART mismatch in table"
assert np.all(tab['TSTOP'] == np.arange(num_im)+1), "TSTOP mismatch in table"

filenames = np.array([path.split(x)[1] for x in ffi_files])
assert np.alltrue(tab['FFI_FILE'] == np.array(filenames)), "FFI_FILE mismatch in table"
assert np.all(tab['FFI_FILE'] == np.array(filenames)), "FFI_FILE mismatch in table"

hdu.close()

Expand Down Expand Up @@ -86,7 +86,7 @@ def test_make_and_update_cube(tmpdir):
# ecube[:, :, i, 1] = plane
plane += img_sz*img_sz

assert np.alltrue(cube == ecube), "Cube values do not match expected values"
assert np.all(cube == ecube), "Cube values do not match expected values"

hdu.close()

Expand All @@ -110,14 +110,14 @@ def test_make_and_update_cube(tmpdir):
# ecube[:, :, i, 1] = plane
plane += img_sz*img_sz

assert np.alltrue(cube == ecube), "Cube values do not match expected values"
assert np.all(cube == ecube), "Cube values do not match expected values"

tab = Table(hdu[2].data)
assert np.alltrue(tab['STARTTJD'] == np.arange(num_im)), "STARTTJD mismatch in table"
assert np.alltrue(tab['ENDTJD'] == np.arange(num_im)+1), "ENDTJD mismatch in table"
assert np.all(tab['STARTTJD'] == np.arange(num_im)), "STARTTJD mismatch in table"
assert np.all(tab['ENDTJD'] == np.arange(num_im)+1), "ENDTJD mismatch in table"

filenames = np.array([path.split(x)[1] for x in ffi_files])
assert np.alltrue(tab['FFI_FILE'] == np.array(filenames)), "FFI_FILE mismatch in table"
assert np.all(tab['FFI_FILE'] == np.array(filenames)), "FFI_FILE mismatch in table"

hdu.close()

Expand Down Expand Up @@ -156,7 +156,7 @@ def test_iteration(tmpdir, capsys):
cube_2 = hdu_2[1].data

assert cube_1.shape == cube_2.shape, "Mismatch between cube shape for 1 vs 2 iterations"
assert np.alltrue(cube_1 == cube_2), "Cubes made in 1 vs 2 iterations do not match"
assert np.all(cube_1 == cube_2), "Cubes made in 1 vs 2 iterations do not match"

# expected values for cube
ecube = np.zeros((img_sz, img_sz, num_im, 2))
Expand All @@ -168,7 +168,7 @@ def test_iteration(tmpdir, capsys):
ecube[:, :, i, 1] = plane
plane += img_sz*img_sz

assert np.alltrue(cube_1 == ecube), "Cube values do not match expected values"
assert np.all(cube_1 == ecube), "Cube values do not match expected values"


@pytest.mark.parametrize("ffi_type", ["TICA", "SPOC"])
Expand Down
Loading