diff --git a/docs/evaluation.md b/docs/evaluation.md index 1d6cfb4ed..33e74dfa9 100644 --- a/docs/evaluation.md +++ b/docs/evaluation.md @@ -127,6 +127,23 @@ env.set_concurrency(256) The `renderer` client type requires the optional renderer package. Install it with `uv add "verifiers[renderers]"` before running evals with `--api-client-type renderer`. +#### Model precedence + +`--model` / `-m` sets the inference client's model. Custom environments that +need to know that model can accept `model` in `load_environment()` / their +constructor, or read the injected `model` environment kwarg, instead of +requiring users to repeat the same value in `--env-args`. + +To use a different model inside the environment than the one driving inference, +pass it explicitly via `--env-args`: + +```bash +prime eval run my-env -m google/gemma-3-27b-it -a '{"model": "qwen/qwen3-14b"}' +``` + +That override changes only the environment's view of `model`; the inference +client still uses `--model`. + For convenience, define model endpoints in `./configs/endpoints.toml` to avoid repeating URL and key flags. ```toml diff --git a/tests/scripts/test_eval_model_kwarg.py b/tests/scripts/test_eval_model_kwarg.py new file mode 100644 index 000000000..bd95ef5c1 --- /dev/null +++ b/tests/scripts/test_eval_model_kwarg.py @@ -0,0 +1,33 @@ +from verifiers.scripts.eval import build_eval_config + + +def test_resolved_model_lands_in_extra_env_kwargs(monkeypatch): + raw = { + "env_id": "math-python", + "model": "openai/gpt-4.1-mini", + "api_base_url": "https://example.test/v1", + "api_key_var": "OPENAI_API_KEY", + } + monkeypatch.setenv("OPENAI_API_KEY", "sk-test") + + cfg = build_eval_config(raw) + + assert cfg.model == "openai/gpt-4.1-mini" + assert cfg.extra_env_kwargs.get("model") == "openai/gpt-4.1-mini" + + +def test_env_args_model_overrides_for_env_but_not_client(monkeypatch): + raw = { + "env_id": "math-python", + "model": "openai/gpt-4.1-mini", + "env_args": {"model": "qwen/qwen3-14b"}, + "api_base_url": "https://example.test/v1", + "api_key_var": "OPENAI_API_KEY", + } + monkeypatch.setenv("OPENAI_API_KEY", "sk-test") + + cfg = build_eval_config(raw) + + assert cfg.model == "openai/gpt-4.1-mini" + assert cfg.env_args.get("model") == "qwen/qwen3-14b" + assert cfg.extra_env_kwargs.get("model") is None diff --git a/verifiers/scripts/eval.py b/verifiers/scripts/eval.py index 0c3c07828..d4f633b17 100644 --- a/verifiers/scripts/eval.py +++ b/verifiers/scripts/eval.py @@ -499,6 +499,299 @@ def build_parser() -> argparse.ArgumentParser: return parser +def build_eval_config(raw: dict) -> EvalConfig: + """Build EvalConfig from a raw config dict.""" + env_id = raw["env_id"] + name = raw.get("name") + if name is not None and (not isinstance(name, str) or not name): + raise ValueError("'name' must be a non-empty string when provided.") + + # Resolve num_examples and rollouts_per_example with env defaults + env_defaults = get_env_eval_defaults(env_id) + raw_num_examples = raw.get("num_examples") + raw_rollouts = raw.get("rollouts_per_example") + + num_examples = ( + raw_num_examples + if raw_num_examples is not None + else env_defaults.get("num_examples", DEFAULT_NUM_EXAMPLES) + ) + rollouts_per_example = ( + raw_rollouts + if raw_rollouts is not None + else env_defaults.get("rollouts_per_example", DEFAULT_ROLLOUTS_PER_EXAMPLE) + ) + + if raw_num_examples is None: + source = ( + "pyproject.toml" if "num_examples" in env_defaults else "global default" + ) + logger.debug(f"Using num_examples={num_examples} from {source}") + if raw_rollouts is None: + source = ( + "pyproject.toml" + if "rollouts_per_example" in env_defaults + else "global default" + ) + logger.debug(f"Using rollouts_per_example={rollouts_per_example} from {source}") + + raw_endpoint_id = raw.get("endpoint_id") + raw_model_field = raw.get("model") + if raw_endpoint_id is not None and raw_model_field is not None: + raise ValueError( + "Cannot set both 'endpoint_id' and 'model' in eval config; choose one." + ) + if raw_endpoint_id is not None and not isinstance(raw_endpoint_id, str): + raise ValueError("'endpoint_id' must be a string when provided.") + if isinstance(raw_endpoint_id, str) and not raw_endpoint_id: + raise ValueError("'endpoint_id' must be a non-empty string when provided.") + endpoints_path = raw.get("endpoints_path", DEFAULT_ENDPOINTS_PATH) + resolved_endpoints_file = resolve_endpoints_file(str(endpoints_path)) + if raw_endpoint_id is not None and ( + resolved_endpoints_file is None or resolved_endpoints_file.suffix != ".toml" + ): + raise ValueError( + "'endpoint_id' is only supported with TOML endpoint registries. " + "Set endpoints_path to an endpoints.toml file." + ) + + raw_model = raw_model_field if raw_model_field is not None else DEFAULT_MODEL + endpoint_lookup_id = raw_endpoint_id if raw_endpoint_id is not None else raw_model + raw_client_type = raw.get("api_client_type") + raw_api_key_var = raw.get("api_key_var") + raw_api_base_url = raw.get("api_base_url") + if isinstance(raw_api_base_url, list): + raise ValueError( + "api_base_url lists are no longer supported. " + "Use endpoint_id + endpoints.toml for multi-endpoint configuration." + ) + + # Provider resolution: + # - model IN registry: registry -> provider overrides -> CLI overrides + # - model NOT in registry: provider (default: prime) -> CLI overrides + raw_provider = raw.get("provider") + api_key_override = raw_api_key_var is not None + api_base_url_override = raw_api_base_url is not None + client_type_override = raw_client_type is not None + direct_endpoint_config = ( + raw_endpoint_id is None and api_key_override and api_base_url_override + ) + endpoints = {} if direct_endpoint_config else load_endpoints(endpoints_path) + endpoint_group: list[Endpoint] | None = None + resolved_endpoint_id: str | None = None + + if endpoint_lookup_id in endpoints: + endpoint_group = endpoints[endpoint_lookup_id] + resolved_endpoint_id = endpoint_lookup_id + endpoint = endpoint_group[0] + + # Start from registry values + api_key_var = endpoint["key"] + api_base_url = endpoint["url"] + client_type = endpoint.get("api_client_type", DEFAULT_CLIENT_TYPE) + + endpoint_models = {entry["model"] for entry in endpoint_group} + if len(endpoint_models) > 1: + raise ValueError( + f"Endpoint alias '{endpoint_lookup_id}' maps to multiple model ids {sorted(endpoint_models)}, " + "which is not yet supported by EvalConfig." + ) + model = endpoint["model"] + + # Provider overrides registry + if raw_provider is not None: + provider_cfg = PROVIDER_CONFIGS[raw_provider] + api_key_var = provider_cfg["key"] + api_base_url = provider_cfg["url"] + if "client_type" in provider_cfg: + client_type = provider_cfg["client_type"] + + # CLI overrides provider / registry + if api_key_override: + api_key_var = raw_api_key_var + if api_base_url_override: + api_base_url = raw_api_base_url + if client_type_override: + client_type = raw_client_type + + if ( + api_key_override + or api_base_url_override + or client_type_override + or raw_provider is not None + ): + logger.debug( + "Using endpoint registry for model '%s' with overrides (key: %s, url: %s, api_client_type: %s)", + model, + "override" if api_key_override or raw_provider else "registry", + "override" if api_base_url_override or raw_provider else "registry", + "override" if client_type_override or raw_provider else "registry", + ) + else: + logger.debug( + "Using endpoint configuration for model '%s' from registry (%d endpoint variant(s))", + model, + len(endpoint_group), + ) + else: + if raw_endpoint_id is not None: + raise ValueError( + f"Endpoint id '{raw_endpoint_id}' not found in endpoint registry at {endpoints_path}" + ) + # Fall back to provider (default: prime) + provider_cfg = PROVIDER_CONFIGS[raw_provider or DEFAULT_PROVIDER] + logger.debug( + "Model '%s' not found in endpoint registry, using provider '%s'", + raw_model, + raw_provider or DEFAULT_PROVIDER, + ) + model = raw_model + api_key_var = raw_api_key_var if api_key_override else provider_cfg["key"] + api_base_url = ( + raw_api_base_url if api_base_url_override else provider_cfg["url"] + ) + client_type = ( + raw_client_type + if client_type_override + else provider_cfg.get("client_type", DEFAULT_CLIENT_TYPE) + ) + + # Merge sampling args + merged_sampling_args = merge_sampling_args( + raw.get("sampling_args"), + max_tokens=raw.get("max_tokens"), + temperature=raw.get("temperature"), + include_none_max_tokens=True, + ) + # Build headers: registry < [[eval]] headers table < header list / --header + eval_headers_merged = build_extra_headers(raw) + # Default X-Session-ID → example_id for sticky DP-aware routing; + # user-supplied headers_from_state / --header-from-state override. + eval_headers_from_state = { + "X-Session-ID": "example_id", + **build_extra_headers_from_state(raw), + } + + registry_headers_base: dict[str, str] = {} + if endpoint_group is not None: + registry_headers_base = dict(endpoint_group[0].get("extra_headers", {})) + + merged_headers: dict[str, str] = { + **registry_headers_base, + **eval_headers_merged, + } + + primary_api_base_url = api_base_url + if not isinstance(primary_api_base_url, str): + raise ValueError("api_base_url must be a single string URL") + assert api_key_var is not None + resolved_api_key_var = api_key_var + + endpoint_configs: list[EndpointClientConfig] = [] + if ( + endpoint_group is not None + and not api_base_url_override + and raw_provider is None + and len(endpoint_group) > 1 + ): + endpoint_configs = [ + EndpointClientConfig( + api_key_var=(resolved_api_key_var if api_key_override else ep["key"]), + api_base_url=ep["url"], + extra_headers={ + **dict(ep.get("extra_headers", {})), + **eval_headers_merged, + }, + ) + for ep in endpoint_group + ] + + assert primary_api_base_url is not None + client_config = ClientConfig( + client_type=cast(ClientType, client_type), + api_key_var=resolved_api_key_var, + api_base_url=primary_api_base_url, + endpoint_configs=endpoint_configs, + extra_headers=merged_headers, + extra_headers_from_state=eval_headers_from_state, + ) + + # Backward-compatible TOML field: resume_path + if raw.get("resume") is None and raw.get("resume_path") is not None: + raw["resume"] = raw["resume_path"] + + # handle resume path resolution + resume_arg = raw.get("resume") + resume_path: Path | None = None + if isinstance(resume_arg, str): + resume_path = Path(resume_arg) + if not is_valid_eval_results_path(resume_path): + raise ValueError( + f"Resume path {resume_path} is not a valid evaluation results path" + ) + logger.info(f"Resuming from explicit path: {resume_path}") + elif resume_arg is True: + auto_resume_path = find_latest_incomplete_eval_results_path( + env_id=env_id, + model=model, + num_examples=num_examples, + rollouts_per_example=rollouts_per_example, + env_dir_path=raw.get("env_dir_path", DEFAULT_ENV_DIR_PATH), + output_dir=raw.get("output_dir"), + name=name, + ) + if auto_resume_path is not None: + resume_path = auto_resume_path + logger.info(f"Auto-resuming from: {resume_path}") + else: + logger.info( + "No matching incomplete run found for --resume; starting a new run" + ) + elif resume_arg in (None, False): + pass + else: + raise ValueError(f"Invalid value for --resume: {resume_arg!r}") + + env_args = raw.get("env_args", {}) + extra_env_kwargs = dict(raw.get("extra_env_kwargs", {})) + if "model" in env_args: + extra_env_kwargs.pop("model", None) + else: + # Make the resolved inference model available to custom envs without + # forcing users to repeat it in -a / env_args. Explicit + # extra_env_kwargs.model still wins for callers that set it directly. + extra_env_kwargs.setdefault("model", model) + if raw.get("timeout") is not None: + extra_env_kwargs["timeout_seconds"] = raw["timeout"] + + return EvalConfig( + env_id=env_id, + name=name, + env_args=env_args, + env_dir_path=raw.get("env_dir_path", DEFAULT_ENV_DIR_PATH), + output_dir=raw.get("output_dir"), + extra_env_kwargs=extra_env_kwargs, + endpoint_id=resolved_endpoint_id, + model=model, + client_config=client_config, + sampling_args=merged_sampling_args, + num_examples=num_examples, + rollouts_per_example=rollouts_per_example, + max_concurrent=raw.get("max_concurrent", DEFAULT_MAX_CONCURRENT), + max_retries=raw.get("max_retries", 0), + num_workers=raw.get("num_workers", "auto"), + disable_env_server=raw.get("disable_env_server", False), + verbose=raw.get("verbose", False), + disable_tui=raw.get("disable_tui", False), + state_columns=raw.get("state_columns", []), + save_results=raw.get("save_results", False), + resume_path=resume_path, + independent_scoring=raw.get("independent_scoring", False), + save_to_hf_hub=raw.get("save_to_hf_hub", False), + hf_hub_dataset_name=raw.get("hf_hub_dataset_name", ""), + ) + + def parse_args(argv: list[str] | None = None) -> argparse.Namespace: parser = build_parser() if argv is None: @@ -533,296 +826,6 @@ def main(argv: list[str] | None = None): raw_config.update(vars(args)) raw_eval_configs = [raw_config] - def build_eval_config(raw: dict) -> EvalConfig: - """Build EvalConfig from a raw config dict.""" - env_id = raw["env_id"] - name = raw.get("name") - if name is not None and (not isinstance(name, str) or not name): - raise ValueError("'name' must be a non-empty string when provided.") - - # Resolve num_examples and rollouts_per_example with env defaults - env_defaults = get_env_eval_defaults(env_id) - raw_num_examples = raw.get("num_examples") - raw_rollouts = raw.get("rollouts_per_example") - - num_examples = ( - raw_num_examples - if raw_num_examples is not None - else env_defaults.get("num_examples", DEFAULT_NUM_EXAMPLES) - ) - rollouts_per_example = ( - raw_rollouts - if raw_rollouts is not None - else env_defaults.get("rollouts_per_example", DEFAULT_ROLLOUTS_PER_EXAMPLE) - ) - - if raw_num_examples is None: - source = ( - "pyproject.toml" if "num_examples" in env_defaults else "global default" - ) - logger.debug(f"Using num_examples={num_examples} from {source}") - if raw_rollouts is None: - source = ( - "pyproject.toml" - if "rollouts_per_example" in env_defaults - else "global default" - ) - logger.debug( - f"Using rollouts_per_example={rollouts_per_example} from {source}" - ) - - raw_endpoint_id = raw.get("endpoint_id") - raw_model_field = raw.get("model") - if raw_endpoint_id is not None and raw_model_field is not None: - raise ValueError( - "Cannot set both 'endpoint_id' and 'model' in eval config; choose one." - ) - if raw_endpoint_id is not None and not isinstance(raw_endpoint_id, str): - raise ValueError("'endpoint_id' must be a string when provided.") - if isinstance(raw_endpoint_id, str) and not raw_endpoint_id: - raise ValueError("'endpoint_id' must be a non-empty string when provided.") - endpoints_path = raw.get("endpoints_path", DEFAULT_ENDPOINTS_PATH) - resolved_endpoints_file = resolve_endpoints_file(str(endpoints_path)) - if raw_endpoint_id is not None and ( - resolved_endpoints_file is None or resolved_endpoints_file.suffix != ".toml" - ): - raise ValueError( - "'endpoint_id' is only supported with TOML endpoint registries. " - "Set endpoints_path to an endpoints.toml file." - ) - - raw_model = raw_model_field if raw_model_field is not None else DEFAULT_MODEL - endpoint_lookup_id = ( - raw_endpoint_id if raw_endpoint_id is not None else raw_model - ) - raw_client_type = raw.get("api_client_type") - raw_api_key_var = raw.get("api_key_var") - raw_api_base_url = raw.get("api_base_url") - if isinstance(raw_api_base_url, list): - raise ValueError( - "api_base_url lists are no longer supported. " - "Use endpoint_id + endpoints.toml for multi-endpoint configuration." - ) - - # Provider resolution: - # - model IN registry: registry -> provider overrides -> CLI overrides - # - model NOT in registry: provider (default: prime) -> CLI overrides - raw_provider = raw.get("provider") - api_key_override = raw_api_key_var is not None - api_base_url_override = raw_api_base_url is not None - client_type_override = raw_client_type is not None - direct_endpoint_config = ( - raw_endpoint_id is None and api_key_override and api_base_url_override - ) - endpoints = {} if direct_endpoint_config else load_endpoints(endpoints_path) - endpoint_group: list[Endpoint] | None = None - resolved_endpoint_id: str | None = None - - if endpoint_lookup_id in endpoints: - endpoint_group = endpoints[endpoint_lookup_id] - resolved_endpoint_id = endpoint_lookup_id - endpoint = endpoint_group[0] - - # Start from registry values - api_key_var = endpoint["key"] - api_base_url = endpoint["url"] - client_type = endpoint.get("api_client_type", DEFAULT_CLIENT_TYPE) - - endpoint_models = {entry["model"] for entry in endpoint_group} - if len(endpoint_models) > 1: - raise ValueError( - f"Endpoint alias '{endpoint_lookup_id}' maps to multiple model ids {sorted(endpoint_models)}, " - "which is not yet supported by EvalConfig." - ) - model = endpoint["model"] - - # Provider overrides registry - if raw_provider is not None: - provider_cfg = PROVIDER_CONFIGS[raw_provider] - api_key_var = provider_cfg["key"] - api_base_url = provider_cfg["url"] - if "client_type" in provider_cfg: - client_type = provider_cfg["client_type"] - - # CLI overrides provider / registry - if api_key_override: - api_key_var = raw_api_key_var - if api_base_url_override: - api_base_url = raw_api_base_url - if client_type_override: - client_type = raw_client_type - - if ( - api_key_override - or api_base_url_override - or client_type_override - or raw_provider is not None - ): - logger.debug( - "Using endpoint registry for model '%s' with overrides (key: %s, url: %s, api_client_type: %s)", - model, - "override" if api_key_override or raw_provider else "registry", - "override" if api_base_url_override or raw_provider else "registry", - "override" if client_type_override or raw_provider else "registry", - ) - else: - logger.debug( - "Using endpoint configuration for model '%s' from registry (%d endpoint variant(s))", - model, - len(endpoint_group), - ) - else: - if raw_endpoint_id is not None: - raise ValueError( - f"Endpoint id '{raw_endpoint_id}' not found in endpoint registry at {endpoints_path}" - ) - # Fall back to provider (default: prime) - provider_cfg = PROVIDER_CONFIGS[raw_provider or DEFAULT_PROVIDER] - logger.debug( - "Model '%s' not found in endpoint registry, using provider '%s'", - raw_model, - raw_provider or DEFAULT_PROVIDER, - ) - model = raw_model - api_key_var = raw_api_key_var if api_key_override else provider_cfg["key"] - api_base_url = ( - raw_api_base_url if api_base_url_override else provider_cfg["url"] - ) - client_type = ( - raw_client_type - if client_type_override - else provider_cfg.get("client_type", DEFAULT_CLIENT_TYPE) - ) - - # Merge sampling args - merged_sampling_args = merge_sampling_args( - raw.get("sampling_args"), - max_tokens=raw.get("max_tokens"), - temperature=raw.get("temperature"), - include_none_max_tokens=True, - ) - # Build headers: registry < [[eval]] headers table < header list / --header - eval_headers_merged = build_extra_headers(raw) - # Default X-Session-ID → example_id for sticky DP-aware routing; - # user-supplied headers_from_state / --header-from-state override. - eval_headers_from_state = { - "X-Session-ID": "example_id", - **build_extra_headers_from_state(raw), - } - - registry_headers_base: dict[str, str] = {} - if endpoint_group is not None: - registry_headers_base = dict(endpoint_group[0].get("extra_headers", {})) - - merged_headers: dict[str, str] = { - **registry_headers_base, - **eval_headers_merged, - } - - primary_api_base_url = api_base_url - if not isinstance(primary_api_base_url, str): - raise ValueError("api_base_url must be a single string URL") - assert api_key_var is not None - resolved_api_key_var = api_key_var - - endpoint_configs: list[EndpointClientConfig] = [] - if ( - endpoint_group is not None - and not api_base_url_override - and raw_provider is None - and len(endpoint_group) > 1 - ): - endpoint_configs = [ - EndpointClientConfig( - api_key_var=( - resolved_api_key_var if api_key_override else ep["key"] - ), - api_base_url=ep["url"], - extra_headers={ - **dict(ep.get("extra_headers", {})), - **eval_headers_merged, - }, - ) - for ep in endpoint_group - ] - - assert primary_api_base_url is not None - client_config = ClientConfig( - client_type=cast(ClientType, client_type), - api_key_var=resolved_api_key_var, - api_base_url=primary_api_base_url, - endpoint_configs=endpoint_configs, - extra_headers=merged_headers, - extra_headers_from_state=eval_headers_from_state, - ) - - # Backward-compatible TOML field: resume_path - if raw.get("resume") is None and raw.get("resume_path") is not None: - raw["resume"] = raw["resume_path"] - - # handle resume path resolution - resume_arg = raw.get("resume") - resume_path: Path | None = None - if isinstance(resume_arg, str): - resume_path = Path(resume_arg) - if not is_valid_eval_results_path(resume_path): - raise ValueError( - f"Resume path {resume_path} is not a valid evaluation results path" - ) - logger.info(f"Resuming from explicit path: {resume_path}") - elif resume_arg is True: - auto_resume_path = find_latest_incomplete_eval_results_path( - env_id=env_id, - model=model, - num_examples=num_examples, - rollouts_per_example=rollouts_per_example, - env_dir_path=raw.get("env_dir_path", DEFAULT_ENV_DIR_PATH), - output_dir=raw.get("output_dir"), - name=name, - ) - if auto_resume_path is not None: - resume_path = auto_resume_path - logger.info(f"Auto-resuming from: {resume_path}") - else: - logger.info( - "No matching incomplete run found for --resume; starting a new run" - ) - elif resume_arg in (None, False): - pass - else: - raise ValueError(f"Invalid value for --resume: {resume_arg!r}") - - extra_env_kwargs = dict(raw.get("extra_env_kwargs", {})) - if raw.get("timeout") is not None: - extra_env_kwargs["timeout_seconds"] = raw["timeout"] - - return EvalConfig( - env_id=env_id, - name=name, - env_args=raw.get("env_args", {}), - env_dir_path=raw.get("env_dir_path", DEFAULT_ENV_DIR_PATH), - output_dir=raw.get("output_dir"), - extra_env_kwargs=extra_env_kwargs, - endpoint_id=resolved_endpoint_id, - model=model, - client_config=client_config, - sampling_args=merged_sampling_args, - num_examples=num_examples, - rollouts_per_example=rollouts_per_example, - max_concurrent=raw.get("max_concurrent", DEFAULT_MAX_CONCURRENT), - max_retries=raw.get("max_retries", 0), - num_workers=raw.get("num_workers", "auto"), - disable_env_server=raw.get("disable_env_server", False), - verbose=raw.get("verbose", False), - disable_tui=raw.get("disable_tui", False), - state_columns=raw.get("state_columns", []), - save_results=raw.get("save_results", False), - resume_path=resume_path, - independent_scoring=raw.get("independent_scoring", False), - save_to_hf_hub=raw.get("save_to_hf_hub", False), - hf_hub_dataset_name=raw.get("hf_hub_dataset_name", ""), - ) - # Check Hub environments are installed before running missing_envs = [] for raw in raw_eval_configs: diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 5e7c8651d..66cca081e 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -1,4 +1,6 @@ import asyncio +import importlib +import inspect import itertools import json import logging @@ -926,6 +928,20 @@ def quiet_datasets(): enable_progress_bar() +def _load_environment_accepts_arg(env_id: str, arg_name: str) -> bool: + module_name = env_id.replace("-", "_").split("/")[-1] + try: + module = importlib.import_module(module_name) + env_load_func = getattr(module, "load_environment") + sig = inspect.signature(env_load_func) + except Exception: + return False + + return arg_name in sig.parameters or any( + param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values() + ) + + async def run_evaluation( config: EvalConfig, on_start: StartCallback | None = None, @@ -937,13 +953,22 @@ async def run_evaluation( maybe_suppress_logs = ( log_level(logging.CRITICAL) if not config.disable_env_server else nullcontext() ) + extra_env_kwargs = dict(config.extra_env_kwargs) + env_args = dict(config.env_args) + if "model" in config.env_args: + extra_env_kwargs.pop("model", None) + elif "model" in extra_env_kwargs and _load_environment_accepts_arg( + config.env_id, "model" + ): + env_args["model"] = extra_env_kwargs["model"] + with maybe_suppress_logs: - vf_env = vf.load_environment(env_id=config.env_id, **config.env_args) + vf_env = vf.load_environment(env_id=config.env_id, **env_args) # set extra environment kwargs - if config.extra_env_kwargs: - logger.info(f"Setting extra environment kwargs: {config.extra_env_kwargs}") - vf_env.set_kwargs(**config.extra_env_kwargs) + if extra_env_kwargs: + logger.info(f"Setting extra environment kwargs: {extra_env_kwargs}") + vf_env.set_kwargs(**extra_env_kwargs) results_path = config.resume_path or get_eval_results_path(config) model_pricing = await _resolve_model_pricing(config) @@ -951,7 +976,7 @@ async def run_evaluation( try: if not config.disable_env_server: - extra_env_kwargs = dict(config.extra_env_kwargs) + extra_env_kwargs = dict(extra_env_kwargs) # resolve total concurrency if "concurrency" not in extra_env_kwargs: if config.max_concurrent <= 0: