2 changes: 1 addition & 1 deletion arccnet/data_generation/tests/test_data_manager.py

@@ -42,7 +42,7 @@ def test_query_missing(example_query):
 
 # Define a fixture for creating a DataManager instance with default arguments
 @pytest.fixture
-@pytest.mark.remote_data
+# @pytest.mark.remote_data
 def data_manager_default():
     dm = DataManager(
         start_date=str(datetime(2010, 4, 15)),
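
Note: commenting the decorator out is harmless here, but pytest has never honored marks applied to fixtures in the first place; recent pytest releases warn about exactly this pattern. If the aim is to gate the fixture's consumers off the network, the usual pattern is to mark the tests themselves. A minimal sketch (the test name is illustrative, not from this suite):

    import pytest

    @pytest.fixture
    def data_manager_default():
        ...

    @pytest.mark.remote_data  # marks belong on tests, not on fixtures
    def test_default_query(data_manager_default):
        ...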
4 changes: 2 additions & 2 deletions arccnet/data_generation/tests/test_mag_processing.py

@@ -16,7 +16,7 @@
 from arccnet.data_generation.utils.utils import save_compressed_map
 
 
-@pytest.fixture(scope="session")
+# @pytest.fixture(scope="session")
 def temp_path_fixture(request):
     temp_dir = tempfile.mkdtemp() # Create temporary directory
 
@@ -33,7 +33,7 @@ def cleanup():
     return (Path(temp_dir), Path(raw_data_dir), Path(processed_data_dir))
 
 
-@pytest.mark.remote_data
+# @pytest.mark.remote_data
 @pytest.fixture
 def sunpy_hmi_copies(temp_path_fixture):
     n = 5
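
Note: if the motivation for disabling temp_path_fixture is the manual tempfile/cleanup bookkeeping, pytest's built-in tmp_path_factory gives session-scoped temporary directories that pytest prunes on its own. A sketch under that assumption, mirroring the tuple returned above:

    import pytest

    @pytest.fixture(scope="session")
    def temp_path_fixture(tmp_path_factory):
        temp_dir = tmp_path_factory.mktemp("arccnet")  # cleaned up by pytest itself
        raw_data_dir = temp_dir / "raw"
        processed_data_dir = temp_dir / "processed"
        raw_data_dir.mkdir()
        processed_data_dir.mkdir()
        return temp_dir, raw_data_dir, processed_data_dir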
85 changes: 42 additions & 43 deletions arccnet/data_generation/tests/test_sdo_processing.py

@@ -6,31 +6,29 @@
 
 import astropy.units as u
 from astropy.coordinates import SkyCoord
-from astropy.table import Table
-from astropy.time import Time
 
 from arccnet import config
-from arccnet.data_generation.timeseries.sdo_processing import crop_map, flare_check, pad_map, rand_select
+from arccnet.data_generation.timeseries.sdo_processing import crop_map, pad_map
 
 test_path = Path(__file__).resolve().parent
 
 
-def test_rand():
-    combined = Table.read(f"{test_path}/data/ts_test_data.ecsv")
-    types = ["F1", "F2", "N1", "N2"]
-
-    # 1. test with full list of samples
-    rand_comb_1 = rand_select(combined, 3, types)
-    rand_comb_2 = rand_select(combined, 3, types)
-    assert list(rand_comb_1) != list(rand_comb_2)
-    # 2. test with partial list of samples
-    rand_comb_1 = rand_select(combined, 3, ["F1", "N1"])
-    rand_comb_2 = rand_select(combined, 3, ["F1", "N1"])
-    assert list(rand_comb_1) != list(rand_comb_2)
-    # 3. test with higher number of sizes
-    rand_comb_1 = rand_select(combined, 6, types)
-    rand_comb_2 = rand_select(combined, 6, types)
-    assert list(rand_comb_1) != list(rand_comb_2)
+# def test_rand():
+#     combined = Table.read(f"{test_path}/data/ts_test_data.ecsv")
+#     # types = ["F1", "F2", "N1", "N2"]
+#     types = list(range(2011, 2020))
+#     # 1. test with full list of samples
+#     rand_comb_1 = rand_select(combined, types, 3)
+#     rand_comb_2 = rand_select(combined, types, 3)
+#     assert list(rand_comb_1) != list(rand_comb_2)
+#     # 2. test with partial list of samples
+#     rand_comb_1 = rand_select(combined, [2014, 2015], 3)
+#     rand_comb_2 = rand_select(combined, [2014, 2015], 3)
+#     assert list(rand_comb_1) != list(rand_comb_2)
+#     # 3. test with higher number of sizes
+#     rand_comb_1 = rand_select(combined, types, 6)
+#     rand_comb_2 = rand_select(combined, types, 6)
+#     assert list(rand_comb_1) != list(rand_comb_2)
 
 
 @pytest.mark.remote_data
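
Note: if test_rand is ever resurrected, asserting that two unseeded draws differ is flaky by construction, since independent random selections can legitimately collide on small tables. Seeding makes the comparison deterministic. A sketch, assuming rand_select could grow an optional generator argument (the rng keyword is hypothetical, not the current signature):

    import numpy as np

    rng_a = np.random.default_rng(42)
    rng_b = np.random.default_rng(7)
    rand_comb_1 = rand_select(combined, types, 3, rng=rng_a)  # hypothetical rng kwarg
    rand_comb_2 = rand_select(combined, types, 3, rng=rng_b)
    assert list(rand_comb_1) != list(rand_comb_2)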
@@ -52,27 +50,28 @@ def test_padding():
     assert int(pad_map(aia_smap, 900).dimensions[0].value) == 900
 
 
-def test_flare_check():
-    combined = Table.read(f"{test_path}/data/ts_test_data.ecsv")
-    flares = combined[combined["goes_class"] != "N"]
-    flare = flares[flares["goes_class"] == "C3.7"]
-    flare = flare[flare["noaa_number"] == 12644]
-    assert (flare_check(flare["start_time"], flare["end_time"], flare["noaa_number"], flares)[0]) == 2
-    # 2. test a non flare run known to contain flares
-    ar = combined[combined["goes_class"] == "N"]
-    ar = ar[ar["noaa_number"] == 12038]
-    ar = ar[ar["tb_date"] == "2014-04-23"]
-    erl_time = Time(ar["start_time"]) - (6 + 1) * u.hour
-    assert (flare_check(erl_time, Time(ar["start_time"]) - 1 * u.hour, ar["noaa_number"], flares)[0]) == 2
-    # 3. test a flare run without flares
-    flares = combined[combined["goes_class"] != "N"]
-    flare = combined[combined["goes_class"] == "M1.6"]
-    flare = flare[flare["noaa_number"] == 12192]
-    print(flare)
-    assert (flare_check(flare["start_time"], flare["end_time"], flare["noaa_number"], flares)[0]) == 1
-    # 4. test a non flare run without flares
-    ar = combined[combined["goes_class"] == "N"]
-    ar = ar[ar["noaa_number"] == 12524]
-    ar = ar[ar["tb_date"] == "2016-03-20"]
-    erl_time = Time(ar["start_time"]) - (6 + 1) * u.hour
-    assert (flare_check(erl_time, Time(ar["start_time"]) - 1 * u.hour, ar["noaa_number"], flares)[0]) == 1
+# Test not necessary for latest dataset.
+# def test_flare_check():
+#     combined = Table.read(f"{test_path}/data/ts_test_data.ecsv")
+#     flares = combined[combined["goes_class"] != "N"]
+#     flare = flares[flares["goes_class"] == "C3.7"]
+#     flare = flare[flare["noaa_number"] == 12644]
+#     assert (flare_check(flare["start_time"], flare["end_time"], flare["noaa_number"], flares)[0]) == 2
+#     # 2. test a non flare run known to contain flares
+#     ar = combined[combined["goes_class"] == "N"]
+#     ar = ar[ar["noaa_number"] == 12038]
+#     ar = ar[ar["tb_date"] == "2014-04-23"]
+#     erl_time = Time(ar["start_time"]) - (6 + 1) * u.hour
+#     assert (flare_check(erl_time, Time(ar["start_time"]) - 1 * u.hour, ar["noaa_number"], flares)[0]) == 2
+#     # 3. test a flare run without flares
+#     flares = combined[combined["goes_class"] != "N"]
+#     flare = combined[combined["goes_class"] == "M1.6"]
+#     flare = flare[flare["noaa_number"] == 12192]
+#     print(flare)
+#     assert (flare_check(flare["start_time"], flare["end_time"], flare["noaa_number"], flares)[0]) == 1
+#     # 4. test a non flare run without flares
+#     ar = combined[combined["goes_class"] == "N"]
+#     ar = ar[ar["noaa_number"] == 12524]
+#     ar = ar[ar["tb_date"] == "2016-03-20"]
+#     erl_time = Time(ar["start_time"]) - (6 + 1) * u.hour
+#     assert (flare_check(erl_time, Time(ar["start_time"]) - 1 * u.hour, ar["noaa_number"], flares)[0]) == 1
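
Note: rather than commenting retired tests out wholesale, the standard skip marker keeps them visible in test reports and records the rationale. A minimal sketch:

    import pytest

    @pytest.mark.skip(reason="Test not necessary for latest dataset.")
    def test_flare_check():
        ...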
47 changes: 33 additions & 14 deletions arccnet/data_generation/timeseries/drms_pipeline.py

@@ -32,32 +32,41 @@
 drms_log.setLevel("ERROR")
 reproj_log = logging.getLogger("reproject.common")
 reproj_log.setLevel("ERROR")
+# May need to find a more robust solution with filters/exceptions for this.
+astropy_log.setLevel("ERROR")
 data_path = config["paths"]["data_folder"]
 wavelengths = config["drms"]["wavelengths"]
 packed_maps = namedtuple("packed_maps", ["hmi_origin", "l2_map"])
-starts = read_data(
+starts, before_fl_tables, after_fl_tables = read_data(
     hek_path=Path(f"{data_path}/flare_files/hek_swpc_1996-01-01T00:00:00-2023-01-01T00:00:00_dev.parq"),
     srs_path=Path(f"{data_path}/flare_files/srs_processed_catalog.parq"),
-    size=1,
+    # Set size to -1 for all AR's in a year
+    size=10,
     duration=6,
     long_lim=65,
-    types=["F1", "F2", "N1", "N2"],
-)[0]
+    # Use these instead of years if generating old flare target data.
+    # types=["F1", "F2", "N1", "N2"],
+    years=[2014],
+)
 
 cores = int(config["drms"]["cores"])
 
 with ProcessPoolExecutor(cores) as executor:
-    for record in starts:
-        noaa_ar, fl_class, start, end, date, center, category = record
+    for rec_num in range(len(starts)):
+        record = starts[rec_num]
+        noaa_ar, mag_class, mcintosh, end, start, date, center = record
+        before_fls = before_fl_tables[rec_num]
+        after_fls = after_fl_tables[rec_num]
+        b_x, b_m, b_c = before_fls[1]["X"], before_fls[1]["M"], before_fls[1]["C"]
+        a_x, a_m, a_c = after_fls[1]["X"], after_fls[1]["M"], after_fls[1]["C"]
         pointing_table = calibrate.util.get_pointing_table(source="jsoc", time_range=[start - 6 * u.hour, end])
-        start_split = start.value.split("T")[0]
-        file_name = f"{category}_{start_split}_{fl_class}_{noaa_ar}"
+        start_split = end.value.split("T")[0]
+        file_name = (
+            f"{start_split}_{noaa_ar}_{mag_class}_{mcintosh}_Xb{b_x}_Mb{b_m}_Cb{b_c}_Xa{a_x}_Ma{a_m}_Ca{a_c}"
+        )
         patch_height = int(config["drms"]["patch_height"]) * u.pix
         patch_width = int(config["drms"]["patch_width"]) * u.pix
         try:
-            logging.info(
-                f"{record['noaa_number']} {record['goes_class']} {record['start_time']} {record['category']}"
-            )
+            logging.info(file_name)
             aia_maps, hmi_maps = drms_pipeline(
                 start_t=start,
                 end_t=end,
@@ -67,7 +76,8 @@
                 wavelengths=config["drms"]["wavelengths"],
                 sample=config["drms"]["sample"],
             )
-            if len(aia_maps) != 60:
+            # WILL NEED TO ADJUST IF USING MORE/LESS THAN 6 TIME STEPS
+            if len(aia_maps) != (60):
                 logging.info("Bad run - missing frames, skipping.")
                 continue
 
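Note: the hard-coded 60 is exactly what the all-caps comment warns about. Deriving the expected frame count from the run parameters would remove the magic number. A sketch, assuming one AIA frame per wavelength per time step (the helper name is illustrative):

    def expected_frame_count(n_wavelengths: int, n_time_steps: int) -> int:
        """Frames expected from a complete run: one per wavelength per time step."""
        return n_wavelengths * n_time_steps

    # e.g. 10 wavelengths x 6 time steps reproduces the current check of 60
    assert expected_frame_count(10, 6) == 60
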
@@ -119,7 +129,16 @@
             home_table.write(f"{batched_name}/records/{file_name}.csv", overwrite=True)
 
             vid_path = vid_match(home_table, file_name, batched_name)
-            l4_file_pack(aia_patch_paths, hmi_patch_paths, batched_name, file_name, away_table, vid_path)
+            l4_file_pack(
+                aia_patch_paths,
+                hmi_patch_paths,
+                batched_name,
+                file_name,
+                away_table,
+                before_fls[0],
+                after_fls[0],
+                vid_path,
+            )
 
         except Exception as error:
             logging.error(error, exc_info=True)
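
Note: indexing three parallel lists via range(len(starts)) works, but zip expresses the pairing directly and drops the rec_num bookkeeping. A sketch of the equivalent loop header, reusing the names from the diff:

    for record, before_fls, after_fls in zip(starts, before_fl_tables, after_fl_tables):
        noaa_ar, mag_class, mcintosh, end, start, date, center = record
        ...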