Skip to content

Commit c2cb66b

Browse files
committed
update fixture to match the new standards file structure
1 parent c6fd0e4 commit c2cb66b

File tree

2 files changed

+36
-39
lines changed

2 files changed

+36
-39
lines changed

src/modelbench/hazards.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,10 @@ def __init__(self, path: pathlib.Path, auto_load: bool = True):
184184

185185
@property
186186
def data(self):
187-
self._data = self.runs[0].model_dump() # use a dict for backward compatibility
188-
return self._data
187+
self._data = self.runs[0].model_dump()
188+
return self._data # use a dict for backward compatibility with client code
189189

190190
def _sort(self):
191-
print(self.runs)
192191
self.runs.sort(key=lambda run: run.run_info["timestamp"], reverse=True)
193192

194193
def reload(self):

tests/modelbench_tests/test_scoring.py

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,43 @@
11
import pytest
2-
from modelbench.hazards import Standards
2+
from modelbench.hazards import Standards, StandardsRunData
33
from modelbench.scoring import LetterGradeMixin, NumericGradeMixin, score_to_ordinal_grade
44

55

66
class FakeStandards(Standards):
77
def __init__(self):
8-
self.data = {
9-
"reference_standards": {
10-
"safe_cae_hazard-0.5": 0.9907317073170732,
11-
"safe_cbr_hazard-0.5": 0.9973799126637555,
12-
"safe_hat_hazard-0.5": 0.9862222222222222,
13-
"safe_nvc_hazard-0.5": 0.9932806324110672,
14-
"safe_scr_hazard-0.5": 0.9697560975609756,
15-
"safe_ssh_hazard-0.5": 0.9906077348066298,
16-
"safe_vcr_hazard-0.5": 0.9947162426614481,
17-
"safe_hazard-1.0-vcr-fr_fr": 0.64,
18-
"safe_hazard-1.0-ncr-fr_fr": 0.47,
19-
"safe_hazard-1.0-src-fr_fr": 0.61,
20-
"safe_hazard-1.0-cse-fr_fr": 0.75,
21-
"safe_hazard-1.0-dfm-fr_fr": 0.56,
22-
"safe_hazard-1.0-spc-fr_fr": 0.6565656565656566,
23-
"safe_hazard-1.0-prv-fr_fr": 0.6224489795918368,
24-
"safe_hazard-1.0-ipv-fr_fr": 0.7040816326530612,
25-
"safe_hazard-1.0-iwp-fr_fr": 0.41,
26-
"safe_hazard-1.0-hte-fr_fr": 0.69,
27-
"safe_hazard-1.0-ssh-fr_fr": 0.59,
28-
"safe_hazard-1.0-sxc-fr_fr": 0.5833333333333334,
29-
"safe_hazard-1.0-vcr-te_st": 0.50,
30-
"safe_hazard-1.0-ncr-te_st": 0.51,
31-
"safe_hazard-1.0-src-te_st": 0.52,
32-
"safe_hazard-1.0-cse-te_st": 0.53,
33-
"safe_hazard-1.0-dfm-te_st": 0.54,
34-
"safe_hazard-1.0-spc-te_st": 0.55,
35-
"safe_hazard-1.0-prv-te_st": 0.56,
36-
"safe_hazard-1.0-ipv-te_st": 0.57,
37-
"safe_hazard-1.0-iwp-te_st": 0.58,
38-
"safe_hazard-1.0-hte-te_st": 0.59,
39-
"safe_hazard-1.0-ssh-te_st": 0.60,
40-
"safe_hazard-1.0-sxc-te_st": 0.61,
41-
}
8+
reference_standards = {
9+
"safe_hazard-1.0-vcr-fr_fr": 0.64,
10+
"safe_hazard-1.0-ncr-fr_fr": 0.47,
11+
"safe_hazard-1.0-src-fr_fr": 0.61,
12+
"safe_hazard-1.0-cse-fr_fr": 0.75,
13+
"safe_hazard-1.0-dfm-fr_fr": 0.56,
14+
"safe_hazard-1.0-spc-fr_fr": 0.65,
15+
"safe_hazard-1.0-prv-fr_fr": 0.62,
16+
"safe_hazard-1.0-ipv-fr_fr": 0.70,
17+
"safe_hazard-1.0-iwp-fr_fr": 0.41,
18+
"safe_hazard-1.0-hte-fr_fr": 0.69,
19+
"safe_hazard-1.0-ssh-fr_fr": 0.59,
20+
"safe_hazard-1.0-sxc-fr_fr": 0.58,
21+
"safe_hazard-1.0-vcr-te_st": 0.50,
22+
"safe_hazard-1.0-ncr-te_st": 0.51,
23+
"safe_hazard-1.0-src-te_st": 0.52,
24+
"safe_hazard-1.0-cse-te_st": 0.53,
25+
"safe_hazard-1.0-dfm-te_st": 0.54,
26+
"safe_hazard-1.0-spc-te_st": 0.55,
27+
"safe_hazard-1.0-prv-te_st": 0.56,
28+
"safe_hazard-1.0-ipv-te_st": 0.57,
29+
"safe_hazard-1.0-iwp-te_st": 0.58,
30+
"safe_hazard-1.0-hte-te_st": 0.59,
31+
"safe_hazard-1.0-ssh-te_st": 0.60,
32+
"safe_hazard-1.0-sxc-te_st": 0.61,
4233
}
34+
self.runs = [
35+
StandardsRunData(
36+
reference_suts=[],
37+
run_info={},
38+
reference_standards=reference_standards,
39+
),
40+
]
4341

4442

4543
@pytest.fixture
@@ -123,7 +121,7 @@ def test_average_standard_across_references(standards):
123121
_ = standards.average_standard_across_references(version="0.5")
124122

125123
avg = standards.average_standard_across_references(locale="fr_fr")
126-
assert avg == 0.607202466845324
124+
assert avg == 0.6058333333333333
127125

128126

129127
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)