Skip to content

Commit 49b7a2e

Browse files
committed
feat: Add cli for benchmark tasks
1 parent c687fde commit 49b7a2e

File tree

6 files changed

+1123
-9
lines changed

6 files changed

+1123
-9
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ dependencies = [
3232
"urllib3 >= 1.15.1",
3333
"packaging",
3434
"protobuf",
35+
"jupytext >= 1.15.0",
3536
]
3637

3738
[project.scripts]

src/kaggle/api/kaggle_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def kernel_push(self, kernel_push_request): # noqa: E501
338338
"""
339339
with tempfile.TemporaryDirectory() as tmpdir:
340340
meta_file = os.path.join(tmpdir, "kernel-metadata.json")
341-
(fd, code_file) = tempfile.mkstemp("code", "py", tmpdir, text=True)
341+
fd, code_file = tempfile.mkstemp("code", "py", tmpdir, text=True)
342342
fd.write(json.dumps(kernel_push_request.code))
343343
os.close(fd)
344344
with open(meta_file, "w") as f:

src/kaggle/api/kaggle_api_extended.py

Lines changed: 216 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@ def _authenticate_with_legacy_apikey(self) -> bool:
721721
return True
722722

723723
def _authenticate_with_access_token(self):
724-
(access_token, source) = get_access_token_from_env()
724+
access_token, source = get_access_token_from_env()
725725
if not access_token:
726726
return False
727727

@@ -1901,7 +1901,7 @@ def dataset_metadata_update(self, dataset, path):
19011901
dataset: The dataset to update.
19021902
path: The path to the metadata file.
19031903
"""
1904-
(owner_slug, dataset_slug, effective_path) = self.dataset_metadata_prep(dataset, path)
1904+
owner_slug, dataset_slug, effective_path = self.dataset_metadata_prep(dataset, path)
19051905
meta_file = self.get_dataset_metadata_file(effective_path)
19061906
with open(meta_file, "r") as f:
19071907
metadata = json.load(f)
@@ -1954,7 +1954,7 @@ def dataset_metadata(self, dataset, path):
19541954
Returns:
19551955
The path to the downloaded metadata file.
19561956
"""
1957-
(owner_slug, dataset_slug, effective_path) = self.dataset_metadata_prep(dataset, path)
1957+
owner_slug, dataset_slug, effective_path = self.dataset_metadata_prep(dataset, path)
19581958

19591959
if not os.path.exists(effective_path):
19601960
os.makedirs(effective_path)
@@ -3462,7 +3462,7 @@ def kernels_output_cli(self, kernel, kernel_opt=None, path=None, force=False, qu
34623462
file_pattern: Regex pattern to match against filenames. Only files matching the pattern will be downloaded.
34633463
"""
34643464
kernel = kernel or kernel_opt
3465-
(_, token) = self.kernels_output(kernel, path, file_pattern, force, quiet)
3465+
_, token = self.kernels_output(kernel, path, file_pattern, force, quiet)
34663466
if token:
34673467
print(f"Next page token: {token}")
34683468

@@ -3508,6 +3508,217 @@ def kernels_status_cli(self, kernel, kernel_opt=None):
35083508
else:
35093509
print('%s has status "%s"' % (kernel, status))
35103510

3511+
def benchmarks_pull(self, kernel: str, path: str = None, quiet: bool = False):
    """Pulls a benchmark notebook and converts it to a .py script.

    Args:
        kernel: Kernel slug ("owner/kernel-name") identifying the benchmark.
        path: Destination directory; when None, kernels_pull chooses its default.
        quiet: Suppress informational prints when True.

    Returns:
        The directory the benchmark was pulled into.

    Raises:
        ImportError: If jupytext is not installed.
        ValueError: If the pulled kernel contains no .ipynb file.
    """
    import os

    try:
        import jupytext
    except ImportError as err:
        # Chain the original error so the underlying cause stays visible.
        raise ImportError("jupytext is required for benchmarks functionality. Please install it.") from err

    effective_path = self.kernels_pull(kernel, path=path, metadata=True, quiet=quiet)

    # After pulling, find the .ipynb file in effective_path. Sort so the
    # choice is deterministic when several notebooks exist (os.listdir
    # returns entries in arbitrary order).
    ipynb_files = sorted(f for f in os.listdir(effective_path) if f.endswith(".ipynb"))
    if not ipynb_files:
        raise ValueError("Could not find a .ipynb file in the pulled kernel.")

    # Rename the first .ipynb file to benchmark.ipynb (if it isn't already)
    ipynb_name = ipynb_files[0]
    ipynb_path = os.path.join(effective_path, ipynb_name)
    benchmark_ipynb_path = os.path.join(effective_path, "benchmark.ipynb")
    if ipynb_name != "benchmark.ipynb":
        os.rename(ipynb_path, benchmark_ipynb_path)
        if not quiet:
            print(f"Renamed {ipynb_name} to benchmark.ipynb")

    # Convert to benchmark.py
    benchmark_py_path = os.path.join(effective_path, "benchmark.py")
    notebook = jupytext.read(benchmark_ipynb_path)

    # Strip confusing metadata formatting from the resulting Python document
    if "jupytext" not in notebook.metadata:
        notebook.metadata["jupytext"] = {}
    notebook.metadata["jupytext"]["notebook_metadata_filter"] = "-all"
    notebook.metadata["jupytext"]["cell_metadata_filter"] = "-all"

    jupytext.write(notebook, benchmark_py_path, fmt="py:percent")
    if not quiet:
        print(f"Converted benchmark.ipynb to {benchmark_py_path}")
    return effective_path
