diff --git a/src/ares/containers/__init__.py b/src/ares/containers/__init__.py
index 1dd0114..9f85967 100644
--- a/src/ares/containers/__init__.py
+++ b/src/ares/containers/__init__.py
@@ -3,6 +3,7 @@
 from ares.containers.containers import Container
 from ares.containers.containers import ContainerFactory
 from ares.containers.containers import Resources
+from ares.containers.containers import SnapshotableContainer
 from ares.containers.daytona import DaytonaContainer
 
 __all__ = [
@@ -10,4 +11,5 @@
     "ContainerFactory",
     "DaytonaContainer",
     "Resources",
+    "SnapshotableContainer",
 ]
diff --git a/src/ares/containers/containers.py b/src/ares/containers/containers.py
index d4da940..6db9622 100644
--- a/src/ares/containers/containers.py
+++ b/src/ares/containers/containers.py
@@ -131,6 +131,66 @@ def stop_and_remove(self) -> None:
         """
 
 
+class SnapshotableContainer(Container, Protocol):
+    """A container that supports filesystem state snapshotting.
+
+    This extends the Container protocol with the ability to capture and
+    restore filesystem state, enabling algorithms like Go-Explore that
+    need to return to previously visited environment states.
+    """
+
+    @abc.abstractmethod
+    async def snapshot(self) -> str:
+        """Capture the container's current filesystem state.
+
+        Creates a snapshot of all filesystem changes since the container was started.
+        Running processes are NOT captured -- only filesystem state.
+
+        Returns:
+            A snapshot ID string that can be passed to from_snapshot() to create
+            a new container with this filesystem state.
+        """
+
+    @classmethod
+    @abc.abstractmethod
+    def from_snapshot(
+        cls,
+        snapshot_id: str,
+        *,
+        name: str | None = None,
+        resources: Resources | None = None,
+        default_workdir: str | None = None,
+    ) -> "SnapshotableContainer":
+        """Create a new (unstarted) container from a previously captured snapshot.
+
+        Args:
+            snapshot_id: A snapshot ID previously returned by snapshot().
+            name: Optional name for the container.
+            resources: Optional resource constraints.
+            default_workdir: Optional default working directory for commands.
+
+        Returns:
+            A new SnapshotableContainer instance (not yet started).
+        """
+        ...
+
+    @abc.abstractmethod
+    async def delete_snapshot(self, snapshot_id: str) -> None:
+        """Delete a previously captured snapshot, freeing its resources.
+
+        Args:
+            snapshot_id: The snapshot ID to delete.
+        """
+
+    @abc.abstractmethod
+    def delete_snapshot_sync(self, snapshot_id: str) -> None:
+        """Synchronous version of delete_snapshot for atexit cleanup.
+
+        Args:
+            snapshot_id: The snapshot ID to delete.
+        """
+
+
 class ContainerFactory(Protocol):
     """Protocol for creating containers from images or Dockerfiles.
 
diff --git a/src/ares/containers/docker.py b/src/ares/containers/docker.py
index 6d1cae5..1892ff2 100644
--- a/src/ares/containers/docker.py
+++ b/src/ares/containers/docker.py
@@ -1,12 +1,15 @@
 """An interface for interacting with local Docker containers."""
 
 import asyncio
+import contextlib
 import dataclasses
 import functools
 import io
+import logging
 import pathlib
 import tarfile
 from typing import cast
+import uuid
 
 import docker
 import docker.errors
@@ -15,6 +18,8 @@
 
 from ares.containers import containers
 
+_LOGGER = logging.getLogger(__name__)
+
 
 def _make_docker_client() -> docker.DockerClient:
     try:
@@ -24,7 +29,7 @@ def _make_docker_client() -> docker.DockerClient:
 
 
 @dataclasses.dataclass(kw_only=True)
-class DockerContainer(containers.Container):
+class DockerContainer(containers.SnapshotableContainer):
     image: str | None = None
     dockerfile_path: pathlib.Path | str | None = None
     name: str | None = None
