5 changes: 5 additions & 0 deletions .gitignore
@@ -231,3 +231,8 @@ docs/whatsnew/latest_changelog.txt
*.log

data/**

### Training Temporary Folders
arccnet/models/temp
arccnet/models/weights
arccnet/models/trained_models
43 changes: 43 additions & 0 deletions arccnet/models/cutouts/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os

from torchvision.transforms import v2

project_name = "arcaff-v2-qs-ia-a-b-bg"

batch_size = 32
num_workers = 12
num_epochs = 200
patience = 10
pretrained = True
learning_rate = 1e-5

model_name = "beit_base_patch16_224"
gpu_index = 0
device = "cuda:" + str(gpu_index)

data_folder = os.getenv("ARCAFF_DATA_FOLDER", "../../data/")
dataset_folder = "arccnet-cutout-dataset-v20240715"
df_file_name = "cutout-mcintosh-catalog-v20240715.parq"

label_mapping = {
"QS": "QS",
"IA": "IA",
"Alpha": "Alpha",
"Beta": "Beta",
"Beta-Delta": "Beta",
"Beta-Gamma": "Beta-Gamma",
"Beta-Gamma-Delta": "Beta-Gamma",
"Gamma": None,
"Gamma-Delta": None,
}
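# Labels mapped to None are presumably dropped during label grouping
# (see ut.undersample_group_filter in train.py, which consumes this mapping).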

train_transforms = v2.Compose(
[
v2.RandomHorizontalFlip(),
v2.RandomVerticalFlip(),
Comment on lines +36 to +37 (Member):

How valid is this without also flipping magnetic field?
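
One possible way to address this, as a sketch (assuming the cutouts are single-channel line-of-sight magnetogram tensors; RandomFlipWithPolarity is illustrative and not part of this PR):

import torch

class RandomFlipWithPolarity(torch.nn.Module):
    """Randomly mirror a (..., H, W) magnetogram and negate its polarity."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        if torch.rand(1).item() < self.p:
            img = torch.flip(img, dims=[-1])  # mirror east-west
            img = -img  # swap the sign of the line-of-sight field
        return img

A module like this could stand in for v2.RandomHorizontalFlip() inside v2.Compose, since v2 pipelines accept any torch.nn.Module operating on tensors.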

v2.RandomResizedCrop(size=(224, 224), scale=(0.9, 0.9), antialias=True),
v2.RandomRotation(35),
]
)

val_transforms = None
99 changes: 99 additions & 0 deletions arccnet/models/cutouts/inference.py
@@ -0,0 +1,99 @@
import argparse
from pathlib import Path

import numpy as np
import requests
import timm
import torch
from comet_ml import API

from astropy.io import fits

from arccnet.models import utilities as ut


def download_model(api, workspace, model_name, model_version, model_path):
    if model_path.exists():
        print(f"Model file already exists at {model_path}.\nSkipping download.")
        return
    model_comet = api.get_model(workspace, model_name)
    model_assets = model_comet.get_assets(model_version)
    model_url = None  # guard against no matching asset
    for asset in model_assets:
        if asset["fileName"] == "model-data/comet-torch-model.pth":
            model_url = asset["s3Link"]
            break
    if model_url is None:
        print("No 'model-data/comet-torch-model.pth' asset found for this model version.")
        return
    response = requests.get(model_url)
    if response.status_code == 200:
        with open(model_path, "wb") as f:
            f.write(response.content)
        print(f"Model downloaded successfully and saved to {model_path}")
    else:
        print(f"Failed to download model. Status code: {response.status_code}")


def preprocess_fits_data(fits_file_path, hardtanh=True, target_height=224, target_width=224):
with fits.open(fits_file_path, memmap=True) as img_fits:
image_data = np.array(img_fits[1].data, dtype=np.float32)
image_data = np.nan_to_num(image_data, nan=0.0)
if hardtanh:
image_data = ut.hardtanh_transform_npy(image_data, divisor=800, min_val=-1.0, max_val=1.0)
image_data = ut.pad_resize_normalize(image_data, target_height=target_height, target_width=target_width)
return torch.from_numpy(image_data).unsqueeze(0)


def run_inference(model, fits_file_path, device):
model.eval()
with torch.no_grad():
data = preprocess_fits_data(fits_file_path)
data = data.unsqueeze(0).to(device) # dimensions: (batch_size, channels, height, width)
output = model(data)
return output.cpu().numpy()


def main(args):
api = API()
script_dir = Path(__file__).parent.resolve()
output_dir = script_dir.parent / "trained_models"
output_dir.mkdir(parents=True, exist_ok=True)
model_path = output_dir / f"{args.model_name}-{args.model_version}.pth"
Comment on lines +55 to +58 (Contributor):

I think this should be using the generic configuration stuff here.
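
One way that could look, as a sketch (trained_models_dir is a hypothetical config entry, not part of this PR):

from pathlib import Path

import arccnet.models.cutouts.config as config

model_name, model_version = "resnet10t", "1.0.0"  # stand-ins for the CLI args
output_dir = Path(getattr(config, "trained_models_dir", "trained_models"))
output_dir.mkdir(parents=True, exist_ok=True)
model_path = output_dir / f"{model_name}-{model_version}.pth"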


download_model(api, args.workspace, args.model_name, args.model_version, model_path)

# Find number of classes from project name
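# e.g. "arcaff-v2-qs-ia-a-b-bg" -> "qs-ia-a-b-bg" -> ["qs", "ia", "a", "b", "bg"] -> 5 classes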
substring_after_v2 = args.project_name.split("arcaff-v2-")[1]
values = substring_after_v2.split("-")
num_classes = len(values)

# Create the model
model = timm.create_model(args.model_name, num_classes=num_classes, in_chans=1)
ut.replace_activations(model, torch.nn.ReLU, torch.nn.LeakyReLU, negative_slope=0.01)

# Load the model state
device = "cpu"
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))

# Run inference
print(f"FITS file: {args.fits_file_path}")
result = run_inference(model, args.fits_file_path, device)
predicted_class = np.argmax(result)
probabilities = torch.softmax(torch.tensor(result), dim=1).numpy()

print("Normalized Predictions:", probabilities)
print("Predicted class:", ut.index_to_label[predicted_class])


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run inference on FITS data using a pre-trained model.")
parser.add_argument(
"--fits_file_path",
type=str,
default="/ARCAFF/data/arccnet-cutout-dataset-v20240715/fits/19970125_235945_QS-5_MDI.fits",
help="Path to the FITS file.",
)
parser.add_argument("--project_name", type=str, default="arcaff-v2-qs-ia-a-b-bg", help="Name of the project.")
parser.add_argument("--workspace", type=str, default="arcaff", help="Workspace name in Comet.ml.")
parser.add_argument("--model_name", type=str, default="resnet10t", help="Model name.")
parser.add_argument("--model_version", type=str, default="1.0.0", help="Model version.")

args = parser.parse_args()
main(args)
10 changes: 10 additions & 0 deletions arccnet/models/cutouts/readme.txt
@@ -0,0 +1,10 @@
The train.py script requires:
- Two folders to store temporary files:
  - ARCCnet/arccnet/models/temp
  - ARCCnet/arccnet/models/weights
- An environment variable named ARCAFF_DATA_FOLDER pointing to the location
  where the dataset is stored (default: ../../data).
- A Comet ML login before training runs are logged:
  import comet_ml
  comet_ml.login()
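
Example invocation (a sketch; adjust paths and options to your setup):
  export ARCAFF_DATA_FOLDER=/path/to/data
  python train.py --model_name beit_base_patch16_224 --batch_size 32 --gpu_index 0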
90 changes: 90 additions & 0 deletions arccnet/models/cutouts/train.py
@@ -0,0 +1,90 @@
import os
import argparse

import torch
from comet_ml import Experiment

import arccnet.models.cutouts.config as config
import arccnet.models.utilities as ut

# Initialize argument parser
parser = argparse.ArgumentParser(description="Training script with configurable options.")
parser.add_argument("--model_name", type=str, help="Timm model name")
parser.add_argument("--batch_size", type=int, help="Batch size for training.")
parser.add_argument("--num_workers", type=int, help="Number of workers for data loading and preprocessing.")
parser.add_argument("--num_epochs", type=int, help="Number of epochs for training.")
parser.add_argument("--patience", type=int, help="Patience for early stopping.")
parser.add_argument("--learning_rate", type=float, help="Learning rate for optimizer.")
parser.add_argument("--gpu_index", type=int, help="Index of the GPU to use.")
parser.add_argument("--data_folder", type=str, help="Path to the data folder.")
parser.add_argument("--dataset_folder", type=str, help="Path to the dataset folder.")
parser.add_argument("--df_file_name", type=str, help="Name of the dataframe file.")

args = parser.parse_args()

# Override config settings with arguments if provided
if args.model_name is not None:
config.model_name = args.model_name
if args.batch_size is not None:
config.batch_size = args.batch_size
if args.num_epochs is not None:
config.num_epochs = args.num_epochs
if args.patience is not None:
config.patience = args.patience
if args.learning_rate is not None:
config.learning_rate = args.learning_rate
if args.gpu_index is not None:
config.gpu_index = args.gpu_index
config.device = f"cuda:{args.gpu_index}"
if args.data_folder is not None:
config.data_folder = args.data_folder
if args.dataset_folder is not None:
config.dataset_folder = args.dataset_folder
if args.df_file_name is not None:
config.df_file_name = args.df_file_name
if args.num_workers is not None:
config.num_workers = args.num_workers

run_id, weights_dir = ut.generate_run_id(config)

run_comet = Experiment(project_name=config.project_name, workspace="arcaff")

run_comet.add_tags([config.model_name])
run_comet.log_parameters(
{
"model_name": config.model_name,
"batch_size": config.batch_size,
"GPU": f"GPU{config.gpu_index}_{torch.cuda.get_device_name()}",
"num_epochs": config.num_epochs,
"patience": config.patience,
}
)

run_comet.log_code(config.__file__)
run_comet.log_code(ut.__file__)

print("Making dataframe...")
df, AR_df = ut.make_dataframe(config.data_folder, config.dataset_folder, config.df_file_name)

df, df_du = ut.undersample_group_filter(
df, config.label_mapping, long_limit_deg=60, undersample=True, buffer_percentage=0.1
)
fold_df = ut.split_data(df_du, label_col="grouped_labels", group_col="number", random_state=42)
df = ut.assign_fold_sets(df, fold_df)
print("done.")
print("Starting Training...")

(avg_test_loss, test_accuracy, test_precision, test_recall, test_f1, cm_test, report_df) = ut.train_model(
config, df, weights_dir, experiment=run_comet
)

print("Logging assets...")
script_dir = os.path.dirname(ut.__file__)
save_path = os.path.join(script_dir, "temp", "working_dataset.png")
ut.make_classes_histogram(
df_du["grouped_labels"], title="Dataset (Grouped Undersampled)", y_off=100, figsz=(7, 5), save_path=save_path
)
run_comet.log_image(save_path)

run_comet.log_asset_data(df.to_csv(index=False), name="dataset.csv")
print("done.")