classify.py (forked from openai/evals)
"""
Generic eval that uses a prompt + classification.
"""
from collections import Counter
from random import Random
from typing import Any, Optional, Union

import evals
import evals.record
from evals.elsuite.modelgraded.classify_utils import classify, sample_and_concat_n_completions
from evals.elsuite.utils import PromptFn, scrub_formatting_from_prompt


class ModelBasedClassify(evals.Eval):
def __init__(
self,
modelgraded_spec: str,
*args,
modelgraded_spec_args: Optional[dict[str, dict[str, str]]] = None,
sample_kwargs: Optional[dict[str, Any]] = None,
eval_kwargs: Optional[dict[str, Any]] = None,
multicomp_n: Union[int, str] = 1,
eval_type: Optional[str] = None,
metaeval: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
# treat last completion_fn as eval_completion_fn
self.eval_completion_fn = self.completion_fns[-1]
if len(self.completion_fns) > 1:
self.completion_fns = self.completion_fns[:-1]
n_models = len(self.completion_fns)
self.sample_kwargs = {"max_tokens": 1024}
self.sample_kwargs.update(sample_kwargs or {})
self.eval_kwargs = {"max_tokens": 1024}
self.eval_kwargs.update(eval_kwargs or {})
self.metaeval = metaeval
self.modelgraded_spec_args = modelgraded_spec_args or {}
self.eval_type = eval_type
if multicomp_n == "from_models":
assert n_models > 1
self.multicomp_n = n_models
else:
assert isinstance(multicomp_n, int)
self.multicomp_n = multicomp_n
if len(self.completion_fns) > 1:
assert self.multicomp_n == n_models
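        # Load the model-graded spec from the registry; it supplies the grading
        # prompt and the input/output field mappings used in eval_sample below.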
self.mg = self.registry.get_modelgraded_spec(modelgraded_spec)

    def eval_sample(self, test_sample: dict, rng: Random) -> None:
        """Evaluate a single sample.

        Recorded metrics are always: one of the self.choice_strings, or "__invalid__".
        """
# process test_sample
for k in self.mg.input_outputs:
test_sample[k] = scrub_formatting_from_prompt(test_sample[k])
# run policy completions
completions = {}
for k, v in self.mg.input_outputs.items():
if v in test_sample: # test_sample already has completion, skip.
continue
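            # Multiple completions requested: sample from each policy
            # completion_fn (or n times from one) and concatenate the results
            # into a single string for the grader.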
if self.multicomp_n > 1:
completion = sample_and_concat_n_completions(
self.completion_fns,
prompt=test_sample[k],
template_i=self.mg.output_template,
sample_kwargs=self.sample_kwargs,
n=self.multicomp_n,
)
else:
get_input_completion = PromptFn(
test_sample[k], completion_fn=self.completion_fn, **self.sample_kwargs
)
completion, _ = get_input_completion()
completions[v] = completion
# run modelgraded eval
metrics = {}
choice, info = classify(
mg=self.mg,
completion_fn=self.eval_completion_fn,
completion_kwargs=self.eval_kwargs,
eval_type=self.eval_type,
n=self.multicomp_n,
format_kwargs={**completions, **test_sample, **self.modelgraded_spec_args},
)
metrics.update(dict(choice=choice, score=info["score"]))
# run metaeval if requested
if self.metaeval:
assert "choice" in test_sample
metrics["metascore"] = choice == test_sample["choice"]
evals.record.record_metrics(**metrics)
return choice

    def run(self, recorder):
samples = self.get_samples()
self.eval_all_samples(recorder, samples)
record_metrics = {}
all_sample_metrics = recorder.get_metrics()
if not all_sample_metrics:
return record_metrics
# record the counts
choices = [m["choice"] for m in all_sample_metrics]
counts = dict(Counter(choices))
record_metrics.update({f"counts/{k}": v for k, v in counts.items()})
# record the scores
scores = [m["score"] for m in all_sample_metrics if m["score"] is not None]
if scores:
record_metrics[f"score"] = sum(scores) / len(scores)
metascores = [m["metascore"] for m in all_sample_metrics if "metascore" in m]
if metascores:
record_metrics[f"metascore"] = sum(metascores) / len(metascores)
return record_metrics
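
For orientation, the sketch below mirrors the aggregation that run() performs: per-sample choice/score metrics are reduced to per-choice counts and a mean score. It is a standalone toy, not part of the original file; the metric dicts are invented, and in the real eval they come from eval_sample() via the recorder.

from collections import Counter

# Hypothetical per-sample metrics, shaped like what eval_sample() records.
all_sample_metrics = [
    {"choice": "Y", "score": 1.0},
    {"choice": "N", "score": 0.0},
    {"choice": "Y", "score": 1.0},
]

# Per-choice counts, e.g. {"counts/Y": 2, "counts/N": 1}.
counts = dict(Counter(m["choice"] for m in all_sample_metrics))
record_metrics = {f"counts/{k}": v for k, v in counts.items()}

# Mean score over the samples that have one (2/3 here).
scores = [m["score"] for m in all_sample_metrics if m["score"] is not None]
if scores:
    record_metrics["score"] = sum(scores) / len(scores)

print(record_metrics)  # {'counts/Y': 2, 'counts/N': 1, 'score': 0.666...}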