-
Notifications
You must be signed in to change notification settings - Fork 8
/
main.py
81 lines (61 loc) · 2.27 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import logging
from pathlib import Path
from typing import Dict, List
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from vegann.trainer import VeganTrainer
# Module-level logger; main() reports dataset sizes and progress through it.
logger = logging.getLogger(__name__)
def metrics_to_csv(split_metrics: List[List[Dict]], out_path: str):
    """Aggregate per-split test metrics and write their mean/std as a one-row CSV.

    For each metric, the mean and the population standard deviation
    (``np.std`` default, ``ddof=0``) across splits are written. Columns are
    ordered as: dataset-level means, dataset-level stds, per-image means,
    per-image stds, i.e. ``OA-dt, IOU-dt, f1-dt, OA-dt_std, IOU-dt_std,
    f1-dt_std, OA-im, IOU-im, f1-im, OA-im_std, IOU-im_std, f1-im_std``.

    Args:
        split_metrics: One entry per split. Only ``entry[0]`` is read, and it
            must be a mapping holding the ``test_dataset_{acc,iou,f1}`` and
            ``test_per_image_{acc,iou,f1}`` keys. Presumably each entry is the
            list returned by the trainer's ``test()``; confirm against
            ``VeganTrainer.test``.
        out_path: Destination CSV path. It is passed through ``str()``, so a
            ``pathlib.Path`` works as well.

    Raises:
        ValueError: If ``split_metrics`` is empty (mean/std would be NaN).
        KeyError: If a split's metrics lack one of the expected keys.
    """
    if not split_metrics:
        raise ValueError("split_metrics is empty; nothing to aggregate")

    # (metric-key suffix, CSV column prefix); tuple order fixes column order.
    metric_names = (("acc", "OA"), ("iou", "IOU"), ("f1", "f1"))
    # (metric-key scope, CSV column suffix): dataset-level vs per-image scores.
    scopes = (("dataset", "dt"), ("per_image", "im"))

    # Only the first metrics mapping of each split is used.
    first_results = [metrics[0] for metrics in split_metrics]

    row = {}
    for scope, scope_tag in scopes:
        values = {
            column: [res[f"test_{scope}_{key}"] for res in first_results]
            for key, column in metric_names
        }
        # All means of a scope first, then all stds, matching the CSV layout.
        for column, vals in values.items():
            row[f"{column}-{scope_tag}"] = [np.mean(vals)]
        for column, vals in values.items():
            row[f"{column}-{scope_tag}_std"] = [np.std(vals)]

    pd.DataFrame(data=row).to_csv(path_or_buf=str(out_path), index=False)
def main(config_path: str):
    """Train and test a ``VeganTrainer`` on every split, then save aggregated metrics.

    Args:
        config_path: Path to an OmegaConf-loadable config. This function reads
            ``dataset.n_split`` and ``expt_dir``; the whole config is also
            handed to ``VeganTrainer``.

    Side effects:
        Each split's trainer gets ``<expt_dir>/split_<k>`` as its experiment
        directory. The aggregated CSV is written to
        ``<expt_dir>/results_<model_>_<encoder_>_fullmetrics.csv``.

    Raises:
        ValueError: If ``dataset.n_split`` is smaller than 1.
    """
    config = OmegaConf.load(config_path)
    n_split = config.dataset.n_split
    if n_split < 1:
        # With no split there is no trainer to name the results file after
        # and nothing to aggregate; fail with a clear message instead of a
        # NameError further down.
        raise ValueError(f"config.dataset.n_split must be >= 1, got {n_split}")
    expt_dir = Path(config.expt_dir)

    split_metrics = []
    # Split ids are 1-based (split_1 ... split_<n_split>).
    for split_id in range(1, n_split + 1):
        trainer = VeganTrainer(config=config, expt_dir=expt_dir / f"split_{split_id}")
        logger.info("loading dataset ..")
        trainer.setup_dataloaders(split_id=split_id)
        logger.info("Train size: %s", len(trainer.train_dataset))
        logger.info("Test size: %s", len(trainer.test_dataset))

        trainer.train()
        split_metrics.append(trainer.test())

    # The model/encoder names come from the last split's trainer; presumably
    # identical across splits since all share one config. TODO confirm.
    out_path = expt_dir / f"results_{trainer.model_}_{trainer.encoder_}_fullmetrics.csv"
    metrics_to_csv(split_metrics, out_path=out_path)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Train and test VeganTrainer on every dataset split."
    )
    parser.add_argument(
        "config_path",
        nargs="?",
        default="./resources/config.yaml",
        help="OmegaConf YAML config (default: ./resources/config.yaml)",
    )
    args = parser.parse_args()
    # Without a handler the root logger stays at WARNING and main()'s
    # logger.info messages are dropped. basicConfig is a no-op when logging
    # has already been configured.
    logging.basicConfig(level=logging.INFO)
    main(args.config_path)