-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathrun.py
126 lines (106 loc) · 3.91 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
Tritonbench benchmark runner.
Note: make sure to `python install.py` first or otherwise make sure the benchmark you are going to run
has been installed. This script intentionally does not automate or enforce setup steps.
"""
import argparse
import os
import sys
from typing import List, Optional

from tritonbench.operator_loader import load_opbench_by_name_from_loader
from tritonbench.operators import load_opbench_by_name
from tritonbench.operators_collection import list_operators_by_collection
from tritonbench.utils.env_utils import is_fbcode
from tritonbench.utils.gpu_utils import gpu_lockdown
from tritonbench.utils.parser import get_parser
from tritonbench.utils.run_utils import run_in_task
from tritonbench.utils.triton_op import BenchmarkOperatorResult
# Resolve the usage-report logger. Inside fbcode the real implementation is
# pulled from fb.utils; everywhere else (or if that import fails) a no-op
# stand-in is used so callers can invoke it unconditionally.
def _noop_usage_report_logger(*args, **kwargs):
    """Do nothing; placeholder when the fbcode usage logger is unavailable."""
    return None


try:
    if is_fbcode():
        from .fb.utils import usage_report_logger  # @manual
    else:
        usage_report_logger = _noop_usage_report_logger
except ImportError:
    usage_report_logger = _noop_usage_report_logger
def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorResult:
    """Run one benchmark operator and emit its metrics to the configured sinks.

    Args:
        args: Parsed tritonbench CLI namespace (op, warmup, iter, csv,
            output/output_json/output_dir, logging flags, ...).
        extra_args: Unparsed CLI arguments forwarded to the operator benchmark.

    Returns:
        The BenchmarkOperatorResult collected from the run.
    """
    if args.operator_loader:
        Opbench = load_opbench_by_name_from_loader(args)
    else:
        Opbench = load_opbench_by_name(args.op)
    opbench = Opbench(
        tb_args=args,
        extra_args=extra_args,
    )
    try:
        opbench.run(args.warmup, args.iter)
    finally:
        # Report whatever metrics were collected, even if the run raised.
        metrics = opbench.output
        if not args.skip_print:
            if args.csv:
                metrics.write_csv_to_file(sys.stdout)
            else:
                print(metrics)
        if is_fbcode() and args.log_scuba:
            from .fb.utils import log_benchmark  # @manual

            kwargs = {
                "metrics": metrics,
                "benchmark_name": args.op,
                "device": args.device,
                "logging_group": args.logging_group or args.op,
                "precision": args.precision,
            }
            if args.production_shapes:
                from tritonbench.utils.fb.durin_data import productionDataLoader

                kwargs["weights_loader"] = productionDataLoader
            if "hardware" in args:
                kwargs["hardware"] = args.hardware
            log_benchmark(**kwargs)
        if args.plot:
            try:
                opbench.plot()
            except NotImplementedError:
                print(f"Plotting is not implemented for {args.op}")
        if args.output:
            with open(args.output, "w") as f:
                metrics.write_csv_to_file(f)
            print(f"[tritonbench] Output result csv to {args.output}")
        if args.output_json:
            with open(args.output_json, "w") as f:
                metrics.write_json_to_file(f)
        if args.output_dir:
            # BUGFIX: the csv branch previously called write_json_to_file,
            # so a file named "{op}.csv" actually contained JSON. Write the
            # format that matches the extension.
            ext = "csv" if args.csv else "json"
            output_file = os.path.join(args.output_dir, f"{args.op}.{ext}")
            with open(output_file, "w") as f:
                if args.csv:
                    metrics.write_csv_to_file(f)
                else:
                    metrics.write_json_to_file(f)
    return metrics
def run(args: Optional[List[str]] = None):
    """Tritonbench CLI entry point: parse arguments and run the selected ops.

    Args:
        args: Command-line argument list. ``None`` or an empty list means
            "use ``sys.argv[1:]``" (empty list kept for backward
            compatibility with the old ``args=[]`` default, which was a
            mutable default argument and has been replaced by a None
            sentinel).
    """
    if not args:
        args = sys.argv[1:]
    # Log the tool usage
    usage_report_logger(benchmark_name="tritonbench")
    parser = get_parser()
    args, extra_args = parser.parse_known_args(args)
    if args.ci:
        from .ci import run_ci  # @manual

        run_ci()
        return
    if args.op:
        ops = args.op.split(",")
    else:
        ops = list_operators_by_collection(args.op_collection)
    # Force isolation in subprocess if testing more than one op.
    if len(ops) >= 2:
        args.isolate = True
    with gpu_lockdown(args.gpu_lockdown):
        for op in ops:
            args.op = op
            if args.isolate:
                run_in_task(op)
            else:
                _run(args, extra_args)
# Script entry point: dispatch to the CLI runner with sys.argv.
if __name__ == "__main__":
    run()