Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions packages/prime/src/prime_cli/api/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ def deploy_adapter(self, adapter_id: str) -> Adapter:
raise APIError(f"Failed to deploy adapter: {e.response.text}")
raise APIError(f"Failed to deploy adapter: {str(e)}")

def deploy_checkpoint(self, checkpoint_id: str) -> Adapter:
"""Deploy a checkpoint by preparing it as an adapter for inference."""
try:
response = self.client.post(f"/rft/checkpoints/{checkpoint_id}/deploy")
return Adapter.model_validate(response.get("adapter"))
except Exception as e:
if hasattr(e, "response") and hasattr(e.response, "text"):
raise APIError(f"Failed to deploy checkpoint: {e.response.text}")
raise APIError(f"Failed to deploy checkpoint: {str(e)}")

def unload_adapter(self, adapter_id: str) -> Adapter:
"""Unload an adapter from inference."""
try:
Expand Down
39 changes: 37 additions & 2 deletions packages/prime/src/prime_cli/commands/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,27 +172,62 @@ def list_deployments(
def create_deployment(
ctx: typer.Context,
model_id: Optional[str] = typer.Argument(None, help="Model ID to deploy"),
checkpoint_id: Optional[str] = typer.Option(
None,
"--checkpoint-id",
help="Deploy a Hosted Training checkpoint by checkpoint ID",
),
yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
) -> None:
"""Deploy a model for inference.

Makes the trained model available for inference requests.
Model must be in READY status.
Model must be in READY status. To deploy a checkpoint, pass --checkpoint-id.

Example:

prime deployments create <model_id>

prime deployments create <model_id> --yes

prime deployments create --checkpoint-id <checkpoint_id>
"""
if model_id is None:
if model_id and checkpoint_id:
console.print("[red]Error:[/red] Use either MODEL_ID or --checkpoint-id, not both.")
raise typer.Exit(1)

if model_id is None and checkpoint_id is None:
console.print(ctx.get_help())
raise typer.Exit(0)

try:
api_client = APIClient()
deployments_client = DeploymentsClient(api_client)

if checkpoint_id:
console.print("[bold]Deploying checkpoint:[/bold]")
console.print(f" Checkpoint ID: {checkpoint_id}")
console.print()

if not yes:
confirm = typer.confirm("Are you sure you want to deploy this checkpoint?")
if not confirm:
console.print("Cancelled.")
raise typer.Exit(0)

adapter = deployments_client.deploy_checkpoint(checkpoint_id)

console.print("[green]Deployment initiated successfully![/green]")
console.print(f"Adapter ID: [cyan]{adapter.id}[/cyan]")
console.print(f"Status: [yellow]{adapter.deployment_status}[/yellow]")
console.print("\n[dim]The model is being deployed. This may take a few minutes.[/dim]")
console.print("[dim]Use 'prime deployments list' to check deployment status.[/dim]")

_print_inference_usage(adapter.base_model, adapter.id)
return

assert model_id is not None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Empty checkpoint-id triggers assert

Medium Severity

Passing --checkpoint-id with an empty value skips the checkpoint deploy path and the no-argument help exit, then hits assert model_id is not None. That assertion is not caught as APIError, so the CLI can crash instead of showing a clear validation error.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 13d3303. Configure here.


# Get model to validate status
model = deployments_client.get_adapter(model_id)

Expand Down
115 changes: 115 additions & 0 deletions packages/prime/tests/test_deployments.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,42 @@
from types import SimpleNamespace
from typing import Any

from prime_cli.api.deployments import DeploymentsClient
from prime_cli.client import APIError
from prime_cli.main import app
from prime_cli.utils import strip_ansi
from typer.testing import CliRunner

runner = CliRunner()

TEST_ENV = {"PRIME_API_KEY": "dummy", "PRIME_DISABLE_VERSION_CHECK": "1", "COLUMNS": "200"}


def _adapter_response(
*,
adapter_id: str = "adapter-123",
base_model: str = "meta-llama/Llama-3.1-8B-Instruct",
deployment_status: str = "DEPLOYING",
) -> dict[str, Any]:
return {
"adapter": {
"id": adapter_id,
"displayName": "Checkpoint Adapter",
"userId": "user-123",
"teamId": None,
"rftRunId": "run-123",
"baseModel": base_model,
"step": 20,
"status": "READY",
"deploymentStatus": deployment_status,
"deployedAt": None,
"deploymentError": None,
"createdAt": "2026-01-01T00:00:00Z",
"updatedAt": "2026-01-01T00:00:00Z",
},
"message": "Checkpoint adapter deployment started",
}


def test_deployments_create_prints_chat_and_api_key_commands(monkeypatch) -> None:
monkeypatch.setenv("PRIME_API_KEY", "dummy")
Expand Down Expand Up @@ -57,3 +87,88 @@ def deploy_adapter(self, model_id: str) -> Any:
assert "export PRIME_API_KEY=<insert_key_here>" in output
assert "PRIME_API_KEY" in output
assert "curl -X POST" in output


def test_deployments_client_deploy_checkpoint_posts_endpoint() -> None:
captured: dict[str, Any] = {}

class DummyAPIClient:
def post(self, endpoint: str, json: dict[str, Any] | None = None) -> dict:
captured["endpoint"] = endpoint
captured["json"] = json
return _adapter_response()

adapter = DeploymentsClient(DummyAPIClient()).deploy_checkpoint("ckpt-123")

assert captured["endpoint"] == "/rft/checkpoints/ckpt-123/deploy"
assert captured["json"] is None
assert adapter.id == "adapter-123"


def test_deployments_create_checkpoint_prints_adapter_result(monkeypatch) -> None:
monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None))

