Merged
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -2,7 +2,7 @@ exclude: ".*(.csv|.fits|.fts|.fit|.header|.txt|tca.*|.json|.asdf)$|^CITATION.rst
repos:
# This should be before any formatting hooks like isort
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.11.2"
rev: "v0.12.1"
hooks:
- id: ruff
args: ["--fix"]
56 changes: 24 additions & 32 deletions arccnet/data_generation/timeseries/drms_pipeline.py
@@ -16,39 +16,40 @@
crop_map,
drms_pipeline,
hmi_l2,
l4_file_pack,
map_reproject,
match_files,
read_data,
table_match,
vid_match,
)

__all__ = []

# Logging settings here.
drms_log = logging.getLogger("drms")
drms_log.setLevel("WARNING")
reproj_log = logging.getLogger("reproject.common")
reproj_log.setLevel("WARNING")
# May need to find a more robust solution with filters/exceptions for this.
astropy_log.setLevel("ERROR")

packed_maps = namedtuple("packed_maps", ["hmi_origin", "l2_map"])

if __name__ == "__main__":
__all__ = []

# Logging settings here.
drms_log = logging.getLogger("drms")
drms_log.setLevel("WARNING")
reproj_log = logging.getLogger("reproject.common")
reproj_log.setLevel("WARNING")
# May need to find a more robust solution with filters/exceptions for this.
astropy_log.setLevel("ERROR")
data_path = config["paths"]["data_folder"]
packed_maps = namedtuple("packed_maps", ["hmi_origin", "l2_map"])
starts = read_data(
hek_path="/Users/danielgass/Desktop/ARCCnet/ARCCnet/hek_swpc_1996-01-01T00:00:00-2023-01-01T00:00:00_dev.parq",
srs_path="/Users/danielgass/Desktop/ARCCnet/ARCCnet/arccnet/data_generation/timeseries/srs_processed_catalog.parq",
hek_path=Path(f"{data_path}/flare_files/hek_swpc_1996-01-01T00:00:00-2023-01-01T00:00:00_dev.parq"),
srs_path=Path(f"{data_path}/flare_files/srs_processed_catalog.parq"),
size=10,
duration=24,
duration=6,
)
cores = int(config["drms"]["cores"])
with ProcessPoolExecutor(cores) as executor:
for record in [starts[0]]:
for record in [starts[-1]]:
noaa_ar, fl_class, start, end, date, center = record
pointing_table = calibrate.util.get_pointing_table(source="jsoc", time_range=[start - 6 * u.hour, end])

start_split = start.value.split("T")[0]
file_name = f"{fl_class}_{noaa_ar}_{start_split}"
file_name = f"{start_split}_{fl_class}_{noaa_ar}"
patch_height = int(config["drms"]["patch_height"]) * u.pix
patch_width = int(config["drms"]["patch_width"]) * u.pix
try:
@@ -89,8 +90,8 @@
print(hmi_patch_paths)

# For some reason, aia_proc becomes an empty list after this function call.
home_table, aia_patch_paths, aia_quality, hmi_patch_paths, hmi_quality = table_match(
list(aia_patch_paths), list(hmi_patch_paths)
home_table, aia_patch_paths, aia_quality, aia_time, hmi_patch_paths, hmi_quality, hmi_time = (
table_match(list(aia_patch_paths), list(hmi_patch_paths))
)

# This can probably be streamlined/functionalized to make the pipeline look better.
@@ -99,8 +100,10 @@
Path(f"{batched_name}/tars").mkdir(parents=True, exist_ok=True)
hmi_away = ["HMI/" + Path(file).name for file in hmi_patch_paths]
aia_away = ["AIA/" + Path(file).name for file in aia_patch_paths]
aia_wvl = home_table["Wavelength"]
away_table = Table(
{
"AIA wavelength": aia_wvl,
"AIA files": aia_away,
"AIA quality": aia_quality,
"HMI files": hmi_away,
@@ -110,19 +113,8 @@

home_table.write(f"{batched_name}/records/{file_name}.csv", overwrite=True)

## Commented out until we're ready to package.
# away_table.write(f"{batched_name}/records/out_{file_name}.csv", overwrite=True)
# with tarfile.open(f"{batched_name}/tars/{file_name}.tar", "w") as tar:
# for file in aia_maps:
# name = Path(file).name
# tar.add(file, arcname=f"AIA/{name}")
# for file in np.unique(hmi_maps):
# name = Path(file).name
# tar.add(file, arcname=f"HMI/{name}")
# tar.add(f"{batched_name}/records/out_{file_name}.csv", arcname=f"{file_name}.csv")
vid_path = vid_match(home_table, file_name, batched_name)
l4_file_pack(aia_patch_paths, hmi_patch_paths, batched_name, file_name, away_table, vid_path)

except Exception as error:
logging.error(error, exc_info=True)
# 70 X class flares.
# Random sample (100 ish for M and below flares)
# Make a .gif of patch for each run
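An aside on the `aia_proc` comment in this file: a sequence that is non-empty before a call and empty afterwards is the classic symptom of an exhausted iterator, which would also explain why wrapping the arguments in `list(...)` works around it. A minimal reproduction of that hypothesis (an assumption, not verified against the repo):

```python
# Minimal reproduction of the suspected cause: a generator can only be
# consumed once, so any second pass over it sees an empty sequence.
aia_proc = (path for path in ["a.fits", "b.fits"])  # a generator, not a list

first = list(aia_proc)   # ['a.fits', 'b.fits']
second = list(aia_proc)  # [] -- the generator is now exhausted
print(first, second)
```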
108 changes: 98 additions & 10 deletions arccnet/data_generation/timeseries/sdo_processing.py
Expand Up @@ -2,6 +2,9 @@
import os
import sys
import glob
import shutil
import logging
import warnings
import itertools
from random import sample
from pathlib import Path
@@ -25,10 +28,13 @@
from arccnet import config
from arccnet.data_generation.mag_processing import pixel_to_bboxcoords
from arccnet.data_generation.utils.utils import save_compressed_map
from arccnet.visualisation.data import mosaic_animate, mosaic_plot

warnings.simplefilter("ignore", RuntimeWarning)
reproj_log = logging.getLogger("reproject.common")
reproj_log.setLevel("WARNING")
os.environ["JSOC_EMAIL"] = "[email protected]"


__all__ = [
"read_data",
"hmi_l2",
@@ -41,6 +47,7 @@
"table_match",
"crop_map",
"map_reproject",
"l4_file_pack",
]


@@ -55,7 +62,7 @@ def read_data(hek_path: str, srs_path: str, size: int, duration: int):
srs_path : `str`
The path to the parquet file containing parsed NOAA SRS active region information.
size : `int`
The size of the sample to be generated. (Generates 10% X, 30% M, 60% C)
The size of the sample to be generated. (Generates 10% X, 30% M, 60% C by default)
duration : `int`
The duration of the data sample in hours.

@@ -84,17 +91,21 @@
flares["tb_date"] = [date.split(" ")[0] for date in flares["tb_date"]]
flares = join(flares, srs, keys_left="noaa_number", keys_right="number")

flares = flares[abs(flares["longitude"].value) <= 70]
flares = flares[abs(flares["longitude"].value) <= 65]
flares = flares[flares["tb_date"] == flares["srs_date"]]
x_flares = flares[[flare.startswith("X") for flare in flares["goes_class"]]]
x_flares = x_flares[x_flares["noaa_number"] == 12192]
x_flares = x_flares[x_flares["goes_class"] == "X1.0"]
# x_flares = x_flares[x_flares['tb_date'] == '2014-10-27']
# x_flares = x_flares[sample(range(len(x_flares)), k=int(0.1 * size))]
# x_flares = x_flares[x_flares["noaa_number"] == 11158]
# x_flares = x_flares[x_flares["goes_class"] == "X2.2"]
# x_flares = x_flares[x_flares["tb_date"] == "2014-10-27"]
x_flares = x_flares[sample(range(len(x_flares)), k=int(0.1 * size))]
m_flares = flares[[flare.startswith("M") for flare in flares["goes_class"]]]
m_flares = m_flares[sample(range(len(m_flares)), k=int(0.3 * size))]

c_flares = flares[[flare.startswith("C") for flare in flares["goes_class"]]]
c_flares = c_flares[sample(range(len(c_flares)), k=int(0.6 * size))]
# c_flares = c_flares[sample(range(len(c_flares)), k=int(0.6 * size))]
c_flares = c_flares[c_flares["noaa_number"] == 11818]
c_flares = c_flares[c_flares["goes_class"] == "C1.6"]
c_flares = c_flares[c_flares["tb_date"] == "2013-08-20"]

combined = vstack([x_flares, m_flares, c_flares])
combined["c_coord"] = [
@@ -621,26 +632,33 @@ def table_match(aia_maps, hmi_maps):
aia_quality = []
hmi_paths = []
hmi_times = [Time(fits.open(hmi_map)[1].header["date-obs"]) for hmi_map in hmi_maps]
paired_times = []
aia_times = []
hmi_quality = []

for aia_map in aia_maps:
t_d = [abs((Time(fits.open(aia_map)[1].header["date-obs"]) - hmi_time).value) for hmi_time in hmi_times]
date = fits.open(aia_map)[1].header["date-obs"]
t_d = [abs((Time(date) - hmi_time).value) for hmi_time in hmi_times]
hmi_match = hmi_maps[t_d.index(min(t_d))]
aia_paths.append(aia_map)
hmi_paths.append(hmi_match)
paired_times.append(fits.open(hmi_match)[1].header["date-obs"])
hmi_quality.append(fits.open(hmi_match)[1].header["quality"])
aia_quality.append(fits.open(aia_map)[1].header["quality"])
aia_wavelnth.append(fits.open(aia_map)[1].header["wavelnth"])
aia_times.append(date)
paired_table = Table(
{
"Wavelength": aia_wavelnth,
"AIA files": aia_paths,
"AIA quality": aia_quality,
"AIA time": aia_times,
"HMI files": hmi_paths,
"HMI quality": hmi_quality,
"HMI time": paired_times,
}
)
return paired_table, aia_paths, aia_quality, hmi_paths, hmi_quality
return paired_table, aia_paths, aia_quality, aia_times, hmi_paths, hmi_quality, paired_times


def crop_map(sdo_map, center, height, width, noaa_time):
@@ -707,3 +725,73 @@ def map_reproject(sdo_packed):
save_compressed_map(sdo_rpr, fits_path, hdu_type=CompImageHDU, overwrite=True)

return fits_path


def vid_match(table, name, path):
Contributor review comment:

This would probably be better as 3 functions:

  1. A function to plot a "batch" of data, or X maps, in a certain layout (dealing with missing maps)
  2. A function to make a bunch of plots based on the input table
  3. A function to turn the plots into a video

Also, did you look at using matplotlib for the video part, or why use OpenCV? I guess we already have a dependency on it, so no big deal either way.
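
A minimal sketch of that split, under stated assumptions: the names `plot_mosaic_batch`, `plot_table_frames`, and `frames_to_video` are hypothetical, and matplotlib's `PillowWriter` stands in for the video step rather than the PR's OpenCV path:

```python
# Hedged sketch of the suggested three-function split; all names are
# hypothetical and the matplotlib video writer is an assumption.
from pathlib import Path

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import sunpy.map


def plot_mosaic_batch(files, nrows, ncols, out_png):
    # 1. Plot one batch of maps in a fixed grid, leaving blank panels
    #    for missing files instead of failing.
    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows))
    for ax in axes.flat:
        ax.set_axis_off()
    for ax, file in zip(axes.flat, files):
        if file is None or not Path(file).exists():
            continue
        smap = sunpy.map.Map(file)
        ax.imshow(smap.data, origin="lower", cmap="gray")
        ax.set_title(Path(file).name, fontsize=6)
    fig.savefig(out_png, dpi=150)
    plt.close(fig)


def plot_table_frames(table, name, path, nrows=4, ncols=3):
    # 2. Make one mosaic per unique HMI file, pairing it with the AIA
    #    files matched to it in the input table.
    frames = []
    for i, hmi_file in enumerate(dict.fromkeys(table["HMI files"])):
        rows = table[table["HMI files"] == hmi_file]
        files = [hmi_file, *rows["AIA files"]]
        out_png = Path(path) / f"{name}_{i:04d}.png"
        plot_mosaic_batch(files, nrows, ncols, out_png)
        frames.append(out_png)
    return frames


def frames_to_video(frames, out_file, fps=5):
    # 3. Stitch the saved frames into an animated GIF (out_file should
    #    end in .gif for PillowWriter); no OpenCV dependency needed.
    fig, ax = plt.subplots()
    ax.set_axis_off()
    artists = [[ax.imshow(plt.imread(frame))] for frame in frames]
    anim = animation.ArtistAnimation(fig, artists, interval=1000 / fps)
    anim.save(str(out_file), writer=animation.PillowWriter(fps=fps))
    plt.close(fig)
    return out_file
```

Splitting it this way keeps the missing-map handling and the video writer independently testable.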

r"""
Creates an animated mosaic of images from the data run, including all AIA wavelengths and the HMI 720s magnetogram.

Parameters
----------
table : `AstropyTable`
An astropy table containing the filenames of all AIA wavelengths and their paired HMI files.
name : `str`
The base file name shared by all versions of a file above level 1, appended to the outputs.
path : `str`
The base path of the processed data.

Returns
-------
output_file : `str`
A string containing the path of the completed mosaic animation.
"""
# Check this carefully
hmi_files = table["HMI files"].value
wvls = np.unique([table["Wavelength"].value])
hmi_files = np.unique(hmi_files)
nrows, ncols = 4, 3 # define your subplot grid
for file in range(len(hmi_files)):
hmi = hmi_files[file]
mosaic_plot(hmi, name, file, nrows, ncols, wvls, table, path)

return mosaic_animate(path, name)


def l4_file_pack(aia_paths, hmi_paths, dir_path, rec, out_table, anim_path):
r"""
Packs files into folders, along with folder-specific records identifying the .fits files.

Parameters
----------
aia_paths : `list`
The paths to the aia files to be packed.
hmi_paths : `list`
The paths to the hmi files to be packed.
dir_path : `str`
The path to the directory of the l4 data.
rec : `str`
The record value unique to the current run (date, ar number, flare class).
out_table : `AstropyTable`
The table containing the records specific to the current run, written into its folder.
anim_path : `str`
The path to the mosaic animation of the current run.
"""
folder_hmi = f"{dir_path}/data/{rec}/HMI/"
Path(folder_hmi).mkdir(parents=True, exist_ok=True)
folder_aia = f"{dir_path}/data/{rec}/AIA/"
Path(folder_aia).mkdir(parents=True, exist_ok=True)
for file in aia_paths:
name = Path(file).name
if os.path.exists(f"{folder_aia}/{name}"):
os.remove(f"{folder_aia}/{name}")
shutil.copy(file, f"{folder_aia}/{name}")

for file in np.unique(hmi_paths):
name = Path(file).name
name = Path(f"{folder_hmi}/{name}").name
if os.path.exists(f"{folder_hmi}/{name}"):
os.remove(f"{folder_hmi}/{name}")
shutil.copy(file, f"{folder_hmi}/{name}")

shutil.copy(anim_path, f"{dir_path}/data/{rec}")
out_table.write(f"{dir_path}/data/{rec}/{rec}.csv", overwrite=True)
5 changes: 3 additions & 2 deletions arccnet/models/cutouts/mcintosh/train_utils.py
@@ -234,13 +234,14 @@ def apply_mask_at_evaluation(output_logits, z_pred, p_pred=None, valid_dict=None
return masked_logits


def test(
# Changed function name from test to validate to avoid pytest trying to run it as a test. Function calls will need to be updated.
def validate(
model: nn.Module,
device: torch.device,
loader: DataLoader,
valid_p_for_z: dict,
valid_c_for_zp: dict,
teacher_forcing_ratio=None,
teacher_forcing_ratio=None, # noqa
) -> tuple:
"""
Tests the model and computes accuracy and F1 scores for each component with optional Teacher Forcing.
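For context on the rename: pytest collects any module-level function whose name starts with `test`, so the old `test(...)` was picked up during test runs and its positional arguments were reported as missing fixtures. A self-contained illustration of the collection rule (setting `__test__ = False` on the function would be another opt-out, as an alternative to renaming):

```python
# Illustration of pytest's collection rule; running pytest against this
# module collects test() but not validate().
def test(model, device):      # collected: name starts with "test"; the run
    ...                       # then fails with "fixture 'model' not found"


def validate(model, device):  # not collected: safe name for a plain helper
    ...
```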
2 changes: 1 addition & 1 deletion arccnet/utils/arccnetrc
@@ -25,7 +25,7 @@ aia_keys = T_REC, T_OBS, QUALITY, FSN, WAVELNTH, INSTRUME
wavelengths = 171, 193, 304, 211, 335, 94, 131
patch_height = 400
patch_width = 800
cores = 6
cores = 4


;;;;;;;;;;;;;;;;