From 7f54c52f7eb95e748eb5048a6ead6c2787dec6ff Mon Sep 17 00:00:00 2001 From: Edoardo Legnaro Date: Wed, 22 Oct 2025 19:32:19 +0000 Subject: [PATCH 1/6] Copied dataset_utils.py, utils.py, and EDA_utils.py from train_AR branch and made new notebooks folder --- arccnet/models/dataset_utils.py | 552 ++++++++++++++-------- arccnet/notebooks/analysis/EDA_cutouts.py | 464 ++++++++++++++++++ arccnet/visualisation/EDA_utils.py | 295 ++++++++++++ arccnet/visualisation/utils.py | 168 +++---- 4 files changed, 1210 insertions(+), 269 deletions(-) create mode 100644 arccnet/notebooks/analysis/EDA_cutouts.py create mode 100644 arccnet/visualisation/EDA_utils.py diff --git a/arccnet/models/dataset_utils.py b/arccnet/models/dataset_utils.py index f8603e24..c13731e7 100644 --- a/arccnet/models/dataset_utils.py +++ b/arccnet/models/dataset_utils.py @@ -1,4 +1,5 @@ import os +import logging import numpy as np import pandas as pd @@ -15,256 +16,427 @@ def make_dataframe(data_folder, dataset_folder, file_name): """ - Process the ARCCNet cutout dataset. + Load and process the ARCCNet cutout dataset. - Parameters - ---------- - data_folder : str, optional - The base directory where the dataset folder is located. - dataset_folder : str, optional - The folder containing the dataset. Default is 'arccnet-cutout-dataset-v20240715'. - file_name : str, optional - The name of the parquet file to read. Default is 'cutout-mcintosh-catalog-v20240715.parq'. + - Reads a parquet file and converts Julian dates to datetime. + - Removes rows with problematic quicklook magnetograms (by filename). + - Adds 'label' (from magnetic_class or region_type) and 'date_only' columns. + - Returns the full DataFrame, a filtered DataFrame with only AR/IA regions, and the removed problematic rows. Returns ------- df : pandas.DataFrame - The processed DataFrame containing all regions with additional date and label columns. + The processed DataFrame. AR_df : pandas.DataFrame - A DataFrame filtered to include only active regions (AR) and plages (IA). - - Notes - ----- - - The function reads a parquet file from the specified folder, processes Julian dates, and converts them to - datetime objects. - - It filters out problematic magnetograms from the dataset by excluding specific records based on quicklook images. - - The magnetic regions are labeled either by their magnetic class or region type, and an additional column - for date is added. - - A subset of the data containing only active regions (AR) and intermediate regions (IA) is returned. - - Examples - -------- - df, AR_df = make_dataframe( - data_folder='../../data/', - dataset_folder='arccnet-cutout-dataset-v20240715', - file_name='cutout-mcintosh-catalog-v20240715.parq') + DataFrame with only active regions (AR) and intermediate regions (IA). + filtered_df : pandas.DataFrame + DataFrame of removed problematic quicklook rows. 
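+
+    Examples
+    --------
+    Minimal usage sketch; the folder and file names below mirror the EDA
+    cutouts notebook and may need adjusting to the dataset release in use:
+
+    df, AR_df, filtered_df = make_dataframe(
+        data_folder="/ARCAFF/data",
+        dataset_folder="arccnet-v20251017/04_final",
+        file_name="data/cutout_classification/region_classification.parq",
+    )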
""" - # Read the parquet file + df = _load_and_process_data(data_folder, dataset_folder, file_name) + df, filtered_df = _remove_problematic_quicklooks(df) + df, AR_df = _add_labels_and_filter_regions(df) + return df, AR_df, filtered_df + + +def _load_and_process_data(data_folder, dataset_folder, file_name): + """Load data and convert dates.""" df = pd.read_parquet(os.path.join(data_folder, dataset_folder, file_name)) + return _convert_jd_to_datetime(df) + - # Convert Julian dates to datetime objects +def _convert_jd_to_datetime(df): + """Convert Julian dates to datetime objects.""" df["time"] = df["target_time.jd1"] + df["target_time.jd2"] times = Time(df["time"], format="jd") - dates = pd.to_datetime(times.iso) # Convert to datetime objects - df["dates"] = dates + df["dates"] = pd.to_datetime(times.iso) + return df + - # Remove problematic magnetograms from the dataset - problematic_quicklooks = config.get("magnetograms", "problematic_quicklooks").split(",") +def _remove_problematic_quicklooks(df): + """Remove problematic magnetograms from the dataset.""" + problematic_quicklooks = [ql.strip() for ql in config.get("magnetograms", "problematic_quicklooks").split(",")] + mask = df["quicklook_path_mdi"].apply(lambda x: os.path.basename(x) in problematic_quicklooks) + filtered_df = df[mask] + df = df[~mask].reset_index(drop=True) + return df, filtered_df - filtered_df = [] - for ql in problematic_quicklooks: - row = df["quicklook_path_mdi"] == "quicklook/" + ql - filtered_df.append(df[row]) - filtered_df = pd.concat(filtered_df) - df = df.drop(filtered_df.index).reset_index(drop=True) - # Label the data +def _add_labels_and_filter_regions(df): + """Label the data and filter for AR and IA regions.""" df["label"] = np.where(df["magnetic_class"] == "", df["region_type"], df["magnetic_class"]) df["date_only"] = df["dates"].dt.date - - # Filter AR and IA regions AR_df = pd.concat([df[df["region_type"] == "AR"], df[df["region_type"] == "IA"]]) - return df, AR_df +def cleanup_df(df, log_level=logging.INFO): + """ + Clean dataframe by removing bad quality data and rows with missing paths. + + Parameters + ---------- + df : pandas.DataFrame + The dataframe to clean. + log_level : int, optional + Logging level for output messages. Use logging.DEBUG, logging.INFO, etc. + Set to None to disable logging. Defaults to logging.INFO. + + Returns + ------- + pandas.DataFrame + Cleaned dataframe with quality filtering and missing path removal applied. 
+ """ + logger = logging.getLogger(__name__) + + # Define quality flags + hmi_good_flags = {"", "0x00000000", "0x00000400"} + mdi_good_flags = {"", "00000000", "00000200"} + + # Filter by quality flags + df_clean = df[df["QUALITY_hmi"].isin(hmi_good_flags) & df["QUALITY_mdi"].isin(mdi_good_flags)].copy() + + if log_level is not None: + _log_filtering_stats(df, df_clean, logger, log_level) + + # Remove rows where both paths are missing + def is_missing(series): + return series.isna() | (series == "") | (series == "None") + + both_missing = is_missing(df_clean["path_image_cutout_hmi"]) & is_missing(df_clean["path_image_cutout_mdi"]) + + if log_level is not None: + _log_path_analysis(df_clean, both_missing, logger, log_level) + + return df_clean[~both_missing].reset_index(drop=True) + + +def _log_filtering_stats(df_orig, df_clean, logger, log_level): + """Log data filtering statistics.""" + df_HMI = df_orig[df_orig["path_image_cutout_mdi"] == ""] + df_MDI = df_orig[df_orig["path_image_cutout_hmi"] == ""] + + hmi_good_flags = {"", "0x00000000", "0x00000400"} + mdi_good_flags = {"", "00000000", "00000200"} + + df_HMI_clean = df_HMI[df_HMI["QUALITY_hmi"].isin(hmi_good_flags)] + df_MDI_clean = df_MDI[df_MDI["QUALITY_mdi"].isin(mdi_good_flags)] + + logger.log(log_level, "DATA FILTERING Stats") + logger.log(log_level, "-" * 40) + + for name, orig, clean in [ + ("HMI", len(df_HMI), len(df_HMI_clean)), + ("MDI", len(df_MDI), len(df_MDI_clean)), + ("Total", len(df_orig), len(df_clean)), + ]: + pct = clean / orig * 100 if orig > 0 else 0 + logger.log(log_level, f"{name}: {clean:,}/{orig:,} ({pct:.1f}% retained)") + + logger.log(log_level, "-" * 40) + + +def _log_path_analysis(df_clean, both_missing, logger, log_level): + """Log path analysis statistics.""" + stats = { + "total": len(df_clean), + "hmi_none": (df_clean["path_image_cutout_hmi"] == "None").sum(), + "hmi_empty": (df_clean["path_image_cutout_hmi"] == "").sum(), + "mdi_none": (df_clean["path_image_cutout_mdi"] == "None").sum(), + "mdi_empty": (df_clean["path_image_cutout_mdi"] == "").sum(), + "both_missing": both_missing.sum(), + } + + logger.log(log_level, "PATH ANALYSIS:") + logger.log(log_level, "-" * 40) + logger.log(log_level, f"Total rows in df_clean: {stats['total']:,}") + logger.log(log_level, f"HMI paths - None: {stats['hmi_none']:,}, Empty: {stats['hmi_empty']:,}") + logger.log(log_level, f"MDI paths - None: {stats['mdi_none']:,}, Empty: {stats['mdi_empty']:,}") + logger.log( + log_level, + f"Both paths missing: {stats['both_missing']:,} ({stats['both_missing'] / stats['total'] * 100:.1f}%)", + ) + + def undersample_group_filter(df, label_mapping, long_limit_deg=60, undersample=True, buffer_percentage=0.1): """ - Filter data based on a specified longitude limit, assign 'front' or 'rear' locations, and group labels - according to a provided mapping. Optionally undersample the majority class. + Filter data based on longitude limit, group labels according to mapping, and optionally undersample. + + This function performs a multi-step data processing pipeline: + 1. Assigns front/rear location based on longitude limits + 2. Maps original labels to grouped labels using the provided mapping + 3. Filters out rows with unmapped labels (None values) + 4. Optionally undersamples the majority class for balanced training + 5. Keeps only front-hemisphere samples for final output Parameters ---------- df : pandas.DataFrame - The dataframe containing the data to be undersampled, grouped, and filtered. 
+ Input dataframe containing solar region data with 'label', 'longitude_hmi', + 'longitude_mdi' columns. label_mapping : dict - A dictionary mapping original labels to grouped labels. + Dictionary mapping original labels to grouped labels. Unmapped labels + (mapped to None) will be filtered out. long_limit_deg : int, optional - The longitude limit for filtering to determine 'front' or 'rear' location. Defaults to 60 degrees. + Longitude limit in degrees for front/rear classification. Regions with + |longitude| <= limit are considered 'front' (default: 60). undersample : bool, optional - Flag to enable or disable undersampling of the majority class. Defaults to True. + Whether to undersample the majority class to balance dataset (default: True). buffer_percentage : float, optional - The percentage buffer added to the second-largest class size when undersampling the majority class. - Defaults to 0.1 (10%). + Buffer percentage added to second-largest class size when undersampling + majority class (default: 0.1 = 10%). Returns ------- - pandas.DataFrame - The modified original dataframe with 'location', 'grouped_labels', and 'encoded_labels' columns added. - pandas.DataFrame - The undersampled and grouped dataframe, with rows from the 'rear' location filtered out. - - Notes - ----- - - This function assigns 'front' or 'rear' location based on a longitude limit. - - Labels are grouped according to the `label_mapping` provided. - - If `undersample` is True, the majority class is reduced to the size of the second-largest class, - plus a specified buffer percentage. - - The function returns two dataframes: the modified original dataframe and an undersampled version where - the 'rear' locations are filtered out. - - Examples - -------- - label_mapping = {'A': 'group1', 'B': 'group1', 'C': 'group2'} - df, undersampled_df = undersample_group_filter( - df=my_dataframe, - label_mapping=label_mapping, - long_limit_deg=60, - undersample=True, - buffer_percentage=0.1 - ) + df_original : pandas.DataFrame + Original dataframe with added columns: 'location', 'grouped_labels', + 'encoded_labels'. + df_processed : pandas.DataFrame + Processed dataframe with label mapping applied, optional undersampling, + and only front-hemisphere samples retained. + + Raises + ------ + ValueError + If input DataFrame is empty or no data remains after label mapping. 
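+
+    Examples
+    --------
+    Illustrative call; the mapping values and dataframe name are placeholders:
+
+    label_mapping = {"A": "group1", "B": "group1", "C": "group2"}
+    df_original, df_processed = undersample_group_filter(
+        df=my_dataframe,
+        label_mapping=label_mapping,
+        long_limit_deg=60,
+        undersample=True,
+        buffer_percentage=0.1,
+    )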
""" - lonV = np.deg2rad(np.where(df["processed_path_image_hmi"] != "", df["longitude_hmi"], df["longitude_mdi"])) - condition = (lonV < -np.deg2rad(long_limit_deg)) | (lonV > np.deg2rad(long_limit_deg)) - df_filtered = df[~condition] - df_rear = df[condition] - df.loc[df_filtered.index, "location"] = "front" - df.loc[df_rear.index, "location"] = "rear" - - # Apply label mapping to the dataframe - df["grouped_labels"] = df["label"].map(label_mapping) + logger = logging.getLogger(__name__) + + if df.empty: + raise ValueError("Input DataFrame is empty") + + logger.debug(f"Input: {len(df):,} rows") + + df = df.copy() + df = _assign_location(df, long_limit_deg, logger) + df = _apply_label_mapping(df, label_mapping, logger) + + if len(df) == 0: + raise ValueError("No data remaining after applying label mapping") + df["encoded_labels"] = df["grouped_labels"].map(labels.LABEL_TO_INDEX) + _log_unmapped_encoded(df, logger) - if undersample: - class_counts = df["grouped_labels"].value_counts() - majority_class = class_counts.idxmax() - second_largest_class_count = class_counts.iloc[1] - n_samples = int(second_largest_class_count * (1 + buffer_percentage)) + df_processed = _perform_undersampling(df, undersample, buffer_percentage, logger) if undersample else df.copy() + df_processed = _filter_front_hemisphere(df_processed, logger) - # Perform undersampling on the majority class - df_majority = df[df["grouped_labels"] == majority_class] - df_majority_undersampled = resample(df_majority, replace=False, n_samples=n_samples, random_state=42) + _log_final_summary(df_processed, logger) - df_list = [df[df["grouped_labels"] == label] for label in class_counts.index if label != majority_class] - df_list.append(df_majority_undersampled) + return df, df_processed - df_du = pd.concat(df_list) - else: - df_du = df.copy() - # Filter out rows with 'rear' location - df_du = df_du[df_du["location"] != "rear"] +def _assign_location(df, lon_limit_deg, logger): + """Assigns front/rear location based on longitude limits.""" + original_labels = df["label"].value_counts() + logger.debug(f"Original label distribution:\n{original_labels}") - return df, df_du + location_results = filter_by_location(df, lon_limit_deg=lon_limit_deg) + df["location"] = "rear" + df.loc[location_results["mask_front"], "location"] = "front" + + front_count = (df["location"] == "front").sum() + rear_count = len(df) - front_count + logger.debug( + f"Location assignment: {front_count:,} front ({front_count / len(df) * 100:.1f}%), {rear_count:,} rear ({rear_count / len(df) * 100:.1f}%)" + ) + + front_labels = df[df["location"] == "front"]["label"].value_counts() + logger.debug(f"Front hemisphere label distribution:\n{front_labels}") + return df + + +def _apply_label_mapping(df, label_mapping, logger): + """Apply label mapping and filter unmapped labels.""" + df["grouped_labels"] = df["label"].map(label_mapping) + initial_count = len(df) + + mapped_count = df["grouped_labels"].notna().sum() + unmapped_count = df["grouped_labels"].isna().sum() + logger.debug( + f"Label mapping results: {mapped_count:,} mapped ({mapped_count / initial_count * 100:.1f}%), {unmapped_count:,} unmapped ({unmapped_count / initial_count * 100:.1f}%)" + ) + + unmapped_labels = df[df["grouped_labels"].isna()]["label"].value_counts() + if len(unmapped_labels) > 0: + logger.debug(f"Unmapped labels being filtered out:\n{unmapped_labels}") + + df = df.dropna(subset=["grouped_labels"]).reset_index(drop=True) + grouped_labels = df["grouped_labels"].value_counts() + logger.debug(f"After 
label mapping - grouped label distribution:\n{grouped_labels}") + return df -def split_data(df_du, label_col, group_col, random_state=42): +def _log_unmapped_encoded(df, logger): + """Check for unmapped encoded labels.""" + unmapped_encoded = df["encoded_labels"].isna().sum() + if unmapped_encoded > 0: + logger.warning(f"Found {unmapped_encoded} rows with unmapped encoded labels") + unmapped_grouped = df[df["encoded_labels"].isna()]["grouped_labels"].value_counts() + logger.warning(f"Unmapped grouped labels:\n{unmapped_grouped}") + + +def _perform_undersampling(df, undersample, buffer_percentage, logger): + """Undersample the majority class.""" + if not undersample: + logger.debug("Undersampling disabled") + return df.copy() + + class_counts = df["grouped_labels"].value_counts() + logger.debug(f"Before undersampling - class distribution:\n{class_counts}") + + if len(class_counts) < 2: + logger.warning("Less than 2 classes available, skipping undersampling") + return df.copy() + + majority_class = class_counts.idxmax() + n_samples = min(int(class_counts.iloc[1] * (1 + buffer_percentage)), class_counts.iloc[0]) + + logger.debug("Undersampling strategy:") + logger.debug(f" Majority class: {majority_class} ({class_counts.iloc[0]:,} samples)") + logger.debug(f" Second largest class: {class_counts.index[1]} ({class_counts.iloc[1]:,} samples)") + logger.debug(f" Target size for majority: {n_samples:,} (with {buffer_percentage * 100:.1f}% buffer)") + + df_majority_resampled = resample( + df[df["grouped_labels"] == majority_class], replace=False, n_samples=n_samples, random_state=42 + ) + df_others = [df[df["grouped_labels"] == label] for label in class_counts.index if label != majority_class] + df_du = pd.concat([*df_others, df_majority_resampled], ignore_index=True) + + logger.debug("After undersampling:") + after_undersample = df_du["grouped_labels"].value_counts() + for label in after_undersample.index: + count = after_undersample[label] + pct = count / len(df_du) * 100 + logger.debug(f" {label}: {count:,} ({pct:.1f}%)") + return df_du + + +def _filter_front_hemisphere(df, logger): + """Keep only front samples.""" + before_front_filter = len(df) + df_filtered = df[df["location"] == "front"].reset_index(drop=True) + after_front_filter = len(df_filtered) + + if before_front_filter > 0: + logger.debug( + f"Front hemisphere filtering: {before_front_filter:,} -> {after_front_filter:,} ({after_front_filter / before_front_filter * 100:.1f}% retained)" + ) + return df_filtered + + +def _log_final_summary(df, logger): + """Log final summary.""" + final_counts = df["grouped_labels"].value_counts() + logger.debug(f"Final output: {len(df):,} rows") + logger.debug("Final class distribution:") + for label in final_counts.index: + count = final_counts[label] + pct = count / len(df) * 100 if len(df) > 0 else 0 + logger.debug(f" {label}: {count:,} ({pct:.1f}%)") + + +def split_data(df, label_col, group_col, n_splits=5, random_state=42): """ - Split the data into training, validation, and test sets using stratified group k-fold cross-validation. + Split the data into training, validation, and test sets for cross-validation + using Stratified Group K-Fold approach. + + This implementation ensures: + 1. Each group appears in the test set exactly once across all folds. + 2. Each group appears in the validation set exactly once across all folds. + 3. There is no overlap between train, validation, and test sets within any given fold. Parameters ---------- - df_du : pandas.DataFrame - The dataframe to be split. 
It must contain the columns specified by `label_col` and `group_col`. + df : pandas.DataFrame + The dataframe to be split. label_col : str - The name of the column to be used for stratification, ensuring balanced class distribution across folds. + The column to be used for stratification. group_col : str - The name of the column to be used for grouping, ensuring that all instances of a group are in the same fold. + The column to group by, ensuring groups are not split across sets. + n_splits : int, optional + The number of folds for cross-validation. Defaults to 5. random_state : int, optional - The random seed for reproducibility of the splits. Defaults to 42. + The random seed for reproducibility. Defaults to 42. Returns ------- list of tuples - A list of tuples, each containing the following for each fold: - - fold : int - The fold number (1 to n_splits). - - train_df : pandas.DataFrame - The training set for the fold. - - val_df : pandas.DataFrame - The validation set for the fold. - - test_df : pandas.DataFrame - The test set for the fold. - - Notes - ----- - - The function uses `StratifiedGroupKFold` to perform k-fold cross-validation with both stratification and - group-wise splits. - - `label_col` is used to ensure balanced class distributions across folds, while `group_col` ensures that - all instances of a group remain in the same fold. - - An inner 10-fold split is performed on the training set to create the test set. - - Examples - -------- - fold_splits = split_data( - df_du=my_dataframe, - label_col='grouped_labels', - group_col='number', - random_state=42 - ) + A list where each tuple contains (fold_number, train_df, val_df, test_df). """ - fold_df = [] - inner_fold_choice = [0, 1, 2, 3, 4] - sgkf = StratifiedGroupKFold(n_splits=5, random_state=random_state, shuffle=True) - X = df_du - - for fold, (train_idx, val_idx) in enumerate(sgkf.split(df_du, df_du[label_col], df_du[group_col]), 1): - temp_df = X.iloc[train_idx] + sgkf = StratifiedGroupKFold(n_splits=n_splits, random_state=random_state, shuffle=True) + X = df + y = df[label_col] + groups = df[group_col] + + # Generate indices for all folds at once + fold_indices = list(sgkf.split(X, y, groups)) + + fold_dataframes = [] + for i in range(n_splits): + # Assign test, validation, and training folds + test_idx = fold_indices[i][1] + val_idx = fold_indices[(i + 1) % n_splits][1] # Use the next fold for validation + + # All other folds are used for training + train_folds_indices = [j for j in range(n_splits) if j != i and j != (i + 1) % n_splits] + train_idx = np.concatenate([fold_indices[j][1] for j in train_folds_indices]) + + # Create the dataframes + train_df = X.iloc[train_idx] val_df = X.iloc[val_idx] - inner_sgkf = StratifiedGroupKFold(n_splits=10) - inner_splits = list(inner_sgkf.split(temp_df, temp_df[label_col], temp_df[group_col])) - inner_train_idx, inner_test_idx = inner_splits[inner_fold_choice[fold - 1]] - train_df = temp_df.iloc[inner_train_idx] - test_df = temp_df.iloc[inner_test_idx] + test_df = X.iloc[test_idx] - fold_df.append((fold, train_df, val_df, test_df)) + fold_dataframes.append((i + 1, train_df, val_df, test_df)) - for fold, train_df, val_df, test_df in fold_df: - X.loc[train_df.index, f"Fold {fold}"] = "train" - X.loc[val_df.index, f"Fold {fold}"] = "val" - X.loc[test_df.index, f"Fold {fold}"] = "test" + # Annotate the original DataFrame + X.loc[train_df.index, f"Fold {i + 1}"] = "train" + X.loc[val_df.index, f"Fold {i + 1}"] = "val" + X.loc[test_df.index, f"Fold {i + 1}"] = "test" - return fold_df 
+ return fold_dataframes -def assign_fold_sets(df, fold_df): +def filter_by_location(df, lon_limit_deg=65, lat_limit_deg=None, limb_r_max=None): """ - Assign training, validation, and test sets to the dataframe based on fold information. - - Parameters - ---------- - df : pandas.DataFrame - The dataframe to be annotated with set information. - fold_df : list of tuples - A list containing tuples for each fold. Each tuple consists of: - - fold : int - The fold number. - - train_df : pandas.DataFrame - The training set for the fold. - - val_df : pandas.DataFrame - The validation set for the fold. - - test_df : pandas.DataFrame - The test set for the fold. - - Returns - ------- - pandas.DataFrame - The original dataframe with an additional 'set' column, which indicates whether a row belongs - to the training, validation, or test set for each fold. - - Notes - ----- - - The function iterates through each fold, adding a 'set' column to the dataframe that assigns rows to - either the 'train', 'val', or 'test' sets based on the information in `fold_df`. - - Examples - -------- - df = assign_fold_sets( - df=df, - fold_df=fold_splits) + Return masks and filtered DataFrames using |lon|, optional |lat| and optional inner-disc radius. + Chooses HMI coords when an HMI path exists, else MDI. """ - for fold, train_set, val_set, test_set in fold_df: - df.loc[train_set.index, f"Fold {fold}"] = "train" - df.loc[val_set.index, f"Fold {fold}"] = "val" - df.loc[test_set.index, f"Fold {fold}"] = "test" - return df + + # Build an HMI-available mask from known path columns + def nonempty_mask(s): + return (~s.isna()) & (s != "") & (s != "None") + + hmi_candidates = ["path_image_cutout_hmi", "processed_path_image_hmi", "quicklook_path_hmi"] + hmi_mask = np.zeros(len(df), dtype=bool) + for col in hmi_candidates: + if col in df.columns: + hmi_mask |= nonempty_mask(df[col]).to_numpy() + + lon_deg = np.where(hmi_mask, df["longitude_hmi"].to_numpy(), df["longitude_mdi"].to_numpy()) + lat_deg = np.where(hmi_mask, df["latitude_hmi"].to_numpy(), df["latitude_mdi"].to_numpy()) + + # Vector masks (avoid Python bools) + lon_ok = (np.abs(lon_deg) <= lon_limit_deg) if lon_limit_deg is not None else np.ones(len(df), dtype=bool) + lat_ok = (np.abs(lat_deg) <= lat_limit_deg) if lat_limit_deg is not None else np.ones(len(df), dtype=bool) + + # Optional limb exclusion via projected radius r = sqrt(y^2 + z^2) + if limb_r_max is not None: + lon = np.deg2rad(lon_deg) + lat = np.deg2rad(lat_deg) + y = np.cos(lat) * np.sin(lon) + z = np.sin(lat) + r = np.sqrt(y**2 + z**2) + limb_ok = r <= limb_r_max + else: + limb_ok = np.ones(len(df), dtype=bool) + + front_mask = lon_ok & lat_ok & limb_ok + rear_mask = ( + (np.logical_not(lon_ok) & lat_ok & limb_ok) if lon_limit_deg is not None else np.zeros(len(df), dtype=bool) + ) + + return { + "mask_front": front_mask, + "mask_rear": rear_mask, + "df_front": df[front_mask].copy(), + "df_rear": df[rear_mask].copy(), + } diff --git a/arccnet/notebooks/analysis/EDA_cutouts.py b/arccnet/notebooks/analysis/EDA_cutouts.py new file mode 100644 index 00000000..e3b4e2c4 --- /dev/null +++ b/arccnet/notebooks/analysis/EDA_cutouts.py @@ -0,0 +1,464 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: py_3.11 +# language: python +# name: python3 +# --- + +# %% +# %load_ext autoreload +# %autoreload 2 + +import os +from 
datetime import datetime +from collections import defaultdict + +import matplotlib.dates as mdates +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from p_tqdm import p_map + +from arccnet import load_config +from arccnet.models import dataset_utils as ut_d +from arccnet.visualisation import utils as ut_v +from arccnet.visualisation.EDA_utils import ( + analyze_quality_flags, + create_solar_grid, + load_and_analyze_fits_pair, + process_row, +) + +pd.set_option("display.max_columns", None) +pd.set_option("display.max_colwidth", None) +config = load_config() + +# %% +data_folder = os.getenv("ARCAFF_DATA_FOLDER", "/ARCAFF/data") +dataset_folder = "arccnet-v20251017/04_final" +df_file_name = "data/cutout_classification/region_classification.parq" +dataset_title = "arccnet v20251017" + +# %% +df, _, filtered_ql_df = ut_d.make_dataframe(data_folder, dataset_folder, df_file_name) +ut_v.make_classes_histogram(df["label"], figsz=(18, 6), text_fontsize=11, title=dataset_title) +plt.show() +df + +# %% +mdi_color = "royalblue" +hmi_color = "tomato" + +df_MDI = df[df["path_image_cutout_hmi"] == ""].copy() +df_HMI = df[df["path_image_cutout_mdi"] == ""].copy() + +# %% [markdown] +# # Quality Flags + +# %% +quality_mdi_df = analyze_quality_flags(df_MDI, "SOHO/MDI") +quality_mdi_df + +# %% +quality_hmi_df = analyze_quality_flags(df_HMI, "SDO/HMI") +quality_hmi_df + + +# %% [markdown] +# ## Data Filtering + +# %% +# Remove bad quality data +hmi_good_flags = ["", "0x00000000", "0x00000400"] +mdi_good_flags = ["", "00000000", "00000200"] + +df_clean = df[df["QUALITY_hmi"].isin(hmi_good_flags) & df["QUALITY_mdi"].isin(mdi_good_flags)] +df_HMI_clean = df_HMI[df_HMI["QUALITY_hmi"].isin(hmi_good_flags)] +df_MDI_clean = df_MDI[df_MDI["QUALITY_mdi"].isin(mdi_good_flags)] + +print("DATA FILTERING Stats") +print("-" * 40) +hmi_orig, hmi_clean = len(df_HMI), len(df_HMI_clean) +mdi_orig, mdi_clean = len(df_MDI), len(df_MDI_clean) +total_orig, total_clean = len(df), len(df_clean) + +print(f"HMI: {hmi_clean:,}/{hmi_orig:,} ({hmi_clean / hmi_orig * 100:.1f}% retained)") +print(f"MDI: {mdi_clean:,}/{mdi_orig:,} ({mdi_clean / mdi_orig * 100:.1f}% retained)") +print(f"Total: {total_clean:,}/{total_orig:,} ({total_clean / total_orig * 100:.1f}% retained)") +print("-" * 40) + + +# %% +# Path Analysis - Check for None, empty strings, and 'None' strings +def is_missing(series): + """ + Check if file paths are missing or invalid. 
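+
+    A path counts as missing if it is NaN, an empty string, or the literal
+    string "None"; e.g. ["", "None", "a.fits"] maps to [True, True, False].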
+ """ + return series.isna() | (series == "") | (series == "None") + + +hmi_missing = is_missing(df_clean["path_image_cutout_hmi"]) +mdi_missing = is_missing(df_clean["path_image_cutout_mdi"]) +both_missing = hmi_missing & mdi_missing + +# Count statistics +stats = { + "total": len(df_clean), + "hmi_none": (df_clean["path_image_cutout_hmi"] == "None").sum(), + "hmi_empty": (df_clean["path_image_cutout_hmi"] == "").sum(), + "mdi_none": (df_clean["path_image_cutout_mdi"] == "None").sum(), + "mdi_empty": (df_clean["path_image_cutout_mdi"] == "").sum(), + "both_missing": both_missing.sum(), + "at_least_one": (~hmi_missing | ~mdi_missing).sum(), + "both_exist": (~hmi_missing & ~mdi_missing).sum(), +} + +print("PATH ANALYSIS:") +print("-" * 40) +print(f"Total rows in df_clean: {stats['total']:,}") +print(f"HMI paths - None: {stats['hmi_none']:,}, Empty: {stats['hmi_empty']:,}") +print(f"MDI paths - None: {stats['mdi_none']:,}, Empty: {stats['mdi_empty']:,}") +print(f"Both paths missing: {stats['both_missing']:,} ({stats['both_missing'] / stats['total'] * 100:.1f}%)") +print(f"At least one path exists: {stats['at_least_one']:,}") +print(f"Both paths exist: {stats['both_exist']:,}") + +# Remove rows where both paths are missing +df_clean = df_clean[~both_missing].copy() + +df_clean = df_clean.reset_index(drop=True) + + +# %% [markdown] +# # Location of ARs on the Sun + +# %% +AR_IA_lbs = ["Alpha", "Beta", "IA", "Beta-Gamma-Delta", "Beta-Gamma", "Beta-Delta", "Gamma-Delta", "Gamma"] +AR_IA_df = df_clean[df_clean["label"].isin(AR_IA_lbs)].reset_index(drop=True) + +# %% +ut_v.make_classes_histogram(AR_IA_df["label"], figsz=(12, 7), text_fontsize=11, title=f"{dataset_title} ARs", y_off=100) +plt.show() + + +# %% +def get_coordinates(df, coord_type): + """Extract longitude or latitude coordinates""" + hmi_col = f"{coord_type}_hmi" + mdi_col = f"{coord_type}_mdi" + return np.deg2rad(np.where(df["path_image_cutout_hmi"] != "", df[hmi_col], df[mdi_col])) + + +def plot_histogram(ax, data, degree_ticks, title, color="#4C72B0"): + """Plot histogram with degree labels.""" + rad_ticks = np.deg2rad(degree_ticks) + ax.hist(data, bins=rad_ticks, color=color, edgecolor="black") + ax.set_xticks(rad_ticks) + ax.set_xticklabels([f"{deg}°" for deg in degree_ticks]) + ax.set_xlabel(f"{title} (degrees)") + ax.set_ylabel("Frequency") + + +# Get coordinates +lonV = get_coordinates(AR_IA_df, "longitude") +latV = get_coordinates(AR_IA_df, "latitude") +degree_ticks = np.arange(-90, 91, 15) + +# Plot histograms +with sns.axes_style("darkgrid"): + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) + plot_histogram(ax1, lonV, degree_ticks, "Longitude") + plot_histogram(ax2, latV, degree_ticks, "Latitude") + plt.tight_layout() + plt.show() + + +# %% +results = ut_d.filter_by_location(AR_IA_df, lon_limit_deg=65) + +# Get coordinates for visualization +lonV = get_coordinates(AR_IA_df, "longitude") +latV = get_coordinates(AR_IA_df, "latitude") + +# Calculate y, z coordinates for plotting +yV = np.cos(latV) * np.sin(lonV) +zV = np.sin(latV) + +# Create solar disc visualization +fig, ax = plt.subplots(figsize=(10, 10)) +ax.add_artist(plt.Circle((0, 0), 1, edgecolor="gray", facecolor="none")) +create_solar_grid(ax) + +# Plot data points using masks +front_mask, rear_mask = results["mask_front"], results["mask_rear"] +ax.scatter(yV[rear_mask], zV[rear_mask], s=1, alpha=0.2, color=hmi_color, label="Rear") +ax.scatter(yV[front_mask], zV[front_mask], s=1, alpha=0.2, color=mdi_color, label="Front") + +# Configure plot 
+ax.set(xlim=(-1.1, 1.1), ylim=(-1.1, 1.1), aspect="equal") +ax.axis("off") +ax.legend(fontsize=12) +plt.show() + +# Print statistics +front_count = int(front_mask.sum()) +rear_count = int(rear_mask.sum()) +total_count = front_count + rear_count +print(f"Rear ARs: {rear_count:,}") +print(f"Front ARs: {front_count:,}") +print(f"Percentage of rear ARs: {100 * rear_count / total_count:.2f}%") + +ut_v.make_classes_histogram(results["df_front"]["label"], title="Front ARs", y_off=10, figsz=(11, 5)) +ut_v.make_classes_histogram(results["df_rear"]["label"], title="Rear ARs", y_off=10, figsz=(11, 5)) +plt.show() + +# %% [markdown] +# # Time Distribution +# %% +mdi_df = AR_IA_df[AR_IA_df["path_image_cutout_mdi"] != ""] +hmi_df = AR_IA_df[AR_IA_df["path_image_cutout_hmi"] != ""] + +# Get time series data +mdi_dates, hmi_dates = mdi_df["dates"].values, hmi_df["dates"].values +mdi_counts, hmi_counts = mdi_df["dates"].value_counts().sort_index(), hmi_df["dates"].value_counts().sort_index() + +# Setup plot +tick_dates = [datetime(year, 1, 1) for year in range(1996, 2025, 2)] + +with plt.style.context("seaborn-v0_8-darkgrid"): + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 9), sharex=True, gridspec_kw={"height_ratios": [4, 1]}) + + # Top panel: Bar chart + ax1.bar(mdi_counts.index, mdi_counts.values, width=0.8, color=mdi_color, alpha=0.9, label="MDI") + ax1.bar(hmi_counts.index, hmi_counts.values, width=0.8, color=hmi_color, alpha=0.9, label="HMI") + ax1.set(ylabel="n° of ARs per day", ylim=[0, 20], yticks=np.arange(0, 20 + 2, 2)) + ax1.tick_params(axis="y", labelsize=14) + ax1.legend(loc="upper left", fontsize=14) + ax1.grid(True, linestyle="--", alpha=0.5) + + # Bottom panel: Timeline + for dates, color, y_range in zip([mdi_dates, hmi_dates], [mdi_color, hmi_color], [(0.2, 0.8), (1.2, 1.8)]): + ax2.vlines(dates, *y_range, color=color, alpha=0.9, linewidth=0.5) + + ax2.set(ylim=[0, 2], yticks=[]) + ax2.grid(True, linestyle="--", alpha=0.75) + + # X-axis formatting + ax2.xaxis_date() + ax2.xaxis.set_major_locator(mdates.YearLocator(2)) + ax2.xaxis.set_major_formatter(mdates.DateFormatter("%Y")) + ax2.set_xticks(tick_dates) + plt.setp(ax2.get_xticklabels(), rotation=45, fontsize=14) + ax2.set_xlabel("Time", fontsize=16) + + plt.tight_layout() + plt.show() + +# %% [markdown] +# # McIntosh Classification + +# %% +AR_df = df[df["magnetic_class"] != ""].copy() + +ut_v.make_classes_histogram( + AR_df["mcintosh_class"], + horizontal=True, + figsz=(10, 18), + y_off=20, + x_rotation=0, + ylabel="Number of Active Regions", + title="McIntosh Class Distribution", + ylim=5900, +) +plt.show() + +# %% + +# McIntosh classification components +AR_df = df[df["magnetic_class"] != ""].copy() +for comp in ["Z_component", "p_component", "c_component"]: + AR_df[comp] = AR_df["mcintosh_class"].str[{"Z_component": 0, "p_component": 1, "c_component": 2}[comp]] + +# Plot parameters and histograms +params = [ + ("Z_component", (10, 6), "Z McIntosh Component"), + ("p_component", (9, 6), "p McIntosh Component"), + ("c_component", (6, 6), "c McIntosh Component"), +] + +for component, figsz, title in params: + ut_v.make_classes_histogram(AR_df[component], y_off=50, figsz=figsz, title=title) +# %% +mappings = { + # Merge D, E, F into LG (LargeGroup) + "Z_component": {"A": "A", "B": "B", "C": "C", "D": "LG", "E": "LG", "F": "LG", "H": "H"}, + # Merge s and h into sym & a and k into asym + "p_component": {"x": "x", "r": "r", "s": "sym", "h": "sym", "a": "asym", "k": "asym"}, + # Merge i and c into frag + "c_component": {"x": "x", "o": 
"o", "i": "frag", "c": "frag"}, +} + +# Apply mappings and plot +for comp, mapping in mappings.items(): + AR_df[f"{comp}_grouped"] = AR_df[comp].map(mapping) + ut_v.make_classes_histogram( + AR_df[f"{comp}_grouped"], + y_off=50, + figsz={"Z_component": (8, 6), "p_component": (6, 6), "c_component": (5, 6)}[comp], + title=f"{comp.split('_')[0].upper()} McIntosh Component (Grouped)", + ) + plt.show() + + +# %% +def group_and_sort_classes(class_list): + """ + Group classes by their initial letter and display them. + """ + # Group classes by their initial letter + grouped_classes = defaultdict(list) + for cls in sorted(class_list): # Sort the entire list alphabetically first + grouped_classes[cls[0]].append(cls) + + # Format the output + for letter, classes in grouped_classes.items(): + print(f"{letter}: {', '.join(classes)}") + + +print("------ McIntosh Classes ------") +group_and_sort_classes(list(AR_df["mcintosh_class"].unique())) +print(f"\nn° of classes: {len(AR_df['mcintosh_class'].unique())}") +print("\n------ Grouped McIntosh Classes ------") +grouped_classes = list( + (AR_df["Z_component_grouped"] + AR_df["p_component_grouped"] + AR_df["c_component_grouped"]).unique() +) +group_and_sort_classes(grouped_classes) +print(f"\nn° of classes: {len(grouped_classes)}") +# %% [markdown] +# # Pixel Values + + +# %% [markdown] +# ### Single Image + + +# %% +idx = 5564 + +data = load_and_analyze_fits_pair(idx, AR_IA_df, data_folder, dataset_folder) + +# Create subplot and display +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)) +fig.suptitle(f"{data['row']['label']} - {data['row']['mcintosh_class']} - {data['row']['dates']}", fontsize=12, y=0.95) + +# Display both images with colorbars +for ax, img_data, title, stats in zip( + [ax1, ax2], + [data["mag_data"], data["cont_data"]], + ["Magnetogram", "Continuum"], + [data["mag_stats"], data["cont_stats"]], +): + im = ax.imshow(img_data, cmap="gray", origin="lower") + ax.set_title(f"{title}\nMean: {stats['mean']:.2f}, Std: {stats['std']:.2f}") + ax.axis("off") + plt.colorbar(im, ax=ax, shrink=0.8) + +plt.tight_layout() +plt.show() + +print(f"Label: {data['row']['label']} - {data['row']['mcintosh_class']}") +print(f"Date: {data['row']['dates']}") +print(f"{'Statistic':<12} {'Magnetogram':<12} {'Continuum':<12}") +print("-" * 36) +print(f"{'Mean':<12} {data['mag_stats']['mean']:<12.2f} {data['cont_stats']['mean']:<12.2f}") +print(f"{'Std Dev':<12} {data['mag_stats']['std']:<12.2f} {data['cont_stats']['std']:<12.2f}") +print(f"{'Min':<12} {data['mag_stats']['min']:<12.2f} {data['cont_stats']['min']:<12.2f}") +print(f"{'Max':<12} {data['mag_stats']['max']:<12.2f} {data['cont_stats']['max']:<12.2f}") + +# %% [markdown] +# ### All images statistics + + +# %% +def process_row_wrapper(idx): + """Wrapper function that uses global variables for parallel processing.""" + return process_row(idx, AR_IA_df, data_folder, dataset_folder) + + +results = p_map(process_row_wrapper, range(len(AR_IA_df))) + +# %% +flat_stats = [] +for entry in results: + if entry is not None: + idx = entry["index"] + row = { + "index": idx, + "label": entry["label"], + "mcintosh_class": entry["mcintosh_class"], + } + row.update({f"mag_{k}": v for k, v in entry["mag_stats"].items()}) + row.update({f"cont_{k}": v for k, v in entry["cont_stats"].items()}) + flat_stats.append(row) +stats_df = pd.DataFrame(flat_stats) +stats_df.describe() + +# %% +# Find the indices of the 10 highest mag_mean values +top10_indices = stats_df["mag_max"].nlargest(10).index +# Get the corresponding rows 
+top10_rows = stats_df.loc[top10_indices] +top10_rows + +# %% +stats_config = [ + ("mag_mean", "Magnetogram Mean", "royalblue"), + ("mag_std", "Magnetogram Std Dev", "royalblue"), + ("mag_min", "Magnetogram Min", "royalblue"), + ("mag_max", "Magnetogram Max", "royalblue"), + ("cont_mean", "Continuum Mean", "tomato"), + ("cont_std", "Continuum Std Dev", "tomato"), + ("cont_min", "Continuum Min", "tomato"), + ("cont_max", "Continuum Max", "tomato"), +] + +# Create histograms +fig, axes = plt.subplots(2, 4, figsize=(20, 8)) +for i, (col, title, color) in enumerate(stats_config): + ax = axes.flat[i] + ax.hist(stats_df[col], bins=50, color=color, alpha=0.7, edgecolor="black", linewidth=0.5, log=True) + ax.set_title(title, fontsize=16, pad=12) + ax.set_xlabel("Value", fontsize=14) + if i % 4 == 0: # First column gets y-label + ax.set_ylabel("Frequency", fontsize=14) + ax.grid(True, alpha=0.3) + ax.tick_params(labelsize=12) + +plt.tight_layout() +plt.show() + +# %% +# Create boxplots by class +colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f"] +all_labels = stats_df["label"].unique() + +for col, title, _ in stats_config: + plt.figure(figsize=(14, 8)) + sns.boxplot(x="label", y=col, hue="label", data=stats_df, palette=colors[: len(all_labels)], legend=False) + plt.title(f"{title} by Active Region Class", fontsize=18) + plt.xticks(rotation=45, ha="right", fontsize=14) + plt.xlabel("") # Remove x-axis label + plt.ylabel("Value", fontsize=16) + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.show() +# %% diff --git a/arccnet/visualisation/EDA_utils.py b/arccnet/visualisation/EDA_utils.py new file mode 100644 index 00000000..fd3fcf21 --- /dev/null +++ b/arccnet/visualisation/EDA_utils.py @@ -0,0 +1,295 @@ +""" +Utility functions for Exploratory Data Analysis (EDA) of ARCCnet cutout data. +""" + +import os +from pathlib import Path + +import numpy as np + +from astropy.io import fits + +# Quality flag definitions +QUALITY_FLAGS = { + "SOHO/MDI": { + 0x00000001: "Missing Data", + 0x00000002: "Saturated Pixel", + 0x00000004: "Truncated (Top)", + 0x00000008: "Truncated (Bottom)", + 0x00000200: "Shutterless Mode", + 0x00010000: "Cosmic Ray", + 0x00020000: "Calibration Mode", + 0x00040000: "Image Bad", + }, + "SDO/HMI": { + 0x00000020: "Missing >50% Data", + 0x00000080: "Limb Darkening Correction Bad", + 0x00000400: "Shutterless Mode", + 0x00001000: "Partial/Missing Frame", + 0x00010000: "Cosmic Ray", + }, +} + + +def decode_flags(flag_hex, flag_dict): + """ + Decode hexadecimal quality flag to human-readable status. + """ + try: + flag_str = str(flag_hex).strip().lstrip("0x") + if not flag_str or flag_str in ["nan", "None", ""]: + return "Good Quality" + flag_int = int(flag_str, 16) + if flag_int == 0: + return "Good Quality" + meanings = [meaning for bit_val, meaning in flag_dict.items() if flag_int & bit_val] + return " | ".join(meanings) or "Unknown Flag" + except (ValueError, TypeError): + return "Invalid Format" + + +def analyze_quality_flags(df, instrument_name): + """ + Analyze and summarize quality flags for "SOHO/MDI" or "SDO/HMI". + + Returns + ------- + pandas.DataFrame or None + Flag statistics with columns ['Flag_Hex', 'Count', 'Percentage', 'Description']. + Returns None if quality column missing or DataFrame empty. 
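+
+    Examples
+    --------
+    As used in the EDA cutouts notebook, on single-instrument subsets:
+
+    quality_mdi_df = analyze_quality_flags(df_MDI, "SOHO/MDI")
+    quality_hmi_df = analyze_quality_flags(df_HMI, "SDO/HMI")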
+ """ + quality_column = "QUALITY_mdi" if instrument_name == "SOHO/MDI" else "QUALITY_hmi" + + if quality_column not in df.columns or len(df) == 0: + return None + + series = ( + df[quality_column] + .astype(str) + .replace(["nan", "None", "", ""], "00000000") + .str.strip() + .str.replace("0x", "", regex=False) + ) + + counts = series.value_counts().reset_index() + counts.columns = ["Flag", "Count"] + total = counts["Count"].sum() + + return ( + counts.assign( + Percentage=(counts["Count"] / total * 100).round(2).apply(lambda p: f"{p:.2f}%"), + Flag_Hex=counts["Flag"].apply(lambda f: f"0x{f.upper()}"), + Description=counts["Flag"].apply(lambda f: decode_flags(f, QUALITY_FLAGS[instrument_name])), + )[["Flag_Hex", "Count", "Percentage", "Description"]] + .sort_values("Count", ascending=False) + .reset_index(drop=True) + ) + + +def create_solar_grid(ax, num_meridians=12, num_parallels=12, num_points=300): + """ + Add meridian and parallel grid lines to a solar disc plot. + + Parameters + ---------- + ax : matplotlib.axes.Axes + Axes object for drawing grid lines. + num_meridians : int, optional + Number of longitude lines. Default 12. + num_parallels : int, optional + Number of latitude lines. Default 12. + num_points : int, optional + Points per grid line for smoothness. Default 300. + """ + phis = np.linspace(0, 2 * np.pi, num_meridians, endpoint=False) + lats = np.linspace(-np.pi / 2, np.pi / 2, num_parallels) + theta = np.linspace(-np.pi / 2, np.pi / 2, num_points) + + # Meridians + for phi in phis: + y, z = np.cos(theta) * np.sin(phi), np.sin(theta) + ax.plot(y, z, "k-", linewidth=0.2) + + # Parallels + for lat in lats: + y = np.cos(lat) * np.sin(theta) + z = np.full(num_points, np.sin(lat)) + ax.plot(y, z, "k-", linewidth=0.2) + + +def analyze_nan_pattern(data, longitude): + """ + Analyze NaN patterns considering longitude position. + + Parameters + ---------- + data : numpy.ndarray + 2D array to analyze for NaN patterns. + longitude : float + Longitude position in degrees for limb detection. + + Returns + ------- + dict + Dictionary with NaN fraction statistics and position info. + """ + nan_mask = np.isnan(data) + total_nans = np.sum(nan_mask) + + # Check for instrumental artifacts + rows_all_nan = np.all(nan_mask, axis=1) + cols_all_nan = np.all(nan_mask, axis=0) + horizontal_nan_rows = np.sum(rows_all_nan) + vertical_nan_cols = np.sum(cols_all_nan) + + # Calculate bar NaNs + horizontal_bar_nans = horizontal_nan_rows * data.shape[1] + vertical_bar_nans = vertical_nan_cols * data.shape[0] + + # Estimate limb vs instrumental NaNs + abs_longitude = abs(longitude) + is_near_limb = abs_longitude > 60 + + if is_near_limb: + edge_nans = total_nans - horizontal_bar_nans - vertical_bar_nans + limb_nans = max(0, edge_nans) + instrumental_nans = horizontal_bar_nans + vertical_bar_nans + else: + limb_nans = 0 + instrumental_nans = total_nans + + return { + "total_nans": total_nans, + "horizontal_nan_rows": horizontal_nan_rows, + "vertical_nan_cols": vertical_nan_cols, + "limb_nans": limb_nans, + "instrumental_nans": instrumental_nans, + "nan_fraction": total_nans / data.size, + "longitude": longitude, + "is_near_limb": is_near_limb, + } + + +def compute_stats(data, longitude): + """ + Compute statistics with longitude-informed NaN handling. + + Parameters + ---------- + data : numpy.ndarray + 2D array to compute statistics for. + longitude : float + Longitude position in degrees. + + Returns + ------- + dict + Dictionary with statistical measures and NaN analysis. 
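+
+    Examples
+    --------
+    Sketch on synthetic data (array values are illustrative):
+
+    data = np.full((64, 64), np.nan)
+    data[16:48, 16:48] = 100.0
+    stats = compute_stats(data, longitude=70.0)
+    # stats["mean"] == 100.0, stats["nan_fraction"] == 0.75, stats["is_near_limb"] is True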
+ """ + nan_analysis = analyze_nan_pattern(data, longitude) + valid_data = data[~np.isnan(data)] + + if len(valid_data) == 0: + stats = { + "mean": np.nan, + "median": np.nan, + "std": np.nan, + "min": np.nan, + "max": np.nan, + "shape": data.shape, + "valid_pixels": 0, + "total_pixels": data.size, + } + else: + with np.errstate(invalid="ignore"): + stats = { + "mean": np.nanmean(data), + "median": np.nanmedian(data), + "std": np.nanstd(data), + "min": np.nanmin(data), + "max": np.nanmax(data), + "shape": data.shape, + "valid_pixels": len(valid_data), + "total_pixels": data.size, + } + + stats.update(nan_analysis) + return stats + + +def load_and_analyze_fits_pair(idx, df_clean, data_folder, dataset_folder): + """ + Load magnetogram and continuum FITS files and compute statistics. + + Returns + ------- + dict + Dictionary with loaded data, statistics, and metadata. + """ + # Check if index is valid + if idx >= len(df_clean): + raise ValueError(f"Index {idx} is out of range. df_clean has {len(df_clean)} rows (0-{len(df_clean) - 1})") + + row = df_clean.iloc[idx] + path = row["path_image_cutout_hmi"] if row["path_image_cutout_mdi"] == "" else row["path_image_cutout_mdi"] + fits_magn_filename = os.path.basename(path) + fits_magn_path = Path(data_folder) / dataset_folder / "data/cutout_classification/fits" / fits_magn_filename + fits_cont_path = Path(str(fits_magn_path).replace("_mag_", "_cont_")) + + # Check if files exist + if not fits_magn_path.exists(): + raise FileNotFoundError(f"Magnetogram file not found: {fits_magn_path}") + if not fits_cont_path.exists(): + raise FileNotFoundError(f"Continuum file not found: {fits_cont_path}") + + # Load data + with fits.open(fits_magn_path) as hdul: + mag_data = hdul[1].data + with fits.open(fits_cont_path) as hdul: + cont_data = hdul[1].data + + # Check if data is not empty + if mag_data is None or mag_data.size == 0: + raise ValueError(f"Magnetogram data is empty: {fits_magn_filename}") + if cont_data is None or cont_data.size == 0: + raise ValueError(f"Continuum data is empty: {fits_cont_path.name}") + + # Get longitude information + longitude = row["longitude_hmi"] if row["path_image_cutout_mdi"] == "" else row["longitude_mdi"] + + return { + "row": row, + "mag_data": mag_data, + "cont_data": cont_data, + "mag_stats": compute_stats(mag_data, longitude), + "cont_stats": compute_stats(cont_data, longitude), + "mag_filename": fits_magn_filename, + "cont_filename": fits_cont_path.name, + } + + +def process_row(idx, df_clean, data_folder, dataset_folder): + """ + Process a single row to extract statistics from FITS files. + + Returns + ------- + dict or None + Statistics dictionary or None if processing fails. 
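+
+    Examples
+    --------
+    As used in the EDA cutouts notebook, where it is wrapped and parallelised
+    with p_map (the index value below is illustrative):
+
+    entry = process_row(5564, AR_IA_df, data_folder, dataset_folder)
+    if entry is not None:
+        print(entry["label"], entry["mag_stats"]["mean"])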
+ """ + try: + data = load_and_analyze_fits_pair(idx, df_clean, data_folder, dataset_folder) + label = df_clean.iloc[idx]["label"] + mcintosh_class = df_clean.iloc[idx]["mcintosh_class"] + mag_stats = data["mag_stats"] + cont_stats = data["cont_stats"] + return { + "index": idx, + "label": label, + "mcintosh_class": mcintosh_class, + "mag_stats": mag_stats, + "cont_stats": cont_stats, + } + except Exception as e: + print(f"Failed at idx={idx}: {e}") + return None diff --git a/arccnet/visualisation/utils.py b/arccnet/visualisation/utils.py index 875d0434..29484783 100644 --- a/arccnet/visualisation/utils.py +++ b/arccnet/visualisation/utils.py @@ -62,8 +62,9 @@ def pad_resize_normalize(image, target_height=224, target_width=224): def make_classes_histogram( series, + horizontal=False, figsz=(13, 5), - y_off=300, + y_off=None, ylim=None, title=None, ylabel="n° of ARs", @@ -77,6 +78,7 @@ def make_classes_histogram( show_percentages=True, ax=None, save_path=None, + transparent=False, ): """ Creates and displays a bar chart (histogram) that visualizes the distribution of classes in a given pandas Series. @@ -84,131 +86,139 @@ def make_classes_histogram( Parameters: - series (pandas.Series): The input series containing the class labels. + - horizontal (bool, optional): + Whether to create a horizontal bar chart. Default is False. - figsz (tuple, optional): A tuple representing the size of the figure (width, height) in inches. - Default is (13, 5). - - y_off (int, optional): - The vertical offset for the text labels above the bars. - Default is 300. + Default is (13, 5) for vertical, but consider (10, 16) for horizontal with many classes. + - y_off (int or float, optional): + The offset for the text labels. For vertical charts, this is the vertical offset; + for horizontal charts, this is the horizontal offset. If None, defaults to 300 for vertical + and 0.5 for horizontal. Default is None. + - ylim (int or float, optional): + The maximum value for the y-axis (vertical) or x-axis (horizontal). Default is None. - title (str, optional): The title of the histogram plot. If `None`, no title will be displayed. Default is None. + - ylabel (str, optional): + The label for the count axis. For vertical charts, this is the y-axis label; + for horizontal charts, this is the x-axis label. Default is "n° of ARs". - titlesize (int, optional): The font size of the title text. Ignored if `title` is `None`. Default is 14. - x_rotation (int, optional): - The rotation angle for the x-axis labels. + The rotation angle for the x-axis labels (vertical) or y-axis labels (horizontal). Default is 0. - fontsize (int, optional): - The font size of the x and y axis labels. - Default is 11. + The font size of the axis labels. Default is 11. - bar_color (str, optional): - The color of the bars in the histogram. - Default is '#4C72B0'. + The color of the bars in the histogram. Default is '#4C72B0'. - edgecolor (str, optional): - The color of the edges of the bars. - Default is 'black'. + The color of the edges of the bars. Default is 'black'. - text_fontsize (int, optional): - The font size of the text displayed above the bars. - Default is 11. + The font size of the text displayed above the bars. Default is 11. - style (str, optional): - The matplotlib style to be used for the plot. - Default is 'seaborn-v0_8-darkgrid'. + The matplotlib style to be used for the plot. Default is 'seaborn-v0_8-darkgrid'. - show_percentages (bool, optional): - Whether to display percentages on top of the bars. - Default is True. 
+ Whether to display percentages on top of the bars. Default is True. - ax (matplotlib.axes.Axes, optional): An existing matplotlib Axes object to plot on. If `None`, a new figure and Axes will be created. Default is None. - save_path (str, optional): Path to save the figure. If `None`, the plot will be displayed instead of saved. Default is None. + - transparent (bool, optional): + Whether to save the figure with a transparent background. Default is False. """ + # Determine the default y_off based on horizontal + if y_off is None: + y_off = 0.5 if horizontal else 300 + + # Process class names and counts + # Filter out None and sort based on horizontal flag + classes_names = sorted(filter(lambda x: x is not None, series.unique()), reverse=horizontal) + if horizontal: + counts = series.value_counts().reindex(classes_names, fill_value=0) + classes_names = counts.index.tolist() + values = counts.values + else: + classes_counts = series.value_counts().reindex(classes_names) + values = classes_counts.values - # Remove None values before sorting - classes_names = sorted(filter(lambda x: x is not None, series.unique())) - - greek_labels = labels.convert_to_greek_label(classes_names) - classes_counts = series.value_counts().reindex(classes_names) - values = classes_counts.values total = np.sum(values) + greek_labels = labels.convert_to_greek_label(classes_names) # Ensure this is defined or imported with plt.style.context(style): + # Create figure and axes if not provided if ax is None: plt.figure(figsize=figsz) - bars = plt.bar(greek_labels, values, color=bar_color, edgecolor=edgecolor) + ax = plt.gca() + + # Create bars + if horizontal: + bars = ax.barh(greek_labels, values, color=bar_color, edgecolor=edgecolor) else: bars = ax.bar(greek_labels, values, color=bar_color, edgecolor=edgecolor) - # Add text on top of the bars + # Annotate bars with values and percentages for bar in bars: - yval = bar.get_height() + if horizontal: + value = bar.get_width() + y_center = bar.get_y() + bar.get_height() / 2 + text_x = value + y_off + text_y = y_center + ha = "left" + va = "center" + else: + value = bar.get_height() + text_x = bar.get_x() + bar.get_width() / 2 + text_y = value + y_off + ha = "center" + va = "bottom" + if show_percentages: - percentage = f"{yval / total * 100:.2f}%" if total > 0 else "0.00%" - if ax is None: - plt.text( - bar.get_x() + bar.get_width() / 2, - yval + y_off, - f"{yval} ({percentage})", - ha="center", - va="bottom", - fontsize=text_fontsize, - ) - else: - ax.text( - bar.get_x() + bar.get_width() / 2, - yval + y_off, - f"{yval} ({percentage})", - ha="center", - va="bottom", - fontsize=text_fontsize, - ) + percentage = f"{value / total * 100:.2f}%" if total > 0 else "0.00%" + text = f"{value} ({percentage})" else: - if ax is None: - plt.text( - bar.get_x() + bar.get_width() / 2, - yval + y_off, - f"{yval}", - ha="center", - va="bottom", - fontsize=text_fontsize, - ) - else: - ax.text( - bar.get_x() + bar.get_width() / 2, - yval + y_off, - f"{yval}", - ha="center", - va="bottom", - fontsize=text_fontsize, - ) - - # Setting x and y ticks - if ax is None: - plt.xticks(rotation=x_rotation, ha="center", fontsize=fontsize) - plt.yticks(fontsize=fontsize) - plt.ylabel(ylabel, fontsize=fontsize) - if ylim: - plt.ylim([0, ylim]) + text = f"{value}" + + ax.text(text_x, text_y, text, ha=ha, va=va, fontsize=text_fontsize) + + # Set axis labels and ticks + if horizontal: + ax.set_yticks(np.arange(len(greek_labels))) + ax.set_yticklabels(greek_labels, rotation=x_rotation, ha="right", 
fontsize=fontsize) + ax.set_xlabel(ylabel, fontsize=fontsize) + ax.set_ylabel("Class", fontsize=fontsize) + if ylim is not None: + ax.set_xlim([0, ylim]) else: ax.set_xticks(np.arange(len(greek_labels))) ax.set_xticklabels(greek_labels, rotation=x_rotation, ha="center", fontsize=fontsize) - ax.tick_params(axis="y", labelsize=fontsize) + ax.set_ylabel(ylabel, fontsize=fontsize) + if ylim is not None: + ax.set_ylim([0, ylim]) + # Set title if title: - if ax is None: - plt.title(title, fontsize=titlesize) - else: - ax.set_title(title, fontsize=titlesize) + ax.set_title(title, fontsize=titlesize) + + # Set tick parameters for the other axis + if horizontal: + ax.tick_params(axis="x", labelsize=fontsize) + else: + ax.tick_params(axis="y", labelsize=fontsize) - # If a new figure was created, show the plot + # Save or show the plot if ax is None if ax is None: if save_path: - plt.savefig(save_path, bbox_inches="tight") + plt.savefig(save_path, bbox_inches="tight", transparent=transparent) plt.close() else: plt.show() + return ax + class HardTanhTransform: """ From 13f5c81bb174aaa5f6eef5b0bf9ec8e8b182c414 Mon Sep 17 00:00:00 2001 From: Edoardo Legnaro Date: Wed, 22 Oct 2025 13:40:07 +0200 Subject: [PATCH 2/6] removed seaborn --- arccnet/notebooks/analysis/EDA_cutouts.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/arccnet/notebooks/analysis/EDA_cutouts.py b/arccnet/notebooks/analysis/EDA_cutouts.py index e3b4e2c4..acb2f6a9 100644 --- a/arccnet/notebooks/analysis/EDA_cutouts.py +++ b/arccnet/notebooks/analysis/EDA_cutouts.py @@ -26,7 +26,6 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -import seaborn as sns from p_tqdm import p_map from arccnet import load_config @@ -174,7 +173,7 @@ def plot_histogram(ax, data, degree_ticks, title, color="#4C72B0"): degree_ticks = np.arange(-90, 91, 15) # Plot histograms -with sns.axes_style("darkgrid"): +with plt.style.context("seaborn-v0_8-darkgrid"): fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) plot_histogram(ax1, lonV, degree_ticks, "Longitude") plot_histogram(ax2, latV, degree_ticks, "Latitude") @@ -452,13 +451,17 @@ def process_row_wrapper(idx): all_labels = stats_df["label"].unique() for col, title, _ in stats_config: - plt.figure(figsize=(14, 8)) - sns.boxplot(x="label", y=col, hue="label", data=stats_df, palette=colors[: len(all_labels)], legend=False) - plt.title(f"{title} by Active Region Class", fontsize=18) - plt.xticks(rotation=45, ha="right", fontsize=14) - plt.xlabel("") # Remove x-axis label - plt.ylabel("Value", fontsize=16) - plt.grid(True, alpha=0.3) + fig, ax = plt.subplots(figsize=(14, 8)) + data_by_label = [stats_df[stats_df["label"] == label][col].values for label in all_labels] + bp = ax.boxplot(data_by_label, labels=all_labels, patch_artist=True) + for patch, color in zip(bp["boxes"], colors[: len(all_labels)]): + patch.set_facecolor(color) + ax.set_title(f"{title} by Active Region Class", fontsize=18) + ax.tick_params(axis="x", rotation=45, labelsize=14) + ax.set_xlabel("") + ax.set_ylabel("Value", fontsize=16) + ax.grid(True, alpha=0.3) + fig.autofmt_xdate(ha="right") plt.tight_layout() plt.show() # %% From 5adfe33e4dffd4bc330b8028c8179e251f92fe1a Mon Sep 17 00:00:00 2001 From: Edoardo Legnaro Date: Wed, 22 Oct 2025 13:45:08 +0200 Subject: [PATCH 3/6] updated dependencies --- docs/conf.py | 2 +- pyproject.toml | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 707f98c8..e4086541 100644 --- a/docs/conf.py 
+++ b/docs/conf.py @@ -25,7 +25,7 @@ "sphinx.ext.intersphinx", "sphinx.ext.todo", "sphinx.ext.coverage", - "sphinx.ext.inheritance_diagram", + # "sphinx.ext.inheritance_diagram", # Removed: requires graphviz, not currently used "sphinx.ext.viewcode", "sphinx.ext.napoleon", "sphinx.ext.doctest", diff --git a/pyproject.toml b/pyproject.toml index 552a56d7..4a319cc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,8 @@ dependencies = [ "pyarrow", "drms~=0.9", "astropy~=7.0", - "pandas~=2.0" + "pandas~=2.0", + "p_tqdm" ] [project.optional-dependencies] From 41cc57025bf47cb65731b301cdc9e30b6a88595b Mon Sep 17 00:00:00 2001 From: Edoardo Legnaro Date: Wed, 22 Oct 2025 14:00:13 +0200 Subject: [PATCH 4/6] fixed tox --- arccnet/models/tests/test_utils.py | 17 +++++++++++------ docs/conf.py | 2 +- pytest.ini | 1 + 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/arccnet/models/tests/test_utils.py b/arccnet/models/tests/test_utils.py index 4f2d44cb..21035903 100644 --- a/arccnet/models/tests/test_utils.py +++ b/arccnet/models/tests/test_utils.py @@ -11,7 +11,9 @@ def sample_dataframe(): data = { "processed_path_image_hmi": ["path1", "", "path3", "path4", "", "path6"], "longitude_hmi": [30, -70, 45, 80, -50, 50], - "longitude_mdi": [np.nan, 60, np.nan, 90, np.nan, 60], + "longitude_mdi": [np.nan, 60, np.nan, 90, 40, 60], + "latitude_hmi": [0, 10, 20, 30, 40, 50], + "latitude_mdi": [np.nan, 15, np.nan, 35, 45, 55], "label": ["A", "B", "C", "A", "B", "C"], } return pd.DataFrame(data) @@ -105,9 +107,12 @@ def test_undersample_group_filter(monkeypatch, sample_dataframe): buffer_percentage=0.1, ) - # Test that all front locations are present without undersampling - expected_front_indices = df_modified_no_undersample[df_modified_no_undersample["location"] == "front"].index - assert set(df_no_undersample.index) == set(expected_front_indices), "Filtering without undersampling incorrect." + # Test that all locations in the result are 'front' (rear should be filtered) + assert all(df_no_undersample["location"] == "front"), "Not all locations are 'front' after filtering." - # Test that all front locations are present - assert df_no_undersample.shape[0] == 5, "Incorrect number of rows after filtering without undersampling." + # Test that we have the correct number of front locations (5 out of 6 total rows) + expected_front_count = (df_modified_no_undersample["location"] == "front").sum() + assert df_no_undersample.shape[0] == expected_front_count, ( + "Incorrect number of rows after filtering without undersampling." + ) + assert df_no_undersample.shape[0] == 5, "Expected 5 front locations, got a different count." 
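
The rewritten test above only asserts behaviour (every surviving row is front-side, and 5 of the 6 fixture rows survive); undersample_group_filter itself is not touched by this patch. The sketch below reproduces that behaviour under two explicit assumptions, a 70-degree longitude cut-off and a fallback to the MDI longitude whenever the HMI image path is empty, so treat the cut-off, the fallback rule, and the tag_location helper as illustrative placeholders rather than the library's actual implementation.

import numpy as np
import pandas as pd

# Same columns and values as the sample_dataframe fixture above.
df = pd.DataFrame(
    {
        "processed_path_image_hmi": ["path1", "", "path3", "path4", "", "path6"],
        "longitude_hmi": [30, -70, 45, 80, -50, 50],
        "longitude_mdi": [np.nan, 60, np.nan, 90, 40, 60],
        "label": ["A", "B", "C", "A", "B", "C"],
    }
)


def tag_location(row, lon_limit=70):
    """Tag a region as 'front' or 'rear' from its absolute longitude.

    Uses the HMI longitude when an HMI image path is present, otherwise
    the MDI longitude; the 70-degree limit is a placeholder and not
    necessarily the value used by undersample_group_filter.
    """
    lon = row["longitude_hmi"] if row["processed_path_image_hmi"] else row["longitude_mdi"]
    return "front" if abs(lon) <= lon_limit else "rear"


df["location"] = df.apply(tag_location, axis=1)
front_only = df[df["location"] == "front"].reset_index(drop=True)
print(front_only[["label", "location"]])  # 5 rows, all 'front', matching the test's expected count
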
diff --git a/docs/conf.py b/docs/conf.py index e4086541..707f98c8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,7 +25,7 @@ "sphinx.ext.intersphinx", "sphinx.ext.todo", "sphinx.ext.coverage", - # "sphinx.ext.inheritance_diagram", # Removed: requires graphviz, not currently used + "sphinx.ext.inheritance_diagram", "sphinx.ext.viewcode", "sphinx.ext.napoleon", "sphinx.ext.doctest", diff --git a/pytest.ini b/pytest.ini index d5c75bc0..3652f0b6 100644 --- a/pytest.ini +++ b/pytest.ini @@ -13,6 +13,7 @@ norecursedirs = ".jupyter" ".history" "tools" + "notebooks" doctest_plus = enabled text_file_format = rst addopts = --doctest-rst From f89f4dd82b8db63ac2afce8b3708421a665ed434 Mon Sep 17 00:00:00 2001 From: Edoardo Legnaro Date: Wed, 22 Oct 2025 21:22:35 +0000 Subject: [PATCH 5/6] updated EDA cutouts --- arccnet/notebooks/analysis/EDA_cutouts.py | 11 ++++++----- arccnet/utils/arccnetrc | 13 ++++++++++++- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/arccnet/notebooks/analysis/EDA_cutouts.py b/arccnet/notebooks/analysis/EDA_cutouts.py index acb2f6a9..b2d126a7 100644 --- a/arccnet/notebooks/analysis/EDA_cutouts.py +++ b/arccnet/notebooks/analysis/EDA_cutouts.py @@ -26,6 +26,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd +import seaborn as sns from p_tqdm import p_map from arccnet import load_config @@ -413,7 +414,7 @@ def process_row_wrapper(idx): # %% # Find the indices of the 10 highest mag_mean values -top10_indices = stats_df["mag_max"].nlargest(10).index +top10_indices = stats_df["mag_mean"].nlargest(10).index # Get the corresponding rows top10_rows = stats_df.loc[top10_indices] top10_rows @@ -450,12 +451,12 @@ def process_row_wrapper(idx): colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f"] all_labels = stats_df["label"].unique() +# Create a custom color palette +palette = {label: colors[i % len(colors)] for i, label in enumerate(all_labels)} + for col, title, _ in stats_config: fig, ax = plt.subplots(figsize=(14, 8)) - data_by_label = [stats_df[stats_df["label"] == label][col].values for label in all_labels] - bp = ax.boxplot(data_by_label, labels=all_labels, patch_artist=True) - for patch, color in zip(bp["boxes"], colors[: len(all_labels)]): - patch.set_facecolor(color) + sns.boxplot(data=stats_df, x="label", y=col, hue="label", palette=palette, ax=ax, legend=False) ax.set_title(f"{title} by Active Region Class", fontsize=18) ax.tick_params(axis="x", rotation=45, labelsize=14) ax.set_xlabel("") diff --git a/arccnet/utils/arccnetrc b/arccnet/utils/arccnetrc index 9477c75f..55c45b85 100644 --- a/arccnet/utils/arccnetrc +++ b/arccnet/utils/arccnetrc @@ -39,7 +39,18 @@ lon_lim_degrees = 85 ; Magnetograms ; ;;;;;;;;;;;;;;;; [magnetograms] -problematic_quicklooks = 20010116_000028_MDI.png, 20001130_000028_MDI.png, 19990420_235943_MDI.png +problematic_quicklooks = 20010116_000028_MDI.png, \ + 20001130_000028_MDI.png, \ + 19990420_235943_MDI.png, \ + 19960904_000030_.png, \ + 19960905_000030_.png, \ + 20010116_000028_.png, \ + 20001130_000028_.png, \ + 19990420_235943_.png, \ + 19960916_000030_.png, \ + 19960903_000030_.png, \ + 20050201_235943_.png, \ + 20090527_000026_.png [magnetograms.cutouts] From 72598ee96d81ba18c52e2b2b05867a80a2785874 Mon Sep 17 00:00:00 2001 From: Edoardo Legnaro Date: Wed, 22 Oct 2025 21:54:14 +0000 Subject: [PATCH 6/6] fixed problematic cutouts not being parsed correctly --- arccnet/models/dataset_utils.py | 18 ++++++++++++++++-- arccnet/notebooks/analysis/EDA_cutouts.py | 
2 +- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/arccnet/models/dataset_utils.py b/arccnet/models/dataset_utils.py index c13731e7..87b846bf 100644 --- a/arccnet/models/dataset_utils.py +++ b/arccnet/models/dataset_utils.py @@ -54,8 +54,22 @@ def _convert_jd_to_datetime(df): def _remove_problematic_quicklooks(df): """Remove problematic magnetograms from the dataset.""" - problematic_quicklooks = [ql.strip() for ql in config.get("magnetograms", "problematic_quicklooks").split(",")] - mask = df["quicklook_path_mdi"].apply(lambda x: os.path.basename(x) in problematic_quicklooks) + # Parse and clean the problematic quicklooks list + # Strip whitespace and remove backslash line continuations + problematic_quicklooks = [ + ql.strip().lstrip("\\").strip() for ql in config.get("magnetograms", "problematic_quicklooks").split(",") + ] + + def is_problematic(path): + """Check if a path's basename matches any problematic quicklook.""" + if not path or pd.isna(path) or path == "" or path == "None": + return False + return os.path.basename(path) in problematic_quicklooks + + # Check both MDI and HMI quicklook paths + mask_mdi = df["quicklook_path_mdi"].apply(is_problematic) + mask_hmi = df["quicklook_path_hmi"].apply(is_problematic) + mask = mask_mdi | mask_hmi filtered_df = df[mask] df = df[~mask].reset_index(drop=True) return df, filtered_df diff --git a/arccnet/notebooks/analysis/EDA_cutouts.py b/arccnet/notebooks/analysis/EDA_cutouts.py index b2d126a7..8e9f5cc5 100644 --- a/arccnet/notebooks/analysis/EDA_cutouts.py +++ b/arccnet/notebooks/analysis/EDA_cutouts.py @@ -352,7 +352,7 @@ def group_and_sort_classes(class_list): # %% -idx = 5564 +idx = 144 data = load_and_analyze_fits_pair(idx, AR_IA_df, data_folder, dataset_folder)
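
The new parsing in _remove_problematic_quicklooks has to cope with the backslash-continued, comma-separated value that patch 5 introduced in arccnetrc. The standalone snippet below replays that cleaning step on an inline config string (a stand-in for the real arccnetrc, trimmed to the first three MDI file names from the original list) to show that the strip/lstrip chain recovers bare file names and that the basename comparison still matches full quicklook paths. It is only a sketch of the parsing behaviour, not an additional change to the patch.

import configparser
import os

# Inline stand-in for the [magnetograms] section of arccnetrc; the
# backslash continuations mimic the multi-line value added in patch 5.
raw = """
[magnetograms]
problematic_quicklooks = 20010116_000028_MDI.png, \\
    20001130_000028_MDI.png, \\
    19990420_235943_MDI.png
"""

config = configparser.ConfigParser()
config.read_string(raw)

# Same cleaning as _remove_problematic_quicklooks: split on commas, then
# strip whitespace and any leftover backslash continuation characters.
problematic = [
    ql.strip().lstrip("\\").strip()
    for ql in config.get("magnetograms", "problematic_quicklooks").split(",")
]
print(problematic)
# ['20010116_000028_MDI.png', '20001130_000028_MDI.png', '19990420_235943_MDI.png']

# Basename matching means a full quicklook path still hits the list.
print(os.path.basename("quicklook/20001130_000028_MDI.png") in problematic)  # True
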