Skip to content

Commit 00d8040

Browse files
xuzhao9facebook-github-bot
authored andcommitted
Load input shapes from highway_self_gating
Summary: Load input shapes from highway_self_gating traces. Differential Revision: D83660588
1 parent 6896537 commit 00d8040

File tree

4 files changed

+32
-10
lines changed

4 files changed

+32
-10
lines changed

tritonbench/data/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,6 @@
1-
from .loader import get_input_loader
1+
from .loader import (
2+
get_input_loader,
3+
INPUT_CONFIG_DIR,
4+
INTERNAL_INPUT_CONFIG_DIR,
5+
SUPPORTED_INPUT_OPS,
6+
)

tritonbench/data/loader.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,29 @@
11
import importlib
2+
from pathlib import Path
23
from typing import Any
34

5+
from tritonbench.utils.env_utils import is_fbcode
6+
7+
SUPPORTED_INPUT_OPS = ["highway_self_gating"]
8+
9+
INPUT_CONFIG_DIR = Path(__file__).parent.parent.joinpath("input_configs")
10+
INTERNAL_INPUT_CONFIG_DIR = (
11+
importlib.resources.files("tritonbench.data.input_configs.fb")
12+
if is_fbcode()
13+
else None
14+
)
15+
416

517
def get_input_loader(tritonbench_op: Any, op: str, input: str):
618
if hasattr(tritonbench_op, "aten_op_name"):
7-
loader_type = "aten"
19+
generator_module = importlib.import_module(
20+
".input_loaders.aten", package=__package__
21+
)
22+
elif op in SUPPORTED_INPUT_OPS:
23+
op_module = ".".join(tritonbench_op.__module__.split(".")[:-1])
24+
generator_module = importlib.import_module(f"{op_module}.input_loader")
825
else:
926
raise RuntimeError(f"Unsupported op: {op}")
10-
11-
generator_module = importlib.import_module(
12-
f".input_loaders.{loader_type}", package=__package__
13-
)
1427
input_iter_getter = generator_module.get_input_iter
1528
input_iter = input_iter_getter(tritonbench_op, op, input)
1629
return input_iter

tritonbench/operator_loader/aten/loader.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ def __init__(
2222
self.tb_args.input_loader = self.aten_op_input
2323

2424
def get_input_iter(self) -> Generator:
25-
for inp in self._get_input_iter():
26-
yield inp
25+
raise NotImplementedError("get_input_iter is not implemented for AtenOperator.")
2726

2827
def eager(self, *input):
2928
args, kwargs = input

tritonbench/utils/triton_op.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from tritonbench.components.export import export_data
3434

3535
from tritonbench.components.power.chart import power_chart_begin, power_chart_end
36+
from tritonbench.data import SUPPORTED_INPUT_OPS
3637
from tritonbench.utils.constants import (
3738
DEFAULT_QUANTILES,
3839
DEFAULT_REP,
@@ -777,7 +778,11 @@ def __init__(
777778
# Run the post initialization
778779
def __post__init__(self):
779780
if self.tb_args.input_loader:
780-
if is_fbcode() and not hasattr(self, "aten_op_name"):
781+
if (
782+
is_fbcode()
783+
and not hasattr(self, "aten_op_name")
784+
and self.name not in SUPPORTED_INPUT_OPS
785+
):
781786
from tritonbench.data.fb.input_loader import get_input_loader
782787

783788
self.get_input_iter = get_input_loader(
@@ -786,7 +791,7 @@ def __post__init__(self):
786791
else:
787792
from tritonbench.data import get_input_loader
788793

789-
self._get_input_iter = get_input_loader(
794+
self.get_input_iter = get_input_loader(
790795
self, self.name, self.tb_args.input_loader
791796
)
792797
# Count total available inputs directly

0 commit comments

Comments
 (0)