diff --git a/golden_tests/conftest.py b/golden_tests/conftest.py index 8c3eabb..a98ab74 100644 --- a/golden_tests/conftest.py +++ b/golden_tests/conftest.py @@ -62,7 +62,6 @@ def module_report_fixture(request): test_result.outcome ) - base_name = Path(request.module.__file__).stem report_filename = f"{base_name}_report.yaml" report_path = Path(__file__).parent / report_filename diff --git a/golden_tests/test_kaggle_client.py b/golden_tests/test_kaggle_client.py new file mode 100644 index 0000000..9cfd912 --- /dev/null +++ b/golden_tests/test_kaggle_client.py @@ -0,0 +1,444 @@ +# Copyright 2026 Kaggle Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Golden tests for kaggle_client.notebook_api.BenchmarkNotebookClient. + +These tests exercise the real Kaggle API (authentication, push, status +polling, fork, output download) and are excluded from CI because they +require valid Kaggle credentials and internet access. + +Tests are organized to mirror the typical user journey: + + 1. Authenticate → TestAuth + 2. Fork a benchmark → TestFork + 3. Publish & run → TestPublish + 4. Poll & get results → TestGetResults + 5. 
def _prepare_workspace(client, slug=None, script=MINIMAL_BENCHMARK_SCRIPT):
    """Materialize a workspace directory holding a benchmark script.

    Args:
        client: Client whose ``_workspace(slug)`` resolves the directory.
        slug: Notebook slug to use; a random ``kbench-golden-<8 hex>`` slug
            is generated when this is falsy.
        script: Text written to the benchmark file.

    Returns:
        A ``(workspace, slug)`` tuple.
    """
    if not slug:
        slug = f"kbench-golden-{uuid.uuid4().hex[:8]}"

    target_dir = client._workspace(slug)
    target_dir.mkdir(parents=True, exist_ok=True)

    script_file = target_dir / BenchmarkNotebookClient.BENCHMARK_FILENAME
    script_file.write_text(script, encoding="utf-8")

    return target_dir, slug
class TestFork:
    """Forking (pulling) an existing notebook from Kaggle."""

    @staticmethod
    def _unique_slug(prefix):
        """Return ``prefix`` followed by eight random hex characters."""
        return f"{prefix}{uuid.uuid4().hex[:8]}"

    def test_fork_creates_workspace_with_files(self, client, fork_source_notebook_id):
        """fork() pulls the notebook, metadata, and converts .ipynb → .py."""
        dest_slug = self._unique_slug("kbench-golden-fork-")
        workspace = client.fork(
            fork_source_notebook_id,
            dest_notebook_slug=dest_slug,
            overwrite=True,
        )

        # is_dir() is only true for an existing directory.
        assert workspace.is_dir()

        # Metadata should be present
        assert (workspace / BenchmarkNotebookClient.METADATA_FILENAME).exists()

        # Should have a .py or .ipynb file (depending on source notebook type)
        candidates = (
            workspace / BenchmarkNotebookClient.BENCHMARK_FILENAME,
            workspace / BenchmarkNotebookClient.NOTEBOOK_FILENAME,
        )
        assert any(path.exists() for path in candidates), (
            f"Expected benchmark.py or benchmark.ipynb in {workspace}"
        )

    def test_fork_raises_on_existing_workspace(self, client, fork_source_notebook_id):
        """fork(overwrite=False) raises FileExistsError if workspace exists."""
        dest_slug = self._unique_slug("kbench-golden-fork-")
        occupied = client._workspace(dest_slug)
        occupied.mkdir(parents=True, exist_ok=True)

        with pytest.raises(FileExistsError, match="Workspace already exists"):
            client.fork(
                fork_source_notebook_id,
                dest_notebook_slug=dest_slug,
                overwrite=False,
            )

    def test_fork_raises_on_missing_notebook(self, client):
        """fork() raises ValueError for a non-existent notebook."""
        missing_id = "kaggle/54321-tsixe-ton-seod-koobeton-siht"
        with pytest.raises(ValueError, match="Failed to pull notebook"):
            client.fork(missing_id)

    def test_fork_modify_publish_lifecycle(self, client, fork_source_notebook_id):
        """Full lifecycle: fork → modify → publish_and_run → get_results."""
        dest_slug = self._unique_slug("kbench-golden-fork-lifecycle-")
        workspace = client.fork(
            fork_source_notebook_id,
            dest_notebook_slug=dest_slug,
            overwrite=True,
        )

        # Append an extra cell to the forked script.
        script_path = workspace / BenchmarkNotebookClient.BENCHMARK_FILENAME
        with script_path.open("a", encoding="utf-8") as handle:
            handle.write('\n# %%\nprint("modified after fork")\n')

        # Push the modified script and block until Kaggle reports a result.
        client.publish_and_run(dest_slug, force=True)
        outcome = client.get_results(dest_slug, poll_interval=30, timeout=600)

        assert outcome.status == "complete", f"Error: {outcome.error}"
class TestPublish:
    """Publishing benchmark scripts to Kaggle."""

    def test_publish_from_workspace(self, client):
        """Tests the default workflow where the user authors benchmark.py directly inside the workspace."""
        workspace, slug = _prepare_workspace(client)

        url = client.publish_and_run(slug, force=True)

        # URL is well-formed
        assert url.startswith("https://www.kaggle.com/")
        assert client.username in url
        assert slug in url

        # Metadata was written with correct id and keyword
        meta_path = workspace / BenchmarkNotebookClient.METADATA_FILENAME
        metadata = json.loads(meta_path.read_text(encoding="utf-8"))
        assert metadata["id"] == f"{client.username}/{slug}"
        assert "personal-benchmark" in metadata.get("keywords", [])

        # Notebook .ipynb was generated
        assert (workspace / BenchmarkNotebookClient.NOTEBOOK_FILENAME).exists()

    def test_publish_with_source_file(self, client, tmp_path):
        """Tests the override workflow where an external script is dynamically copied into the workspace."""
        slug = f"kbench-golden-source-file-{uuid.uuid4().hex[:8]}"

        # Write a source file outside the workspace
        source = tmp_path / "my_bench.py"
        source.write_text('# %%\nprint("source file test")\n', encoding="utf-8")

        url = client.publish_and_run(slug, source_file=source, force=True)
        assert url is not None and slug in url

        # Verify the file was copied and converted
        workspace = client._workspace(slug)
        bench_py = workspace / BenchmarkNotebookClient.BENCHMARK_FILENAME
        assert bench_py.exists()
        assert (
            bench_py.read_text(encoding="utf-8") == '# %%\nprint("source file test")\n'
        )

        # Verify the .ipynb has the source code in a cell
        notebook_data = json.loads(
            (workspace / BenchmarkNotebookClient.NOTEBOOK_FILENAME).read_text(
                encoding="utf-8"
            )
        )
        source_found = any(
            'print("source file test")' in "".join(cell.get("source", []))
            for cell in notebook_data["cells"]
            if cell.get("cell_type") == "code"
        )
        assert source_found, "Source code not found in generated notebook cells"

    def test_publish_with_dataset_sources(self, client):
        """publish_and_run(dataset_sources=...) includes them in metadata."""
        workspace, slug = _prepare_workspace(
            client, script='# %%\nprint("dataset test")\n'
        )

        datasets = ["kaggle/meta-kaggle", "kaggle/meta-kaggle-code"]
        client.publish_and_run(slug, dataset_sources=datasets, force=True)

        meta_path = workspace / BenchmarkNotebookClient.METADATA_FILENAME
        metadata = json.loads(meta_path.read_text(encoding="utf-8"))
        assert metadata["dataset_sources"] == datasets

    def test_publish_raises_without_benchmark_file(self, client):
        """publish_and_run() raises FileNotFoundError if no benchmark.py exists."""
        slug = f"kbench-golden-no-source-{uuid.uuid4().hex[:8]}"

        with pytest.raises(FileNotFoundError, match="Benchmark file not found"):
            client.publish_and_run(slug, force=True)

    def test_concurrent_guard_blocks_duplicate_push(self, fresh_client):
        """publish_and_run(force=False) raises ConcurrentRunError if already running.

        The guard can only be observed while Kaggle reports the first run as
        queued or running. When the second push is accepted, or fails with an
        HTTP error, the guard was not exercised, so the test is reported as
        skipped instead of silently passing without checking anything.
        """
        from requests.exceptions import HTTPError

        _, slug = _prepare_workspace(fresh_client)

        # First push starts the run
        fresh_client.publish_and_run(slug, force=True)

        # Immediate second push without force should be blocked
        try:
            fresh_client.publish_and_run(slug, force=False)
        except ConcurrentRunError as e:
            assert "already running" in str(e)
        except HTTPError as e:
            # An API failure says nothing about the guard either way.
            pytest.skip(f"Kaggle API error on second push; guard not verified: {e}")
        else:
            # The push went through, so Kaggle did not report the first run
            # as queued/running at the time of the check.
            pytest.skip(
                "Second push was accepted (first run not reported as "
                "queued/running); concurrent-run guard not exercised."
            )

    def test_concurrent_guard_bypassed_with_force(self, fresh_client):
        """publish_and_run(force=True) always succeeds regardless of run status."""
        _, slug = _prepare_workspace(fresh_client)

        # First push starts the run
        fresh_client.publish_and_run(slug, force=True)

        # Immediate second push with force=True should bypass the guard and succeed
        url = fresh_client.publish_and_run(slug, force=True)
        assert url is not None
