diff --git a/arccnet/data_generation/euv/__init__.py b/arccnet/data_generation/euv/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/arccnet/data_generation/utils/tests/test_utils.py b/arccnet/data_generation/utils/tests/test_utils.py index 9ce08d0e..2bfe2ced 100644 --- a/arccnet/data_generation/utils/tests/test_utils.py +++ b/arccnet/data_generation/utils/tests/test_utils.py @@ -6,10 +6,17 @@ import pandas as pd import pytest from numpy.testing import assert_allclose, assert_array_equal +from sunpy.data import sample +from sunpy.map import Map + +import astropy.units as u +from astropy.time import Time from arccnet.data_generation.utils.utils import ( check_column_values, grouped_stratified_split, + reproject, + round_to_daystart, save_df_to_html, time_split, ) @@ -141,3 +148,37 @@ def test_time_split(): assert len(train) == 6 assert len(test) == 2 assert len(valid) == 2 + + +def test_rereproject(): + aia = Map(sample.AIA_171_IMAGE) + hmi = Map(sample.HMI_LOS_IMAGE) + + aia_to_hmi = reproject(aia, hmi) + assert aia_to_hmi.dimensions == hmi.dimensions + assert aia_to_hmi.wcs.wcs == hmi.wcs.wcs # this compares the important info not sure why + assert aia_to_hmi.date == hmi.date + assert aia_to_hmi.detector == aia.detector + assert aia_to_hmi.instrument == aia.instrument + assert aia_to_hmi.observatory == aia.observatory + assert aia_to_hmi.wavelength == aia.wavelength + + +def test_round_to_daystart(): + expected = Time( + [ + "2021-01-01", + "2021-01-01", + "2021-01-02", + "2021-01-02", + "2021-01-02", + "2021-01-02", + "2021-01-03", + "2021-01-03", + "2021-01-03", + ] + ) + t = Time("2021-01-01") + times = t + [0, 6, 12, 18, 24, 30, 36, 42, 48] * u.hour + out = round_to_daystart(times) + assert_array_equal(expected, out) diff --git a/arccnet/data_generation/utils/utils.py b/arccnet/data_generation/utils/utils.py index a92cfec6..c39f0dcf 100644 --- a/arccnet/data_generation/utils/utils.py +++ b/arccnet/data_generation/utils/utils.py @@ -1,11 +1,18 @@ from pathlib import Path from datetime import datetime, timedelta +from concurrent.futures import ProcessPoolExecutor import numpy as np import pandas as pd import sunpy.map -from pandas.core.interchange.dataframe_protocol import DataFrame +from pandas import DataFrame from sklearn.model_selection import StratifiedGroupKFold +from sunpy.map import Map + +import astropy.units as u +from astropy.io.fits import CompImageHDU +from astropy.table import Table +from astropy.time import Time from arccnet.utils.logging import logger @@ -266,3 +273,165 @@ def time_split(df: DataFrame, *, time_column: str, train: int, test: int, valida val_indices = np.hstack(val_indices) return train_indices, test_indices, val_indices + + +def _get_info(path): + map = Map(path) + info = { + "Observatory": map.observatory, + "Instrument": map.instrument, + "Detector": map.detector, + "Measurement": map.measurement, + "Wavelength": map.wavelength, + "Date": map.date, + "Path": path, + } + del map + return info + + +def build_local_files(path): + r""" + Build local fits file table based on recursive search from given path. + + + Parameters + ---------- + path : `str` or `pathlib.Path` + Directory to start search from + + Returns + ------- + fits_file_info : `astropy.table.Table` + Table of fits file info + """ + root = Path(path) + files = root.rglob("*.fits") + + with ProcessPoolExecutor() as executor: + file_info = executor.map(_get_info, files, chunksize=1000) + + fits_file_info = Table(list(file_info)) + return fits_file_info + + +def reproject(input_map: Map, traget_map: Map, **reproject_kwargs) -> Map: + r""" + Reproject the input map to the target map WCS. + + Copy back over some import meta information removed when using reproject. + + Parameters + ---------- + input_map + Input map to reproject + traget_map + Target map + + Returns + ------- + + """ + output_map = input_map.reproject_to(traget_map.wcs, **reproject_kwargs) + # Created by manually looking at the headers and removing anything related to WCS, dates, data values and observer + # fmt: off + KEY_MAP = {('SDO', 'AIA'): {'bld_vers', 'lvl_num', 'trecstep', 'trecepoc', 'trecroun', 'origin', 'telescop', + 'instrume', 'camera', 'img_type', 'exptime', 'expsdev', 'int_time', 'wavelnth', + 'waveunit', 'wave_str', 'fsn', 'fid', 'quallev0', 'quality', 'flat_rec', 'nspikes', + 'mpo_rec', 'inst_rot', 'imscl_mp', 'x0_mp', 'y0_mp', 'asd_rec', 'sat_y0', 'sat_z0', + 'sat_rot', 'acs_mode', 'acs_eclp', 'acs_sunp', 'acs_safe', 'acs_cgt', 'orb_rec', + 'roi_sum', 'roi_nax1', 'roi_nay1', 'roi_llx1', 'roi_lly1', 'roi_nax2', 'roi_nay2', + 'roi_llx2', 'roi_lly2', 'pixlunit', 'dn_gain', 'eff_area', 'eff_ar_v', 'tempccd', + 'tempgt', 'tempsmir', 'tempfpad', 'ispsname', 'isppktim', 'isppktvn', 'aivnmst', + 'aimgots', 'asqhdr', 'asqtnum', 'asqfsn', 'aiahfsn', 'aecdelay', 'aiaecti', 'aiasen', + 'aifdbid', 'aimgotss', 'aifcps', 'aiftswth', 'aifrmlid', 'aiftsid', 'aihismxb', + 'aihis192', 'aihis348', 'aihis604', 'aihis860', 'aifwen', 'aimgshce', 'aectype', + 'aecmode', 'aistate', 'aiaecenf', 'aifiltyp', 'aimshobc', 'aimshobe', 'aimshotc', + 'aimshote', 'aimshcbc', 'aimshcbe', 'aimshctc', 'aimshcte', 'aicfgdl1', 'aicfgdl2', + 'aicfgdl3', 'aicfgdl4', 'aifoenfl', 'aimgfsn', 'aimgtyp', 'aiawvlen', 'aiagp1', + 'aiagp2', 'aiagp3', 'aiagp4', 'aiagp5', 'aiagp6', 'aiagp7', 'aiagp8', 'aiagp9', + 'aiagp10', 'agt1svy', 'agt1svz', 'agt2svy', 'agt2svz', 'agt3svy', 'agt3svz', + 'agt4svy', 'agt4svz', 'aimgshen', 'keywddoc', 'recnum', 'blank', 'drms_id', + 'primaryk', 'comment', 'history', 'keycomments'}, + ('SDO', 'HMI'): {'telescop', 'instrume', 'wavelnth', 'camera', 'bunit', 'origin', 'content', 'quality', + 'quallev1', 'bld_vers', 'hcamid', 'source', 'trecepoc', 'trecstep', 'trecunit', + 'cadence', 'datasign', 'hflid', 'hcftid', 'qlook', 'cal_fsn', 'lutquery', 'tsel', + 'tfront', 'tintnum', 'sintnum', 'distcoef', 'rotcoef', 'odicoeff', 'orocoeff', + 'polcalm', 'codever0', 'codever1', 'codever2', 'codever3', 'calver64', 'recnum', + 'blank', 'checksum', 'datasum', 'waveunit', 'detector', 'history', 'comment', + 'keycomments'}, + } + # fmt: on + keys = KEY_MAP.get((input_map.observatory, input_map.detector), []) + meta_to_update = {key: input_map.meta[key] for key in keys} + output_map.meta.update(meta_to_update) + return output_map + + +def round_to_daystart(time: Time) -> Time: + r""" + Round time to given interval start of day + + + Parameters + ---------- + time : + Times to round + interval : + Interval + + Examples + -------- + >>> import astropy.units as u + >>> from astropy.time import Time + >>> time = Time('2000-01-01') + >>> times = time + [0, 6, 12, 18, 24, 30, 36, 48] *u.h +