@@ -183,6 +188,46 @@ async def download_files(self, remote_paths: list[str], local_paths: list[pathli
             with open(local_path, "wb") as f:
                 f.write(file_data.read())
 
+    async def snapshot(self) -> str:
+        """Commit current container state as a Docker image."""
+        if self._container is None:
+            raise RuntimeError("Container not started, snapshot is not possible.")
+
+        tag = f"ares-snapshot-{uuid.uuid4().hex[:12]}"
+        image = await asyncio.to_thread(
+            self._container.commit,
+            repository="ares-go-explore",
+            tag=tag,
+            conf={"Labels": {"ares-go-explore": "true"}},
+        )
+        _LOGGER.info("Snapshot created: %s (tag: %s)", image.id, tag)
+        return image.id
+
+    async def delete_snapshot(self, snapshot_id: str) -> None:
+        """Delete a Docker image created by snapshot()."""
+        try:
+            await asyncio.to_thread(self._client.images.remove, snapshot_id)
+            _LOGGER.info("Snapshot deleted: %s", snapshot_id)
+        except docker.errors.ImageNotFound:
+            _LOGGER.debug("Snapshot %s already deleted.", snapshot_id)
+
+    def delete_snapshot_sync(self, snapshot_id: str) -> None:
+        """Synchronous version for atexit cleanup."""
+        with contextlib.suppress(docker.errors.ImageNotFound):
+            self._client.images.remove(snapshot_id)
+
+    @classmethod
+    def from_snapshot(
+        cls,
+        snapshot_id: str,
+        *,
+        name: str | None = None,
+        resources: containers.Resources | None = None,
+        default_workdir: str | None = None,
+    ) -> "DockerContainer":
+        """Create a DockerContainer from a previously captured snapshot."""
+        return cls(image=snapshot_id, name=name, resources=resources, default_workdir=default_workdir)
+
     @classmethod
     def from_image(
         cls,
diff --git a/src/ares/containers/docker_snapshot_test.py b/src/ares/containers/docker_snapshot_test.py
new file mode 100644
index 0000000..9e8259f
--- /dev/null
+++ b/src/ares/containers/docker_snapshot_test.py
@@ -0,0 +1,86 @@
+"""Tests for Docker container snapshotting."""
+
+import unittest.mock
+
+import pytest
+
+from ares.containers import docker
+
+
+@pytest.fixture
+def mock_docker_client():
+    """Create a mock Docker client."""
+    with unittest.mock.patch.object(docker, "_make_docker_client") as mock_fn:
+        client = unittest.mock.MagicMock()
+        mock_fn.return_value = client
+        yield client
+
+
+@pytest.mark.asyncio
+async def test_snapshot_creates_image(mock_docker_client):  # noqa: ARG001
+    """Test that snapshot() commits the container and returns an image ID."""
+    container = docker.DockerContainer(image="test:latest")
+
+    # Set up mock container
+    mock_inner = unittest.mock.MagicMock()
+    container._container = mock_inner
+
+    mock_image = unittest.mock.MagicMock()
+    mock_image.id = "sha256:abc123"
+    mock_inner.commit.return_value = mock_image
+
+    snapshot_id = await container.snapshot()
+
+    assert snapshot_id == "sha256:abc123"
+    mock_inner.commit.assert_called_once()
+    call_kwargs = mock_inner.commit.call_args
+    assert call_kwargs[1]["repository"] == "ares-go-explore"
+    assert call_kwargs[1]["conf"]["Labels"]["ares-go-explore"] == "true"
+
+
+@pytest.mark.asyncio
+async def test_snapshot_raises_if_not_started():
+    """Test that snapshot() raises if container isn't started."""
+    container = docker.DockerContainer(image="test:latest")
+    with pytest.raises(RuntimeError, match="not started"):
+        await container.snapshot()
+
+
+def test_from_snapshot_creates_container():
+    """Test that from_snapshot() creates a DockerContainer with the snapshot as image."""
+    container = docker.DockerContainer.from_snapshot(
+        "sha256:abc123",
+        name="restored",
+        default_workdir="/workspace",
+    )
+
+    assert isinstance(container, docker.DockerContainer)
+    assert container.image == "sha256:abc123"
+    assert container.name == "restored"
+    assert container.default_workdir == "/workspace"
+
+
+@pytest.mark.asyncio
+async def test_delete_snapshot(mock_docker_client):
+    """Test that delete_snapshot() removes the Docker image."""
+    container = docker.DockerContainer(image="test:latest")
+    await container.delete_snapshot("sha256:abc123")
+    mock_docker_client.images.remove.assert_called_once_with("sha256:abc123")
+
+
+@pytest.mark.asyncio
+async def test_delete_snapshot_ignores_not_found(mock_docker_client):
+    """Test that delete_snapshot() handles already-deleted images."""
+    import docker as docker_lib
+
+    mock_docker_client.images.remove.side_effect = docker_lib.errors.ImageNotFound("not found")
+    container = docker.DockerContainer(image="test:latest")
+    # Should not raise
+    await container.delete_snapshot("sha256:abc123")
+
+
+def test_delete_snapshot_sync(mock_docker_client):
+    """Test synchronous snapshot deletion for atexit cleanup."""
+    container = docker.DockerContainer(image="test:latest")
+    container.delete_snapshot_sync("sha256:abc123")
+    mock_docker_client.images.remove.assert_called_once_with("sha256:abc123")
diff --git a/src/ares/presets.py b/src/ares/presets.py
index 7429cfa..07824bb 100644
--- a/src/ares/presets.py
+++ b/src/ares/presets.py
@@ -125,14 +125,19 @@ def _register_default_presets() -> None:
     This function is called automatically when the presets module is imported,
     ensuring built-in presets are always available.
     """
+    seen: set[str] = set()
     for ds_spec in code_env.list_harbor_datasets():
         for code_agent_id, code_agent_factory in [
             ("mswea", mini_swe_agent.MiniSWECodeAgent),
             ("terminus2", terminus2_agent.Terminus2Agent),
         ]:
             ds_id = _make_harbor_dataset_id(ds_spec.name, ds_spec.version)
+            preset_name = f"{ds_id}-{code_agent_id}"
+            if preset_name in seen:
+                continue
+            seen.add(preset_name)
             registry.register_preset(
-                f"{ds_id}-{code_agent_id}",
+                preset_name,
                 HarborSpec(
                     ds_spec=ds_spec,
                     dataset_id=ds_id,