Error: {result.error}" + ) + assert result.output_dir is not None + assert result.error is None + + # Verify the on_status callback fired correctly + assert len(statuses_seen) > 0, "on_status callback was never invoked" + assert statuses_seen[-1] == "complete", ( + f"Final callback was {statuses_seen[-1]}, not complete" + ) + + # Verify *.run.json files were downloaded and have expected structure + runs = dict(result.iter_run_results()) + assert len(runs) >= 2, ( + f"Expected at least 2 run files, got {len(runs)}: {list(runs.keys())}" + ) + for filename, run_data in runs.items(): + assert run_data.get("state") == "BENCHMARK_TASK_RUN_STATE_COMPLETED", ( + f"Run '{filename}' has unexpected state: {run_data.get('state')}" + ) + assert run_data.get("taskVersion", {}).get("name") == "subtraction", ( + f"Run '{filename}' has unexpected task name" + ) + + def test_custom_output_dir(self, client, tmp_path): + """get_results(output_dir=...) saves output to the given path.""" + _, slug = _prepare_workspace(client) + client.publish_and_run(slug, force=True) + + custom_output = tmp_path / "custom_output" + result = client.get_results( + slug, output_dir=str(custom_output), poll_interval=30, timeout=600 + ) + + if result.status == "timeout": + pytest.skip("Notebook did not complete in time.") + + assert result.status == "complete", f"Failed with error: {result.error}" + assert result.output_dir == str(custom_output) + assert custom_output.exists() + + def test_timeout_returns_early(self, fresh_client): + """get_results returns status='timeout' when the time limit is hit.""" + _, slug = _prepare_workspace(fresh_client) + fresh_client.publish_and_run(slug, force=True) + + result = fresh_client.get_results(slug, poll_interval=5, timeout=1) + + assert isinstance(result, RunResult) + assert result.status == "timeout" + + def test_cancel_event_returns_early(self, fresh_client): + """get_results respects cancel_event and returns 'cancelled'.""" + _, slug = 
class TestErrorHandling:
    """Tests for Kaggle-side execution errors."""

    def test_crashing_script_returns_error(self, client):
        """A notebook whose script raises must come back with status 'error'."""
        crashing_script = '# %%\nraise ValueError("Deliberate crash")\n'
        _, slug = _prepare_workspace(client, script=crashing_script)

        client.publish_and_run(slug, force=True)
        outcome = client.get_results(slug, poll_interval=20, timeout=600)

        assert outcome.status == "error"
        assert outcome.error is not None
        assert "finished with status: error" in outcome.error
+ +from kaggle_benchmarks.kaggle_client.notebook_api import ( + BenchmarkNotebookClient, + ConcurrentRunError, + KaggleAuthError, + RunResult, +) + +__all__ = [ + "BenchmarkNotebookClient", + "ConcurrentRunError", + "KaggleAuthError", + "RunResult", +] diff --git a/src/kaggle_benchmarks/kaggle_client/notebook_api.py b/src/kaggle_benchmarks/kaggle_client/notebook_api.py new file mode 100644 index 0000000..5242968 --- /dev/null +++ b/src/kaggle_benchmarks/kaggle_client/notebook_api.py @@ -0,0 +1,639 @@ +# Copyright 2026 Kaggle Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core client for the Kaggle benchmark notebook workflow.""" + +import json +import logging +import os +import shutil +import tarfile +import tempfile +import threading +import time +import zipfile +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Iterator + +from kaggle_benchmarks.kaggle_client import utils as kaggle_utils + +logger = logging.getLogger(__name__) + + +class KaggleAuthError(RuntimeError): + """Raised when Kaggle authentication fails or credentials are invalid.""" + + +class ConcurrentRunError(RuntimeError): + """Raised when a notebook is already running and force=False.""" + + +@dataclass +class RunResult: + """Result of a benchmark notebook execution. + + Attributes: + status: One of "queued", "running", "complete", "error", + "cancelled", or "timeout". 
+ output_dir: String path to the output directory (not Path, + for JSON serialization). None if no output available. + tracking_url: URL to the notebook on Kaggle. + error: Error message (if any). + """ + + status: str + output_dir: str | None + tracking_url: str | None + error: str | None = None + + def iter_run_results(self) -> Iterator[tuple[str, dict[str, Any]]]: + """Yields (filename, parsed_data) for all *run.json files in the output directory.""" + if not self.output_dir or not (output_path := Path(self.output_dir)).exists(): + return + + for run_file in output_path.glob("*run.json"): + yield run_file.name, json.loads(run_file.read_text(encoding="utf-8")) + + +# ============================================================================= +# Module-level helper functions (authentication) +# ============================================================================= + +_AUTH_ERROR_INSTRUCTIONS = ( + "To authenticate, use one of the following methods (in priority order):\n\n" + " Option 1: Bearer token (recommended)\n" + " export KAGGLE_API_TOKEN=\n" + " Or save your token to ~/.kaggle/access_token\n\n" + " Option 2: Basic auth via environment variables\n" + " export KAGGLE_USERNAME=\n" + " export KAGGLE_KEY=\n\n" + " Option 3: Basic auth via credentials file\n" + " Save kaggle.json to ~/.kaggle/kaggle.json\n\n" + "Get your API token at: https://www.kaggle.com/settings\n" + "Docs: https://github.com/Kaggle/kagglehub/tree/main#authenticate" +) + + +def _authenticate() -> tuple[Any, str]: + """Authenticate with the Kaggle API and resolve the username. + + Credential resolution relies on the kagglesdk for token introspection + and kagglehub for standard Basic auth/json fallbacks. + + Returns: + A tuple of (KaggleClient, username). + + Raises: + KaggleAuthError: If the SDK is missing, no credentials are found, + or the username cannot be determined. 
+ """ + try: + from kagglesdk.kaggle_client import KaggleClient + from kagglesdk.kaggle_env import get_access_token_from_env, get_env + from kagglesdk.security.types.oauth_service import IntrospectTokenRequest + except ImportError as e: + raise KaggleAuthError( + "The 'kagglesdk' package is required. Install with:\n" + " pip install kaggle-benchmarks[kaggle-client]" + ) from e + + try: + client = KaggleClient(env=get_env()) + except Exception as e: + raise KaggleAuthError( + f"Kaggle authentication failed: {e}\n\n{_AUTH_ERROR_INSTRUCTIONS}" + ) from e + + def get_token_user() -> str | None: + if api_token := get_access_token_from_env()[0]: + try: + req = IntrospectTokenRequest() + req.token = api_token + resp = client.security.oauth_client.introspect_token(req) + return resp.username if resp.active else None + except Exception: + pass + return None + + def get_hub_user() -> str | None: + try: + import kagglehub + + return ( + creds.username + if (creds := kagglehub.config.get_kaggle_credentials()) + else None + ) + except Exception: + return None + + username = client.username or get_token_user() or get_hub_user() + + if not username: + raise KaggleAuthError( + f"No Kaggle credentials found or invalid.\n\n{_AUTH_ERROR_INSTRUCTIONS}" + ) + + return client, username + + +# ============================================================================= +# Client class +# ============================================================================= + + +class BenchmarkNotebookClient: + """Client for the Kaggle benchmark notebook workflow. + + Wraps the Kaggle SDK (kagglesdk) to handle publishing, running, + and retrieving results from benchmark notebooks (tagged with + 'personal-benchmark' keyword). + """ + + BENCHMARK_FILENAME = "benchmark.py" + NOTEBOOK_FILENAME = "benchmark.ipynb" + METADATA_FILENAME = "kernel-metadata.json" + OUTPUT_DIRNAME = "output" + + # Retry config for initial 404s from get_kernel_session_status after save_kernel. 
+ _STATUS_RETRIES = 5 + _STATUS_RETRY_WAIT = 10 # seconds + + # ========================================================================= + # Authentication & Identity + # ========================================================================= + + def __init__( + self, + base_dir: str | Path = ".", + ): + """Initialize the BenchmarkNotebookClient. + + Args: + base_dir: Parent directory for benchmark workspaces. + """ + self.api, self.username = _authenticate() + self.base_dir = Path(base_dir) + + # ========================================================================= + # Forking + # ========================================================================= + + def fork( + self, + source_notebook_id: str, + dest_notebook_slug: str | None = None, + overwrite: bool = False, + ) -> Path: + """Pull an existing benchmark from Kaggle as a starting point. + + Workspace: // + + Downloads via get_kernel() and converts: + 1. Pulls the .ipynb source and kernel metadata + 2. Writes the source as benchmark.ipynb + 3. Converts benchmark.ipynb -> benchmark.py with # %% cell delimiters + 4. Reconstructs kernel-metadata.json from API response + + The kernel-metadata.json is preserved so that publish_and_run() + can reuse it. + + Args: + source_notebook_id: Full Kaggle notebook path including owner + (e.g., 'alice/riddle-benchmark'). + dest_notebook_slug: Local name for the benchmark directory. + Defaults to the basename of source_notebook_id. + overwrite: If True, replace the existing workspace directory. + If False (default), raises FileExistsError. + + Returns: + Path to the workspace directory containing the .py file. + """ + if dest_notebook_slug is None: + dest_notebook_slug = source_notebook_id.split("/")[-1] + + workspace = self._workspace(dest_notebook_slug) + + if workspace.exists(): + if not overwrite: + raise FileExistsError( + f"Workspace already exists: {workspace}. " + "Use overwrite=True to replace it." 
+ ) + shutil.rmtree(workspace) + + workspace.mkdir(parents=True, exist_ok=True) + + # Pull notebook source and metadata from Kaggle via get_kernel() + from kagglesdk.kernels.types.kernels_api_service import ( + ApiGetKernelRequest, + ) + from requests.exceptions import HTTPError + + try: + req = ApiGetKernelRequest() + req.user_name, req.kernel_slug = source_notebook_id.split("/") + response = self.api.kernels.kernels_api_client.get_kernel(req) + except HTTPError as e: + raise ValueError( + f"Failed to pull notebook '{source_notebook_id}'. " + "Ensure the notebook exists, is public (or you have access), " + "and you have accepted any necessary competition rules." + ) from e + + # Write the notebook source to benchmark.ipynb + # response.blob.source is exactly the raw .ipynb JSON for notebook-type kernels + if response.blob and response.blob.source: + ipynb_path = workspace / self.NOTEBOOK_FILENAME + ipynb_path.write_text(response.blob.source, encoding="utf-8") + + # Convert .ipynb to .py with # %% cell delimiters + py_path = workspace / self.BENCHMARK_FILENAME + kaggle_utils.convert_ipynb_to_py(ipynb_path, py_path) + else: + logger.warning( + "No source found after pulling '%s'. 
" + "The notebook may be a script notebook.", + source_notebook_id, + ) + + # Reconstruct kernel-metadata.json from response.metadata + if response.metadata: + metadata = kaggle_utils.parse_remote_metadata( + meta=response.metadata, + default_id=source_notebook_id, + default_slug=dest_notebook_slug, + ) + meta_path = workspace / self.METADATA_FILENAME + meta_path.write_text(json.dumps(metadata, indent=4), encoding="utf-8") + + return workspace + + # ========================================================================= + # Publish & Run + # ========================================================================= + + def publish_and_run( + self, + notebook_slug: str, + source_file: str | Path | None = None, + dataset_sources: list[str] | None = None, + force: bool = False, + ) -> str: + """Convert .py -> .ipynb, push to Kaggle, and trigger execution. + + Workspace: // + + A notebook acts as the execution vehicle for your benchmark tasks. While you + can define and run multiple tasks within a single notebook, the Kaggle + leaderboard currently requires a single "main" task per notebook to be saved + and evaluated. You can use the `%choose ` magic command at the end + of your notebook to select which task's results to publish. + + For more details on task selection, see: + https://github.com/Kaggle/kaggle-benchmarks/blob/ci/quick_start.md#82-using-choose-to-select-a-notebooks-task + + Args: + notebook_slug: Short notebook name (e.g., 'my-benchmark'). + source_file: Optional path to a .py file to copy into workspace. It will replace benchmark.py. + dataset_sources: Optional list of Kaggle dataset slugs (e.g., ``["owner/dataset-name"]``) + to mount at ``/kaggle/input/``. + force: If True, push even if a previous run is in progress. + + Returns: + Tracking URL for the running notebook. + + Raises: + FileNotFoundError: If benchmark.py not found in workspace. + ConcurrentRunError: If a concurrent run is detected and force=False. 
+ """ + from requests.exceptions import HTTPError + + workspace = self._workspace(notebook_slug) + workspace.mkdir(parents=True, exist_ok=True) + benchmark_py = workspace / self.BENCHMARK_FILENAME + + # Copy source_file into workspace (if provided) + if source_file is not None: + source = Path(source_file) + if not source.exists(): + raise FileNotFoundError(f"Source file not found: {source}") + shutil.copy2(source, benchmark_py) + + if not benchmark_py.exists(): + raise FileNotFoundError( + f"Benchmark file not found: {benchmark_py}. " + f"Create it or pass source_file to copy from." + ) + + # Resolve metadata (load existing kernel-metadata.json or generate new) + # 'personal-benchmark' keyword is ensured by resolve_metadata + metadata = kaggle_utils.build_local_metadata( + workspace_dir=workspace, + notebook_slug=notebook_slug, + username=self.username, + dataset_sources=dataset_sources, + ) + + # Convert benchmark.py (.py with # %% delimiters) to .ipynb + notebook_path = workspace / self.NOTEBOOK_FILENAME + kaggle_utils.convert_py_to_ipynb(benchmark_py, notebook_path) + + # Concurrent run guard + notebook_id = self._notebook_id(notebook_slug) + if not force: + try: + status = self._get_kernel_session_status(notebook_id) + if status in ("queued", "running"): + raise ConcurrentRunError( + f"Notebook '{notebook_id}' is already running " + f"(status: {status}). " + "Use force=True to push a new version anyway." 
def _build_save_request(
    self,
    notebook_id: str,
    notebook_content: str,
    metadata: dict[str, Any],
):
    """Assemble the ApiSaveKernelRequest used to push a notebook.

    Args:
        notebook_id: Value assigned to the request's ``slug`` field; its last
            "/"-separated segment is the fallback title.
        notebook_content: Notebook text assigned to the request's ``text`` field.
        metadata: Kernel metadata; ``title`` plus any key listed in
            ``kaggle_utils.KAGGLE_METADATA_MAP`` is copied onto the request.

    Returns:
        The populated request object.
    """
    from kagglesdk.kernels.types.kernels_api_service import ApiSaveKernelRequest

    save_request = ApiSaveKernelRequest()
    fallback_title = notebook_id.split("/")[-1]

    # Fields that are always set, regardless of the metadata map.
    save_request.slug = notebook_id
    if "title" in metadata:
        save_request.new_title = metadata["title"]
    else:
        save_request.new_title = fallback_title
    save_request.text = notebook_content

    # Copy each mapped metadata key that is present onto its request attribute.
    for metadata_key, mapping in kaggle_utils.KAGGLE_METADATA_MAP.items():
        attribute_name, _unused = mapping
        if metadata_key not in metadata:
            continue
        setattr(save_request, attribute_name, metadata[metadata_key])

    return save_request
+ + Workspace: // + + Polls get_kernel_session_status() until the notebook completes (or fails), + then downloads output via download_kernel_output(). + + Neither timeout nor cancel_event stops the Kaggle run itself — + the notebook continues executing on Kaggle. They only stop + the local polling. + + Args: + notebook_slug: Short notebook name (e.g., 'my-benchmark'). + output_dir: Where to save output files (defaults to + //output/). + poll_interval: Seconds between status checks. + timeout: Maximum seconds to wait before returning + RunResult(status="timeout"). None means wait indefinitely. + cancel_event: A threading.Event; when set, the poll loop exits + with RunResult(status="cancelled"). + on_status: Callback invoked when the notebook status changes. + clear_output: If True (default), clear the output directory + before downloading new output. If False, download into + the existing directory without clearing. + + Returns: + RunResult with status, output path, and parsed run.json data. + """ + notebook_id = self._notebook_id(notebook_slug) + workspace = self._workspace(notebook_slug) + tracking_url = self._tracking_url(notebook_slug) + + if output_dir is None: + output_path = workspace / self.OUTPUT_DIRNAME + else: + output_path = Path(output_dir) + + status_str = self._wait_for_notebook_creation(notebook_id) + if status_str is None: + return RunResult( + status="error", + output_dir=None, + tracking_url=tracking_url, + error=( + f"Failed to find notebook '{notebook_id}' " + f"after {self._STATUS_RETRIES} retries." 
+ ), + ) + + final_status = self._poll_notebook_status( + notebook_id=notebook_id, + initial_status=status_str, + poll_interval=poll_interval, + timeout=timeout, + cancel_event=cancel_event, + on_status=on_status, + ) + + if final_status != "complete": + error_msg = None + if final_status not in ("timeout", "cancelled"): + error_msg = f"Notebook execution finished with status: {final_status}" + + return RunResult( + status=final_status, + output_dir=None, + tracking_url=tracking_url, + error=error_msg, + ) + + self._download_notebook_output(notebook_id, output_path, clear_output) + + return RunResult( + status="complete", + output_dir=str(output_path), + tracking_url=tracking_url, + ) + + def _wait_for_notebook_creation(self, notebook_id: str) -> str | None: + """Wait for Kaggle to index a newly pushed notebook. + + Returns: + The initial status string, or None if retries exceeded. + """ + from requests.exceptions import HTTPError + + for attempt in range(self._STATUS_RETRIES): + try: + return self._get_kernel_session_status(notebook_id) + except HTTPError as e: + if e.response.status_code == 404: + logger.info( + "Notebook not found yet (attempt %d/%d), waiting...", + attempt + 1, + self._STATUS_RETRIES, + ) + time.sleep(self._STATUS_RETRY_WAIT) + else: + raise + return None + + def _poll_notebook_status( + self, + notebook_id: str, + initial_status: str, + poll_interval: int, + timeout: int | None, + cancel_event: threading.Event | None, + on_status: Callable[[str], None] | None, + ) -> str: + """Poll notebook status until complete, error, cancelled, or timeout.""" + status_str = initial_status + last_reported_status = None + start_time = time.monotonic() + + while status_str not in ("complete", "error", "cancelled"): + if cancel_event is not None and cancel_event.is_set(): + return "cancelled" + + if timeout is not None and (time.monotonic() - start_time) >= timeout: + return "timeout" + + if on_status is not None and status_str != last_reported_status: + 
on_status(status_str) + last_reported_status = status_str + + # Wait with cancel_event support for early exit + if cancel_event is not None: + cancel_event.wait(poll_interval) + if cancel_event.is_set(): + return "cancelled" + else: + time.sleep(poll_interval) + + status_str = self._get_kernel_session_status(notebook_id) + + # Report final status + if on_status is not None and status_str != last_reported_status: + on_status(status_str) + + return status_str + + def _download_notebook_output( + self, notebook_id: str, output_path: Path, clear_output: bool + ) -> None: + """Clear existing output (if requested) and download new output from Kaggle.""" + from kagglesdk.kernels.types.kernels_api_service import ( + ApiDownloadKernelOutputRequest, + ) + + if clear_output and output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + # Build request (no file_path -> downloads all output as archive) + req = ApiDownloadKernelOutputRequest() + req.owner_slug, req.kernel_slug = notebook_id.split("/") + + # download_kernel_output returns a streamed requests.Response + response = self.api.kernels.kernels_api_client.download_kernel_output(req) + + # Save the archive to a temp file, then extract + with tempfile.NamedTemporaryFile(delete=False, suffix=".archive") as tmp: + for chunk in response.iter_content(chunk_size=8192): + tmp.write(chunk) + archive_path = tmp.name + + try: + if tarfile.is_tarfile(archive_path): + with tarfile.open(archive_path) as tf: + tf.extractall(output_path) + elif zipfile.is_zipfile(archive_path): + with zipfile.ZipFile(archive_path, "r") as zf: + zf.extractall(output_path) + else: + raise RuntimeError("Unexpected archive format from Kaggle API") + finally: + os.remove(archive_path) + + def _get_kernel_session_status(self, notebook_id: str): + """Get the session status for a notebook via kagglesdk. + + Args: + notebook_id: Full notebook ID (format: "username/slug"). + + Returns: + The normalized status string. 
+ """ + from kagglesdk.kernels.types.kernels_api_service import ( + ApiGetKernelSessionStatusRequest, + ) + + user_name, kernel_slug = notebook_id.split("/") + req = ApiGetKernelSessionStatusRequest() + req.user_name = user_name + req.kernel_slug = kernel_slug + response = self.api.kernels.kernels_api_client.get_kernel_session_status(req) + return kaggle_utils.normalize_status(response) + + # ========================================================================= + # Internal Utilities + # ========================================================================= + + def _workspace(self, notebook_slug: str) -> Path: + """Return the workspace directory for a notebook slug.""" + return self.base_dir / notebook_slug + + def _notebook_id(self, notebook_slug: str) -> str: + """Return the full notebook ID (username/slug).""" + return f"{self.username}/{notebook_slug}" + + def _tracking_url(self, notebook_slug: str) -> str: + """Return the Kaggle tracking URL for a notebook.""" + return f"https://www.kaggle.com/{self._notebook_id(notebook_slug)}" diff --git a/src/kaggle_benchmarks/kaggle_client/utils.py b/src/kaggle_benchmarks/kaggle_client/utils.py index a601206..7c45ed5 100644 --- a/src/kaggle_benchmarks/kaggle_client/utils.py +++ b/src/kaggle_benchmarks/kaggle_client/utils.py @@ -12,12 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Utility functions for the Kaggle benchmark client.""" + import json import re import warnings from pathlib import Path from typing import Any +# --------------------------------------------------------------------------- +# File format conversion +# --------------------------------------------------------------------------- + def convert_py_to_ipynb(py_path: str | Path, ipynb_path: str | Path) -> None: """Converts a Python script in percent format to a Jupyter Notebook. 
# Maps kernel-metadata.json keys to (ApiSaveKernelRequest attribute, default value)
# NOTE(review): the list defaults are single shared objects. build_local_metadata
# hands them to dict.setdefault and then appends to "keywords", which looks like
# it mutates this table in place — confirm, and copy list defaults before
# storing them in a metadata dict.
KAGGLE_METADATA_MAP: dict[str, tuple[str, Any]] = {
    "language": ("language", "python"),
    "kernel_type": ("kernel_type", "notebook"),
    "is_private": ("is_private", True),
    "enable_gpu": ("enable_gpu", False),
    "enable_tpu": ("enable_tpu", False),
    "enable_internet": ("enable_internet", True),
    "dataset_sources": ("dataset_data_sources", []),
    "competition_sources": ("competition_data_sources", []),
    "kernel_sources": ("kernel_data_sources", []),
    "model_sources": ("model_data_sources", []),
    "keywords": ("category_ids", []),
    "docker_image": ("docker_image", None),
    "machine_shape": ("machine_shape", "None"),
}