3550+
3551+
def benchmarks_pull_cli(self, kernel, kernel_opt=None, path=None):
    """CLI wrapper for benchmarks_pull: resolves the kernel argument and reports the destination."""
    slug = kernel if kernel else kernel_opt
    destination = self.benchmarks_pull(slug, path=path, quiet=False)
    print(f"Benchmark pulled and converted successfully to {destination}")
3555+
3556+
def benchmarks_publish_and_run(
    self, kernel: str = None, path: str = None, file_name: str = None, quiet: bool = False
):
    """Converts a local .py benchmark to .ipynb and pushes it to Kaggle.

    Args:
        kernel: Kernel slug ("owner/kernel-name" or bare name, which uses
            the configured user as owner). Required when no metadata file
            exists yet; optional otherwise.
        path: Directory containing the benchmark; defaults to cwd.
        file_name: Source script name; defaults to "benchmark.py".
        quiet: Suppress informational prints when True.

    Returns:
        The result of kernels_push for the prepared directory.

    Raises:
        ImportError: If jupytext is not installed.
        FileNotFoundError: If the source .py file does not exist.
        ValueError: If no metadata file exists and no kernel slug was given.
    """
    import os
    import json

    try:
        import jupytext
    except ImportError as err:
        # Chain the original error so the underlying cause stays visible.
        raise ImportError("jupytext is required for benchmarks functionality. Please install it.") from err

    def resolve_slug(slug):
        # Split "owner/name" after validation; a bare name falls back to the
        # configured user as owner. Shared by the create and update branches.
        if "/" in slug:
            self.validate_kernel_string(slug)
            owner, name = slug.split("/")
            return owner, name
        return self.get_config_value(self.CONFIG_NAME_USER), slug

    path = path or os.getcwd()
    file_name = file_name or "benchmark.py"
    py_path = os.path.join(path, file_name)

    if not os.path.exists(py_path):
        raise FileNotFoundError(f"Source file not found: {py_path}")

    ipynb_path = os.path.join(path, "benchmark.ipynb")
    notebook = jupytext.read(py_path, fmt="py:percent")

    # Inject default Kaggle kernelspec so Papermill executes the notebook correctly
    if "kernelspec" not in notebook.metadata:
        notebook.metadata["kernelspec"] = {
            "display_name": "Python 3",
            "language": "python",
            "name": "python3",
        }

    jupytext.write(notebook, ipynb_path)
    if not quiet:
        print(f"Converted {py_path} to {ipynb_path}")

    # Ensure kernel-metadata.json exists and has "personal-benchmark"
    metadata_path = os.path.join(path, self.KERNEL_METADATA_FILE)
    if not os.path.exists(metadata_path):
        # Create a default metadata file
        if not kernel:
            raise ValueError("A kernel slug must be specified to create a new metadata file.")

        owner_slug, kernel_slug = resolve_slug(kernel)
        title = kernel_slug.replace("-", " ").title()
        metadata = {
            "id": f"{owner_slug}/{kernel_slug}",
            "title": title,
            "code_file": "benchmark.ipynb",
            "language": "python",
            "kernel_type": "notebook",
            "is_private": "true",
            "enable_gpu": "false",
            "enable_internet": "true",
            "dataset_sources": [],
            "competition_sources": [],
            "kernel_sources": [],
            "model_sources": [],
            "keywords": ["personal-benchmark"],
        }
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)
        if not quiet:
            print(f"Created kernel metadata at {metadata_path}")
    else:
        # Read existing and inject keyword if missing
        with open(metadata_path, "r") as f:
            metadata = json.load(f)

        if kernel:
            owner_slug, kernel_slug = resolve_slug(kernel)
            new_id = f"{owner_slug}/{kernel_slug}"
            if metadata.get("id") != new_id:
                metadata["id"] = new_id
                metadata["title"] = kernel_slug.replace("-", " ").title()
                # Drop the stale numeric id so the push targets the new slug.
                metadata.pop("id_no", None)

        metadata.setdefault("keywords", [])
        if "personal-benchmark" not in metadata["keywords"]:
            metadata["keywords"].append("personal-benchmark")
        if metadata.get("code_file") != "benchmark.ipynb":
            metadata["code_file"] = "benchmark.ipynb"

        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

    # Now push using kernels_push
    return self.kernels_push(path)
