Skip to content

Commit 91f78ef

Browse files
edoardolegnaroEdoardo Legnaro
andauthored
McIntosh Code (#174)
* added mcintosh code and moved mount_wilson code to respective folder * renamed mount_wilson to hale --------- Co-authored-by: Edoardo Legnaro <[email protected]>
1 parent 2665349 commit 91f78ef

File tree

11 files changed

+1641
-6
lines changed

11 files changed

+1641
-6
lines changed

arccnet/cli/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from collections.abc import Mapping
1010

1111
from arccnet import load_config
12-
from arccnet.models.cutouts import config as config_module
13-
from arccnet.models.cutouts.inference import predict
14-
from arccnet.models.cutouts.train import run_training
12+
from arccnet.models.cutouts.hale import config as config_module
13+
from arccnet.models.cutouts.hale.inference import predict
14+
from arccnet.models.cutouts.hale.train import run_training
1515
from arccnet.pipeline.main import process_ar_catalogs, process_ars, process_flares
1616
from arccnet.utils.logging import get_logger
1717

arccnet/models/cutouts/hale/__init__.py

Whitespace-only changes.
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from astropy.io import fits
1313

1414
from arccnet.models import train_utils as ut_t
15-
from arccnet.models.cutouts import config
15+
from arccnet.models.cutouts.hale import config
1616
from arccnet.utils.logging import get_logger
1717
from arccnet.visualisation import utils as ut_v
1818

@@ -66,7 +66,7 @@ def run_inference(model, fits_file_path, device):
6666
model.eval()
6767
with torch.no_grad():
6868
data = preprocess_fits_data(fits_file_path)
69-
data = data.to(device) # Removed extra unsqueeze since preprocess already has unsqueeze
69+
data = data.to(device)
7070
output = model(data)
7171
return output.cpu().numpy()
7272
except Exception:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
from comet_ml import Experiment
66

7-
import arccnet.models.cutouts.config as config
7+
import arccnet.models.cutouts.hale.config as config
88
import arccnet.models.dataset_utils as ut_d
99
import arccnet.models.train_utils as ut_t
1010
import arccnet.visualisation.utils as ut_v

arccnet/models/cutouts/mcintosh/__init__.py

Whitespace-only changes.
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import os
2+
3+
import torchvision.transforms as v2
4+
5+
### General ###
6+
resnet_version = "resnet18"
7+
gpu_index = 0
8+
epochs = 500
9+
patience = 15
10+
batch_size = 32
11+
num_workers = 12 # os.cpu_count()
12+
learning_rate = 1e-5
13+
random_state = 42
14+
15+
train_transforms = v2.Compose(
16+
[
17+
# v2.RandomVerticalFlip(),
18+
# v2.RandomHorizontalFlip(),
19+
v2.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.98, 1.02)),
20+
]
21+
)
22+
23+
### Teacher Forcing ###
24+
initial_teacher_forcing_ratio = 0.75
25+
min_teacher_forcing_ratio = 0.0
26+
teacher_forcing_decay = 0.9
27+
teacher_forcing = True
28+
29+
### Dataset ###
30+
data_folder = os.getenv("ARCAFF_DATA_FOLDER", "../../../../../data/")
31+
dataset_folder = "arccnet-cutout-dataset-v20240715"
32+
df_name = "cutout-magnetic-catalog-v20240715.parq"
33+
long_limit_deg = 65
34+
train_size = 0.7
35+
val_size = 0.15
36+
test_size = 0.15
37+
38+
### Logging ###
39+
plot_histograms = False
40+
41+
### Comet ###
42+
use_comet = True
43+
project_name = "arcaff-mcintosh"
44+
workspace = "arcaff"

0 commit comments

Comments
 (0)