simplify locale handling by using lowercase strings everywhere (#768)
* simplify locale handling by using lowercase strings everywhere

* add test for version string

* remove guard clause

* poetry is now complaining about missing project.name in CI

* update lockfile

* force a refresh of the lockfile
rogthefrog authored Jan 6, 2025
1 parent fe91f3e commit 01b49c5
Showing 16 changed files with 1,269 additions and 1,221 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
@@ -81,7 +81,7 @@ jobs:
cache: 'poetry'

- name: Install dependencies
run: poetry install --no-interaction --with dev --extras all_plugins
run: poetry lock && poetry install --no-interaction --with dev --extras all_plugins

- name: Lint formatting
run: poetry run black --check .
2,107 changes: 1,058 additions & 1,049 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -1,4 +1,5 @@
[project]
name = "modelbench"
requires-python = ">=3.10, !=3.12.5, <3.13"

[tool.poetry]
7 changes: 4 additions & 3 deletions src/modelbench/benchmarks.py
@@ -6,7 +6,7 @@

import casefy
from modelgauge.sut import PromptResponseSUT
from modelgauge.tests.safe_v1 import Locale
from modelgauge.locales import DEFAULT_LOCALE, validate_locale

from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1, Standards, STANDARDS
from modelbench.scoring import LetterGradeMixin, score_to_ordinal_grade
@@ -29,7 +29,7 @@ def __init__(
self.score = 0.0
self._scoring_log = {}

def numeric_grade(self, locale: str = "en_us", version: str = "1.0b", standards: Standards = STANDARDS):
def numeric_grade(self, locale: str = DEFAULT_LOCALE, version: str = "1.0b", standards: Standards = STANDARDS):
total_scored_items = 0
total_safe_items = 0
for hazard_score in self.hazard_scores:
@@ -132,7 +132,8 @@ class GeneralPurposeAiChatBenchmarkV1(BenchmarkDefinition):
- Hazard score = Fraction unsafe
"""

def __init__(self, locale: Locale, prompt_set: str, evaluator: str = "default"):
def __init__(self, locale: str, prompt_set: str, evaluator: str = "default"):
validate_locale(locale)
self.locale = locale
self.prompt_set = prompt_set
self.evaluator = evaluator
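
With this change, GeneralPurposeAiChatBenchmarkV1 takes a plain lowercase locale string and validates it up front instead of accepting a Locale enum member. A minimal usage sketch (not part of the commit; it only uses names visible in the diff above):

from modelgauge.locales import DEFAULT_LOCALE
from modelbench.benchmarks import GeneralPurposeAiChatBenchmarkV1

# lowercase strings are accepted directly
benchmark = GeneralPurposeAiChatBenchmarkV1("fr_fr", prompt_set="practice")
default_benchmark = GeneralPurposeAiChatBenchmarkV1(DEFAULT_LOCALE, prompt_set="practice")

# mixed-case or unknown locales now fail fast inside validate_locale()
# GeneralPurposeAiChatBenchmarkV1("en_US", prompt_set="practice")  # AssertionError
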
24 changes: 11 additions & 13 deletions src/modelbench/hazards.py
@@ -7,10 +7,11 @@

import casefy
from modelgauge.base_test import BaseTest, PromptResponseTest
from modelgauge.locales import display_for, validate_locale
from modelgauge.records import TestRecord
from modelgauge.secret_values import RawSecrets
from modelgauge.test_registry import TESTS
from modelgauge.tests.safe_v1 import BaseSafeTestVersion1, Locale, SafeTestResult
from modelgauge.tests.safe_v1 import BaseSafeTestVersion1, SafeTestResult
from pydantic import BaseModel, ConfigDict

from modelbench.scoring import LetterGradeMixin, NumericGradeMixin, ValueEstimate
@@ -63,16 +64,17 @@ class SafeHazardV1(HazardDefinition, ABC):

all_hazard_keys = BaseSafeTestVersion1.hazards

def __init__(self, hazard_key: str, locale: Locale, prompt_set: str, evaluator="default"):
def __init__(self, hazard_key: str, locale: str, prompt_set: str, evaluator="default"):
assert hazard_key in self.all_hazard_keys, f"Unknown hazard key {hazard_key}."
validate_locale(locale)
self.hazard_key = hazard_key
self.locale = locale
self.prompt_set = prompt_set
self._evaluator = evaluator
super().__init__()

def name(self):
return f"Safe {self.hazard_key.title()} {self.locale.value} Hazard V1.0 ({self.prompt_set.title()})"
return f"Safe {self.hazard_key.title()} {display_for(self.locale)} Hazard V1.0 ({self.prompt_set.title()})"

def key(self):
return f"safe_hazard-1_0-{self.hazard_key}"
@@ -170,16 +172,12 @@ def reference_standard_for(self, name):
return self.data["reference_standards"][name]

def average_standard_across_references(self, locale: str = "", version: str = "1.0") -> float:
values = []
if version == "0.5":
raise ValueError("Version 0.5 is no longer supported.")
elif version == "1.0":
if not locale:
raise ValueError("Locale is required for v1.0 scoring.")
locale = locale.lower()
values = [v for k, v in self.data["reference_standards"].items() if locale in k]
else:
raise ValueError(f"Unknown benchmark version: {version}")
assert version == "1.0", "Only version 1.0 is supported."
if not locale:
raise ValueError("Locale is required for v1.0 scoring.")
locale = locale.lower()
validate_locale(locale)
values = [v for k, v in self.data["reference_standards"].items() if locale in k]
assert len(values), "No reference values found"
return fmean(values)

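
The rewritten average_standard_across_references now lowercases the locale, validates it, and averages every reference standard whose key contains that locale. A hedged sketch of the filtering step; the reference-standard keys below are hypothetical stand-ins, since the real ones come from the standards file and are not shown in this diff:

from statistics import fmean

# hypothetical keys/values; only the "locale substring in key" matching mirrors the diff
reference_standards = {
    "safe_hazard-1.0-cse-en_us-practice": 0.86,
    "safe_hazard-1.0-vcr-en_us-practice": 0.91,
    "safe_hazard-1.0-cse-fr_fr-practice": 0.83,
}

locale = "EN_US".lower()  # the method lowercases before matching
values = [v for k, v in reference_standards.items() if locale in k]
assert len(values), "No reference values found"
print(fmean(values))  # 0.885, the mean of the en_us entries only
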
63 changes: 32 additions & 31 deletions src/modelbench/run.py
@@ -15,23 +15,25 @@
from typing import List

import click

import modelgauge
import termcolor
from click import echo
from modelgauge.config import load_secrets_from_config, raise_if_missing_from_config, write_default_config
from modelgauge.load_plugins import load_plugins
from modelgauge.locales import DEFAULT_LOCALE, EN_US, LOCALES, validate_locale
from modelgauge.sut import SUT
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
from modelgauge.tests.safe_v1 import PROMPT_SETS
from rich.console import Console
from rich.table import Table

import modelgauge
from modelbench.benchmark_runner import BenchmarkRunner, TqdmRunTracker, JsonRunTracker
from modelbench.benchmark_runner import BenchmarkRunner, JsonRunTracker, TqdmRunTracker
from modelbench.benchmarks import BenchmarkDefinition, GeneralPurposeAiChatBenchmarkV1
from modelbench.consistency_checker import ConsistencyChecker, summarize_consistency_check_results
from modelbench.hazards import STANDARDS
from modelbench.record import dump_json
from modelgauge.config import load_secrets_from_config, raise_if_missing_from_config, write_default_config
from modelgauge.load_plugins import load_plugins
from modelgauge.sut import SUT
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
from modelgauge.tests.safe_v1 import PROMPT_SETS, Locale


def load_local_plugins(_, __, path: pathlib.Path):
@@ -93,14 +95,14 @@ def cli() -> None:
@click.option(
"--locale",
"-l",
type=click.Choice(["en_us", "fr_fr"], case_sensitive=False),
default="en_us",
help=f"Locale for v1.0 benchmark (Default: en_us)",
type=click.Choice(LOCALES, case_sensitive=False),
default=DEFAULT_LOCALE,
help=f"Locale for v1.0 benchmark (Default: {DEFAULT_LOCALE})",
multiple=False,
)
@click.option(
"--prompt-set",
type=click.Choice(PROMPT_SETS.keys()),
type=click.Choice(list(PROMPT_SETS.keys())),
default="practice",
help="Which prompt set to use",
show_default=True,
@@ -131,9 +133,11 @@ def benchmark(
start_time = datetime.now(timezone.utc)
suts = find_suts_for_sut_argument(sut_uids)
if locale == "all":
locales = Locale
locales = LOCALES
else:
locales = [Locale(locale)]
locales = [
locale.lower(),
]

benchmarks = [get_benchmark(version, l, prompt_set, evaluator) for l in locales]

@@ -217,7 +221,7 @@ def find_suts_for_sut_argument(sut_uids: List[str]):

def ensure_ensemble_annotators_loaded():
try:
from modelgauge.private_ensemble_annotator_set import EnsembleAnnotatorSet, ensemble_secrets
from modelgauge.private_ensemble_annotator_set import ensemble_secrets, EnsembleAnnotatorSet

private_annotators = EnsembleAnnotatorSet(secrets=ensemble_secrets(load_secrets_from_config()))
modelgauge.tests.safe_v1.register_private_annotator_tests(private_annotators, "ensemble")
@@ -227,24 +231,21 @@ def ensure_ensemble_annotators_loaded():
return False


def get_benchmark(version: str, locale: Locale, prompt_set: str, evaluator) -> BenchmarkDefinition:
if version == "0.5":
raise ValueError("Version 0.5 is no longer supported.")
elif version == "1.0":
if evaluator == "ensemble":
if not ensure_ensemble_annotators_loaded():
print(f"Can't build benchmark for {str} {locale} {prompt_set} {evaluator}; couldn't load evaluator.")
exit(1)
return GeneralPurposeAiChatBenchmarkV1(locale, prompt_set, evaluator)
else:
raise ValueError(f"Unknown benchmark version: {version}")
def get_benchmark(version: str, locale: str, prompt_set: str, evaluator) -> BenchmarkDefinition:
assert version == "1.0", ValueError(f"Version {version} is not supported.")
validate_locale(locale)
if evaluator == "ensemble":
if not ensure_ensemble_annotators_loaded():
print(f"Can't build benchmark for {str} {locale} {prompt_set} {evaluator}; couldn't load evaluator.")
exit(1)
return GeneralPurposeAiChatBenchmarkV1(locale, prompt_set, evaluator)


def score_benchmarks(benchmarks, suts, max_instances, json_logs=False, debug=False):
run = run_benchmarks_for_suts(benchmarks, suts, max_instances, debug=debug, json_logs=json_logs)
benchmark_scores = []
for bd, score_dict in run.benchmark_scores.items():
for k, score in score_dict.items():
for _, score_dict in run.benchmark_scores.items():
for _, score in score_dict.items():
benchmark_scores.append(score)
return benchmark_scores

@@ -344,13 +345,13 @@ def update_standards_to(standards_file):
exit(1)

benchmarks = []
for l in [Locale.EN_US]:
for l in [EN_US]:
for prompt_set in PROMPT_SETS:
benchmarks.append(GeneralPurposeAiChatBenchmarkV1(l, prompt_set, "ensemble"))
run_result = run_benchmarks_for_suts(benchmarks, reference_suts, None)
all_hazard_numeric_scores = defaultdict(list)
for benchmark, scores_by_sut in run_result.benchmark_scores.items():
for sut, benchmark_score in scores_by_sut.items():
for _, scores_by_sut in run_result.benchmark_scores.items():
for _, benchmark_score in scores_by_sut.items():
for hazard_score in benchmark_score.hazard_scores:
all_hazard_numeric_scores[hazard_score.hazard_definition.uid].append(hazard_score.score.estimate)

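
In the CLI, the --locale choices are now drawn from LOCALES and the selected value is lowercased before building benchmarks; a value of "all" (still handled in the code) expands to every supported locale. A simplified, hedged restatement of that expansion logic:

from modelgauge.locales import DEFAULT_LOCALE, LOCALES

def expand_locales(locale: str) -> list[str]:
    # "all" runs every supported locale; anything else is lowercased so that
    # case-insensitive CLI input maps onto the lowercase constants
    if locale == "all":
        return list(LOCALES)
    return [locale.lower()]

assert expand_locales("all") == ["en_us", "fr_fr"]
assert expand_locales("FR_FR") == ["fr_fr"]
assert expand_locales(DEFAULT_LOCALE) == ["en_us"]
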
34 changes: 34 additions & 0 deletions src/modelgauge/locales.py
@@ -0,0 +1,34 @@
# Keep these in all lowercase
# Always and only use these named constants in function calls.
# They are meant to simplify the Locale(enum) and prevent case errors.
EN_US = "en_us"
FR_FR = "fr_fr"
ZH_CN = "zh_cn"
HI_IN = "hi_in"
DEFAULT_LOCALE = "en_us"

# add the other languages after we have official and practice prompt sets
LOCALES = (EN_US, FR_FR)


def is_valid(locale: str) -> bool:
return locale in LOCALES


def display_for(locale: str) -> str:
chunks = locale.split("_")
try:
assert len(chunks) == 2
display = f"{chunks[0].lower()}_{chunks[1].upper()}"
except:
display = locale
return display


def bad_locale(locale: str) -> str:
return f"You requested \"{locale}.\" Only {', '.join(LOCALES)} (in lowercase) are supported."


def validate_locale(locale) -> bool:
assert is_valid(locale), bad_locale(locale)
return True
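
A hedged usage sketch of the new helpers (not part of the commit), showing the behaviour that follows from the module above:

from modelgauge.locales import DEFAULT_LOCALE, EN_US, display_for, is_valid, validate_locale

assert DEFAULT_LOCALE == EN_US == "en_us"
assert is_valid("fr_fr") and not is_valid("fr_FR")  # only lowercase strings are accepted
assert display_for("fr_fr") == "fr_FR"              # display form: language lower, region upper
assert display_for("english") == "english"          # unsplittable input is returned unchanged

validate_locale("en_us")    # returns True
# validate_locale("zh_cn")  # would raise: zh_cn is defined but not yet in LOCALES
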
38 changes: 11 additions & 27 deletions src/modelgauge/tests/safe_v1.py
@@ -1,7 +1,6 @@
import csv
from abc import ABC
from enum import Enum
from pydantic import BaseModel
from typing import Dict, List, Mapping, Optional

from modelgauge.aggregations import get_measurement_stats_by_key
@@ -10,12 +9,14 @@
from modelgauge.default_annotator_set import DEFAULT_ANNOTATOR_SET
from modelgauge.dependency_helper import DependencyHelper
from modelgauge.external_data import ExternalData, WebData
from modelgauge.locales import EN_US, LOCALES, validate_locale
from modelgauge.prompt import SUTOptions, TextPrompt
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItem, TestItemAnnotations
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.test_decorator import modelgauge_test
from modelgauge.test_registry import TESTS
from pydantic import BaseModel


# TODO: agree on a standard for this
@@ -52,22 +53,6 @@ def description(cls) -> SecretDescription:
)


class Locale(Enum):
EN_US = "en_US" # English, United States
FR_FR = "fr_FR" # French, France
# TODO: uncomment when we have prompt support for these locales
# ZH_CN = "zh_CN" # Simplified Chinese, China
# HI_IN = "hi_IN" # Hindi, India

@classmethod
def _missing_(cls, key: object):
key = str(key).lower()
for member in cls:
if member.value.lower() == key:
return member
return None


class SafePersonasVersion1(str, Enum):
NORMAL = "normal"
SKILLED = "skilled"
@@ -112,10 +97,9 @@ class BaseSafeTestVersion1(PromptResponseTest, ABC):
annotators: AnnotatorSet

@staticmethod
def create_uid(hazard: str, locale: str | Locale, prompt_set: str, evaluator=None):
def create_uid(hazard: str, locale: str, prompt_set: str, evaluator=None):
assert prompt_set in PROMPT_SETS
if isinstance(locale, Locale):
locale = locale.value
validate_locale(locale)
if evaluator is None or evaluator == "default":
postfix = ""
else:
@@ -127,7 +111,7 @@ def __init__(
self,
uid: str,
hazard: str,
locale: Locale,
locale: str,
persona_types: List[SafePersonasVersion1],
prompt_set: str,
token: Optional[ModellabFileDownloadToken] = None,
@@ -140,6 +124,7 @@ def __init__(
persona_types
), f"Must specify a unique set of persona types, but got {persona_types}"
assert prompt_set in PROMPT_SETS, f"Invalid prompt set {prompt_set}. Must be one of {PROMPT_SETS.keys()}."
validate_locale(locale)

self.hazard = hazard
self.locale = locale
@@ -155,9 +140,8 @@ def _check_annotators(cls):

def __localize_filename(self) -> str:
# TODO implement a better standard for prompt file names by locale
locale = self.locale.value.lower()
if locale and locale != "en_us":
tail = f"_{locale}"
if self.locale != EN_US:
tail = f"_{self.locale}"
else:
tail = ""
filename = f"{self.prompt_set_file_name}{tail}"
@@ -202,7 +186,7 @@ def make_test_items(self, dependency_helper: DependencyHelper) -> List[TestItem]
# Check that prompt is for correct hazard/persona/locale.
hazard = row["hazard"].split("_")[0]
persona = SafePersonasVersion1(row["persona"])
locale = Locale(row["locale"])
locale = row["locale"].lower()
if not hazard == self.hazard:
continue
if persona not in self.persona_types:
@@ -261,15 +245,15 @@ class SafeTestVersion1(BaseSafeTestVersion1):


def register_tests(cls, evaluator=None):
for locale in Locale:
for locale in LOCALES:
for hazard in cls.hazards:
for prompt_set in PROMPT_SETS:
test_uid = BaseSafeTestVersion1.create_uid(hazard, locale, prompt_set, evaluator)
# TODO: Remove this 'if', duplicates are already caught during registration and should raise errors.
if not test_uid in TESTS.keys():
token = None
# only practice prompt sets in English are publicly available for now
if prompt_set == "official" or locale != Locale.EN_US:
if prompt_set == "official" or locale != EN_US:
token = InjectSecret(ModellabFileDownloadToken)
TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set, token)

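
Test UIDs are likewise built from lowercase locale strings now; create_uid validates the locale instead of unwrapping a Locale member. A sketch of the call shape (the hazard key "cse" is a hypothetical example, and the exact UID string produced is defined in code not shown in this diff):

from modelgauge.tests.safe_v1 import BaseSafeTestVersion1

# the default evaluator adds no postfix; a named evaluator is assumed to be appended
uid = BaseSafeTestVersion1.create_uid("cse", "fr_fr", "practice")
uid_ensemble = BaseSafeTestVersion1.create_uid("cse", "fr_fr", "practice", "ensemble")

# passing the old enum-style value would now fail validation
# BaseSafeTestVersion1.create_uid("cse", "fr_FR", "practice")  # AssertionError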