def parse_remote_metadata(
    meta: Any, default_id: str, default_slug: str
) -> dict[str, Any]:
    """Converts a Kaggle API `Kernel` object into a local `kernel-metadata.json` dictionary.

    Translates SDK-specific attribute names (like `dataset_data_sources`) back into
    standard Kaggle JSON fields to allow for local editing on disk.

    Args:
        meta: Object exposing the remote metadata as attributes. Attributes may
            be missing, None, or (for "ref"/"title") empty.
        default_id: Value for "id" when `meta.ref` is missing, None or empty.
        default_slug: Value for "title" when `meta.title` is missing, None or
            empty.

    Returns:
        A dictionary with "id", "title" and every key of KAGGLE_METADATA_MAP.
        List-valued fields are always fresh lists.
    """
    # Fall back on unset values too, not only on missing attributes: an
    # attribute that exists but is None/"" must not end up as the id or title.
    meta_dict = {
        "id": getattr(meta, "ref", None) or default_id,
        "title": getattr(meta, "title", None) or default_slug,
    }

    for json_key, (api_key, default_val) in KAGGLE_METADATA_MAP.items():
        val = getattr(meta, api_key, None)
        # Handle repeated enum/list types cleanly by defaulting to []
        if isinstance(default_val, list):
            meta_dict[json_key] = list(val or [])
        else:
            meta_dict[json_key] = val if val is not None else default_val

    return meta_dict


