diff --git a/tritonbench/data/loader.py b/tritonbench/data/loader.py
index 8e9f371e1..bc5da16bf 100644
--- a/tritonbench/data/loader.py
+++ b/tritonbench/data/loader.py
@@ -6,7 +6,7 @@
 
 SUPPORTED_INPUT_OPS = ["highway_self_gating"]
 
-INPUT_CONFIG_DIR = Path(__file__).parent.parent.joinpath("input_configs")
+INPUT_CONFIG_DIR = Path(__file__).parent.joinpath("input_configs")
 INTERNAL_INPUT_CONFIG_DIR = (
     importlib.resources.files("tritonbench.data.input_configs.fb")
     if is_fbcode()
@@ -15,15 +15,11 @@
 
 
 def get_input_loader(tritonbench_op: Any, op: str, input: str):
-    if hasattr(tritonbench_op, "aten_op_name"):
-        generator_module = importlib.import_module(
-            ".input_loaders.aten", package=__package__
-        )
-    elif op in SUPPORTED_INPUT_OPS:
-        op_module = ".".join(tritonbench_op.__module__.split(".")[:-1])
-        generator_module = importlib.import_module(f"{op_module}.input_loader")
-    else:
-        raise RuntimeError(f"Unsupported op: {op}")
+    assert (
+        hasattr(tritonbench_op, "aten_op_name") or op in SUPPORTED_INPUT_OPS
+    ), f"Unsupported op: {op}. "
+    op_module = ".".join(tritonbench_op.__module__.split(".")[:-1])
+    generator_module = importlib.import_module(f"{op_module}.input_loader")
     input_iter_getter = generator_module.get_input_iter
     input_iter = input_iter_getter(tritonbench_op, op, input)
     return input_iter
diff --git a/tritonbench/operator_loader/aten/__init__.py b/tritonbench/operator_loader/aten/__init__.py
index 5f2943029..a6585d163 100644
--- a/tritonbench/operator_loader/aten/__init__.py
+++ b/tritonbench/operator_loader/aten/__init__.py
@@ -1 +1 @@
-from .loader import get_aten_loader_cls_by_name, list_aten_ops
+from .op_loader import get_aten_loader_cls_by_name, list_aten_ops
diff --git a/tritonbench/data/input_loaders/aten.py b/tritonbench/operator_loader/aten/input_loader.py
similarity index 98%
rename from tritonbench/data/input_loaders/aten.py
rename to tritonbench/operator_loader/aten/input_loader.py
index fa8d1a6c7..2ace894fb 100644
--- a/tritonbench/data/input_loaders/aten.py
+++ b/tritonbench/operator_loader/aten/input_loader.py
@@ -39,7 +39,7 @@
 
 dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()}
 
-INPUT_CONFIG_DIR = Path(__file__).parent.parent.joinpath("input_configs")
+from tritonbench.data import INPUT_CONFIG_DIR
 
 
 def truncate_inp(arg):
diff --git a/tritonbench/operator_loader/aten/loader.py b/tritonbench/operator_loader/aten/op_loader.py
similarity index 97%
rename from tritonbench/operator_loader/aten/loader.py
rename to tritonbench/operator_loader/aten/op_loader.py
index c5cc9bc6c..c6910c6f5 100644
--- a/tritonbench/operator_loader/aten/loader.py
+++ b/tritonbench/operator_loader/aten/op_loader.py
@@ -51,7 +51,7 @@ def get_aten_loader_cls_by_name(aten_op_name: str, aten_op_input: Optional[str]
     If input is not provided, use the default input from the config file.
     """
     op_cls_name = aten_op_name.replace(".", "_")
-    module_name = f"tritonbench.operator_loader.loaders.{op_cls_name}"
+    module_name = f"tritonbench.operator_loader.aten.{op_cls_name}"
     op_name_module = types.ModuleType(module_name)
     op_class = AtenOperator
     op_class.__module__ = module_name
diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index 328f85aa1..3153be26d 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -198,12 +198,13 @@ def _split_params_by_comma(params: Optional[str]) -> List[str]:
 def _find_op_name_from_module_path(module_path: str) -> str:
     PATH_PREFIX = "tritonbench.operators."
     # We have a separate operator loader for aten operator benchmark.
-    PATH_PREFIX_LOADER = "tritonbench.operator_loader.loaders."
+    PATH_PREFIX_LOADER = "tritonbench.operator_loader."
     assert (
         PATH_PREFIX in module_path or PATH_PREFIX_LOADER in module_path
     ), f"We rely on module path prefix to identify operator name. Expected {PATH_PREFIX}, get {module_path}."
     if PATH_PREFIX_LOADER in module_path:
         suffix = module_path.partition(PATH_PREFIX_LOADER)[2]
+        suffix = suffix.partition(".")[2]
     else:
         suffix = module_path.partition(PATH_PREFIX)[2]
     if suffix.startswith("fb."):
@@ -948,7 +949,9 @@ def input_callable(self):
         self._num_inputs = self._available_num_inputs - len(self._input_ids)
         self._input_ids = [i for i in range(0, self._num_inputs)]
 
-    def add_benchmark(self, bm_func_name: str, bm_callable: Callable):
+    def add_benchmark(
+        self, bm_func_name: str, bm_callable: Callable, baseline: bool = False
+    ):
         def _inner(self, *args, **kwargs):
             return bm_callable(*args, **kwargs)
 
@@ -956,11 +959,14 @@ def _inner(self, *args, **kwargs):
             "operator_name": self.name,
             "func_name": bm_func_name,
             "enabled": True,
+            "baseline": baseline,
         }
         decorated_func = register_benchmark(**decorator_kwargs)(_inner)
         bound_method = types.MethodType(decorated_func, self)
         setattr(self, bm_func_name or bm_callable.__name__, bound_method)
-        REGISTERED_BENCHMARKS[bm_func_name] = _inner
+        if self.name not in REGISTERED_BENCHMARKS:
+            REGISTERED_BENCHMARKS[self.name] = {}
+        REGISTERED_BENCHMARKS[self.name][bm_func_name] = _inner
 
     def run(
         self,