Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/prime-sandboxes/src/prime_sandboxes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from .models import (
AdvancedConfigs,
AuthorizeSSHForVMResponse,
BackgroundJob,
BackgroundJobStatus,
BulkDeleteSandboxRequest,
Expand Down Expand Up @@ -84,6 +85,7 @@
"ExposedPort",
"ListExposedPortsResponse",
"SSHSession",
"AuthorizeSSHForVMResponse",
# Exceptions
"APIError",
"UnauthorizedError",
Expand Down
12 changes: 12 additions & 0 deletions packages/prime-sandboxes/src/prime_sandboxes/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,18 @@ class SSHSession(BaseModel):
token: str


class AuthorizeSSHForVMResponse(BaseModel):
"""Result of authorizing an SSH public key for a VM sandbox session."""

session_id: str
sandbox_id: str
host: str
port: int
external_endpoint: str
expires_at: datetime
ttl_seconds: int


class BackgroundJob(BaseModel):
"""Background job handle returned when starting a background job"""

Expand Down
67 changes: 59 additions & 8 deletions packages/prime-sandboxes/src/prime_sandboxes/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
UploadTimeoutError,
)
from .models import (
AuthorizeSSHForVMResponse,
BackgroundJob,
BackgroundJobStatus,
BulkDeleteSandboxRequest,
Expand Down Expand Up @@ -1367,8 +1368,12 @@ def create_ssh_session(
sandbox_id: str,
ttl_seconds: Optional[int] = None,
) -> SSHSession:
"""Create an SSH session"""
self._guard_vm_unsupported(sandbox_id, "SSH")
"""Create an SSH session.

This only creates the session. Register the ephemeral public key in a
separate step: container sandboxes authorize directly against the SSH
sidecar, VM sandboxes authorize via ``authorize_ssh_session``.
"""
payload: Dict[str, Any] = {}
if ttl_seconds is not None:
payload["ttl_seconds"] = ttl_seconds
Expand All @@ -1379,9 +1384,30 @@ def create_ssh_session(
)
return SSHSession.model_validate(response)

def authorize_ssh_session(
self,
sandbox_id: str,
session_id: str,
public_key: str,
ttl_seconds: Optional[int] = None,
) -> AuthorizeSSHForVMResponse:
"""Authorize an ephemeral public key for a VM sandbox SSH session.

Container sandboxes authorize directly against the SSH sidecar and
should not call this method.
"""
payload: Dict[str, Any] = {"public_key": public_key}
if ttl_seconds is not None:
payload["ttl_seconds"] = ttl_seconds
response = self.client.request(
"POST",
f"/sandbox/{sandbox_id}/ssh-session/{session_id}/authorize",
json=payload,
)
return AuthorizeSSHForVMResponse.model_validate(response)

def close_ssh_session(self, sandbox_id: str, session_id: str) -> None:
"""Close an SSH session and remove its exposure"""
self._guard_vm_unsupported(sandbox_id, "SSH")
"""Close an SSH session."""
self.client.request("DELETE", f"/sandbox/{sandbox_id}/ssh-session/{session_id}")


Expand Down Expand Up @@ -2308,8 +2334,12 @@ async def create_ssh_session(
sandbox_id: str,
ttl_seconds: Optional[int] = None,
) -> SSHSession:
"""Create an SSH session"""
await self._guard_vm_unsupported(sandbox_id, "SSH")
"""Create an SSH session.

This only creates the session. Register the ephemeral public key in a
separate step: container sandboxes authorize directly against the SSH
sidecar, VM sandboxes authorize via ``authorize_ssh_session``.
"""
payload: Dict[str, Any] = {}
if ttl_seconds is not None:
payload["ttl_seconds"] = ttl_seconds
Expand All @@ -2320,9 +2350,30 @@ async def create_ssh_session(
)
return SSHSession.model_validate(response)

async def authorize_ssh_session(
self,
sandbox_id: str,
session_id: str,
public_key: str,
ttl_seconds: Optional[int] = None,
) -> AuthorizeSSHForVMResponse:
"""Authorize an ephemeral public key for a VM sandbox SSH session.

Container sandboxes authorize directly against the SSH sidecar and
should not call this method.
"""
payload: Dict[str, Any] = {"public_key": public_key}
if ttl_seconds is not None:
payload["ttl_seconds"] = ttl_seconds
response = await self.client.request(
"POST",
f"/sandbox/{sandbox_id}/ssh-session/{session_id}/authorize",
json=payload,
)
return AuthorizeSSHForVMResponse.model_validate(response)

async def close_ssh_session(self, sandbox_id: str, session_id: str) -> None:
"""Close an SSH session and remove its exposure"""
await self._guard_vm_unsupported(sandbox_id, "SSH")
"""Close an SSH session."""
await self.client.request("DELETE", f"/sandbox/{sandbox_id}/ssh-session/{session_id}")


Expand Down
116 changes: 96 additions & 20 deletions packages/prime-sandboxes/tests/test_vm_guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,56 @@ def test_sync_list_exposed_ports_blocked_for_vm():
assert recording.calls == []


def test_sync_create_ssh_session_blocked_for_vm():
def test_sync_create_ssh_session_allowed_for_vm():
client, recording = _make_sync_client(is_vm=True)
with pytest.raises(APIError) as exc_info:
client.create_ssh_session("sbx-vm")
assert "SSH" in str(exc_info.value)
assert recording.calls == []
cast(Any, recording)._response = {
"session_id": "s",
"exposure_id": "s",
"sandbox_id": "sbx-vm",
"host": "h",
"port": 2222,
"external_endpoint": "h:2222",
"expires_at": datetime.now(timezone.utc).isoformat(),
"ttl_seconds": 300,
"gateway_url": "",
"user_ns": "",
"job_id": "sbx-vm",
"token": "",
}
client.create_ssh_session("sbx-vm")
assert any(
method == "POST" and path.endswith("/ssh-session") for method, path, _ in recording.calls
)
# Create no longer carries the public key; that happens in authorize.
assert "public_key" not in recording.calls[-1][2]["json"]


def test_sync_close_ssh_session_blocked_for_vm():
def test_sync_authorize_ssh_session_allowed_for_vm():
client, recording = _make_sync_client(is_vm=True)
with pytest.raises(APIError) as exc_info:
client.close_ssh_session("sbx-vm", "sess-1")
assert "SSH" in str(exc_info.value)
assert recording.calls == []
cast(Any, recording)._response = {
"session_id": "sess-1",
"sandbox_id": "sbx-vm",
"host": "h",
"port": 2222,
"external_endpoint": "h:2222",
"expires_at": datetime.now(timezone.utc).isoformat(),
"ttl_seconds": 300,
}
client.authorize_ssh_session("sbx-vm", "sess-1", public_key="ssh-ed25519 AAAA...")
assert any(
method == "POST" and path.endswith("/ssh-session/sess-1/authorize")
for method, path, _ in recording.calls
)
assert recording.calls[-1][2]["json"]["public_key"] == "ssh-ed25519 AAAA..."


def test_sync_close_ssh_session_allowed_for_vm():
client, recording = _make_sync_client(is_vm=True)
client.close_ssh_session("sbx-vm", "sess-1")
assert any(
method == "DELETE" and path.endswith("/ssh-session/sess-1")
for method, path, _ in recording.calls
)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -327,25 +363,65 @@ async def test_async_list_exposed_ports_blocked_for_vm():


@pytest.mark.asyncio
async def test_async_create_ssh_session_blocked_for_vm():
async def test_async_create_ssh_session_allowed_for_vm():
client, recording = _make_async_client(is_vm=True)
try:
with pytest.raises(APIError) as exc_info:
await client.create_ssh_session("sbx-vm")
assert "SSH" in str(exc_info.value)
assert recording.calls == []
cast(Any, recording)._response = {
"session_id": "s",
"exposure_id": "s",
"sandbox_id": "sbx-vm",
"host": "h",
"port": 2222,
"external_endpoint": "h:2222",
"expires_at": datetime.now(timezone.utc).isoformat(),
"ttl_seconds": 300,
"gateway_url": "",
"user_ns": "",
"job_id": "sbx-vm",
"token": "",
}
await client.create_ssh_session("sbx-vm")
assert any(
method == "POST" and path.endswith("/ssh-session")
for method, path, _ in recording.calls
)
assert "public_key" not in recording.calls[-1][2]["json"]
finally:
await client.aclose()


@pytest.mark.asyncio
async def test_async_close_ssh_session_blocked_for_vm():
async def test_async_authorize_ssh_session_allowed_for_vm():
client, recording = _make_async_client(is_vm=True)
try:
with pytest.raises(APIError) as exc_info:
await client.close_ssh_session("sbx-vm", "sess-1")
assert "SSH" in str(exc_info.value)
assert recording.calls == []
cast(Any, recording)._response = {
"session_id": "sess-1",
"sandbox_id": "sbx-vm",
"host": "h",
"port": 2222,
"external_endpoint": "h:2222",
"expires_at": datetime.now(timezone.utc).isoformat(),
"ttl_seconds": 300,
}
await client.authorize_ssh_session("sbx-vm", "sess-1", public_key="ssh-ed25519 AAAA...")
assert any(
method == "POST" and path.endswith("/ssh-session/sess-1/authorize")
for method, path, _ in recording.calls
)
assert recording.calls[-1][2]["json"]["public_key"] == "ssh-ed25519 AAAA..."
finally:
await client.aclose()


@pytest.mark.asyncio
async def test_async_close_ssh_session_allowed_for_vm():
client, recording = _make_async_client(is_vm=True)
try:
await client.close_ssh_session("sbx-vm", "sess-1")
assert any(
method == "DELETE" and path.endswith("/ssh-session/sess-1")
for method, path, _ in recording.calls
)
finally:
await client.aclose()

Expand Down
Loading
Loading