adapter = SimpleNamespace(
id="adapter-456",
base_model="Qwen/Qwen3.5-0.8B",
deployment_status="DEPLOYING",
)

class DummyDeploymentsClient:
def __init__(self, api_client: Any) -> None:
self.api_client = api_client

def deploy_checkpoint(self, checkpoint_id: str) -> Any:
assert checkpoint_id == "ckpt-456"
return adapter

monkeypatch.setattr("prime_cli.commands.deployments.APIClient", lambda: object())
monkeypatch.setattr(
"prime_cli.commands.deployments.DeploymentsClient",
DummyDeploymentsClient,
)

result = runner.invoke(
app,
["deployments", "create", "--checkpoint-id", "ckpt-456", "--yes"],
env=TEST_ENV,
)
output = strip_ansi(result.output)

assert result.exit_code == 0, result.output
assert "Deploying checkpoint:" in output
assert "Checkpoint ID: ckpt-456" in output
assert "Deployment initiated successfully!" in output
assert "Adapter ID: adapter-456" in output
assert "Status: DEPLOYING" in output
assert '"Qwen/Qwen3.5-0.8B:adapter-456"' in output
assert "prime deployments list" in output


def test_deployments_create_checkpoint_surfaces_conflict_errors(monkeypatch) -> None:
monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None))

class DummyDeploymentsClient:
def __init__(self, api_client: Any) -> None:
self.api_client = api_client

def deploy_checkpoint(self, checkpoint_id: str) -> Any:
assert checkpoint_id == "ckpt-busy"
raise APIError("HTTP 409: Checkpoint adapter preparation is already in progress")

monkeypatch.setattr("prime_cli.commands.deployments.APIClient", lambda: object())
monkeypatch.setattr(
"prime_cli.commands.deployments.DeploymentsClient",
DummyDeploymentsClient,
)

result = runner.invoke(
app,
["deployments", "create", "--checkpoint-id", "ckpt-busy", "--yes"],
env=TEST_ENV,
)
output = strip_ansi(result.output)

assert result.exit_code == 1
assert "Error: HTTP 409" in output
assert "Checkpoint adapter preparation is already in progress" in output
Loading