
Commit 65b7c1a

[BE] Add testing for output saving logic (#153)
1 parent f9e7911 commit 65b7c1a

File tree

3 files changed: +429 −90 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -15,3 +15,4 @@ generated_kernels/
 *.csv
 backendbench_output*
 .DS_Store
+*.bak

BackendBench/output.py

Lines changed: 151 additions & 90 deletions
@@ -7,6 +7,7 @@
 import csv
 import json
 import logging
+import os
 from collections import defaultdict
 from dataclasses import asdict
 from pathlib import Path
@@ -19,38 +20,15 @@
 logger = logging.getLogger(__name__)
 
 
-def save_results(
+def _prepare_results_data(
     correctness_results: List[CorrectnessTestResult],
     performance_results: List[PerformanceTestResult],
-    output_path: Union[str, Path] = "backendbench_output",
-    command: str = None,
-    mean_correctness: float = None,
-    geomean_perf: float = None,
-    perf_at_p_score: float = None,
-    p: float = 1.0,
-):
-    """Save results without creating per-operator directories.
-
-    Args:
-        correctness_results: List of correctness test results
-        performance_results: List of performance test results
-        output_path: Base directory for saving results
-        command: Command used to run the benchmark
-        mean_correctness: Mean correctness score
-        geomean_perf: Geometric mean of performance scores
-        perf_at_p_score: Performance at threshold p score
-        p: The threshold value used for perf@p calculation
+) -> Tuple[List[dict], List[dict], dict]:
+    """Prepare and process results data without file I/O.
 
-    Structure created:
-        output_path/
-        ├── OVERALL_SUMMARY.md    # Top level summary of results
-        ├── full_results.json     # Complete results log
-        ├── operator_summary.csv  # Operator-level summary
-        └── failed_ops.json       # Log of failed operations
+    Returns:
+        Tuple of (all_results, failed_tests, op_summaries)
     """
-    base_dir = Path(output_path)
-    base_dir.mkdir(parents=True, exist_ok=True)
-
     # Prep work: save all results as a list of dicts
     all_results = [asdict(result) for result in correctness_results] + [
         asdict(result) for result in performance_results
@@ -59,16 +37,7 @@ def save_results(
     # sort by op_name, then args
     all_results.sort(key=lambda x: (x["op_name"], x["args"]))
 
-    # 1. Save the full log in the base directory
-    full_log_path = base_dir / "full_results.json"
-    failed_ops_path = base_dir / "failed_ops.json"
-    summary_csv_path = base_dir / "operator_summary.csv"
-
-    with open(full_log_path, "w") as f:
-        json.dump(all_results, f, indent=2)
-    logger.info(f"Full results saved to {full_log_path}")
-
-    # 2. Organize results by operator for csv
+    # Organize results by operator for csv
     op_all_results = defaultdict(list)
     op_summaries = {}
 
@@ -87,8 +56,10 @@ def save_results(
         ]
 
         # Calculate operator-level summary
-        total_tests = len(op_tests)
-        correct_tests = sum(1 for result in op_correctness_results if result.is_correct)
+        correct_correctness_tests = sum(1 for result in op_correctness_results if result.is_correct)
+        passed_performance_tests = sum(
+            1 for result in op_performance_results if result.successfully_ran
+        )
         # Collect performance metrics
         speedups = []
         abs_errors = []
@@ -105,25 +76,90 @@ def save_results(
             speedups.append(float(test.speedup))
 
         # Calculate summary statistics
-        correctness_rate = correct_tests / total_tests if total_tests > 0 else 0.0
+        correctness_rate = (
+            correct_correctness_tests / len(op_correctness_results)
+            if len(op_correctness_results) > 0
+            else 0.0
+        )
         avg_speedup = sum(speedups) / len(speedups) if speedups else 0.0
         geomean_speedup = torch.tensor(speedups).log().mean().exp().item() if speedups else 0.0
         max_abs_error = max(abs_errors) if abs_errors else 0.0
         max_rel_error = max(rel_errors) if rel_errors else 0.0
 
         op_summaries[op_name] = {
             "operator": op_name,
-            "total_tests": total_tests,
-            "passed_tests": correct_tests,
-            "failed_tests": total_tests - correct_tests,
+            "total_tests": len(op_all_results),
+            "correctness_tests": len(op_correctness_results),
+            "performance_tests": len(op_performance_results),
+            "passed_correctness_tests": correct_correctness_tests,
+            "passed_performance_tests": passed_performance_tests,
+            "failed_correctness_tests": len(op_correctness_results) - correct_correctness_tests,
+            "failed_performance_tests": len(op_performance_results) - passed_performance_tests,
             "correctness_rate": correctness_rate,
             "avg_speedup": avg_speedup,
             "geomean_speedup": geomean_speedup,
             "max_absolute_error": max_abs_error,
             "max_relative_error": max_rel_error,
         }
 
-    # 3. Create operator-level summary CSV
+    # Prepare failed operations log
+    failed_tests = [asdict(result) for result in correctness_results if not result.is_correct] + [
+        asdict(result) for result in performance_results if not result.successfully_ran
+    ]
+
+    # sort failed_tests
+    failed_tests.sort(key=lambda x: (x["op_name"], x["args"]))
+
+    return all_results, failed_tests, op_summaries
+
+
+def save_results(
+    correctness_results: List[CorrectnessTestResult],
+    performance_results: List[PerformanceTestResult],
+    output_path: str,
+    command: str,
+    mean_correctness: float,
+    geomean_perf: float,
+    perf_at_p_score: float,
+    p: float = 1.0,
+) -> Tuple[List[dict], List[dict], dict]:
+    """Prepare and process results data without file I/O.
+
+    Args:
+        correctness_results: List of correctness test results
+        performance_results: List of performance test results
+        output_path: Base directory for saving results
+        command: Command used to run the benchmark
+        mean_correctness: Mean correctness score
+        geomean_perf: Geometric mean of performance scores
+        perf_at_p_score: Performance at threshold p score
+        p: The threshold value used for perf@p calculation
+
+    Structure created:
+        output_path/
+        ├── OVERALL_SUMMARY.md    # Top level summary of results
+        ├── full_results.json     # Complete results log
+        ├── operator_summary.csv  # Operator-level summary
+        └── failed_tests.json     # Log of failed operations
+    """
+    base_dir = Path(output_path)
+    base_dir.mkdir(parents=True, exist_ok=True)
+
+    # Process data using the extracted function
+    all_results, failed_tests, op_summaries = _prepare_results_data(
+        correctness_results, performance_results
+    )
+
+    # 1. Save the full log in the base directory
+    full_log_path = os.path.join(base_dir, "full_results.json")
+    failed_tests_path = os.path.join(base_dir, "failed_tests.json")
+    summary_csv_path = os.path.join(base_dir, "operator_summary.csv")
+
+    with open(full_log_path, "w") as f:
+        json.dump(all_results, f, indent=2)
+    logger.info(f"Full results saved to {full_log_path}")
+
+    # 2. Create operator-level summary CSV
     if len(op_summaries) > 0:
         op_summary_list = list(op_summaries.values())
         fieldnames = list(op_summary_list[0].keys())
@@ -136,16 +172,10 @@ def save_results(
 
         logger.info(f"Operator summary CSV saved to {summary_csv_path}")
 
-    # 4. Save failed operations log
-    failed_tests = [asdict(result) for result in correctness_results if not result.is_correct] + [
-        asdict(result) for result in performance_results if not result.successfully_ran
-    ]
-    # sort failed_tests
-    failed_tests.sort(key=lambda x: (x["op_name"], x["args"]))
-
-    with open(failed_ops_path, "w") as f:
+    # 3. Save failed operations log
+    with open(failed_tests_path, "w") as f:
         json.dump(failed_tests, f, indent=2)
-    logger.info(f"Failed operations log saved to {failed_ops_path}")
+    logger.info(f"Failed operations log saved to {failed_tests_path}")
 
     # Save overall_summary if metrics are provided
     if all(x is not None for x in [command, mean_correctness, geomean_perf, perf_at_p_score]):
@@ -203,6 +233,61 @@ def _get_summary_op_results(
     return op_results
 
 
+def _generate_overall_summary_content(
+    command: str,
+    mean_correctness: float,
+    geomean_perf: float,
+    perf_at_p_score: float,
+    p: float = 1.0,
+    performance_results: List[PerformanceTestResult] = None,
+    correctness_results: List[CorrectnessTestResult] = None,
+) -> str:
+    """Generate the content for the overall summary markdown file.
+
+    Returns:
+        The markdown content as a string.
+    """
+    op_results = _get_summary_op_results(performance_results, correctness_results)
+
+    content = []
+    content.append("# BackendBench Run Summary\n")
+
+    content.append("## Command")
+    content.append("```bash")
+    content.append(f"{command}")
+    content.append("```\n")
+
+    content.append("## Results\n")
+    content.append("| Metric | Value |")
+    content.append("|--------|-------|")
+    content.append(f"| Correctness Score | {mean_correctness:.2f} |")
+    content.append(f"| Performance Score (geomean speedup) | {geomean_perf:.2f} |")
+    content.append(f"| Perf@{p} Score | {perf_at_p_score:.2f} |")
+    content.append("")
+
+    content.append("### Metric Descriptions\n")
+    content.append("- **Correctness Score**: Mean pass rate over all operators")
+    content.append("- **Performance Score**: Geometric mean speedup over all operators")
+    content.append(f"- **Perf@{p} Score**: Rate of correct samples with a speedup greater than {p}")
+    content.append("")
+
+    content.append("## Output Files\n")
+    content.append("The following files are saved in this directory:\n")
+    content.append("- `full_results.json`: Complete test results for all operators")
+    content.append("- `operator_summary.csv`: Operator-level summary statistics")
+    content.append("- `failed_tests.json`: Log of failed tests (if any)")
+    content.append("- `OVERALL_SUMMARY.md`: This file")
+
+    content.append("### Operator Speedups vs Eager in Descending Order\n")
+    content.append("| Operator | Correctness Ratio | Speedup vs Eager |")
+    content.append("|----------|-----------|----------------|")
+    for op, correctness, speedup in op_results:
+        content.append(f"| {op} | {correctness} | {speedup}|")
+    content.append("")
+
+    return "\n".join(content)
+
+
 def save_overall_summary(
     output_path: Union[str, Path],
     command: str,
@@ -226,43 +311,19 @@ def save_overall_summary(
     base_dir = Path(output_path)
     base_dir.mkdir(parents=True, exist_ok=True)
 
-    overall_summary_path = base_dir / "OVERALL_SUMMARY.md"
-    op_results = _get_summary_op_results(performance_results, correctness_results)
+    overall_summary_path = os.path.join(base_dir, "OVERALL_SUMMARY.md")
+
+    content = _generate_overall_summary_content(
+        command,
+        mean_correctness,
+        geomean_perf,
+        perf_at_p_score,
+        p,
+        performance_results,
+        correctness_results,
+    )
 
     with open(overall_summary_path, "w") as f:
-        f.write("# BackendBench Run Summary\n\n")
-
-        f.write("## Command\n")
-        f.write("```bash\n")
-        f.write(f"{command}\n")
-        f.write("```\n\n")
-
-        f.write("## Results\n\n")
-        f.write("| Metric | Value |\n")
-        f.write("|--------|-------|\n")
-        f.write(f"| Correctness Score | {mean_correctness:.2f} |\n")
-        f.write(f"| Performance Score (geomean speedup) | {geomean_perf:.2f} |\n")
-        f.write(f"| Perf@{p} Score | {perf_at_p_score:.2f} |\n")
-        f.write("\n")
-
-        f.write("### Metric Descriptions\n\n")
-        f.write("- **Correctness Score**: Mean pass rate over all operators\n")
-        f.write("- **Performance Score**: Geometric mean speedup over all operators\n")
-        f.write(f"- **Perf@{p} Score**: Rate of correct samples with a speedup greater than {p}\n")
-        f.write("\n")
-
-        f.write("## Output Files\n\n")
-        f.write("The following files are saved in this directory:\n\n")
-        f.write("- `full_results.json`: Complete test results for all operators\n")
-        f.write("- `operator_summary.csv`: Operator-level summary statistics\n")
-        f.write("- `failed_ops.json`: Log of failed operations (if any)\n")
-        f.write("- `OVERALL_SUMMARY.md`: This file\n")
-
-        f.write("### Operator Speedups vs Eager in Descending Order\n\n")
-        f.write("| Operator | Correctness Ratio | Speedup vs Eager |\n")
-        f.write("|----------|-----------|----------------|\n")
-        for op, correctness, speedup in op_results:
-            f.write(f"| {op} | {correctness} | {speedup}|\n")
-        f.write("\n")
+        f.write(content)
 
     logger.info(f"Overall summary saved to {overall_summary_path}")
