forked from EleutherAI/lm-evaluation-harness
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwrite_out.py
62 lines (49 loc) · 1.99 KB
/
write_out.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
import numpy as np
import os
import random
from lm_eval import tasks
from lm_eval.utils import join_iters
# Header line written before every example in an output file; the ``{i}``
# placeholder is filled with the example's index via ``str.format``.
EXAMPLE_DIVIDER = "!!@@##@@!! -- Example {i}\n"
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Options are declared as ``(flag, keyword-arguments)`` pairs and registered
    in order, so ``--help`` lists them in the sequence shown below.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    option_table = (
        ("--output_base_path", {"required": True}),
        ("--tasks", {"default": "all_tasks"}),
        ("--provide_description", {"action": "store_true"}),
        # Comma-separated split names, for example: val,test
        ("--sets", {"type": str, "default": "val"}),
        ("--num_fewshot", {"type": int, "default": 1}),
        ("--seed", {"type": int, "default": 1234}),
        ("--num_examples", {"type": int, "default": 1}),
    )

    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def main():
    """Write few-shot prompt examples for each selected task to disk.

    Seeds ``random`` and NumPy from ``--seed``, resolves the task list
    (``"all_tasks"`` selects ``tasks.ALL_TASKS``; any other value is split on
    commas), creates ``--output_base_path`` if needed, and writes one file per
    task, named after the task. Every example is preceded by
    ``EXAMPLE_DIVIDER`` and followed by a newline.

    ``--sets`` is a comma-separated list drawn from ``train``, ``val`` and
    ``test``. A split that is not one of those names, or that the task does
    not provide, is skipped. ``--num_examples`` caps the number of examples
    per task; a value of zero or less writes every available document.
    """
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict(task_names)
    split_names = args.sets.split(",")

    os.makedirs(args.output_base_path, exist_ok=True)
    for task_name, task in task_dict.items():
        # Collect only the splits this task actually has. Appending
        # unconditionally would reuse a stale ``docs`` from an earlier split
        # or task (or raise NameError when nothing had been assigned yet).
        iters = []
        for split in split_names:
            if split == "train" and task.has_train_docs():
                iters.append(task.train_docs())
            elif split == "val" and task.has_validation_docs():
                iters.append(task.validation_docs())
            elif split == "test" and task.has_test_docs():
                iters.append(task.test_docs())

        # ``iters`` may be empty when no requested split exists; the task's
        # file is then written without examples (assumes ``join_iters``
        # yields nothing for an empty list -- TODO confirm).
        docs = join_iters(iters)

        if args.num_examples > 0:
            # ``range`` comes first so ``zip`` stops without pulling an extra
            # document from ``docs`` once the cap is reached.
            numbered_docs = zip(range(args.num_examples), docs)
        else:
            numbered_docs = enumerate(docs)

        out_path = os.path.join(args.output_base_path, task_name)
        # Explicit UTF-8 keeps the output independent of the platform locale.
        with open(out_path, "w", encoding="utf-8") as f:
            for i, doc in numbered_docs:
                f.write(EXAMPLE_DIVIDER.format(i=i))
                ctx = task.fewshot_context(
                    doc=doc,
                    provide_description=args.provide_description,
                    num_fewshot=args.num_fewshot,
                )
                f.write(ctx + "\n")
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()