diff --git a/packages/prime-sandboxes/src/prime_sandboxes/__init__.py b/packages/prime-sandboxes/src/prime_sandboxes/__init__.py index a854a7d45..49753bf1c 100644 --- a/packages/prime-sandboxes/src/prime_sandboxes/__init__.py +++ b/packages/prime-sandboxes/src/prime_sandboxes/__init__.py @@ -25,6 +25,7 @@ ) from .models import ( AdvancedConfigs, + AuthorizeSSHForVMResponse, BackgroundJob, BackgroundJobStatus, BulkDeleteSandboxRequest, @@ -84,6 +85,7 @@ "ExposedPort", "ListExposedPortsResponse", "SSHSession", + "AuthorizeSSHForVMResponse", # Exceptions "APIError", "UnauthorizedError", diff --git a/packages/prime-sandboxes/src/prime_sandboxes/models.py b/packages/prime-sandboxes/src/prime_sandboxes/models.py index acdec9e25..4239d328d 100644 --- a/packages/prime-sandboxes/src/prime_sandboxes/models.py +++ b/packages/prime-sandboxes/src/prime_sandboxes/models.py @@ -290,6 +290,18 @@ class SSHSession(BaseModel): token: str +class AuthorizeSSHForVMResponse(BaseModel): + """Result of authorizing an SSH public key for a VM sandbox session.""" + + session_id: str + sandbox_id: str + host: str + port: int + external_endpoint: str + expires_at: datetime + ttl_seconds: int + + class BackgroundJob(BaseModel): """Background job handle returned when starting a background job""" diff --git a/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py b/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py index b36092fbc..647e9469b 100644 --- a/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py +++ b/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py @@ -37,6 +37,7 @@ UploadTimeoutError, ) from .models import ( + AuthorizeSSHForVMResponse, BackgroundJob, BackgroundJobStatus, BulkDeleteSandboxRequest, @@ -1367,8 +1368,12 @@ def create_ssh_session( sandbox_id: str, ttl_seconds: Optional[int] = None, ) -> SSHSession: - """Create an SSH session""" - self._guard_vm_unsupported(sandbox_id, "SSH") + """Create an SSH session. + + This only creates the session. Register the ephemeral public key in a + separate step: container sandboxes authorize directly against the SSH + sidecar, VM sandboxes authorize via ``authorize_ssh_session``. + """ payload: Dict[str, Any] = {} if ttl_seconds is not None: payload["ttl_seconds"] = ttl_seconds @@ -1379,9 +1384,30 @@ def create_ssh_session( ) return SSHSession.model_validate(response) + def authorize_ssh_session( + self, + sandbox_id: str, + session_id: str, + public_key: str, + ttl_seconds: Optional[int] = None, + ) -> AuthorizeSSHForVMResponse: + """Authorize an ephemeral public key for a VM sandbox SSH session. + + Container sandboxes authorize directly against the SSH sidecar and + should not call this method. + """ + payload: Dict[str, Any] = {"public_key": public_key} + if ttl_seconds is not None: + payload["ttl_seconds"] = ttl_seconds + response = self.client.request( + "POST", + f"/sandbox/{sandbox_id}/ssh-session/{session_id}/authorize", + json=payload, + ) + return AuthorizeSSHForVMResponse.model_validate(response) + def close_ssh_session(self, sandbox_id: str, session_id: str) -> None: - """Close an SSH session and remove its exposure""" - self._guard_vm_unsupported(sandbox_id, "SSH") + """Close an SSH session.""" self.client.request("DELETE", f"/sandbox/{sandbox_id}/ssh-session/{session_id}") @@ -2308,8 +2334,12 @@ async def create_ssh_session( sandbox_id: str, ttl_seconds: Optional[int] = None, ) -> SSHSession: - """Create an SSH session""" - await self._guard_vm_unsupported(sandbox_id, "SSH") + """Create an SSH session. + + This only creates the session. Register the ephemeral public key in a + separate step: container sandboxes authorize directly against the SSH + sidecar, VM sandboxes authorize via ``authorize_ssh_session``. + """ payload: Dict[str, Any] = {} if ttl_seconds is not None: payload["ttl_seconds"] = ttl_seconds @@ -2320,9 +2350,30 @@ async def create_ssh_session( ) return SSHSession.model_validate(response) + async def authorize_ssh_session( + self, + sandbox_id: str, + session_id: str, + public_key: str, + ttl_seconds: Optional[int] = None, + ) -> AuthorizeSSHForVMResponse: + """Authorize an ephemeral public key for a VM sandbox SSH session. + + Container sandboxes authorize directly against the SSH sidecar and + should not call this method. + """ + payload: Dict[str, Any] = {"public_key": public_key} + if ttl_seconds is not None: + payload["ttl_seconds"] = ttl_seconds + response = await self.client.request( + "POST", + f"/sandbox/{sandbox_id}/ssh-session/{session_id}/authorize", + json=payload, + ) + return AuthorizeSSHForVMResponse.model_validate(response) + async def close_ssh_session(self, sandbox_id: str, session_id: str) -> None: - """Close an SSH session and remove its exposure""" - await self._guard_vm_unsupported(sandbox_id, "SSH") + """Close an SSH session.""" await self.client.request("DELETE", f"/sandbox/{sandbox_id}/ssh-session/{session_id}") diff --git a/packages/prime-sandboxes/tests/test_vm_guards.py b/packages/prime-sandboxes/tests/test_vm_guards.py index 8bb751e63..7ab431ad4 100644 --- a/packages/prime-sandboxes/tests/test_vm_guards.py +++ b/packages/prime-sandboxes/tests/test_vm_guards.py @@ -119,20 +119,56 @@ def test_sync_list_exposed_ports_blocked_for_vm(): assert recording.calls == [] -def test_sync_create_ssh_session_blocked_for_vm(): +def test_sync_create_ssh_session_allowed_for_vm(): client, recording = _make_sync_client(is_vm=True) - with pytest.raises(APIError) as exc_info: - client.create_ssh_session("sbx-vm") - assert "SSH" in str(exc_info.value) - assert recording.calls == [] + cast(Any, recording)._response = { + "session_id": "s", + "exposure_id": "s", + "sandbox_id": "sbx-vm", + "host": "h", + "port": 2222, + "external_endpoint": "h:2222", + "expires_at": datetime.now(timezone.utc).isoformat(), + "ttl_seconds": 300, + "gateway_url": "", + "user_ns": "", + "job_id": "sbx-vm", + "token": "", + } + client.create_ssh_session("sbx-vm") + assert any( + method == "POST" and path.endswith("/ssh-session") for method, path, _ in recording.calls + ) + # Create no longer carries the public key; that happens in authorize. + assert "public_key" not in recording.calls[-1][2]["json"] -def test_sync_close_ssh_session_blocked_for_vm(): +def test_sync_authorize_ssh_session_allowed_for_vm(): client, recording = _make_sync_client(is_vm=True) - with pytest.raises(APIError) as exc_info: - client.close_ssh_session("sbx-vm", "sess-1") - assert "SSH" in str(exc_info.value) - assert recording.calls == [] + cast(Any, recording)._response = { + "session_id": "sess-1", + "sandbox_id": "sbx-vm", + "host": "h", + "port": 2222, + "external_endpoint": "h:2222", + "expires_at": datetime.now(timezone.utc).isoformat(), + "ttl_seconds": 300, + } + client.authorize_ssh_session("sbx-vm", "sess-1", public_key="ssh-ed25519 AAAA...") + assert any( + method == "POST" and path.endswith("/ssh-session/sess-1/authorize") + for method, path, _ in recording.calls + ) + assert recording.calls[-1][2]["json"]["public_key"] == "ssh-ed25519 AAAA..." + + +def test_sync_close_ssh_session_allowed_for_vm(): + client, recording = _make_sync_client(is_vm=True) + client.close_ssh_session("sbx-vm", "sess-1") + assert any( + method == "DELETE" and path.endswith("/ssh-session/sess-1") + for method, path, _ in recording.calls + ) # --------------------------------------------------------------------------- @@ -327,25 +363,65 @@ async def test_async_list_exposed_ports_blocked_for_vm(): @pytest.mark.asyncio -async def test_async_create_ssh_session_blocked_for_vm(): +async def test_async_create_ssh_session_allowed_for_vm(): client, recording = _make_async_client(is_vm=True) try: - with pytest.raises(APIError) as exc_info: - await client.create_ssh_session("sbx-vm") - assert "SSH" in str(exc_info.value) - assert recording.calls == [] + cast(Any, recording)._response = { + "session_id": "s", + "exposure_id": "s", + "sandbox_id": "sbx-vm", + "host": "h", + "port": 2222, + "external_endpoint": "h:2222", + "expires_at": datetime.now(timezone.utc).isoformat(), + "ttl_seconds": 300, + "gateway_url": "", + "user_ns": "", + "job_id": "sbx-vm", + "token": "", + } + await client.create_ssh_session("sbx-vm") + assert any( + method == "POST" and path.endswith("/ssh-session") + for method, path, _ in recording.calls + ) + assert "public_key" not in recording.calls[-1][2]["json"] finally: await client.aclose() @pytest.mark.asyncio -async def test_async_close_ssh_session_blocked_for_vm(): +async def test_async_authorize_ssh_session_allowed_for_vm(): client, recording = _make_async_client(is_vm=True) try: - with pytest.raises(APIError) as exc_info: - await client.close_ssh_session("sbx-vm", "sess-1") - assert "SSH" in str(exc_info.value) - assert recording.calls == [] + cast(Any, recording)._response = { + "session_id": "sess-1", + "sandbox_id": "sbx-vm", + "host": "h", + "port": 2222, + "external_endpoint": "h:2222", + "expires_at": datetime.now(timezone.utc).isoformat(), + "ttl_seconds": 300, + } + await client.authorize_ssh_session("sbx-vm", "sess-1", public_key="ssh-ed25519 AAAA...") + assert any( + method == "POST" and path.endswith("/ssh-session/sess-1/authorize") + for method, path, _ in recording.calls + ) + assert recording.calls[-1][2]["json"]["public_key"] == "ssh-ed25519 AAAA..." + finally: + await client.aclose() + + +@pytest.mark.asyncio +async def test_async_close_ssh_session_allowed_for_vm(): + client, recording = _make_async_client(is_vm=True) + try: + await client.close_ssh_session("sbx-vm", "sess-1") + assert any( + method == "DELETE" and path.endswith("/ssh-session/sess-1") + for method, path, _ in recording.calls + ) finally: await client.aclose() diff --git a/packages/prime/src/prime_cli/commands/sandbox.py b/packages/prime/src/prime_cli/commands/sandbox.py index bf9442d86..6345f9f7b 100644 --- a/packages/prime/src/prime_cli/commands/sandbox.py +++ b/packages/prime/src/prime_cli/commands/sandbox.py @@ -5,6 +5,7 @@ import shutil import string import subprocess +import sys import tempfile import time from datetime import datetime, timedelta @@ -221,11 +222,9 @@ def _select_sandbox_for_ssh(sandbox_client: SandboxClient) -> str: break page += 1 - # SSH is only supported for non-VM sandboxes (see _guard_vm_unsupported). items = [ {"id": sb.id, "name": sb.name, "image": sb.docker_image} for sb in sort_by_created(sandboxes) - if not sb.vm ] selected = require_selection( @@ -1593,7 +1592,7 @@ def cleanup() -> None: with console.status("[bold blue]Checking sandbox status...", spinner="dots"): sandbox = sandbox_client.get(sandbox_id) - _guard_vm_unsupported(sandbox, "SSH") + is_vm_sandbox = bool(getattr(sandbox, "vm", False)) if sandbox.status != "RUNNING": console.print(f"[red]Error:[/red] Sandbox is not running (status: {sandbox.status})") @@ -1613,25 +1612,35 @@ def cleanup() -> None: with open(f"{key_path}.pub", "r") as f: public_key = f.read().strip() - # Create SSH session + # Create SSH session (same call for container and VM sandboxes). console.print("[bold blue]Creating SSH session...[/bold blue]") with console.status("[bold blue]Setting up SSH session...", spinner="dots"): session = sandbox_client.create_ssh_session(sandbox_id) session_id = session.session_id - # Authorize the key - authorize_url = ( - f"{session.gateway_url.rstrip('/')}/{session.user_ns}/{session.job_id}/authorize" - ) - headers = {"Authorization": f"Bearer {session.token}"} - payload = { - "session_id": session.session_id, - "public_key": public_key, - "ttl_seconds": session.ttl_seconds, - } + # Authorize the ephemeral key. The two-step create -> authorize flow is + # identical across runtimes; only the authorize transport differs: + # container sandboxes hit the SSH sidecar directly, VM sandboxes go + # through the platform authorize endpoint. try: - with httpx.Client(timeout=30) as client: - client.post(authorize_url, json=payload, headers=headers).raise_for_status() + if is_vm_sandbox: + sandbox_client.authorize_ssh_session( + sandbox_id, + session_id, + public_key=public_key, + ttl_seconds=session.ttl_seconds, + ) + else: + gateway = session.gateway_url.rstrip("/") + authorize_url = f"{gateway}/{session.user_ns}/{session.job_id}/authorize" + headers = {"Authorization": f"Bearer {session.token}"} + payload = { + "session_id": session.session_id, + "public_key": public_key, + "ttl_seconds": session.ttl_seconds, + } + with httpx.Client(timeout=30) as client: + client.post(authorize_url, json=payload, headers=headers).raise_for_status() except Exception as e: console.print(f"[red]Error:[/red] Failed to authorize SSH key: {e}") cleanup() @@ -1660,6 +1669,45 @@ def cleanup() -> None: ssh_cmd.extend(["-o", "UserKnownHostsFile=/dev/null"]) ssh_cmd.extend(["-o", "LogLevel=ERROR"]) + if is_vm_sandbox: + # VM SSH connections start by sending a session prefix to the L4 gateway. + # The proxy then passes the remaining SSH byte stream through unchanged. + prefix = f"PRIME-SSH-SESSION {session.session_id}\n" + python_exec = sys.executable or "python3" + proxy_script = ( + "import socket, sys, threading\n" + f"s = socket.create_connection(({ssh_host!r}, {int(ssh_port)}))\n" + f"s.sendall({prefix!r}.encode())\n" + "def _reader():\n" + " try:\n" + " while True:\n" + " b = s.recv(4096)\n" + " if not b:\n" + " break\n" + " sys.stdout.buffer.write(b)\n" + " sys.stdout.buffer.flush()\n" + " except OSError:\n" + " pass\n" + "t = threading.Thread(target=_reader, daemon=True)\n" + "t.start()\n" + "try:\n" + " while True:\n" + " b = sys.stdin.buffer.read1(4096)\n" + " if not b:\n" + " break\n" + " s.sendall(b)\n" + "except OSError:\n" + " pass\n" + "finally:\n" + " try:\n" + " s.shutdown(socket.SHUT_WR)\n" + " except OSError:\n" + " pass\n" + " s.close()\n" + ) + proxy_cmd = f"{shlex.quote(python_exec)} -c {shlex.quote(proxy_script)}" + ssh_cmd.extend(["-o", f"ProxyCommand={proxy_cmd}"]) + # Add identity file if specified if key_path: ssh_cmd.extend(["-i", key_path]) diff --git a/packages/prime/tests/test_sandbox_cli.py b/packages/prime/tests/test_sandbox_cli.py index c1f2bbcba..8153754c7 100644 --- a/packages/prime/tests/test_sandbox_cli.py +++ b/packages/prime/tests/test_sandbox_cli.py @@ -404,7 +404,7 @@ def mock_bulk_delete(self: Any, **kwargs: Any) -> Any: def test_sandbox_ssh_no_id_picks_running_sandbox(monkeypatch: pytest.MonkeyPatch) -> None: - """`prime sandbox ssh` with no ID lists running, non-VM sandboxes to pick from. + """`prime sandbox ssh` with no ID lists running sandboxes to pick from. Selecting one feeds its ID into the rest of the flow; we stop the flow right after by returning a non-RUNNING sandbox from ``get``. @@ -450,10 +450,10 @@ def mock_get(self: Any, sandbox_id: str) -> Any: result = runner.invoke(app, ["sandbox", "ssh"], input="1\n") output = strip_ansi(result.output) - # Only RUNNING sandboxes are requested, and the VM one is filtered out of the picker. + # Only RUNNING sandboxes are requested, and VMs are valid SSH targets. assert captured["list_kwargs"]["status"] == "RUNNING" assert "sbx-container" in output - assert "sbx-vm" not in output + assert "sbx-vm" in output # The chosen sandbox flows into the rest of the SSH flow. assert captured["get_id"] == "sbx-container" assert "not running" in output @@ -467,18 +467,9 @@ def test_sandbox_ssh_no_id_no_running_sandboxes(monkeypatch: pytest.MonkeyPatch) monkeypatch.setattr("prime_cli.commands.sandbox.shutil.which", lambda _: "/usr/bin/ssh") def mock_list(self: Any, **kwargs: Any) -> Any: - # Only a VM sandbox exists; it is not SSH-able, so the picker is empty. return SimpleNamespace( - sandboxes=[ - SimpleNamespace( - id="sbx-vm", - name="gpu-box", - docker_image="cuda:12", - vm=True, - created_at="2026-05-02T00:00:00Z", - ), - ], - total=1, + sandboxes=[], + total=0, page=1, per_page=100, has_next=False, @@ -498,11 +489,7 @@ def mock_get(self: Any, sandbox_id: str) -> Any: def test_sandbox_ssh_no_id_pages_through_results(monkeypatch: pytest.MonkeyPatch) -> None: - """The picker pages past page 1, even when page 1 holds only VMs. - - Guards against reporting "no running sandboxes" when the only SSH-able - container lives on a later page. - """ + """The picker pages past page 1 before presenting all running sandboxes.""" monkeypatch.setenv("PRIME_API_KEY", "dummy") monkeypatch.setenv("PRIME_DISABLE_VERSION_CHECK", "1") monkeypatch.setattr("prime_cli.commands.sandbox.shutil.which", lambda _: "/usr/bin/ssh") @@ -512,7 +499,7 @@ def test_sandbox_ssh_no_id_pages_through_results(monkeypatch: pytest.MonkeyPatch 1: SimpleNamespace( sandboxes=[ SimpleNamespace( - id="sbx-vm", + id="sbx-first", name="gpu-box", docker_image="cuda:12", vm=True, @@ -556,8 +543,9 @@ def mock_get(self: Any, sandbox_id: str) -> Any: output = strip_ansi(result.output) assert captured["pages_requested"] == [1, 2] + assert "sbx-first" in output assert "sbx-container" in output - assert captured["get_id"] == "sbx-container" + assert captured["get_id"] == "sbx-first" assert result.exit_code == 1 @@ -600,6 +588,102 @@ def mock_get(self: Any, sandbox_id: str) -> Any: assert result.exit_code == 1 +def test_sandbox_ssh_vm_uses_public_key_and_proxy_command( + monkeypatch: pytest.MonkeyPatch, tmp_path: Any +) -> None: + monkeypatch.setenv("PRIME_API_KEY", "dummy") + monkeypatch.setenv("PRIME_DISABLE_VERSION_CHECK", "1") + monkeypatch.setattr("prime_cli.commands.sandbox.shutil.which", lambda _: "/usr/bin/ssh") + + temp_dir = tmp_path / "ssh" + temp_dir.mkdir() + monkeypatch.setattr("prime_cli.commands.sandbox.tempfile.mkdtemp", lambda prefix: str(temp_dir)) + + captured: dict[str, Any] = {"closed": []} + + def mock_get(self: Any, sandbox_id: str) -> Any: + return SimpleNamespace(id=sandbox_id, vm=True, status="RUNNING") + + def mock_create_ssh_session(self: Any, sandbox_id: str, **kwargs: Any) -> Any: + captured["create_kwargs"] = kwargs + return SimpleNamespace( + session_id="sess-vm", + host="gw-tcp.example.com", + port=2222, + ttl_seconds=300, + gateway_url="", + user_ns="", + job_id=sandbox_id, + token="", + ) + + def mock_authorize_ssh_session( + self: Any, sandbox_id: str, session_id: str, **kwargs: Any + ) -> Any: + captured["authorize_args"] = (sandbox_id, session_id) + captured["authorize_kwargs"] = kwargs + return SimpleNamespace( + session_id=session_id, + sandbox_id=sandbox_id, + host="gw-tcp.example.com", + port=2222, + external_endpoint="gw-tcp.example.com:2222", + ttl_seconds=300, + ) + + def mock_close_ssh_session(self: Any, sandbox_id: str, session_id: str) -> None: + captured["closed"].append((sandbox_id, session_id)) + + def mock_subprocess_run(cmd: list[str], **kwargs: Any) -> Any: + if cmd[:2] == ["ssh-keygen", "-t"]: + (temp_dir / "id_ed25519.pub").write_text("ssh-ed25519 AAAA vm@test") + return SimpleNamespace(returncode=0) + captured["ssh_cmd"] = cmd + return SimpleNamespace(returncode=0) + + class _NoHTTPClient: + def __init__(self, *_args: Any, **_kwargs: Any) -> None: + pass + + def __enter__(self) -> "_NoHTTPClient": + return self + + def __exit__(self, *_args: Any) -> None: + return None + + def post(self, *_args: Any, **_kwargs: Any) -> Any: + raise AssertionError("VM SSH should not call sidecar authorize") + + monkeypatch.setattr("prime_cli.commands.sandbox.SandboxClient.get", mock_get) + monkeypatch.setattr( + "prime_cli.commands.sandbox.SandboxClient.create_ssh_session", + mock_create_ssh_session, + ) + monkeypatch.setattr( + "prime_cli.commands.sandbox.SandboxClient.authorize_ssh_session", + mock_authorize_ssh_session, + ) + monkeypatch.setattr( + "prime_cli.commands.sandbox.SandboxClient.close_ssh_session", + mock_close_ssh_session, + ) + monkeypatch.setattr("prime_cli.commands.sandbox.subprocess.run", mock_subprocess_run) + monkeypatch.setattr("prime_cli.commands.sandbox.httpx.Client", _NoHTTPClient) + + result = runner.invoke(app, ["sandbox", "ssh", "sbx-vm"]) + + assert result.exit_code == 0, result.output + # Create no longer carries the public key. + assert "public_key" not in captured["create_kwargs"] + # Authorize is a separate step, called with the generated public key. + assert captured["authorize_args"] == ("sbx-vm", "sess-vm") + assert captured["authorize_kwargs"]["public_key"] == "ssh-ed25519 AAAA vm@test" + ssh_cmd = captured["ssh_cmd"] + assert any("ProxyCommand=" in arg for arg in ssh_cmd) + assert any("PRIME-SSH-SESSION sess-vm" in arg for arg in ssh_cmd) + assert captured["closed"] == [("sbx-vm", "sess-vm")] + + def test_format_sandbox_expiry_running_shows_time_left() -> None: now = datetime.now(timezone.utc) sb = _fake_sandbox(status="RUNNING", started_at=now - timedelta(minutes=10), timeout_minutes=60)