3656+
3657+
def benchmarks_publish_and_run_cli(self, kernel=None, kernel_opt=None, path=None, file_name=None):
    """CLI wrapper: publish the local benchmark and print where to track the run."""
    slug = kernel if kernel else kernel_opt
    push_result = self.benchmarks_publish_and_run(slug, path=path, file_name=file_name, quiet=False)

    tracking_url = getattr(push_result, "url", None) if push_result else None
    url_text = f"\nTracking URL: {tracking_url}" if tracking_url else ""

    print(f"Benchmark pushed and started successfully.{url_text}\nRun kaggle benchmarks tasks results to stream output.")
3666+
3667+
def benchmarks_get_results(self, kernel: str = None, path: str = None, poll_interval: int = 60, timeout: int = None):
    """Polls the status of a benchmark until complete, then downloads the output.

    Args:
        kernel: Kernel slug ("owner/kernel-name"); when None it is read from
            the kernel metadata file in `path` (or the current directory).
        path: Directory used for metadata lookup and as the download target.
        poll_interval: Seconds to wait between status checks.
        timeout: Optional overall limit in seconds; None waits indefinitely.

    Returns:
        The result of kernels_output for the completed benchmark.

    Raises:
        ValueError: If no kernel can be determined, or the run ends in an
            error state.
        TimeoutError: If the benchmark does not finish within `timeout`.
    """
    import os
    import time
    import json

    # Resolve check_path unconditionally: it is also used to download partial
    # logs in the error branch below. (Previously it was only assigned when
    # kernel was None, so a failing run with an explicit kernel raised
    # NameError instead of reporting the failure.)
    check_path = path or os.getcwd()

    if kernel is None:
        meta_file = os.path.join(check_path, self.KERNEL_METADATA_FILE)
        if os.path.exists(meta_file):
            with open(meta_file, "r") as f:
                try:
                    metadata = json.load(f)
                    kernel = metadata.get("id")
                except Exception:
                    # Best-effort read; fall through to the explicit error below.
                    pass
    if kernel is None:
        raise ValueError("A kernel must be specified")

    start_time = time.time()
    print(f"Waiting for benchmark {kernel} to complete...")

    while True:
        response = self.kernels_status(kernel)
        status = response.status
        status_str = str(status).upper()
        if "COMPLETE" in status_str:
            print(f"Benchmark {kernel} completed.")
            break
        elif "ERROR" in status_str:
            message = getattr(response, "failure_message", "")
            error_txt = f" Message: {message}" if message else ""
            print(f"Benchmark {kernel} failed!{error_txt}")
            print("Attempting to download partial logs for debugging...")
            try:
                self.kernels_output(kernel, path=check_path, force=True, quiet=False)
            except Exception as log_err:
                print(f"Could not retrieve backend logs: {log_err}")
            raise ValueError("Benchmark execution terminated with an error state.")
        else:
            # Check the deadline before sleeping so we fail promptly instead
            # of overshooting by up to one poll interval.
            if timeout is not None and (time.time() - start_time) > timeout:
                raise TimeoutError(f"Timed out waiting for benchmark after {timeout} seconds.")
            print(f"Status: {status}. Waiting {poll_interval}s...")
            time.sleep(poll_interval)

    # Now download output
    print(f"Downloading results for {kernel}...")
    return self.kernels_output(kernel=kernel, path=path, force=True, quiet=False)
3716+
3717+
def benchmarks_get_results_cli(self, kernel, kernel_opt=None, path=None, poll_interval=60, timeout=None):
    """CLI wrapper: wait for a benchmark run to finish and download its output."""
    slug = kernel if kernel else kernel_opt
    self.benchmarks_get_results(slug, path=path, poll_interval=poll_interval, timeout=timeout)
    print("Output downloaded successfully.")
3721+
35113722
def model_get(self, model: str) -> ApiModel:
35123723
"""Gets a model.
35133724
@@ -4587,7 +4798,7 @@ def files_upload_cli(self, local_paths, inbox_path, no_resume, no_compress):
45874798
files_to_create = []
45884799
with ResumableUploadContext(no_resume) as upload_context:
45894800
for local_path in local_paths:
4590-
(upload_file, file_name) = self.file_upload_cli(local_path, inbox_path, no_compress, upload_context)
4801+
upload_file, file_name = self.file_upload_cli(local_path, inbox_path, no_compress, upload_context)
45914802
if upload_file is None:
45924803
continue
45934804

0 commit comments

Comments
 (0)