File tree Expand file tree Collapse file tree 4 files changed +32
-10
lines changed Expand file tree Collapse file tree 4 files changed +32
-10
lines changed Original file line number Diff line number Diff line change 1
- from .loader import get_input_loader
1
+ from .loader import (
2
+ get_input_loader ,
3
+ INPUT_CONFIG_DIR ,
4
+ INTERNAL_INPUT_CONFIG_DIR ,
5
+ SUPPORTED_INPUT_OPS ,
6
+ )
Original file line number Diff line number Diff line change 1
1
import importlib
2
+ from pathlib import Path
2
3
from typing import Any
3
4
5
+ from tritonbench .utils .env_utils import is_fbcode
6
+
7
+ SUPPORTED_INPUT_OPS = ["highway_self_gating" ]
8
+
9
+ INPUT_CONFIG_DIR = Path (__file__ ).parent .parent .joinpath ("input_configs" )
10
+ INTERNAL_INPUT_CONFIG_DIR = (
11
+ importlib .resources .files ("tritonbench.data.input_configs.fb" )
12
+ if is_fbcode ()
13
+ else None
14
+ )
15
+
4
16
5
17
def get_input_loader (tritonbench_op : Any , op : str , input : str ):
6
18
if hasattr (tritonbench_op , "aten_op_name" ):
7
- loader_type = "aten"
19
+ generator_module = importlib .import_module (
20
+ ".input_loaders.aten" , package = __package__
21
+ )
22
+ elif op in SUPPORTED_INPUT_OPS :
23
+ op_module = "." .join (tritonbench_op .__module__ .split ("." )[:- 1 ])
24
+ generator_module = importlib .import_module (f"{ op_module } .input_loader" )
8
25
else :
9
26
raise RuntimeError (f"Unsupported op: { op } " )
10
-
11
- generator_module = importlib .import_module (
12
- f".input_loaders.{ loader_type } " , package = __package__
13
- )
14
27
input_iter_getter = generator_module .get_input_iter
15
28
input_iter = input_iter_getter (tritonbench_op , op , input )
16
29
return input_iter
Original file line number Diff line number Diff line change @@ -22,8 +22,7 @@ def __init__(
22
22
self .tb_args .input_loader = self .aten_op_input
23
23
24
24
def get_input_iter (self ) -> Generator :
25
- for inp in self ._get_input_iter ():
26
- yield inp
25
+ raise NotImplementedError ("get_input_iter is not implemented for AtenOperator." )
27
26
28
27
def eager (self , * input ):
29
28
args , kwargs = input
Original file line number Diff line number Diff line change 33
33
from tritonbench .components .export import export_data
34
34
35
35
from tritonbench .components .power .chart import power_chart_begin , power_chart_end
36
+ from tritonbench .data import SUPPORTED_INPUT_OPS
36
37
from tritonbench .utils .constants import (
37
38
DEFAULT_QUANTILES ,
38
39
DEFAULT_REP ,
@@ -777,7 +778,11 @@ def __init__(
777
778
# Run the post initialization
778
779
def __post__init__ (self ):
779
780
if self .tb_args .input_loader :
780
- if is_fbcode () and not hasattr (self , "aten_op_name" ):
781
+ if (
782
+ is_fbcode ()
783
+ and not hasattr (self , "aten_op_name" )
784
+ and self .name not in SUPPORTED_INPUT_OPS
785
+ ):
781
786
from tritonbench .data .fb .input_loader import get_input_loader
782
787
783
788
self .get_input_iter = get_input_loader (
@@ -786,7 +791,7 @@ def __post__init__(self):
786
791
else :
787
792
from tritonbench .data import get_input_loader
788
793
789
- self ._get_input_iter = get_input_loader (
794
+ self .get_input_iter = get_input_loader (
790
795
self , self .name , self .tb_args .input_loader
791
796
)
792
797
# Count total available inputs directly
You can’t perform that action at this time.
0 commit comments