41 changes: 41 additions & 0 deletions arccnet/data_generation/utils/tests/test_utils.py
@@ -6,10 +6,17 @@
import pandas as pd
import pytest
from numpy.testing import assert_allclose, assert_array_equal
from sunpy.data import sample
from sunpy.map import Map

import astropy.units as u
from astropy.time import Time

from arccnet.data_generation.utils.utils import (
check_column_values,
grouped_stratified_split,
reproject,
round_to_daystart,
save_df_to_html,
time_split,
)
@@ -141,3 +148,37 @@ def test_time_split():
assert len(train) == 6
assert len(test) == 2
assert len(valid) == 2


def test_reproject():
aia = Map(sample.AIA_171_IMAGE)
hmi = Map(sample.HMI_LOS_IMAGE)

aia_to_hmi = reproject(aia, hmi)
assert aia_to_hmi.dimensions == hmi.dimensions
    assert aia_to_hmi.wcs.wcs == hmi.wcs.wcs  # comparing the low-level Wcsprm objects checks the core WCS information
assert aia_to_hmi.date == hmi.date
assert aia_to_hmi.detector == aia.detector
assert aia_to_hmi.instrument == aia.instrument
assert aia_to_hmi.observatory == aia.observatory
assert aia_to_hmi.wavelength == aia.wavelength


def test_round_to_daystart():
expected = Time(
[
"2021-01-01",
"2021-01-01",
"2021-01-02",
"2021-01-02",
"2021-01-02",
"2021-01-02",
"2021-01-03",
"2021-01-03",
"2021-01-03",
]
)
t = Time("2021-01-01")
times = t + [0, 6, 12, 18, 24, 30, 36, 42, 48] * u.hour
out = round_to_daystart(times)
assert_array_equal(expected, out)
171 changes: 170 additions & 1 deletion arccnet/data_generation/utils/utils.py
@@ -1,11 +1,18 @@
from pathlib import Path
from datetime import datetime, timedelta
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import pandas as pd
import sunpy.map
from pandas.core.interchange.dataframe_protocol import DataFrame
from pandas import DataFrame
from sklearn.model_selection import StratifiedGroupKFold
from sunpy.map import Map

import astropy.units as u
from astropy.io.fits import CompImageHDU
from astropy.table import Table
from astropy.time import Time

from arccnet.utils.logging import logger

@@ -266,3 +273,165 @@ def time_split(df: DataFrame, *, time_column: str, train: int, test: int, valida
val_indices = np.hstack(val_indices)

return train_indices, test_indices, val_indices


def _get_info(path):
    """Return basic metadata for the FITS file at ``path``."""
    smap = Map(path)
    info = {
        "Observatory": smap.observatory,
        "Instrument": smap.instrument,
        "Detector": smap.detector,
        "Measurement": smap.measurement,
        "Wavelength": smap.wavelength,
        "Date": smap.date,
        "Path": path,
    }
    del smap  # drop the map (and its data) as soon as the header info has been extracted
    return info


def build_local_files(path):
r"""
    Build a table of local FITS files by recursively searching from the given path.

    Parameters
    ----------
    path : `str` or `pathlib.Path`
        Directory to start the search from.

    Returns
    -------
    fits_file_info : `astropy.table.Table`
        Table of FITS file info.
"""
root = Path(path)
files = root.rglob("*.fits")

with ProcessPoolExecutor() as executor:
file_info = executor.map(_get_info, files, chunksize=1000)

fits_file_info = Table(list(file_info))
return fits_file_info
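

# Illustrative sketch (not used by the pipeline): the table returned by
# `build_local_files` can be filtered with plain astropy Table masks, e.g. to
# split the files by detector as the patch-up code further down does for AIA.
def _example_split_by_detector(root):
    fits_info = build_local_files(root)
    hmi_files = fits_info[fits_info["Detector"] == "HMI"]
    aia_files = fits_info[fits_info["Detector"] == "AIA"]
    return hmi_files, aia_files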


def reproject(input_map: Map, target_map: Map, **reproject_kwargs) -> Map:
r"""
Reproject the input map to the target map WCS.

    Copy back some important meta information that is removed when reprojecting.

Parameters
----------
    input_map : `sunpy.map.Map`
        Input map to reproject.
    target_map : `sunpy.map.Map`
        Target map providing the WCS to reproject onto.

    Returns
    -------
    `sunpy.map.Map`
        The reprojected map with selected instrument metadata copied back over.

"""
    output_map = input_map.reproject_to(target_map.wcs, **reproject_kwargs)
# Created by manually looking at the headers and removing anything related to WCS, dates, data values and observer
# fmt: off
KEY_MAP = {('SDO', 'AIA'): {'bld_vers', 'lvl_num', 'trecstep', 'trecepoc', 'trecroun', 'origin', 'telescop',
'instrume', 'camera', 'img_type', 'exptime', 'expsdev', 'int_time', 'wavelnth',
'waveunit', 'wave_str', 'fsn', 'fid', 'quallev0', 'quality', 'flat_rec', 'nspikes',
'mpo_rec', 'inst_rot', 'imscl_mp', 'x0_mp', 'y0_mp', 'asd_rec', 'sat_y0', 'sat_z0',
'sat_rot', 'acs_mode', 'acs_eclp', 'acs_sunp', 'acs_safe', 'acs_cgt', 'orb_rec',
'roi_sum', 'roi_nax1', 'roi_nay1', 'roi_llx1', 'roi_lly1', 'roi_nax2', 'roi_nay2',
'roi_llx2', 'roi_lly2', 'pixlunit', 'dn_gain', 'eff_area', 'eff_ar_v', 'tempccd',
'tempgt', 'tempsmir', 'tempfpad', 'ispsname', 'isppktim', 'isppktvn', 'aivnmst',
'aimgots', 'asqhdr', 'asqtnum', 'asqfsn', 'aiahfsn', 'aecdelay', 'aiaecti', 'aiasen',
'aifdbid', 'aimgotss', 'aifcps', 'aiftswth', 'aifrmlid', 'aiftsid', 'aihismxb',
'aihis192', 'aihis348', 'aihis604', 'aihis860', 'aifwen', 'aimgshce', 'aectype',
'aecmode', 'aistate', 'aiaecenf', 'aifiltyp', 'aimshobc', 'aimshobe', 'aimshotc',
'aimshote', 'aimshcbc', 'aimshcbe', 'aimshctc', 'aimshcte', 'aicfgdl1', 'aicfgdl2',
'aicfgdl3', 'aicfgdl4', 'aifoenfl', 'aimgfsn', 'aimgtyp', 'aiawvlen', 'aiagp1',
'aiagp2', 'aiagp3', 'aiagp4', 'aiagp5', 'aiagp6', 'aiagp7', 'aiagp8', 'aiagp9',
'aiagp10', 'agt1svy', 'agt1svz', 'agt2svy', 'agt2svz', 'agt3svy', 'agt3svz',
'agt4svy', 'agt4svz', 'aimgshen', 'keywddoc', 'recnum', 'blank', 'drms_id',
'primaryk', 'comment', 'history', 'keycomments'},
('SDO', 'HMI'): {'telescop', 'instrume', 'wavelnth', 'camera', 'bunit', 'origin', 'content', 'quality',
'quallev1', 'bld_vers', 'hcamid', 'source', 'trecepoc', 'trecstep', 'trecunit',
'cadence', 'datasign', 'hflid', 'hcftid', 'qlook', 'cal_fsn', 'lutquery', 'tsel',
'tfront', 'tintnum', 'sintnum', 'distcoef', 'rotcoef', 'odicoeff', 'orocoeff',
'polcalm', 'codever0', 'codever1', 'codever2', 'codever3', 'calver64', 'recnum',
'blank', 'checksum', 'datasum', 'waveunit', 'detector', 'history', 'comment',
'keycomments'},
}
# fmt: on
keys = KEY_MAP.get((input_map.observatory, input_map.detector), [])
    meta_to_update = {key: input_map.meta[key] for key in keys if key in input_map.meta}  # skip keys absent from this header
output_map.meta.update(meta_to_update)
return output_map
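

# Illustrative usage sketch (mirrors the unit test; the sunpy sample data is
# downloaded on first use): reproject an AIA map onto the HMI map's pixel grid.
def _example_reproject_aia_to_hmi_sample():
    from sunpy.data import sample

    aia = Map(sample.AIA_171_IMAGE)
    hmi = Map(sample.HMI_LOS_IMAGE)
    return reproject(aia, hmi)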


def round_to_daystart(time: Time) -> Time:
r"""
    Round times to the start of the nearest day (midnight).

    Times before midday are rounded down to the start of the same day; times at
    or after midday are rounded up to the start of the next day.

    Parameters
    ----------
    time : `astropy.time.Time`
        Times to round.

    Returns
    -------
    `astropy.time.Time`
        Times rounded to the start of the nearest day.

    Examples
    --------
    >>> import astropy.units as u
    >>> from astropy.time import Time
    >>> times = Time('2000-01-01') + [0, 6, 12, 18, 24] * u.hour
    >>> round_to_daystart(times)  # doctest: +SKIP
    <Time object: scale='utc' format='iso' value=['2000-01-01 00:00:00.000' '2000-01-01 00:00:00.000'
     '2000-01-02 00:00:00.000' '2000-01-02 00:00:00.000' '2000-01-02 00:00:00.000']>
"""
day_start = Time(time.strftime("%Y-%m-%d"))
next_day = day_start + 1 * u.day
diff = time - day_start
rounded = np.where(diff.to_value(u.hour) < 12, day_start, next_day)
return Time(rounded)


### Ad-hoc code run to patch up the current data; this needs to be integrated into the pipeline
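

# Illustrative sketch (an assumption, not part of the current patch-up code): the
# functions below expect the FITS file table to carry a "Day" column; one possible
# way to attach it is to round each file's observation date to the start of its day.
def _add_day_column(fits_info):
    fits_info["Day"] = round_to_daystart(Time(fits_info["Date"]))
    return fits_info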


def _aia_to_hmi(aia_path, hmi_map, hmi_path):
    aia_path = Path(aia_path)
    aia_map = Map(aia_path)
    logger.info(f"Reprojecting {aia_path.name} onto {hmi_path.name}")
    aia_repro = reproject(aia_map, hmi_map)
    aia_repro.meta["repotar"] = str(hmi_path.name)  # record the reprojection target file in the header
    outpath = hmi_path.parent.parent.parent.parent / "euv" / "fits" / "aia"
    outpath.mkdir(parents=True, exist_ok=True)
    outfile = outpath / f"{aia_path.stem}_reprojected.fits"
    logger.info(f"Saving reprojected map to {outfile}")
    aia_repro.save(outfile, hdu_type=CompImageHDU)


def reproject_aia_to_hmi(hmi_paths, fits_info, executor):
    for path in list(hmi_paths):
        logger.info(f"HMI path {path}")
        hmi_path = Path("/mnt/ARCAFF/v0.2.2/ARCCnet/" + path)
        hmi_map = Map(hmi_path)
        day = round_to_daystart(hmi_map.date)
        logger.debug(f"Rounded HMI observation day: {day}")
        fits = fits_info[fits_info["Day"] == day]
        for row in fits[fits["Detector"] == "AIA"]:
            logger.info(f"AIA and HMI paths: {row['Path_str']}, {hmi_path}")
            executor.submit(_aia_to_hmi, row["Path_str"], hmi_map, hmi_path)


# with ProcessPoolExecutor(max_workers=32) as executor:
# reproject_aia_to_hmi(hmi_paths, sdo_only, executor)
84 changes: 81 additions & 3 deletions arccnet/pipeline/main.py
@@ -4,6 +4,7 @@
from datetime import datetime, timedelta

import numpy as np
import pandas as pd

import astropy.units as u
from astropy.table import Column, MaskedColumn, QTable, Table, join, vstack
@@ -1126,11 +1127,88 @@ def process_ar_catalogs(config):
return processed_catalog


def create_pit_flare_dataset(config):
logger.info("Generating `Point in Time Flare Forecasting` dataset")
data_root = Path(config["paths"]["data_root"])

flare_file = (
data_root
/ "02_intermediate"
/ "metadata"
/ "flares"
/ "hek_swpc_1996-01-01T00:00:00-2023-01-01T00:00:00_dev.parq"
)
classification_file = data_root / "04_final" / "data" / "region_cutouts" / "region_classification.parq"

# Load flares and extract flare stats
flares = Table.read(flare_file)
flare_stats = extract_flare_statistics(flares)

# rename to match ARs and parse time
flare_stats.rename(columns={"date": "target_time", "noaa_number": "number"}, inplace=True)
flare_stats.target_time = pd.to_datetime(flare_stats.target_time)

# Merge cutouts and flare stats
cutouts = Table.read(classification_file)
ars = cutouts[cutouts["region_type"] == "AR"]
    good_cols = [name for name in cutouts.colnames if len(cutouts[name].shape) <= 1]  # to_pandas() cannot handle multidimensional columns
ars_df = ars[good_cols].to_pandas()

# outer join to keep flaring and no-flaring ARs for training
flares_and_ars = ars_df.merge(flare_stats, on=["target_time", "number"], how="outer")

version = __version__ if "dev" not in __version__ else "dev" # unless it's a release use dev
start = config["general"]["start_date"]
end = config["general"]["end_date"]
start = start if isinstance(start, datetime) else datetime.fromisoformat(start)
end = end if isinstance(end, datetime) else datetime.fromisoformat(end)
file_name = f"mag-pit-flare-dataset_{start.isoformat()}" f"-{end.isoformat()}_{version}.parq"

data_dir_processed = Path(config["paths"]["data_dir_processed"])
flare_processed_catalog_file = data_dir_processed / file_name
flare_processed_catalog_file.parent.mkdir(exist_ok=True, parents=True)
flares_and_ars.to_parquet(flare_processed_catalog_file)

return flare_processed_catalog_file
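

# A toy sketch of the outer merge above (hypothetical values): an AR cutout with no
# matching flare record keeps its row with NaN flare counts, and a flare record with
# no matching cutout is also retained, so both flaring and non-flaring ARs survive.
def _example_outer_merge():
    ars = pd.DataFrame({"target_time": pd.to_datetime(["2021-01-01", "2021-01-01"]), "number": [11158, 11160]})
    flare_stats = pd.DataFrame({"target_time": pd.to_datetime(["2021-01-01"]), "number": [11158], "M": [2]})
    return ars.merge(flare_stats, on=["target_time", "number"], how="outer")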


def extract_flare_statistics(flares):
r"""
Extract daily (24h) flare statistics for each NOAA AR

    Uses the flare peak time to count the number of A, B, C, M and X class flares per AR per 24 hours.

Parameters
----------
    flares : `astropy.table.Table`
        Flare events with ``peak_time``, ``noaa_number`` and ``goes_class`` columns.

    Returns
    -------
    `pandas.DataFrame`
        One row per (day, NOAA number) combination with the count of flares of each GOES class.
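
    Examples
    --------
    A minimal sketch with hypothetical flare values:

    >>> from astropy.table import Table
    >>> from astropy.time import Time
    >>> flares = Table({"peak_time": Time(["2021-01-01T01:00", "2021-01-01T06:00", "2021-01-02T12:00"]),
    ...                 "noaa_number": [11158, 11158, 11158],
    ...                 "goes_class": ["C1.0", "M2.3", "C5.5"]})
    >>> counts = extract_flare_statistics(flares)  # doctest: +SKIP
    >>> sorted(counts.columns)  # doctest: +SKIP
    ['C', 'M', 'date', 'noaa_number']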

"""
noaa_num_mask = flares["noaa_number"] != 0
flares = flares[noaa_num_mask]

flares_df = flares.to_pandas()

# Group by day (date) and NOAA number and calculate number of flares per class
flare_counts = pd.DataFrame()
for (date, noaa_num), group in flares_df.groupby([flares_df["peak_time"].dt.date, "noaa_number"]):
flare_count_by_class = group["goes_class"].str[0].value_counts()
new_row = flare_count_by_class.to_dict()
new_row["date"] = date
new_row["noaa_number"] = noaa_num
flare_counts = pd.concat([flare_counts, pd.DataFrame([new_row])])

    # (date, AR) groups without flares of a given class have NaN counts for that class, so fill with 0s
flare_counts.fillna(0, inplace=True)

return flare_counts


def main():
logger.debug("Starting main")
process_flares(config)
catalog = process_ar_catalogs(config)
process_ars(config, catalog)
create_pit_flare_dataset(config)
# catalog = process_ar_catalogs(config)
# process_ars(config, catalog)


if __name__ == "__main__":