Skip to content

Commit 7f0887c

Browse files
authored
feat: add stdlib.bool_confidence function (#1368)
Signed-off-by: Louis Mandel <[email protected]>
1 parent 2705d94 commit 7f0887c

File tree

2 files changed

+78
-35
lines changed

2 files changed

+78
-35
lines changed

src/pdl/pdl_stdlib.pdl

Lines changed: 10 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,24 @@
11

22
defs:
3-
reward:
3+
bool_confidence:
44
function:
55
response: object
66
return:
77
lang: python
88
code: |
9-
import math
10-
11-
def _get_logprob(value: str, top_logprobs):
12-
min_logprob = math.inf
13-
for logprob in top_logprobs:
14-
if value.startswith(logprob["token"]):
15-
return logprob["logprob"]
16-
min_logprob = min(min_logprob, logprob['logprob'])
17-
return min_logprob
18-
19-
def _find_first_token(content: str, logprobs):
20-
for logprob in reversed(logprobs):
21-
if content.startswith(logprob["token"]):
22-
return logprob
23-
assert False
24-
25-
def reward(response):
26-
content = response['choices'][0]['message']['content']
27-
if (content != 'true' and content != 'false'):
28-
raise Exception(f'Wrong value: {content}')
29-
30-
first_token_logprob = _find_first_token(content, response['choices'][0]['logprobs']['content'])
31-
top_logprobs = first_token_logprob["top_logprobs"]
32-
33-
lp_true = _get_logprob("true", top_logprobs)
34-
lp_false = _get_logprob("false", top_logprobs)
35-
p_true = math.exp(lp_true)
36-
p_false = math.exp(lp_false)
37-
if p_true == 0.0:
38-
result = -math.inf
39-
else:
40-
result = math.log(p_true / (p_true + p_false))
9+
from pdl.pdl_stdlib import bool_confidence
10+
result = bool_confidence(response)
4111

42-
return result
4312

13+
reward:
14+
function:
15+
response: object
16+
return:
17+
lang: python
18+
code: |
19+
from pdl.pdl_stdlib import reward
4420
result = reward(response)
4521

46-
4722
llm_as_judge:
4823
function:
4924
model: string

src/pdl/pdl_stdlib.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import math
2+
3+
4+
def _get_logprob(value: str, top_logprobs):
5+
min_logprob = math.inf
6+
for logprob in top_logprobs:
7+
if value.startswith(logprob["token"]):
8+
return logprob["logprob"]
9+
min_logprob = min(min_logprob, logprob["logprob"])
10+
return min_logprob
11+
12+
13+
def _find_first_token(content: str, logprobs):
14+
for logprob in reversed(logprobs):
15+
if content.startswith(logprob["token"]):
16+
return logprob
17+
assert False
18+
19+
20+
def reward(response):
21+
content = response["choices"][0]["message"]["content"]
22+
if content not in ["true", "false"]:
23+
raise ValueError(f"Wrong value: {content}")
24+
25+
first_token_logprob = _find_first_token(
26+
content, response["choices"][0]["logprobs"]["content"]
27+
)
28+
top_logprobs = first_token_logprob["top_logprobs"]
29+
30+
lp_true = _get_logprob("true", top_logprobs)
31+
lp_false = _get_logprob("false", top_logprobs)
32+
p_true = math.exp(lp_true)
33+
p_false = math.exp(lp_false)
34+
if p_true == 0.0:
35+
result = -math.inf
36+
else:
37+
result = math.log(p_true / (p_true + p_false))
38+
39+
return result
40+
41+
42+
def bool_confidence(response):
43+
content = response["choices"][0]["message"]["content"]
44+
if content not in ["true", "false"]:
45+
raise ValueError(f"Wrong value: {content}")
46+
47+
first_token_logprob = _find_first_token(
48+
content, response["choices"][0]["logprobs"]["content"]
49+
)
50+
top_logprobs = first_token_logprob["top_logprobs"]
51+
52+
lp_true = _get_logprob("true", top_logprobs)
53+
lp_false = _get_logprob("false", top_logprobs)
54+
p_true = math.exp(lp_true)
55+
p_false = math.exp(lp_false)
56+
match content:
57+
case "true":
58+
p_content = p_true
59+
case "false":
60+
p_content = p_false
61+
case _:
62+
assert False
63+
if p_content == 0.0:
64+
result = -math.inf
65+
else:
66+
result = math.log(p_content / (p_true + p_false))
67+
68+
return result

0 commit comments

Comments
 (0)