Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions tritonbench/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

SUPPORTED_INPUT_OPS = ["highway_self_gating"]

INPUT_CONFIG_DIR = Path(__file__).parent.parent.joinpath("input_configs")
INPUT_CONFIG_DIR = Path(__file__).parent.joinpath("input_configs")
INTERNAL_INPUT_CONFIG_DIR = (
importlib.resources.files("tritonbench.data.input_configs.fb")
if is_fbcode()
Expand All @@ -15,15 +15,11 @@


def get_input_loader(tritonbench_op: Any, op: str, input: str):
if hasattr(tritonbench_op, "aten_op_name"):
generator_module = importlib.import_module(
".input_loaders.aten", package=__package__
)
elif op in SUPPORTED_INPUT_OPS:
op_module = ".".join(tritonbench_op.__module__.split(".")[:-1])
generator_module = importlib.import_module(f"{op_module}.input_loader")
else:
raise RuntimeError(f"Unsupported op: {op}")
assert (
hasattr(tritonbench_op, "aten_op_name") or op in SUPPORTED_INPUT_OPS
), f"Unsupported op: {op}. "
op_module = ".".join(tritonbench_op.__module__.split(".")[:-1])
generator_module = importlib.import_module(f"{op_module}.input_loader")
input_iter_getter = generator_module.get_input_iter
input_iter = input_iter_getter(tritonbench_op, op, input)
return input_iter
2 changes: 1 addition & 1 deletion tritonbench/operator_loader/aten/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .loader import get_aten_loader_cls_by_name, list_aten_ops
from .op_loader import get_aten_loader_cls_by_name, list_aten_ops
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()}

INPUT_CONFIG_DIR = Path(__file__).parent.parent.joinpath("input_configs")
from tritonbench.data import INPUT_CONFIG_DIR


def truncate_inp(arg):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_aten_loader_cls_by_name(aten_op_name: str, aten_op_input: Optional[str]
If input is not provided, use the default input from the config file.
"""
op_cls_name = aten_op_name.replace(".", "_")
module_name = f"tritonbench.operator_loader.loaders.{op_cls_name}"
module_name = f"tritonbench.operator_loader.aten.{op_cls_name}"
op_name_module = types.ModuleType(module_name)
op_class = AtenOperator
op_class.__module__ = module_name
Expand Down
12 changes: 9 additions & 3 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,13 @@ def _split_params_by_comma(params: Optional[str]) -> List[str]:
def _find_op_name_from_module_path(module_path: str) -> str:
PATH_PREFIX = "tritonbench.operators."
# We have a separate operator loader for aten operator benchmark.
PATH_PREFIX_LOADER = "tritonbench.operator_loader.loaders."
PATH_PREFIX_LOADER = "tritonbench.operator_loader."
assert (
PATH_PREFIX in module_path or PATH_PREFIX_LOADER in module_path
), f"We rely on module path prefix to identify operator name. Expected {PATH_PREFIX}<operator_name>, get {module_path}."
if PATH_PREFIX_LOADER in module_path:
suffix = module_path.partition(PATH_PREFIX_LOADER)[2]
suffix = suffix.partition(".")[2]
else:
suffix = module_path.partition(PATH_PREFIX)[2]
if suffix.startswith("fb."):
Expand Down Expand Up @@ -948,19 +949,24 @@ def input_callable(self):
self._num_inputs = self._available_num_inputs - len(self._input_ids)
self._input_ids = [i for i in range(0, self._num_inputs)]

def add_benchmark(self, bm_func_name: str, bm_callable: Callable):
def add_benchmark(
self, bm_func_name: str, bm_callable: Callable, baseline: bool = False
):
def _inner(self, *args, **kwargs):
return bm_callable(*args, **kwargs)

decorator_kwargs = {
"operator_name": self.name,
"func_name": bm_func_name,
"enabled": True,
"baseline": baseline,
}
decorated_func = register_benchmark(**decorator_kwargs)(_inner)
bound_method = types.MethodType(decorated_func, self)
setattr(self, bm_func_name or bm_callable.__name__, bound_method)
REGISTERED_BENCHMARKS[bm_func_name] = _inner
if self.name not in REGISTERED_BENCHMARKS:
REGISTERED_BENCHMARKS[self.name] = {}
REGISTERED_BENCHMARKS[self.name][bm_func_name] = _inner

def run(
self,
Expand Down