# ---------------------------------------------------------------------------
# Status normalization
# ---------------------------------------------------------------------------


def normalize_status(status: Any) -> str:
    """Normalize a Kaggle notebook status to a lowercase string.

    The Kaggle API may return a KernelWorkerStatus enum
    (e.g. "KernelWorkerStatus.complete") or a plain string.
    This function normalizes both to a simple lowercase string
    like "complete".

    Args:
        status: A status value, or an object carrying it in a `.status`
            attribute (such as a session-status response).

    Returns:
        The text after the last "." of the stringified status, lower-cased.
    """
    # Accept either the response wrapper or the bare status value.
    status_raw = getattr(status, "status", status)
    return str(status_raw).lower().split(".")[-1]
import json +from unittest.mock import MagicMock import pytest -from kaggle_benchmarks.kaggle_client.utils import ( - convert_ipynb_to_py, - convert_py_to_ipynb, - resolve_metadata, -) +from kaggle_benchmarks.kaggle_client import utils as kaggle_utils + +# --------------------------------------------------------------------------- +# File format conversion +# --------------------------------------------------------------------------- def test_convert_py_to_ipynb(tmp_path): @@ -35,7 +36,7 @@ def test_convert_py_to_ipynb(tmp_path): """ py_file.write_text(content) - convert_py_to_ipynb(py_file, ipynb_file) + kaggle_utils.convert_py_to_ipynb(py_file, ipynb_file) assert ipynb_file.exists() with open(ipynb_file, "r") as f: @@ -45,6 +46,14 @@ def test_convert_py_to_ipynb(tmp_path): assert notebook["cells"][0]["cell_type"] == "markdown" assert notebook["cells"][1]["cell_type"] == "code" + # Verify the kernelspec is added so papermill can run the notebook + assert "kernelspec" in notebook["metadata"] + assert notebook["metadata"]["kernelspec"] == { + "display_name": "Python 3", + "language": "python", + "name": "python3", + } + def test_convert_py_to_ipynb_warning(tmp_path): py_file = tmp_path / "benchmark.py" @@ -56,7 +65,7 @@ def test_convert_py_to_ipynb_warning(tmp_path): py_file.write_text(content) with pytest.warns(UserWarning, match="has no '# %%' cell delimiters"): - convert_py_to_ipynb(py_file, ipynb_file) + kaggle_utils.convert_py_to_ipynb(py_file, ipynb_file) assert ipynb_file.exists() @@ -84,7 +93,7 @@ def test_convert_ipynb_to_py(tmp_path): with open(ipynb_file, "w") as f: json.dump(notebook, f) - convert_ipynb_to_py(ipynb_file, py_file) + kaggle_utils.convert_ipynb_to_py(ipynb_file, py_file) assert py_file.exists() content = py_file.read_text() @@ -94,12 +103,44 @@ def test_convert_ipynb_to_py(tmp_path): assert "print('Hello')" in content -def test_resolve_metadata_new(tmp_path): +def test_roundtrip_conversion(tmp_path): + """Test that py -> ipynb -> py 
preserves cell structure.""" + py_file = tmp_path / "benchmark.py" + ipynb_file = tmp_path / "benchmark.ipynb" + py_roundtrip = tmp_path / "benchmark_roundtrip.py" + + content = """# %% [markdown] +# # Title + +# %% +print("Hello") + +# %% +x = 42 +""" + py_file.write_text(content) + + kaggle_utils.convert_py_to_ipynb(py_file, ipynb_file) + kaggle_utils.convert_ipynb_to_py(ipynb_file, py_roundtrip) + + roundtrip_content = py_roundtrip.read_text() + assert "# %% [markdown]" in roundtrip_content + assert "# # Title" in roundtrip_content + assert 'print("Hello")' in roundtrip_content + assert "x = 42" in roundtrip_content + + +# --------------------------------------------------------------------------- +# build_local_metadata +# --------------------------------------------------------------------------- + + +def test_build_local_metadata_new(tmp_path): workspace = tmp_path slug = "my-bench" username = "alice" - metadata = resolve_metadata(workspace, slug, username) + metadata = kaggle_utils.build_local_metadata(workspace, slug, username) assert metadata["id"] == "alice/my-bench" assert metadata["title"] == "my-bench" # Defaults to slug @@ -110,19 +151,21 @@ def test_resolve_metadata_new(tmp_path): assert metadata["docker_image"] is None -def test_resolve_metadata_custom_title(tmp_path): +def test_build_local_metadata_custom_title(tmp_path): """Test that title can be customized separately from slug.""" workspace = tmp_path slug = "my-bench" username = "alice" - metadata = resolve_metadata(workspace, slug, username, title="My Awesome Benchmark") + metadata = kaggle_utils.build_local_metadata( + workspace, slug, username, title="My Awesome Benchmark" + ) assert metadata["id"] == "alice/my-bench" # slug used for id assert metadata["title"] == "My Awesome Benchmark" # custom title -def test_resolve_metadata_existing(tmp_path): +def test_build_local_metadata_existing(tmp_path): workspace = tmp_path slug = "my-bench" username = "alice" @@ -138,7 +181,7 @@ def 
test_resolve_metadata_existing(tmp_path): with open(workspace / "kernel-metadata.json", "w") as f: json.dump(existing, f) - metadata = resolve_metadata( + metadata = kaggle_utils.build_local_metadata( workspace, slug, username, dataset_sources=["alice/data"] ) @@ -151,12 +194,12 @@ def test_resolve_metadata_existing(tmp_path): assert metadata["dataset_sources"] == ["alice/data"] -def test_resolve_metadata_overrides(tmp_path): +def test_build_local_metadata_overrides(tmp_path): workspace = tmp_path slug = "my-bench" username = "alice" - metadata = resolve_metadata( + metadata = kaggle_utils.build_local_metadata( workspace, slug, username, @@ -174,8 +217,8 @@ def test_resolve_metadata_overrides(tmp_path): assert metadata["enable_tpu"] is False # Default -def test_resolve_metadata_idempotent_keywords(tmp_path): - """Test that calling resolve_metadata doesn't duplicate 'personal-benchmark'.""" +def test_build_local_metadata_idempotent_keywords(tmp_path): + """Test that calling build_local_metadata doesn't duplicate 'personal-benchmark'.""" workspace = tmp_path slug = "my-bench" username = "alice" @@ -186,43 +229,163 @@ def test_resolve_metadata_idempotent_keywords(tmp_path): with open(workspace / "kernel-metadata.json", "w") as f: json.dump(existing, f) - metadata = resolve_metadata(workspace, slug, username) + metadata = kaggle_utils.build_local_metadata(workspace, slug, username) assert metadata["keywords"].count("personal-benchmark") == 1 assert "foo" in metadata["keywords"] -def test_resolve_metadata_malformed_json(tmp_path): +def test_build_local_metadata_malformed_json(tmp_path): """Test behavior when kernel-metadata.json contains invalid JSON.""" workspace = tmp_path (workspace / "kernel-metadata.json").write_text("not valid json") with pytest.raises(json.JSONDecodeError): - resolve_metadata(workspace, "my-bench", "alice") + kaggle_utils.build_local_metadata(workspace, "my-bench", "alice") -def test_roundtrip_conversion(tmp_path): - """Test that py -> ipynb -> 
py preserves cell structure.""" - py_file = tmp_path / "benchmark.py" - ipynb_file = tmp_path / "benchmark.ipynb" - py_roundtrip = tmp_path / "benchmark_roundtrip.py" +# --------------------------------------------------------------------------- +# parse_remote_metadata +# --------------------------------------------------------------------------- - content = """# %% [markdown] -# # Title -# %% -print("Hello") +def test_parse_remote_metadata_from_api_response(): + """Extracts all fields from a fully populated API metadata object.""" + meta = MagicMock( + ref="alice/my-benchmark", + title="My Benchmark", + language="python", + kernel_type="notebook", + is_private=False, + enable_gpu=True, + enable_internet=True, + enable_tpu=False, + dataset_data_sources=["alice/dataset"], + competition_data_sources=["comp1"], + kernel_data_sources=["alice/kernel"], + model_data_sources=["alice/model"], + category_ids=["personal-benchmark", "nlp"], + docker_image="custom-image:latest", + machine_shape="T4", + ) -# %% -x = 42 -""" - py_file.write_text(content) + result = kaggle_utils.parse_remote_metadata( + meta, default_id="fallback/id", default_slug="fallback" + ) - convert_py_to_ipynb(py_file, ipynb_file) - convert_ipynb_to_py(ipynb_file, py_roundtrip) + assert result == { + "id": "alice/my-benchmark", + "title": "My Benchmark", + "language": "python", + "kernel_type": "notebook", + "is_private": False, + "enable_gpu": True, + "enable_tpu": False, + "enable_internet": True, + "dataset_sources": ["alice/dataset"], + "competition_sources": ["comp1"], + "kernel_sources": ["alice/kernel"], + "model_sources": ["alice/model"], + "keywords": ["personal-benchmark", "nlp"], + "docker_image": "custom-image:latest", + "machine_shape": "T4", + } - roundtrip_content = py_roundtrip.read_text() - assert "# %% [markdown]" in roundtrip_content - assert "# # Title" in roundtrip_content - assert 'print("Hello")' in roundtrip_content - assert "x = 42" in roundtrip_content + +def 
test_parse_remote_metadata_uses_defaults_for_missing_attrs(): + """Falls back to defaults when API metadata has missing attributes.""" + meta = MagicMock(spec=[]) # spec=[] means no attributes exist + + result = kaggle_utils.parse_remote_metadata( + meta, default_id="owner/slug", default_slug="my-slug" + ) + + assert result == { + "id": "owner/slug", + "title": "my-slug", + "language": "python", + "kernel_type": "notebook", + "is_private": True, + "enable_gpu": False, + "enable_tpu": False, + "enable_internet": True, + "dataset_sources": [], + "competition_sources": [], + "kernel_sources": [], + "model_sources": [], + "keywords": [], + "docker_image": None, + "machine_shape": "None", + } + + +def test_parse_remote_metadata_handles_none_lists(): + """Converts None list fields to empty lists.""" + meta = MagicMock( + ref="alice/bench", + title="Bench", + language="python", + kernel_type="notebook", + is_private=True, + enable_gpu=False, + enable_tpu=False, + enable_internet=True, + dataset_data_sources=None, + competition_data_sources=None, + kernel_data_sources=None, + model_data_sources=None, + category_ids=None, + docker_image=None, + machine_shape=None, + ) + + result = kaggle_utils.parse_remote_metadata( + meta, default_id="x/y", default_slug="y" + ) + + assert result == { + "id": "alice/bench", + "title": "Bench", + "language": "python", + "kernel_type": "notebook", + "is_private": True, + "enable_gpu": False, + "enable_tpu": False, + "enable_internet": True, + "dataset_sources": [], + "competition_sources": [], + "kernel_sources": [], + "model_sources": [], + "keywords": [], + "docker_image": None, + "machine_shape": "None", + } + + +# --------------------------------------------------------------------------- +# normalize_status +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "raw_status, expected", + [ + ("complete", "complete"), + ("Complete", "complete"), + ("KernelWorkerStatus.complete", 
"complete"), + ("kernelworkerstatus.running", "running"), + ("error", "error"), + ], +) +def test_normalize_status_strings(raw_status, expected): + """normalize_status should strip enum prefixes and lower-case.""" + assert kaggle_utils.normalize_status(raw_status) == expected + + +def test_normalize_status_object_with_attribute(): + """Simulates an API response object with a .status attribute.""" + + class FakeStatus: + status = "running" + + assert kaggle_utils.normalize_status(FakeStatus()) == "running" diff --git a/tests/kaggle_client/test_notebook_api.py b/tests/kaggle_client/test_notebook_api.py new file mode 100644 index 0000000..0d034e4 --- /dev/null +++ b/tests/kaggle_client/test_notebook_api.py @@ -0,0 +1,834 @@ +# Copyright 2026 Kaggle Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import io
import json
import tarfile
import threading
import time
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
from requests.exceptions import HTTPError

