diff --git a/packages/prime/src/prime_cli/api/deployments.py b/packages/prime/src/prime_cli/api/deployments.py index dee21359..69f0ea58 100644 --- a/packages/prime/src/prime_cli/api/deployments.py +++ b/packages/prime/src/prime_cli/api/deployments.py @@ -82,6 +82,16 @@ def deploy_adapter(self, adapter_id: str) -> Adapter: raise APIError(f"Failed to deploy adapter: {e.response.text}") raise APIError(f"Failed to deploy adapter: {str(e)}") + def deploy_checkpoint(self, checkpoint_id: str) -> Adapter: + """Deploy a checkpoint by preparing it as an adapter for inference.""" + try: + response = self.client.post(f"/rft/checkpoints/{checkpoint_id}/deploy") + return Adapter.model_validate(response.get("adapter")) + except Exception as e: + if hasattr(e, "response") and hasattr(e.response, "text"): + raise APIError(f"Failed to deploy checkpoint: {e.response.text}") + raise APIError(f"Failed to deploy checkpoint: {str(e)}") + def unload_adapter(self, adapter_id: str) -> Adapter: """Unload an adapter from inference.""" try: diff --git a/packages/prime/src/prime_cli/commands/deployments.py b/packages/prime/src/prime_cli/commands/deployments.py index c3fabe1d..72dbb563 100644 --- a/packages/prime/src/prime_cli/commands/deployments.py +++ b/packages/prime/src/prime_cli/commands/deployments.py @@ -172,20 +172,31 @@ def list_deployments( def create_deployment( ctx: typer.Context, model_id: Optional[str] = typer.Argument(None, help="Model ID to deploy"), + checkpoint_id: Optional[str] = typer.Option( + None, + "--checkpoint-id", + help="Deploy a Hosted Training checkpoint by checkpoint ID", + ), yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"), ) -> None: """Deploy a model for inference. Makes the trained model available for inference requests. - Model must be in READY status. + Model must be in READY status. To deploy a checkpoint, pass --checkpoint-id. Example: prime deployments create prime deployments create --yes + + prime deployments create --checkpoint-id """ - if model_id is None: + if model_id and checkpoint_id: + console.print("[red]Error:[/red] Use either MODEL_ID or --checkpoint-id, not both.") + raise typer.Exit(1) + + if model_id is None and checkpoint_id is None: console.print(ctx.get_help()) raise typer.Exit(0) @@ -193,6 +204,30 @@ def create_deployment( api_client = APIClient() deployments_client = DeploymentsClient(api_client) + if checkpoint_id: + console.print("[bold]Deploying checkpoint:[/bold]") + console.print(f" Checkpoint ID: {checkpoint_id}") + console.print() + + if not yes: + confirm = typer.confirm("Are you sure you want to deploy this checkpoint?") + if not confirm: + console.print("Cancelled.") + raise typer.Exit(0) + + adapter = deployments_client.deploy_checkpoint(checkpoint_id) + + console.print("[green]Deployment initiated successfully![/green]") + console.print(f"Adapter ID: [cyan]{adapter.id}[/cyan]") + console.print(f"Status: [yellow]{adapter.deployment_status}[/yellow]") + console.print("\n[dim]The model is being deployed. This may take a few minutes.[/dim]") + console.print("[dim]Use 'prime deployments list' to check deployment status.[/dim]") + + _print_inference_usage(adapter.base_model, adapter.id) + return + + assert model_id is not None + # Get model to validate status model = deployments_client.get_adapter(model_id) diff --git a/packages/prime/tests/test_deployments.py b/packages/prime/tests/test_deployments.py index 759e5de3..122eaee8 100644 --- a/packages/prime/tests/test_deployments.py +++ b/packages/prime/tests/test_deployments.py @@ -1,12 +1,42 @@ from types import SimpleNamespace from typing import Any +from prime_cli.api.deployments import DeploymentsClient +from prime_cli.client import APIError from prime_cli.main import app from prime_cli.utils import strip_ansi from typer.testing import CliRunner runner = CliRunner() +TEST_ENV = {"PRIME_API_KEY": "dummy", "PRIME_DISABLE_VERSION_CHECK": "1", "COLUMNS": "200"} + + +def _adapter_response( + *, + adapter_id: str = "adapter-123", + base_model: str = "meta-llama/Llama-3.1-8B-Instruct", + deployment_status: str = "DEPLOYING", +) -> dict[str, Any]: + return { + "adapter": { + "id": adapter_id, + "displayName": "Checkpoint Adapter", + "userId": "user-123", + "teamId": None, + "rftRunId": "run-123", + "baseModel": base_model, + "step": 20, + "status": "READY", + "deploymentStatus": deployment_status, + "deployedAt": None, + "deploymentError": None, + "createdAt": "2026-01-01T00:00:00Z", + "updatedAt": "2026-01-01T00:00:00Z", + }, + "message": "Checkpoint adapter deployment started", + } + def test_deployments_create_prints_chat_and_api_key_commands(monkeypatch) -> None: monkeypatch.setenv("PRIME_API_KEY", "dummy") @@ -57,3 +87,88 @@ def deploy_adapter(self, model_id: str) -> Any: assert "export PRIME_API_KEY=" in output assert "PRIME_API_KEY" in output assert "curl -X POST" in output + + +def test_deployments_client_deploy_checkpoint_posts_endpoint() -> None: + captured: dict[str, Any] = {} + + class DummyAPIClient: + def post(self, endpoint: str, json: dict[str, Any] | None = None) -> dict: + captured["endpoint"] = endpoint + captured["json"] = json + return _adapter_response() + + adapter = DeploymentsClient(DummyAPIClient()).deploy_checkpoint("ckpt-123") + + assert captured["endpoint"] == "/rft/checkpoints/ckpt-123/deploy" + assert captured["json"] is None + assert adapter.id == "adapter-123" + + +def test_deployments_create_checkpoint_prints_adapter_result(monkeypatch) -> None: + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + + adapter = SimpleNamespace( + id="adapter-456", + base_model="Qwen/Qwen3.5-0.8B", + deployment_status="DEPLOYING", + ) + + class DummyDeploymentsClient: + def __init__(self, api_client: Any) -> None: + self.api_client = api_client + + def deploy_checkpoint(self, checkpoint_id: str) -> Any: + assert checkpoint_id == "ckpt-456" + return adapter + + monkeypatch.setattr("prime_cli.commands.deployments.APIClient", lambda: object()) + monkeypatch.setattr( + "prime_cli.commands.deployments.DeploymentsClient", + DummyDeploymentsClient, + ) + + result = runner.invoke( + app, + ["deployments", "create", "--checkpoint-id", "ckpt-456", "--yes"], + env=TEST_ENV, + ) + output = strip_ansi(result.output) + + assert result.exit_code == 0, result.output + assert "Deploying checkpoint:" in output + assert "Checkpoint ID: ckpt-456" in output + assert "Deployment initiated successfully!" in output + assert "Adapter ID: adapter-456" in output + assert "Status: DEPLOYING" in output + assert '"Qwen/Qwen3.5-0.8B:adapter-456"' in output + assert "prime deployments list" in output + + +def test_deployments_create_checkpoint_surfaces_conflict_errors(monkeypatch) -> None: + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + + class DummyDeploymentsClient: + def __init__(self, api_client: Any) -> None: + self.api_client = api_client + + def deploy_checkpoint(self, checkpoint_id: str) -> Any: + assert checkpoint_id == "ckpt-busy" + raise APIError("HTTP 409: Checkpoint adapter preparation is already in progress") + + monkeypatch.setattr("prime_cli.commands.deployments.APIClient", lambda: object()) + monkeypatch.setattr( + "prime_cli.commands.deployments.DeploymentsClient", + DummyDeploymentsClient, + ) + + result = runner.invoke( + app, + ["deployments", "create", "--checkpoint-id", "ckpt-busy", "--yes"], + env=TEST_ENV, + ) + output = strip_ansi(result.output) + + assert result.exit_code == 1 + assert "Error: HTTP 409" in output + assert "Checkpoint adapter preparation is already in progress" in output