Skip to content

Commit 6fffb13

Browse files
xuzhao9 authored and facebook-github-bot committed
Move aten input loader to the new loader style (#507)
Summary: Pull Request resolved: #507. We are refactoring the input loader component. Now, each operator must implement its input loader script in the `input_loader.py` file under its directory. This file must provide a function, `get_input_loader(tritonbench_op: Any, op: str, input: str)`, which will be plugged in as the new input generator. The same interface holds for operator loaders like aten. Reviewed By: FindHao Differential Revision: D83754693
1 parent 00d8040 commit 6fffb13

File tree

5 files changed

+18
-16
lines changed

5 files changed

+18
-16
lines changed

tritonbench/data/loader.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
SUPPORTED_INPUT_OPS = ["highway_self_gating"]
88

9-
INPUT_CONFIG_DIR = Path(__file__).parent.parent.joinpath("input_configs")
9+
INPUT_CONFIG_DIR = Path(__file__).parent.joinpath("input_configs")
1010
INTERNAL_INPUT_CONFIG_DIR = (
1111
importlib.resources.files("tritonbench.data.input_configs.fb")
1212
if is_fbcode()
@@ -15,15 +15,11 @@
1515

1616

1717
def get_input_loader(tritonbench_op: Any, op: str, input: str):
18-
if hasattr(tritonbench_op, "aten_op_name"):
19-
generator_module = importlib.import_module(
20-
".input_loaders.aten", package=__package__
21-
)
22-
elif op in SUPPORTED_INPUT_OPS:
23-
op_module = ".".join(tritonbench_op.__module__.split(".")[:-1])
24-
generator_module = importlib.import_module(f"{op_module}.input_loader")
25-
else:
26-
raise RuntimeError(f"Unsupported op: {op}")
18+
assert (
19+
hasattr(tritonbench_op, "aten_op_name") or op in SUPPORTED_INPUT_OPS
20+
), f"Unsupported op: {op}. "
21+
op_module = ".".join(tritonbench_op.__module__.split(".")[:-1])
22+
generator_module = importlib.import_module(f"{op_module}.input_loader")
2723
input_iter_getter = generator_module.get_input_iter
2824
input_iter = input_iter_getter(tritonbench_op, op, input)
2925
return input_iter
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .loader import get_aten_loader_cls_by_name, list_aten_ops
1+
from .op_loader import get_aten_loader_cls_by_name, list_aten_ops

tritonbench/data/input_loaders/aten.py renamed to tritonbench/operator_loader/aten/input_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()}
4141

42-
INPUT_CONFIG_DIR = Path(__file__).parent.parent.joinpath("input_configs")
42+
from tritonbench.data import INPUT_CONFIG_DIR
4343

4444

4545
def truncate_inp(arg):

tritonbench/operator_loader/aten/loader.py renamed to tritonbench/operator_loader/aten/op_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def get_aten_loader_cls_by_name(aten_op_name: str, aten_op_input: Optional[str]
5151
If input is not provided, use the default input from the config file.
5252
"""
5353
op_cls_name = aten_op_name.replace(".", "_")
54-
module_name = f"tritonbench.operator_loader.loaders.{op_cls_name}"
54+
module_name = f"tritonbench.operator_loader.aten.{op_cls_name}"
5555
op_name_module = types.ModuleType(module_name)
5656
op_class = AtenOperator
5757
op_class.__module__ = module_name

tritonbench/utils/triton_op.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,13 @@ def _split_params_by_comma(params: Optional[str]) -> List[str]:
198198
def _find_op_name_from_module_path(module_path: str) -> str:
199199
PATH_PREFIX = "tritonbench.operators."
200200
# We have a separate operator loader for aten operator benchmark.
201-
PATH_PREFIX_LOADER = "tritonbench.operator_loader.loaders."
201+
PATH_PREFIX_LOADER = "tritonbench.operator_loader."
202202
assert (
203203
PATH_PREFIX in module_path or PATH_PREFIX_LOADER in module_path
204204
), f"We rely on module path prefix to identify operator name. Expected {PATH_PREFIX}<operator_name>, get {module_path}."
205205
if PATH_PREFIX_LOADER in module_path:
206206
suffix = module_path.partition(PATH_PREFIX_LOADER)[2]
207+
suffix = suffix.partition(".")[2]
207208
else:
208209
suffix = module_path.partition(PATH_PREFIX)[2]
209210
if suffix.startswith("fb."):
@@ -947,19 +948,24 @@ def input_callable(self):
947948
self._num_inputs = self._available_num_inputs - len(self._input_ids)
948949
self._input_ids = [i for i in range(0, self._num_inputs)]
949950

950-
def add_benchmark(self, bm_func_name: str, bm_callable: Callable):
951+
def add_benchmark(
952+
self, bm_func_name: str, bm_callable: Callable, baseline: bool = False
953+
):
951954
def _inner(self, *args, **kwargs):
952955
return bm_callable(*args, **kwargs)
953956

954957
decorator_kwargs = {
955958
"operator_name": self.name,
956959
"func_name": bm_func_name,
957960
"enabled": True,
961+
"baseline": baseline,
958962
}
959963
decorated_func = register_benchmark(**decorator_kwargs)(_inner)
960964
bound_method = types.MethodType(decorated_func, self)
961965
setattr(self, bm_func_name or bm_callable.__name__, bound_method)
962-
REGISTERED_BENCHMARKS[bm_func_name] = _inner
966+
if self.name not in REGISTERED_BENCHMARKS:
967+
REGISTERED_BENCHMARKS[self.name] = {}
968+
REGISTERED_BENCHMARKS[self.name][bm_func_name] = _inner
963969

964970
def run(
965971
self,

0 commit comments

Comments (0)