from kaggle_benchmarks.kaggle_client.notebook_api import (
    BenchmarkNotebookClient,
    ConcurrentRunError,
    KaggleAuthError,
    RunResult,
    _authenticate,
)

_API_MOD = "kaggle_benchmarks.kaggle_client.notebook_api"

# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def mock_api():
    """Provide a MagicMock standing in for the kagglesdk KaggleClient."""
    # username is populated for Basic auth scenarios; no API token is set.
    return MagicMock(username="testuser", api_token=None)


@pytest.fixture
def client(mock_api, tmp_path):
    """Provide a BenchmarkNotebookClient whose authentication is patched out."""
    fake_auth = patch(f"{_API_MOD}._authenticate", return_value=(mock_api, "testuser"))
    with fake_auth:
        return BenchmarkNotebookClient(base_dir=tmp_path)


def _make_http_error(status_code):
    """Return an HTTPError whose mocked response reports `status_code`."""
    fake_response = MagicMock()
    fake_response.status_code = status_code
    return HTTPError(response=fake_response)


def _make_404_error():
    """Return an HTTPError carrying a mocked 404 response."""
    return _make_http_error(404)


def _make_benchmark_py(workspace: Path) -> None:
    """Create `workspace` if needed and drop a one-cell benchmark.py in it."""
    workspace.mkdir(parents=True, exist_ok=True)
    script = workspace / "benchmark.py"
    script.write_text("# %%\nprint('hello')\n")


