-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathmain_ensemble.py
77 lines (61 loc) · 2.9 KB
/
main_ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from shutil import copy
from copy import deepcopy
import torch
import os
import numpy as np
from prototree.prototree import ProtoTree
from util.log import Log
from util.args import get_args, save_args, get_optimizer
from util.data import get_dataloaders
from util.analyse import analyse_ensemble
import gc
from main_tree import run_tree
def run_ensemble():
    """Train an ensemble of ProtoTrees one tree at a time and analyse it as it grows.

    The run configuration comes from ``get_args()``. For each of the
    ``nr_trees_ensemble`` trees:

    - a deep copy of the arguments is made;
    - its ``log_dir`` is pointed at ``<log_dir>/tree_<i>``;
    - the copy is passed to ``run_tree``.

    The ten values returned by ``run_tree`` are accumulated in parallel lists.
    From the second tree onward, ``analyse_ensemble`` is called on everything
    collected so far. The ensemble is therefore analysed with 2, 3, ... trees.
    """
    all_args = get_args()

    # Create a logger and record the run arguments in its metadata directory.
    log = Log(all_args.log_dir)
    print("Log dir: ", all_args.log_dir, flush=True)
    save_args(all_args, log.metadata_dir)

    if not all_args.disable_cuda and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # makedirs(exist_ok=True) replaces the isdir()/mkdir() pair: there is no
    # check-then-create race, and missing parent directories are created.
    os.makedirs(os.path.join(all_args.log_dir, "files"), exist_ok=True)

    # Only the test loader is used in this function (for the ensemble
    # analysis). The other values returned by get_dataloaders are discarded.
    _, _, test_loader, _, _ = get_dataloaders(all_args)

    log_dir_orig = all_args.log_dir
    trained_orig_trees = []
    trained_pruned_trees = []
    trained_pruned_projected_trees = []
    orig_test_accuracies = []
    pruned_test_accuracies = []
    pruned_projected_test_accuracies = []
    project_infos = []
    infos_sample_max = []
    infos_greedy = []
    infos_fidelity = []

    # Train trees in the ensemble one by one and keep every tree and its metrics.
    for pt in range(1, all_args.nr_trees_ensemble + 1):
        # Release cached GPU memory that is not currently in use before the
        # next tree is trained.
        torch.cuda.empty_cache()
        print("\nTraining tree ", pt, "/", all_args.nr_trees_ensemble, flush=True)
        log.log_message('Training tree %s...' % str(pt))

        # Each tree gets its own copy of the arguments and its own log subdirectory,
        # so per-tree changes never leak into all_args.
        args = deepcopy(all_args)
        args.log_dir = os.path.join(log_dir_orig, 'tree_' + str(pt))

        (trained_tree, pruned_tree, pruned_projected_tree,
         original_test_acc, pruned_test_acc, pruned_projected_test_acc,
         project_info, eval_info_samplemax, eval_info_greedy,
         info_fidelity) = run_tree(args)

        trained_orig_trees.append(trained_tree)
        trained_pruned_trees.append(pruned_tree)
        trained_pruned_projected_trees.append(pruned_projected_tree)
        orig_test_accuracies.append(original_test_acc)
        pruned_test_accuracies.append(pruned_test_acc)
        pruned_projected_test_accuracies.append(pruned_projected_test_acc)
        project_infos.append(project_info)
        infos_sample_max.append(eval_info_samplemax)
        infos_greedy.append(eval_info_greedy)
        infos_fidelity.append(info_fidelity)

        if pt > 1:
            # An ensemble needs more than one tree; analyse the trees trained so far.
            analyse_ensemble(log, all_args, test_loader, device,
                             trained_orig_trees, trained_pruned_trees,
                             trained_pruned_projected_trees,
                             orig_test_accuracies, pruned_test_accuracies,
                             pruned_projected_test_accuracies, project_infos,
                             infos_sample_max, infos_greedy, infos_fidelity)
if __name__ == "__main__":
    # Run the full ensemble training/analysis only when executed as a script.
    run_ensemble()