-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathprocess_results_interpolation.py
81 lines (61 loc) · 2.47 KB
/
process_results_interpolation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Merge results from distributed evaluation."""
import torch
import argparse
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from lm_eval.tasks import get_task, ALL_TASKS
from utils import aggregate
from utils import PILE_WEIGHTS
if __name__ == '__main__':
    # Merge the per-rank result shards written by a distributed evaluation
    # run, aggregate them per task, plot the selected metric against the
    # interpolation weight alpha, and print before/after summaries.
    parser = argparse.ArgumentParser()
    parser.add_argument('--results_dir', type=str, default='results/')
    parser.add_argument('--tasks', type=str, default='pile_all')
    parser.add_argument('--world_size', type=int, default=32)
    args = parser.parse_args()

    # Resolve the task selection: two named presets, otherwise a
    # comma-separated list of patterns matched against ALL_TASKS.
    if args.tasks == 'pile_all':
        task_names = list(PILE_WEIGHTS)
    elif args.tasks == 'super_glue':
        task_names = ['boolq', 'cb', 'copa', 'multirc', 'record', 'rte', 'wic', 'wsc']
    else:
        from utils import pattern_match
        task_names = pattern_match(args.tasks.split(","), ALL_TASKS)

    # Spacing of the alpha grid used for the x-axis. The grid starts at 0 and
    # has one point per column of the merged statistics, so 20 columns give
    # the same axis as np.arange(0, 0.2, 0.01).
    alpha_step = 0.01

    metrics_before = {}
    metrics_after = {}
    pile_all = 0
    for task_name in task_names:
        all_stats = []
        task = get_task(task_name)(download=False)
        for rank in tqdm(range(args.world_size)):
            results_file = '%s/%s_%d.pth' % (args.results_dir, task_name, rank)
            # NOTE(review): torch.load unpickles the file, so --results_dir
            # should only point at shards produced by a trusted run.
            try:
                results = torch.load(results_file)
            except FileNotFoundError:
                # A missing shard is tolerated; merge whatever is available.
                print('Not found: %s' % (results_file))
            except Exception as err:
                # Keep the merge best-effort, but do not report an unreadable
                # shard as "not found".
                print('Failed to load %s: %r' % (results_file, err))
            else:
                # The first element of each shard holds that rank's stats
                # (presumably one row per example, one column per alpha --
                # TODO confirm against the writer of these files).
                all_stats += results[0]

        if not all_stats:
            # Nothing loaded for this task: there is no column to aggregate,
            # so skip it instead of failing on an empty array.
            print('No results for task: %s' % (task_name))
            continue

        all_stats = np.array(all_stats)
        # Aggregate each column independently across all merged rows.
        aggregate_stats = [
            aggregate(all_stats[:, j], task)
            for j in range(all_stats.shape[1])
        ]

        # The third key of the aggregate dict is the metric that gets plotted
        # and summarised.
        metric_name = list(aggregate_stats[0].keys())[2]
        plot_stats = [entry[metric_name] for entry in aggregate_stats]

        # "before" is the first grid point; "after" is the best value over the
        # grid, taking lower as better.
        metrics_before[task_name] = plot_stats[0]
        metrics_after[task_name] = min(plot_stats)

        alphas = np.arange(len(plot_stats)) * alpha_step
        plt.figure()
        plt.plot(alphas, plot_stats)
        plt.ylabel(metric_name)
        plt.xlabel('alpha')
        plt.savefig('%s/alpha_%s_%s.pdf' % (args.results_dir, task_name, metric_name))
        # Release the figure; otherwise one stays open per task.
        plt.close()

        # Only tasks that have a Pile weight contribute to the weighted curve;
        # other task sets (e.g. super_glue) have no entry in PILE_WEIGHTS.
        if task_name in PILE_WEIGHTS:
            pile_all += np.array(plot_stats) * PILE_WEIGHTS[task_name]

    # Report only tasks that actually produced results (insertion order
    # matches processing order).
    for task_name in metrics_after:
        print(task_name, metrics_before[task_name], metrics_after[task_name])
    for task_name in metrics_after:
        print('%s %.2f' % (task_name, metrics_after[task_name]))
    print(pile_all)