Skip to content

Commit 858c3da

Browse files
committed
nit: copy_ops -> get_copy_ops
Signed-off-by: gcunhase <[email protected]>
1 parent f4547d1 commit 858c3da

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

modelopt/onnx/op_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def is_fusible_scaling_op(op_type: str):
9696
]
9797

9898

99-
def copy_ops():
99+
def get_copy_ops():
100100
"""Returns list of copy operators."""
101101
return [
102102
"Flatten",
@@ -120,7 +120,7 @@ def copy_ops():
120120

121121
def is_copy_op(op_type: str):
122122
"""Returns whether the given op is a copy operator or not."""
123-
return op_type in copy_ops()
123+
return op_type in get_copy_ops()
124124

125125

126126
def is_linear_op(op_type: str):

modelopt/onnx/quantization/graph_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from onnxruntime.quantization.calibrate import CalibrationDataReader
3030

3131
from modelopt.onnx.logging_config import logger
32-
from modelopt.onnx.op_types import copy_ops, is_copy_op, is_linear_op
32+
from modelopt.onnx.op_types import get_copy_ops, is_copy_op, is_linear_op
3333
from modelopt.onnx.quantization.ort_utils import create_inference_session
3434
from modelopt.onnx.utils import (
3535
find_lowest_common_ancestor,
@@ -203,7 +203,7 @@ def _get_backbone(root: Node):
203203
["MaxPool", "Relu", "BatchNormalization", "BiasAdd", conv_type],
204204
]
205205
for idx, path_type in enumerate(fusible_linear_path_types):
206-
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=copy_ops()):
206+
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=get_copy_ops()):
207207
return _get_backbone(node)
208208

209209
return None

0 commit comments

Comments
 (0)