diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index 7bcdd7c2..5a026636 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -879,6 +879,16 @@ def create_run( help="Skip action status check and run even if environment action failed.", ), yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"), + cluster: Optional[str] = typer.Option( + None, + "--cluster", + help=( + "Pin the dispatch to a specific cluster by name. Overrides the " + "TOML's `cluster_name` field. The backend hard-fails with a 400 " + "error if the cluster is unknown, not allocated, or out of " + "capacity — no silent fallback to another cluster." + ), + ), ) -> None: """Launch a Hosted Training run from a config file. @@ -891,6 +901,14 @@ def create_run( console.print(f"[dim]Loading config from {config_path}[/dim]\n") cfg = load_config(config_path) + # `--cluster` overrides the TOML's `cluster_name` so a user can + # retarget a single dispatch without editing the config. We don't + # validate the name client-side: the backend's picker is the source + # of truth for which clusters the caller can hit, and it returns a + # clear 400 with the available alternatives when the name is wrong. + if cluster is not None: + cfg.cluster_name = cluster + # Collect secrets from all sources def warn(msg: str) -> None: console.print(f"[yellow]Warning:[/yellow] {msg}") diff --git a/packages/prime/tests/test_train_cli.py b/packages/prime/tests/test_train_cli.py index 3259213d..addf73ab 100644 --- a/packages/prime/tests/test_train_cli.py +++ b/packages/prime/tests/test_train_cli.py @@ -65,6 +65,134 @@ def test_train_init_defaults_to_rl_toml() -> None: assert Path("rl.toml").exists() +def test_train_help_lists_cluster_flag() -> None: + # The flag has to be discoverable from `--help` so users don't have + # to grep source to find it. Regression guard against future arg + # reorders silently hiding the option. + result = runner.invoke(app, ["train", "--help"], env=TEST_ENV) + + assert result.exit_code == 0, result.output + assert "--cluster" in result.output + + +def test_train_cluster_flag_overrides_config_cluster_name(monkeypatch, tmp_path: Path) -> None: + # CLI `--cluster` should win over `cluster_name = "..."` in the TOML + # so users can retarget a single dispatch without editing the config. + # We don't validate the cluster name client-side: the backend's picker + # is the source of truth — verify here only that the override reaches + # the RLClient payload as `cluster_name`. + captured: dict[str, Any] = {} + + def mock_create_run(self: Any, **kwargs: Any) -> Any: + captured["kwargs"] = kwargs + + class _Run: + id = "run-1" + status = "QUEUED" + runs_ahead = None + + def model_dump(self_inner) -> dict[str, Any]: + return {"id": "run-1", "status": "QUEUED"} + + return _Run() + + def mock_list_models(self: Any, **kwargs: Any) -> list: + return [] + + monkeypatch.setattr( + "prime_cli.api.rl.RLClient.create_run", + mock_create_run, + ) + monkeypatch.setattr( + "prime_cli.api.rl.RLClient.list_models", + mock_list_models, + ) + + config_path = tmp_path / "rl.toml" + config_path.write_text( + 'model = "Qwen/Qwen3-0.6B"\n' + 'cluster_name = "config-cluster"\n' + "\n" + "[[env]]\n" + 'id = "reverse-text"\n' + ) + + result = runner.invoke( + app, + [ + "train", + str(config_path), + "--cluster", + "flag-cluster", + "--output", + "json", + "--yes", + "--skip-action-check", + ], + env={**TEST_ENV, "PRIME_API_KEY": "test-key"}, + ) + + assert result.exit_code == 0, result.output + assert captured["kwargs"]["cluster_name"] == "flag-cluster" + + +def test_train_without_cluster_flag_uses_config_cluster_name(monkeypatch, tmp_path: Path) -> None: + # Sanity check the inverse: with no --cluster, the TOML's + # cluster_name is what reaches the backend. Without this we'd never + # know if the override path silently took over the no-override path. + captured: dict[str, Any] = {} + + def mock_create_run(self: Any, **kwargs: Any) -> Any: + captured["kwargs"] = kwargs + + class _Run: + id = "run-1" + status = "QUEUED" + runs_ahead = None + + def model_dump(self_inner) -> dict[str, Any]: + return {"id": "run-1", "status": "QUEUED"} + + return _Run() + + def mock_list_models(self: Any, **kwargs: Any) -> list: + return [] + + monkeypatch.setattr( + "prime_cli.api.rl.RLClient.create_run", + mock_create_run, + ) + monkeypatch.setattr( + "prime_cli.api.rl.RLClient.list_models", + mock_list_models, + ) + + config_path = tmp_path / "rl.toml" + config_path.write_text( + 'model = "Qwen/Qwen3-0.6B"\n' + 'cluster_name = "config-cluster"\n' + "\n" + "[[env]]\n" + 'id = "reverse-text"\n' + ) + + result = runner.invoke( + app, + [ + "train", + str(config_path), + "--output", + "json", + "--yes", + "--skip-action-check", + ], + env={**TEST_ENV, "PRIME_API_KEY": "test-key"}, + ) + + assert result.exit_code == 0, result.output + assert captured["kwargs"]["cluster_name"] == "config-cluster" + + def test_train_request_submits_model_request(monkeypatch) -> None: captured: dict[str, Any] = {}