-
Notifications
You must be signed in to change notification settings - Fork 846
/
Copy pathtest_falcon_color.py
115 lines (97 loc) · 2.86 KB
/
test_falcon_color.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import re
import json
import torch
import random
import transformers
from tqdm import tqdm
from datasets import DatasetDict, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from prompt_pattern import PROMPT, STOP_WORD
def clean(content):
    """Strip ``<<...>>`` annotations from *content* and flatten it to one line.

    Every ``<<...>>`` span (at least one character between the markers, not
    crossing a newline) is deleted, then each newline is replaced by ``'. '``.

    The match is non-greedy so that several annotations on the same line are
    removed one by one; a greedy ``.+`` would also delete the text that sits
    between the first ``<<`` and the last ``>>`` of the line.

    Args:
        content: Text to clean.

    Returns:
        The cleaned, single-line text.
    """
    without_annotations = re.sub(r'<<.+?>>', '', content)
    return without_annotations.replace('\n', '. ')
def load_multi_line_json(f):
    """Parse JSON objects that were written back to back, each over many lines.

    Lines are buffered until one starts with ``'}'``; that line is treated as
    the end of an object, the buffered text is decoded with ``json.loads`` and
    the buffer is reset.

    Args:
        f: Iterable of text lines, e.g. an open text file.

    Returns:
        A list of the decoded objects, in the order they appear. Lines after
        the last line starting with ``'}'`` are ignored.

    Raises:
        json.JSONDecodeError: If a buffered chunk is not valid JSON.
    """
    records = []
    pending = []  # lines of the object currently being read
    for line in f:
        pending.append(line)
        if line.startswith('}'):
            # Join once per object instead of growing a string line by line.
            records.append(json.loads(''.join(pending)))
            pending.clear()
    return records
# Model location; passed to AutoTokenizer.from_pretrained and to
# transformers.pipeline further down in this script.
model = "model/falcon/"
# Seed transformers (set_seed) and Python's random module with the same fixed value.
set_seed(2023)
random.seed(2023)
# Look up the 'VANILLA' entries of PROMPT and STOP_WORD (imported from prompt_pattern).
# NOTE(review): neither `pattern` nor `stop` is referenced again in the visible
# part of this script - confirm whether they are still needed.
pattern = PROMPT['VANILLA']
stop = STOP_WORD['VANILLA']
def load_from_json(folder_path):
    """Load the demonstration text and the test examples from *folder_path*.

    ``demo.txt`` is read as a whole, stripped, and terminated with a blank
    line. ``test.json`` must contain an ``'examples'`` list whose items have
    an ``'input'`` string and a ``'target_scores'`` mapping; every key of that
    mapping whose score equals 1 is kept as an accepted answer.

    Both files are opened as UTF-8 explicitly so the result does not depend on
    the platform's default encoding.

    Args:
        folder_path: Directory containing ``demo.txt`` and ``test.json``.

    Returns:
        A ``(demo, dataset)`` tuple: ``demo`` is the demonstration string and
        ``dataset`` is a ``Dataset`` with a ``'problem'`` column
        (``'Problem: <input>\\nAnswer:'``) and a ``'solution'`` column (list of
        accepted answers per example).

    Raises:
        ValueError: If an example has no target with a score of 1. This is an
            explicit check rather than an ``assert`` so it still runs under
            ``python -O``.
    """
    demo_file = os.path.join(folder_path, 'demo.txt')
    with open(demo_file, 'r', encoding='utf-8') as fin:
        demo = fin.read().strip() + '\n\n'

    data_file = os.path.join(folder_path, 'test.json')
    with open(data_file, 'r', encoding='utf-8') as fin:
        raw_dataset = json.load(fin)

    problems = []
    solutions = []
    for index, example in enumerate(raw_dataset['examples']):
        answers = [
            label
            for label, score in example['target_scores'].items()
            if score == 1
        ]
        if not answers:
            raise ValueError(
                f"example {index} in {data_file} has no target with score 1"
            )
        problems.append('Problem: ' + example['input'] + '\nAnswer:')
        solutions.append(answers)

    return demo, Dataset.from_dict({'problem': problems, 'solution': solutions})
# Load the demonstration text (demo.txt) and the test set (test.json) from the
# colored-objects folder, then echo both for inspection.
demo, test_dataset = load_from_json('dataset/colored_objects')
print(test_dataset)
print(demo)
# Number of prompts collected before each make_query() call; the same value is
# passed to the pipeline as its batch_size.
batch_size = 6
# The tokenizer is created with left-side padding.
tokenizer = AutoTokenizer.from_pretrained(model, padding_side='left')
# Text-generation pipeline built from the same model path and tokenizer.
pipeline = transformers.pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto",
batch_size=batch_size
)
# Use the EOS token id as the padding token id.
pipeline.tokenizer.pad_token_id = tokenizer.eos_token_id
# Output file written by make_query(); it is opened without a `with` block and
# closed by the fout.close() call at the end of the script.
fout = open('result/falcon-color-3shot.json', 'w')
# Pending prompts and their source records: filled by the main loop, consumed
# and reset by make_query().
inputs = []
origin_data = []
def make_query():
    """Generate completions for the pending prompts and write them to ``fout``.

    Reads the module-level ``inputs`` (prompts) and ``origin_data`` (their
    source records), runs ``pipeline`` on the prompts with sampling disabled,
    and for each record stores the reference answers under ``'real_ans'`` and
    the generated text under ``'pred_ans'`` before writing it to ``fout`` as
    indented JSON followed by a newline. Both pending lists are then rebound
    to fresh empty lists.
    """
    global inputs, origin_data

    generation_kwargs = dict(
        max_length=450,
        do_sample=False,
        top_k=1,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    outputs = pipeline(inputs, **generation_kwargs)

    for generated, record in zip(outputs, origin_data):
        # The record is annotated in place, exactly as it is written out.
        record['real_ans'] = record['solution']
        record['pred_ans'] = generated[0]['generated_text']
        serialized = json.dumps(record, indent=4)
        fout.write(serialized + '\n')

    # Start the next batch from empty lists.
    origin_data, inputs = [], []
# Feed the test set to the model in batches of `batch_size` prompts.
try:
    for data in tqdm(test_dataset):
        inputs.append(demo + data['problem'])
        origin_data.append(data)
        if len(inputs) == batch_size:
            make_query()
    # Flush the final, partially filled batch (if any).
    if inputs:
        make_query()
finally:
    # Close the output file even when generation fails part-way, so the
    # results written so far are flushed and the handle is released.
    fout.close()