def _make_save_kernel_response(error=""):
    """Return a mock ApiSaveKernelResponse with the given error text."""
    save_response = MagicMock()
    save_response.error = error
    save_response.invalid_tags = []
    return save_response


def _make_status_response(status_str):
    """Return a stand-in ApiGetKernelSessionStatusResponse exposing `.status`."""
    fake_status_cls = type("FakeStatus", (), {"status": status_str})
    return fake_status_cls()


def _make_archive_response(*files):
    """Return a mock streamed response whose payload is a gzipped tar archive.

    Args:
        *files: Tuples of (filename, content_string) to include in the archive.
    """
    archive = io.BytesIO()
    with tarfile.open(fileobj=archive, mode="w:gz") as tar:
        for member_name, text in files:
            raw = text.encode("utf-8")
            entry = tarfile.TarInfo(name=member_name)
            entry.size = len(raw)
            tar.addfile(entry, io.BytesIO(raw))
    archive.seek(0)

    def _stream(chunk_size=8192):
        # The whole archive is delivered as a single chunk, read lazily.
        return iter([archive.read()])

    streamed = MagicMock()
    streamed.iter_content = _stream
    return streamed


def _make_get_kernel_response(
    source_json=None, metadata_dict=None, ipynb_name="notebook.ipynb"
):
    """Return a mock ApiGetKernelResponse.

    Args:
        source_json: The notebook source (raw .ipynb JSON dict). If None, uses a default.
        metadata_dict: The metadata dict. If None, uses a default.
        ipynb_name: Used to construct the default metadata 'ref'.
    """
    import types

    if source_json is None:
        default_cell = {
            "cell_type": "code",
            "source": ["print('hello')\n"],
            "metadata": {},
            "outputs": [],
            "execution_count": None,
        }
        source_json = {
            "cells": [default_cell],
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 5,
        }

    if metadata_dict is None:
        stem = ipynb_name.replace(".ipynb", "")
        metadata_dict = {
            "ref": f"source/{stem}",
            "title": stem,
            "language": "python",
            "kernel_type": "notebook",
            "is_private": True,
            "enable_gpu": False,
            "enable_internet": True,
            "enable_tpu": False,
            "dataset_data_sources": [],
            "competition_data_sources": [],
            "kernel_data_sources": [],
            "model_data_sources": [],
            "category_ids": [],
        }

    kernel_response = MagicMock()
    kernel_response.blob.source = json.dumps(source_json)
    # Use SimpleNamespace so getattr(meta, "missing_field", None) returns None correctly
    kernel_response.metadata = types.SimpleNamespace(**metadata_dict)
    return kernel_response


