model_run.py
import os
import argparse
from pathlib import Path

import pytorch_lightning as pl
from pytorch_lightning import Trainer

from .utils import *
from .model_module import ModelModule
from .data_module import DataModule


# Parse command-line arguments to set parameters for the script
def parse_arguments():
    parser = argparse.ArgumentParser(description='Parameters for CNS tumor classification')
    parser.add_argument('--dataset', type=str, default='./aggregator_train_val/data/dataset',
                        help='Path to the dataset directory containing embeddings generated by a feature extractor, saved in .pt format')
    parser.add_argument('--label', type=str, default='./aggregator_train_val/annot_files/labels.csv',
                        help='Path to the slide label CSV file, which should contain columns including slide, family, probability vector, age, and location')
    parser.add_argument('--split', type=str, default='./aggregator_train_val/annot_files/split.yaml',
                        help='Path to the dataset split file (YAML) containing train and test slide IDs, structured as {"train": [slide_id], "test": [slide_id]}')
    parser.add_argument('--mode', type=str, default='train', help='Operation mode: train or test')
    parser.add_argument('--data_aug', action='store_true', help='Apply data augmentation during training')
    parser.add_argument('--soft_labels', action='store_true', help='Use soft labels during training')
    parser.add_argument('--exp_name', type=str, default='default_exp', help='Identifier for the experiment')
    parser.add_argument('--output_dir', type=str, default='./aggregator_train_val/predictions', help='Directory to save predictions')
    parser.add_argument('--model', type=str, default='ATransMIL', help='Model architecture to use')
    parser.add_argument('--groups', type=int, default=3, help='Number of slide matrix divisions')
    parser.add_argument('--classes', type=int, default=186, help='Number of output classes for the classifier')
    parser.add_argument('--cl_weight', type=float, default=20, help='Weight for contrastive loss')
    parser.add_argument('--config', type=str, default='./aggregator_train_val/config.yaml', help='Path to configuration file')
    parser.add_argument('--label_map', type=str, default='./aggregator_train_val/annot_files/class_ID.yaml', help='Path to label mapping file')
    parser.add_argument('--resume', action='store_true', help='Resume training from the latest checkpoint')
    return parser.parse_args()
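
# Illustrative sketch (not part of the original script) of the annotation files the
# arguments above expect; the structure is inferred solely from the --split and --label
# help texts, slide IDs are placeholders, and the real files may contain more columns or keys.
#
#   split.yaml:
#       train: [slide_001, slide_002, ...]
#       test: [slide_003, ...]
#
#   labels.csv (columns per the --label help text):
#       slide, family, probability vector, age, location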


# Main function that orchestrates the training/testing process
def model_run(cfg):
    # Set the random seed for reproducibility (only when not in training mode)
    if cfg['General']['mode'] != 'train':
        set_seed(cfg['General']['seed'])

    # Load loggers and callbacks based on the configuration
    loggers = load_loggers(cfg)
    callbacks = load_callbacks(cfg)

    # Initialize the data and model modules using the configuration
    data_module = DataModule(**cfg['Data'])
    model = ModelModule(**cfg)

    # Set up the PyTorch Lightning trainer with the specified settings
    trainer = Trainer(
        logger=loggers,
        callbacks=callbacks,
        max_epochs=cfg['General']['epochs'],
        accelerator='gpu',
        devices='auto',
        precision=cfg['General']['precision'],
        accumulate_grad_batches=cfg['General']['grad_acc'],
        check_val_every_n_epoch=1,
        num_sanity_val_steps=0
    )
    # Train or test based on the mode specified in the configuration
    if cfg['General']['mode'] == 'train':
        # Resume training from the last checkpoint if the resume flag is set
        if cfg['resume']:
            last_checkpoint_path = os.path.join(cfg['General']['log_path'], 'last.ckpt')
            model = model.load_from_checkpoint(checkpoint_path=last_checkpoint_path, cfg=cfg)
        trainer.fit(model=model, datamodule=data_module)
    else:
        # Test the model using the most recent checkpoint in the log directory
        latest_checkpoint_path = max(Path(cfg['General']['log_path']).glob('*.ckpt'), key=os.path.getctime)
        print(f'Testing with checkpoint: {latest_checkpoint_path}')
        loaded_model = model.load_from_checkpoint(checkpoint_path=latest_checkpoint_path, Data=cfg['Data'])
        trainer.test(model=loaded_model, datamodule=data_module)
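
# Minimal sketch of the config.yaml layout this script appears to assume, reconstructed
# only from the keys accessed above; all values are placeholders, the real file may define
# additional fields, and read_yaml is assumed to return an attribute-accessible dict
# (addict/EasyDict style), since cfg is read both as cfg['General'] and cfg.General.
#
#   General:
#       mode: train
#       seed: <int>
#       epochs: <int>
#       precision: <16 or 32>
#       grad_acc: <int>
#       log_path: <directory holding checkpoints and logs>
#   Data:
#       data_dir / data_split / label_file / aug / soft_labels / preds_save / label_mapping
#   Model:
#       exp_name / name / group_num / n_classes / cl_w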


if __name__ == '__main__':
    # Parse command line arguments
    args = parse_arguments()

    # Read configuration from YAML file
    cfg = read_yaml(args.config)

    # Update configuration with command line arguments
    cfg.Data.data_dir = args.dataset
    cfg.Data.data_split = args.split
    cfg.Data.label_file = args.label
    cfg.Data.aug = args.data_aug
    cfg.Data.soft_labels = args.soft_labels
    cfg.General.mode = args.mode
    cfg.Model.exp_name = args.exp_name
    cfg.Model.name = args.model
    cfg.Model.group_num = args.groups
    cfg.Model.n_classes = args.classes
    cfg.Model.cl_w = args.cl_weight
    cfg.Data.preds_save = args.output_dir
    cfg.Data.label_mapping = args.label_map
    cfg.resume = args.resume

    # Run the main function
    model_run(cfg)
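
# Example invocation (illustrative only): the paths mirror the argparse defaults above,
# the module path is a guess based on those defaults, and running the script as a module
# is assumed so the relative imports (.utils, .model_module, .data_module) resolve.
#
#   python -m aggregator_train_val.model_run \
#       --dataset ./aggregator_train_val/data/dataset \
#       --label ./aggregator_train_val/annot_files/labels.csv \
#       --split ./aggregator_train_val/annot_files/split.yaml \
#       --mode train --exp_name my_experiment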