78 changes: 78 additions & 0 deletions testing/test_perf_regression_runner.py
@@ -0,0 +1,78 @@
from __future__ import annotations

import importlib.util
import runpy
import sys
import types
from pathlib import Path


def _load_perf_module(monkeypatch):
"""Load perf_regression directly while stubbing the tilelang package to avoid heavy test dependencies."""
module_path = Path(__file__).resolve().parents[1] / "tilelang/testing/perf_regression.py"
spec = importlib.util.spec_from_file_location("tilelang.testing.perf_regression", module_path)
assert spec is not None and spec.loader is not None
module = importlib.util.module_from_spec(spec)
sys.modules["tilelang.testing.perf_regression"] = module
spec.loader.exec_module(module)

tilelang_pkg = types.ModuleType("tilelang")
testing_pkg = types.ModuleType("tilelang.testing")
testing_pkg.process_func = module.process_func # type: ignore[attr-defined]
testing_pkg.regression = module.regression # type: ignore[attr-defined]
testing_pkg.perf_regression = module # type: ignore[attr-defined]
tilelang_pkg.testing = testing_pkg # type: ignore[attr-defined]

monkeypatch.setitem(sys.modules, "tilelang", tilelang_pkg)
monkeypatch.setitem(sys.modules, "tilelang.testing", testing_pkg)
monkeypatch.setitem(sys.modules, "tilelang.testing.perf_regression", module)

return module


def test_run_bench_file_executes_regressions(monkeypatch, tmp_path):
    perf = _load_perf_module(monkeypatch)
    bench_file = tmp_path / "regression_sample.py"
    bench_file.write_text(
        "import tilelang.testing\n"
        "\n"
        "def regression_sample():\n"
        "    tilelang.testing.process_func(lambda: 1.0, 'sample')\n",
        encoding="utf-8",
    )

    perf._reset_results()
    perf._run_bench_file(bench_file)

    assert perf._results_to_jsonable() == [{"name": "sample", "latency": 1.0}]


def test_regression_all_uses_pytest_wrapper(monkeypatch, tmp_path):
    perf = _load_perf_module(monkeypatch)
    bench_file = tmp_path / "regression_sample.py"
    bench_file.write_text(
        "import tilelang.testing\n"
        "\n"
        "def regression_sample():\n"
        "    tilelang.testing.process_func(lambda: 2.5, 'sample')\n",
        encoding="utf-8",
    )

    calls: dict[str, list[str]] = {}

    def fake_pytest_main(args, _plugins=None):
        # _plugins unused in mock; kept for signature compatibility with pytest.main
        calls["args"] = args
        module_vars = runpy.run_path(args[0])
        for name, fn in module_vars.items():
            if name.startswith("test_perf_regression_") and callable(fn):
                fn()
        return 0

    monkeypatch.setitem(sys.modules, "pytest", types.SimpleNamespace(main=fake_pytest_main))

    perf._reset_results()
    perf.regression_all(examples_root=tmp_path)

    assert Path(calls["args"][0]).name.startswith("test_perf_regression_wrapper")
    assert perf._results_to_jsonable() == [{"name": "sample", "latency": 2.5}]
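
A minimal sketch of driving the new entry point from Python in CI, assuming pytest is installed; the extra pytest flags are illustrative and are simply forwarded to pytest.main after the generated wrapper path and "-s":

    import tilelang.testing.perf_regression as pr

    # Run every discovered benchmark driver; "-x" stops at the first failing driver.
    pr.regression_all(pytest_args=["-x", "-q"])
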
151 changes: 111 additions & 40 deletions tilelang/testing/perf_regression.py
@@ -1,12 +1,15 @@
from __future__ import annotations

import contextlib
import importlib.util
import hashlib
import inspect
import json
import os
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
import tempfile
from typing import Any, Callable, Iterable, Sequence

try:
@@ -50,6 +53,83 @@ def _reset_results() -> None:
    _RESULTS.clear()


@contextlib.contextmanager
def _pushd(path: Path) -> Iterable[None]:
"""Temporarily change working directory (process-wide; avoid in concurrent contexts)."""
cwd = Path.cwd()
os.chdir(path)
try:
yield
finally:
os.chdir(cwd)


@contextlib.contextmanager
def _prepend_sys_path(path: Path) -> Iterable[None]:
orig = list(sys.path)
sys.path.insert(0, str(path))
try:
yield
finally:
sys.path[:] = orig


def _iter_regression_functions(namespace: dict[str, Any], prefixes: Sequence[str]) -> Iterable[tuple[str, Callable[..., Any]]]:
for k, v in namespace.items():
if not callable(v):
continue
if any(k.startswith(p) for p in prefixes):
yield k, v


def _run_bench_file(bench_file: Path, *, prefixes: Sequence[str] = ("regression_",)) -> None:
bench_file = bench_file.resolve()
if not bench_file.is_file():
raise FileNotFoundError(f"Benchmark driver not found: {bench_file}")

with _pushd(bench_file.parent), _prepend_sys_path(bench_file.parent):
module_tag = hashlib.sha256(str(bench_file).encode("utf-8")).hexdigest()[:12]
parent_stem = bench_file.parent.name.replace("-", "_") or "root"
stem = bench_file.stem.replace("-", "_")
module_name = f"tilelang.testing.perf_regression.bench_{parent_stem}_{stem}_{module_tag}"
spec = importlib.util.spec_from_file_location(module_name, bench_file)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot import benchmark driver: {bench_file}")
module = importlib.util.module_from_spec(spec)
prev = sys.modules.get(module_name)
sys.modules[module_name] = module
try:
spec.loader.exec_module(module)

for _, fn in sorted(_iter_regression_functions(module.__dict__, prefixes), key=lambda kv: kv[0]):
fn()
finally:
if prev is None:
sys.modules.pop(module_name, None)
else:
sys.modules[module_name] = prev
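
# Illustration (hypothetical path and numbers): a minimal benchmark driver that
# _run_bench_file() would discover and execute. Drivers live under the examples
# tree and expose functions matching the default "regression_" prefix; each one
# reports a latency through tilelang.testing.process_func:
#
#     # examples/gemm/regression_gemm.py
#     import tilelang.testing
#
#     def regression_gemm_fp16():
#         # A real driver benchmarks a kernel; the constant stands in for its latency.
#         tilelang.testing.process_func(lambda: 0.42, "gemm_fp16")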


def _build_pytest_wrapper(bench_files: Sequence[Path]) -> str:
    lines = [
        "from pathlib import Path",
        "import tilelang.testing.perf_regression as _pr",
        "",
        "def _make_test(path_str):",
        "    path = Path(path_str)",
        "    def _inner():",
        "        _pr._run_bench_file(path)",
        "    return _inner",
        "",
    ]

    for idx, bench in enumerate(bench_files):
        lines.append(f"test_perf_regression_{idx} = _make_test({str(bench)!r})")

    lines.append("")
    return "\n".join(lines)
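
# Illustration (hypothetical path): for bench_files containing a single driver at
# /tmp/examples/gemm/regression_gemm.py, _build_pytest_wrapper() returns the module
# source below; pytest then collects one test_perf_regression_<idx> function per driver:
#
#     from pathlib import Path
#     import tilelang.testing.perf_regression as _pr
#
#     def _make_test(path_str):
#         path = Path(path_str)
#         def _inner():
#             _pr._run_bench_file(path)
#         return _inner
#
#     test_perf_regression_0 = _make_test('/tmp/examples/gemm/regression_gemm.py')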


def process_func(func: Callable[..., float], name: str | None = None, /, **kwargs: Any) -> float:
"""Execute a single perf function and record its latency.

@@ -133,60 +213,51 @@ def _discover_bench_files(examples_root: Path) -> list[Path]:
     return sorted({p for p in files if p.is_file() and p.name != "__init__.py"})


-def regression_all(examples_root: str | os.PathLike[str] | None = None) -> None:
+def regression_all(examples_root: str | os.PathLike[str] | None = None, *, pytest_args: Sequence[str] | None = None) -> None:
     """Run all example benchmark drivers and print a consolidated table.

     Intended usage (CI): `python -c "import tilelang.testing.perf_regression as pr; pr.regression_all()"`
+    Additional pytest arguments can be passed via `pytest_args`.
     """

     root = Path(examples_root) if examples_root is not None else _examples_root()
     if not root.exists():
         raise FileNotFoundError(f"Examples root not found: {root}")

-    bench_files = _discover_bench_files(root)
+    bench_files = [p.resolve() for p in _discover_bench_files(root)]
     if not bench_files:
         raise RuntimeError(f"No benchmark drivers found under: {root}")

     _reset_results()
+    wrapper_source = _build_pytest_wrapper(bench_files)
     merged: dict[str, float] = {}
-    failures: list[str] = []

-    for bench_file in bench_files:
-        proc = subprocess.run(
-            [sys.executable, str(bench_file)],
-            cwd=str(bench_file.parent),
-            capture_output=True,
-            text=True,
-            env={
-                **os.environ,
-                # Keep child processes from picking up user-site or random paths.
-                "PYTHONNOUSERSITE": "1",
-                # Ask child to emit a single JSON marker line for robust parsing.
-                "TL_PERF_REGRESSION_FORMAT": "json",
-            },
-        )
-        if proc.returncode != 0:
-            failures.append(
-                f"{bench_file.relative_to(root)}\nSTDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}"
-            )
-            continue
+    with tempfile.TemporaryDirectory() as td:
+        wrapper = Path(td) / "test_perf_regression_wrapper.py"
+        wrapper.write_text(wrapper_source, encoding="utf-8")

-        parsed = _parse_table(proc.stdout)
-        for k, v in parsed.items():
-            # First writer wins to keep stable behavior if duplicates happen.
-            if k not in merged:
-                merged[k] = v
-            _RESULTS.append(PerfResult(name=k, latency=v))
-
-    if failures and not merged:
-        raise RuntimeError("All benchmark drivers failed:\n\n" + "\n\n".join(failures))
-    if failures:
-        # Don't hard-fail if we have some results; surface the errors for debugging.
-        print("# Some benchmark drivers failed (partial results)")
-        for msg in failures:
-            print("# ---")
-            for line in msg.splitlines():
-                print(f"# {line}")
+        try:
+            import pytest  # type: ignore
+        except ImportError as exc:  # pragma: no cover - tested via stubbed import
+            raise RuntimeError("pytest is required to run perf regression suite. Install with: pip install pytest") from exc
+
+        # Disable output capturing so benchmark progress remains visible.
+        args = [str(wrapper), "-s"]
+        if pytest_args:
+            args.extend(pytest_args)
+
+        exit_code = pytest.main(args)
+
+    for res in _RESULTS:
+        if res.name not in merged:
+            merged[res.name] = res.latency
+
+    if not merged:
+        if exit_code != 0:
+            raise RuntimeError("All benchmark drivers failed")
+        raise RuntimeError("No benchmark results collected")
+    if exit_code != 0:
+        # Don't hard-fail if we have some results; pytest already reported details.
+        print("# Some benchmark drivers failed (partial results)")

     rows = [[k, merged[k]] for k in sorted(merged.keys())]
     headers = ["File", "Latency"]