# ---------------------------------------------------------------------------
# Internal Helpers
--------------------------------------------------------------------------- + + +def test_notebook_id(client): + """_notebook_id should be username/slug.""" + assert client._notebook_id("my-notebook") == "testuser/my-notebook" + + +def test_tracking_url(client): + """_tracking_url should build a valid Kaggle URL.""" + assert ( + client._tracking_url("my-notebook") + == "https://www.kaggle.com/testuser/my-notebook" + ) + + +def test_workspace_path(client, tmp_path): + """_workspace should return base_dir / slug.""" + ws = client._workspace("my-notebook") + assert ws == tmp_path / "my-notebook" + + +# --------------------------------------------------------------------------- +# publish_and_run +# --------------------------------------------------------------------------- + + +def test_publish_and_run_basic(client, mock_api, tmp_path): + """Happy path: workspace exists, benchmark.py exists, no prior kernel.""" + slug = "test-bench" + _make_benchmark_py(tmp_path / slug) + + # No existing kernel (404) + mock_api.kernels.kernels_api_client.get_kernel_session_status.side_effect = ( + _make_404_error() + ) + mock_api.kernels.kernels_api_client.save_kernel.return_value = ( + _make_save_kernel_response() + ) + + url = client.publish_and_run(slug) + + assert "testuser/test-bench" in url + mock_api.kernels.kernels_api_client.save_kernel.assert_called_once() + + # Verify the save_kernel request has correct slug + call_args = mock_api.kernels.kernels_api_client.save_kernel.call_args + req = call_args[0][0] + assert req.slug == "testuser/test-bench" + + # Verify metadata was written + meta = json.loads((tmp_path / slug / "kernel-metadata.json").read_text()) + assert meta["id"] == "testuser/test-bench" + assert "personal-benchmark" in meta["keywords"] + + # Verify .ipynb was created + assert (tmp_path / slug / "benchmark.ipynb").exists() + + +def test_publish_and_run_with_source_file(client, mock_api, tmp_path): + """Source file is copied into the workspace before processing.""" 
+ slug = "test-bench" + source = tmp_path / "my_script.py" + source.write_text("# %%\nimport kaggle_benchmarks as kbench\n") + + mock_api.kernels.kernels_api_client.get_kernel_session_status.side_effect = ( + _make_404_error() + ) + mock_api.kernels.kernels_api_client.save_kernel.return_value = ( + _make_save_kernel_response() + ) + + _ = client.publish_and_run(slug, source_file=str(source)) + + workspace = tmp_path / slug + assert (workspace / "benchmark.py").exists() + assert (workspace / "benchmark.ipynb").exists() + mock_api.kernels.kernels_api_client.save_kernel.assert_called_once() + + +def test_publish_and_run_dataset_sources(client, mock_api, tmp_path): + """Dataset sources are passed through to metadata.""" + slug = "test-bench" + _make_benchmark_py(tmp_path / slug) + mock_api.kernels.kernels_api_client.get_kernel_session_status.side_effect = ( + _make_404_error() + ) + mock_api.kernels.kernels_api_client.save_kernel.return_value = ( + _make_save_kernel_response() + ) + + client.publish_and_run(slug, dataset_sources=["alice/data", "bob/more-data"]) + + meta = json.loads((tmp_path / slug / "kernel-metadata.json").read_text()) + assert meta["dataset_sources"] == ["alice/data", "bob/more-data"] + + +def test_publish_and_run_concurrent_guard(client, mock_api, tmp_path): + """Raises ConcurrentRunError when notebook is already running.""" + slug = "test-bench" + _make_benchmark_py(tmp_path / slug) + + mock_api.kernels.kernels_api_client.get_kernel_session_status.return_value = ( + _make_status_response("running") + ) + + with pytest.raises(ConcurrentRunError, match="already running"): + client.publish_and_run(slug) + + mock_api.kernels.kernels_api_client.save_kernel.assert_not_called() + + +def test_publish_and_run_concurrent_guard_queued(client, mock_api, tmp_path): + """Raises ConcurrentRunError when notebook is queued.""" + slug = "test-bench" + _make_benchmark_py(tmp_path / slug) + + mock_api.kernels.kernels_api_client.get_kernel_session_status.return_value 
= ( + _make_status_response("queued") + ) + + with pytest.raises(ConcurrentRunError, match="already running"): + client.publish_and_run(slug) + + +def test_publish_and_run_concurrent_guard_non_404_error(client, mock_api, tmp_path): + """Non-404 HTTPError during status check should propagate as HTTPError.""" + slug = "test-bench" + _make_benchmark_py(tmp_path / slug) + + mock_api.kernels.kernels_api_client.get_kernel_session_status.side_effect = ( + _make_http_error(500) + ) + + with pytest.raises(HTTPError): + client.publish_and_run(slug) + + mock_api.kernels.kernels_api_client.save_kernel.assert_not_called() + + +def test_publish_and_run_force(client, mock_api, tmp_path): + """force=True bypasses the concurrent run guard entirely.""" + slug = "test-bench" + _make_benchmark_py(tmp_path / slug) + + mock_api.kernels.kernels_api_client.save_kernel.return_value = ( + _make_save_kernel_response() + ) + + _ = client.publish_and_run(slug, force=True) + + # get_kernel_session_status should NOT be called (guard skipped) + mock_api.kernels.kernels_api_client.get_kernel_session_status.assert_not_called() + mock_api.kernels.kernels_api_client.save_kernel.assert_called_once() + + +def test_publish_and_run_missing_file(client, tmp_path): + """Raises FileNotFoundError when benchmark.py doesn't exist.""" + with pytest.raises(FileNotFoundError, match="Benchmark file not found"): + client.publish_and_run("nonexistent") + + +def test_publish_and_run_missing_source_file(client, tmp_path): + """Raises FileNotFoundError when source_file doesn't exist.""" + with pytest.raises(FileNotFoundError, match="Source file not found"): + client.publish_and_run("test-bench", source_file="/nonexistent/path.py") + + +# --------------------------------------------------------------------------- +# get_results +# --------------------------------------------------------------------------- + + +def test_get_results_complete_immediately(client, mock_api, tmp_path): + """Kernel is already complete — 
downloads output immediately.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + + mock_api.kernels.kernels_api_client.get_kernel_session_status.return_value = ( + _make_status_response("complete") + ) + mock_api.kernels.kernels_api_client.download_kernel_output.return_value = ( + _make_archive_response(("run.json", '{"score": 0.95}')) + ) + + result = client.get_results(slug) + + assert result.status == "complete" + assert result.output_dir is not None + assert result.error is None + # Tests that it works with the single legacy `run.json` as well + runs = dict(result.iter_run_results()) + assert len(runs) == 1 + assert runs["run.json"] == {"score": 0.95} + + +def test_get_results_multiple_run_files(client, mock_api, tmp_path): + """Kernel completes and produces multiple *.run.json files.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + + mock_api.kernels.kernels_api_client.get_kernel_session_status.return_value = ( + _make_status_response("complete") + ) + mock_api.kernels.kernels_api_client.download_kernel_output.return_value = ( + _make_archive_response( + ("task_1.run.json", '{"score": 1}'), + ("task_2.run.json", '{"score": 2}'), + ) + ) + + result = client.get_results(slug) + + assert result.status == "complete" + assert result.output_dir is not None + + runs = dict(result.iter_run_results()) + assert len(runs) == 2 + assert runs["task_1.run.json"] == {"score": 1} + assert runs["task_2.run.json"] == {"score": 2} + + +def test_get_results_polls_until_complete(client, mock_api, tmp_path, monkeypatch): + """Polls through queued -> running -> complete.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + monkeypatch.setattr(time, "sleep", lambda s: None) + + statuses = iter( + [ + _make_status_response("queued"), + _make_status_response("running"), + _make_status_response("complete"), + ] + ) + mock_api.kernels.kernels_api_client.get_kernel_session_status.side_effect = ( + lambda *a, **k: next(statuses) + ) + 
mock_api.kernels.kernels_api_client.download_kernel_output.return_value = ( + _make_archive_response() + ) + + collected = [] + result = client.get_results(slug, poll_interval=0.01, on_status=collected.append) + + assert result.status == "complete" + # Callbacks: "queued" (in-loop), "running" (in-loop), "complete" (post-loop) + assert collected == ["queued", "running", "complete"] + + +def test_get_results_timeout(client, mock_api, tmp_path, monkeypatch): + """Returns timeout when exceeding the timeout limit.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + monkeypatch.setattr(time, "sleep", lambda s: None) + + mock_api.kernels.kernels_api_client.get_kernel_session_status.return_value = ( + _make_status_response("running") + ) + + result = client.get_results(slug, poll_interval=1, timeout=0) + + assert result.status == "timeout" + assert result.output_dir is None + + +def test_get_results_cancel(client, mock_api, tmp_path): + """Returns cancelled when cancel_event is set.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + + mock_api.kernels.kernels_api_client.get_kernel_session_status.return_value = ( + _make_status_response("running") + ) + + cancel_event = threading.Event() + cancel_event.set() # Set immediately + + result = client.get_results(slug, cancel_event=cancel_event) + + assert result.status == "cancelled" + + +def test_get_results_cancel_mid_wait(client, mock_api, tmp_path): + """Returns cancelled when cancel_event fires during poll wait.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + + mock_api.kernels.kernels_api_client.get_kernel_session_status.return_value = ( + _make_status_response("running") + ) + + cancel_event = threading.Event() + + # Set the cancel event after a short delay to trigger during wait() + def cancel_after_delay(): + time.sleep(0.05) + cancel_event.set() + + timer = threading.Thread(target=cancel_after_delay, daemon=True) + timer.start() + + result = client.get_results( + slug, + poll_interval=10, # Long poll — 
cancel fires during the wait + cancel_event=cancel_event, + ) + timer.join(timeout=5) + + assert result.status == "cancelled" + + +def test_get_results_error_status(client, mock_api, tmp_path): + """Returns error when Kaggle reports an error status.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + + mock_api.kernels.kernels_api_client.get_kernel_session_status.return_value = ( + _make_status_response("error") + ) + + result = client.get_results(slug) + + assert result.status == "error" + assert result.error is not None + assert "error" in result.error + + +def test_get_results_initial_404_retries(client, mock_api, tmp_path, monkeypatch): + """Handles initial 404s after push with retries.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + monkeypatch.setattr(time, "sleep", lambda s: None) + + # First 2 calls return 404, then "complete" + call_count = 0 + + def status_side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise _make_404_error() + return _make_status_response("complete") + + mock_api.kernels.kernels_api_client.get_kernel_session_status.side_effect = ( + status_side_effect + ) + mock_api.kernels.kernels_api_client.download_kernel_output.return_value = ( + _make_archive_response() + ) + + result = client.get_results(slug, poll_interval=0.01) + + assert result.status == "complete" + assert call_count == 3 # 2 retries + 1 success + + +def test_get_results_all_retries_exhausted(client, mock_api, tmp_path, monkeypatch): + """Returns error when all 404 retries are exhausted.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + monkeypatch.setattr(time, "sleep", lambda s: None) + + mock_api.kernels.kernels_api_client.get_kernel_session_status.side_effect = ( + _make_404_error() + ) + + result = client.get_results(slug) + + assert result.status == "error" + assert "retries" in result.error + + +def test_get_results_non_404_error_during_retries( + client, mock_api, tmp_path, monkeypatch +): + """Non-404 
HTTPError during initial retry loop should propagate immediately.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + monkeypatch.setattr(time, "sleep", lambda s: None) + + mock_api.kernels.kernels_api_client.get_kernel_session_status.side_effect = ( + _make_http_error(500) + ) + + with pytest.raises(HTTPError): + client.get_results(slug) + + +def test_get_results_on_status_callback(client, mock_api, tmp_path, monkeypatch): + """on_status callback receives intermediate and final statuses but deduplicates repeats.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + monkeypatch.setattr(time, "sleep", lambda s: None) + + statuses = iter( + [ + _make_status_response("running"), + _make_status_response("running"), + _make_status_response("complete"), + ] + ) + mock_api.kernels.kernels_api_client.get_kernel_session_status.side_effect = ( + lambda *a, **k: next(statuses) + ) + mock_api.kernels.kernels_api_client.download_kernel_output.return_value = ( + _make_archive_response() + ) + + collected = [] + result = client.get_results(slug, poll_interval=0.01, on_status=collected.append) + + assert result.status == "complete" + assert collected == ["running", "complete"] + + +def test_get_results_no_run_json(client, mock_api, tmp_path): + """iter_run_results() yields nothing when no run.json is produced.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + + mock_api.kernels.kernels_api_client.get_kernel_session_status.return_value = ( + _make_status_response("complete") + ) + mock_api.kernels.kernels_api_client.download_kernel_output.return_value = ( + _make_archive_response() + ) + + result = client.get_results(slug) + + assert result.status == "complete" + assert not dict(result.iter_run_results()) + + +def test_get_results_custom_output_dir(client, mock_api, tmp_path): + """output_dir parameter overrides the default output path.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + + custom_output = tmp_path / "my_custom_output" + + 
mock_api.kernels.kernels_api_client.get_kernel_session_status.return_value = ( + _make_status_response("complete") + ) + mock_api.kernels.kernels_api_client.download_kernel_output.return_value = ( + _make_archive_response(("run.json", '{"score": 0.95}')) + ) + + result = client.get_results(slug, output_dir=str(custom_output)) + + assert result.status == "complete" + assert result.output_dir == str(custom_output) + assert custom_output.exists() + + runs = dict(result.iter_run_results()) + assert len(runs) == 1 + + +def test_get_results_clears_existing_output(client, mock_api, tmp_path): + """Existing output files are cleared before downloading by default.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + + # Create pre-existing output directory with an old file + output_dir = tmp_path / slug / "output" + output_dir.mkdir(parents=True, exist_ok=True) + old_file = output_dir / "old_result.txt" + old_file.write_text("stale data") + + mock_api.kernels.kernels_api_client.get_kernel_session_status.return_value = ( + _make_status_response("complete") + ) + mock_api.kernels.kernels_api_client.download_kernel_output.return_value = ( + _make_archive_response(("run.json", '{"score": 0.95}')) + ) + + result = client.get_results(slug) + + assert result.status == "complete" + # Old file should have been removed + assert not old_file.exists() + + +def test_get_results_no_clear_output(client, mock_api, tmp_path): + """clear_output=False preserves existing files in the output directory.""" + slug = "test-bench" + (tmp_path / slug).mkdir() + + # Create pre-existing output directory with an old file + output_dir = tmp_path / slug / "output" + output_dir.mkdir(parents=True, exist_ok=True) + old_file = output_dir / "old_result.txt" + old_file.write_text("keep me") + + mock_api.kernels.kernels_api_client.get_kernel_session_status.return_value = ( + _make_status_response("complete") + ) + mock_api.kernels.kernels_api_client.download_kernel_output.return_value = ( + 
_make_archive_response(("run.json", '{"score": 0.95}')) + ) + + result = client.get_results(slug, clear_output=False) + + assert result.status == "complete" + # Old file should still exist + assert old_file.exists() + assert old_file.read_text() == "keep me" + # New file should also exist + runs = dict(result.iter_run_results()) + assert len(runs) == 1 + + +# --------------------------------------------------------------------------- +# RunResult.iter_run_results edge cases +# --------------------------------------------------------------------------- + + +def test_iter_run_results_nonexistent_output_dir(tmp_path): + """iter_run_results() yields nothing when output_dir does not exist.""" + result = RunResult( + status="complete", + output_dir=str(tmp_path / "nonexistent"), + tracking_url=None, + ) + assert list(result.iter_run_results()) == [] + assert dict(result.iter_run_results()) == {} + + +def test_iter_run_results_none_output_dir(): + """iter_run_results() yields nothing when output_dir is None.""" + result = RunResult( + status="error", + output_dir=None, + tracking_url=None, + ) + assert list(result.iter_run_results()) == [] + assert dict(result.iter_run_results()) == {} + + +# --------------------------------------------------------------------------- +# fork +# --------------------------------------------------------------------------- + + +def test_fork_basic(client, mock_api, tmp_path): + """Happy path: pulls notebook and converts to .py.""" + mock_api.kernels.kernels_api_client.get_kernel.return_value = ( + _make_get_kernel_response(ipynb_name="riddle-benchmark.ipynb") + ) + + result = client.fork("alice/riddle-benchmark") + + assert result == tmp_path / "riddle-benchmark" + assert (result / "benchmark.py").exists() + assert (result / "kernel-metadata.json").exists() + assert (result / "benchmark.ipynb").exists() + + # Verify the .py file has cell delimiters + py_content = (result / "benchmark.py").read_text() + assert "# %%" in py_content + assert 
"print('hello')" in py_content + + +def test_fork_custom_slug(client, mock_api, tmp_path): + """Custom notebook_slug overrides the default.""" + mock_api.kernels.kernels_api_client.get_kernel.return_value = ( + _make_get_kernel_response(ipynb_name="notebook.ipynb") + ) + + result = client.fork("alice/riddle-benchmark", dest_notebook_slug="my-riddle") + + assert result == tmp_path / "my-riddle" + assert (result / "benchmark.py").exists() + + +def test_fork_exists_error(client, mock_api, tmp_path): + """Raises FileExistsError when workspace already exists.""" + (tmp_path / "riddle-benchmark").mkdir() + + with pytest.raises(FileExistsError, match="already exists"): + client.fork("alice/riddle-benchmark") + + +def test_fork_overwrite(client, mock_api, tmp_path): + """overwrite=True removes the existing workspace.""" + workspace = tmp_path / "riddle-benchmark" + workspace.mkdir() + (workspace / "old-file.txt").write_text("old content") + + mock_api.kernels.kernels_api_client.get_kernel.return_value = ( + _make_get_kernel_response(ipynb_name="riddle-benchmark.ipynb") + ) + + _ = client.fork("alice/riddle-benchmark", overwrite=True) + + assert not (workspace / "old-file.txt").exists() + assert (workspace / "benchmark.py").exists() + + +def test_fork_missing_notebook_raises_value_error(client, mock_api): + """fork() should raise a friendly ValueError for any HTTPError.""" + mock_api.kernels.kernels_api_client.get_kernel.side_effect = _make_404_error() + + with pytest.raises( + ValueError, match="Failed to pull notebook 'kaggle/does-not-exist'" + ): + client.fork("kaggle/does-not-exist") + + +def test_fork_http_error_500_raises_value_error(client, mock_api): + """fork() wraps any HTTPError (including 500) in ValueError.""" + mock_api.kernels.kernels_api_client.get_kernel.side_effect = _make_http_error(500) + + with pytest.raises(ValueError, match="Failed to pull notebook"): + client.fork("kaggle/server-error-notebook") + + +def test_fork_no_source(client, mock_api, 
tmp_path, caplog): + """fork() should log a warning when no source is found (empty blob).""" + resp = MagicMock() + resp.blob.source = None + resp.metadata = None # No metadata either + mock_api.kernels.kernels_api_client.get_kernel.return_value = resp + + import logging + + with caplog.at_level(logging.WARNING): + result = client.fork("alice/script-notebook") + + assert result == tmp_path / "script-notebook" + # No source → no benchmark.py conversion + assert not (result / "benchmark.py").exists() + # Warning should be logged + assert "No source found" in caplog.text + + +# --------------------------------------------------------------------------- +# _authenticate (integration of the above) +# --------------------------------------------------------------------------- + + +def test_authenticate_import_error(): + """Raises KaggleAuthError when kagglesdk package is not installed.""" + with patch.dict( + "sys.modules", + { + "kagglesdk": None, + "kagglesdk.kaggle_client": None, + "kagglesdk.kaggle_env": None, + }, + ): + # Force ImportError by removing the module from sys.modules cache + import sys + + saved = {} + for mod_name in list(sys.modules): + if mod_name.startswith("kagglesdk"): + saved[mod_name] = sys.modules.pop(mod_name) + + try: + with patch( + "builtins.__import__", + side_effect=ImportError("No module named 'kagglesdk'"), + ): + with pytest.raises(KaggleAuthError, match="kagglesdk.*required"): + _authenticate() + finally: + sys.modules.update(saved) diff --git a/uv.lock b/uv.lock index 066de63..e43d037 100644 --- a/uv.lock +++ b/uv.lock @@ -1892,6 +1892,8 @@ dependencies = [ [package.optional-dependencies] kaggle-client = [ { name = "jupytext" }, + { name = "kagglehub" }, + { name = "kagglesdk" }, ] [package.dev-dependencies] @@ -1976,6 +1978,8 @@ requires-dist = [ { name = "joblib" }, { name = "jupyter" }, { name = "jupytext", marker = "extra == 'kaggle-client'" }, + { name = "kagglehub", marker = "extra == 'kaggle-client'", specifier = ">=0.3.10" }, 
+ { name = "kagglesdk", marker = "extra == 'kaggle-client'", specifier = ">=0.1.14,<1.0" }, { name = "nest-asyncio", specifier = ">=1.6.0" }, { name = "openai", specifier = ">=1.66" }, { name = "pandas" }, @@ -2077,6 +2081,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/8e/4077b08b95a1f8302c694a8b399bd413815fbe89045c41e6e08cd7d9439a/kagglehub-0.3.13-py3-none-any.whl", hash = "sha256:e00dec8b81396cbad9c7b5eb62a33cf8ae27da26227abd196ed8f054c845ca00", size = 68257, upload-time = "2025-08-26T16:17:32.13Z" }, ] +[[package]] +name = "kagglesdk" +version = "0.1.16" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/23/0e/51bf72a462e1e72fe3427b7c52b11c9c52cbcc63d7ce90f81a8f56d5a71b/kagglesdk-0.1.16.tar.gz", hash = "sha256:4a20da4ac6f4085e64b976a313ee136d4698737dc5be7c0f13009fadd41d5540", size = 121064, upload-time = "2026-02-27T19:32:34.019Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/6b/db30f17ad132391ac37a751fa45b32fd954a7ffa484fa3550eee9678334d/kagglesdk-0.1.16-py3-none-any.whl", hash = "sha256:a26ba7a754866f8eef1e327e78101f2960b6fe9b1b323183f2f61170abdb11ff", size = 160520, upload-time = "2026-02-27T19:32:32.721Z" }, +] + [[package]] name = "lark" version = "1.3.1"