11 changes: 11 additions & 0 deletions test/modules/op/to.py
@@ -85,3 +85,14 @@ def forward(self, x):
 
     def get_example_inputs(self):
         return (torch.randn(1, 3),), {}
+
+
+class SimpleToForCast(TestModuleBase):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.to(torch.int32)
+
+    def get_example_inputs(self):
+        return (torch.randn(1, 3),), {}
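
Note: for reference, the cast this test exercises has plain PyTorch semantics (a minimal sketch, not part of the change):

```python
import torch

x = torch.randn(1, 3)
y = x.to(torch.int32)  # element-wise cast; float values are truncated toward zero
print(y.dtype)         # torch.int32
```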
34 changes: 34 additions & 0 deletions test/modules/op/top_k.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from test.modules.base import TestModuleBase
+from test.utils.tag import use_onert
+
+# luci-interpreter doesn't support the TopK operator yet
+@use_onert
+class SimpleTopK(TestModuleBase):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        values, indices = torch.topk(x, 2)
+        return values, indices
+
+    def get_example_inputs(self):
+        batch_size = 1
+        seq_len = 63
+        num_experts = 8
+        return (torch.randn(batch_size * seq_len, num_experts),), {}
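
Note: a quick sanity check of what this module computes (a sketch, assuming `TestModuleBase` behaves like a regular `torch.nn.Module`):

```python
import torch

m = SimpleTopK()
(x,), kwargs = m.get_example_inputs()
values, indices = m(x)
print(values.shape)   # torch.Size([63, 2]): top-2 along the last dimension
print(indices.dtype)  # torch.int64, PyTorch's index dtype for topk
```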
49 changes: 48 additions & 1 deletion tico/passes/legalize_predefined_layout_operators.py
@@ -17,6 +17,8 @@
 
 if TYPE_CHECKING:
     import torch.fx
+from operator import getitem
+
 import torch
 from torch.export import ExportedProgram
 
@@ -26,7 +28,7 @@
 from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
-from tico.utils.utils import is_target_node
+from tico.utils.utils import is_target_node, set_new_meta_val
 from tico.utils.validate_args_kwargs import (
     AvgPool2dArgs,
     Conv2DArgs,
@@ -35,6 +37,7 @@
     DequantizePerTensorArgs,
     InstanceNormArgs,
     MaxPool2dWithIndicesArgs,
+    TopKArgs,
 )
 
 
@@ -434,6 +437,49 @@ def legalize_avg_pool2d(self, exported_program, node) -> bool:
         modified = True
         return modified
 
+    def legalize_top_k(self, exported_program, node) -> bool:
+        logger = logging.getLogger(__name__)
+        modified = False
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        args = TopKArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input_ = args.input
+        k = args.k
+        dim = args.dim
+
+        if not (dim == -1 or dim == len(extract_shape(input_)) - 1):
+            raise NotYetSupportedError("Only support dim = -1 (last dimension)")
+
+        with graph.inserting_after(input_):
+            circle_topk = create_node(
+                graph,
+                torch.ops.circle_custom.top_k,
+                args=(input_, k),
+                origin=input_,
+            )
+
+        with graph.inserting_after(circle_topk):
+            topk_values = create_node(graph, getitem, args=(circle_topk, 0))
+            topk_indices = create_node(graph, getitem, args=(circle_topk, 1))
+        with graph.inserting_after(topk_indices):
+            topk_indices_int64 = create_node(
+                graph,
+                torch.ops.aten._to_copy.default,
+                args=(topk_indices,),
+                kwargs={"dtype": torch.int64},
+            )
+
+        get_item, get_item_1 = node.users.keys()
+        get_item.replace_all_uses_with(topk_values, propagate_meta=True)
+        get_item_1.replace_all_uses_with(topk_indices_int64, propagate_meta=True)
+
+        logger.debug(f"{node.name} is replaced with {circle_topk.name}")
+        modified = True
+
+        return modified
+
     def call(self, exported_program: ExportedProgram) -> PassResult:
         target_to_legalize_func = {
             torch.ops.aten.conv2d.default: self.legalize_conv2d,
@@ -442,6 +488,7 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
             torch.ops.aten.max_pool2d_with_indices.default: self.legalize_max_pool2d_with_indices,
             torch.ops.aten.avg_pool2d.default: self.legalize_avg_pool2d,
             torch.ops.aten.instance_norm.default: self.legalize_instance_norm,
+            torch.ops.aten.topk.default: self.legalize_top_k,
        }
 
         graph_module = exported_program.graph_module
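
Note: the extra `_to_copy` node exists because PyTorch's `topk` contract gives int64 indices while Circle's TOPK_V2 emits int32 ones (as in TFLite). A minimal sketch of the dtype mismatch, not part of the change:

```python
import torch

x = torch.randn(4, 8)
values, indices = torch.topk(x, 2)
print(indices.dtype)  # torch.int64, what downstream aten consumers expect
# Circle's TOPK_V2 produces int32 indices, hence the aten._to_copy cast
# back to int64 that the pass inserts after circle_custom.top_k.
```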
2 changes: 2 additions & 0 deletions tico/serialize/circle_serializer.py
@@ -32,6 +32,8 @@
 multiple_output_ops = [
     torch.ops.aten.split_with_sizes.default,
     torch.ops.aten.max.dim,
+    torch.ops.aten.topk.default,
+    torch.ops.circle_custom.top_k,
 ]
 
 
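Note: both ops return a (values, indices) pair, so registering them here plausibly makes the serializer allocate one Circle tensor per output rather than a single result tensor, matching how `split_with_sizes` and `max.dim` are already handled.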
59 changes: 37 additions & 22 deletions tico/serialize/operators/op_to_copy.py
@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List, TYPE_CHECKING, Union
 
 if TYPE_CHECKING:
     import torch._ops
     import torch.fx
 import torch
 from circle_schema import circle
 
+from tico.passes import ops
+
 from tico.serialize.circle_mapping import (
     extract_circle_dtype,
     extract_torch_dtype,
@@ -29,12 +31,12 @@
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
 from tico.utils.errors import NotYetSupportedError
-from tico.utils.validate_args_kwargs import ToCopyArgs
+from tico.utils.validate_args_kwargs import ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs
 
 
 @register_node_visitor
 class ToCopyVisitor(NodeVisitor):
-    target: List[torch._ops.OpOverload] = [torch.ops.aten._to_copy.default]
+    target: List[torch._ops.OpOverload] = ops.aten.to_copy
 
     def __init__(self, op_codes: Dict[OpCode, int], graph):
         super().__init__(op_codes, graph)
@@ -60,42 +62,55 @@ def define_cast_node(
 
         return operator
 
+    def parse_args(self, op: torch._ops.OpOverload, args, kwargs):
+        ret: Union[ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs]
+        if op is torch.ops.aten._to_copy.default:
+            ret = ToCopyArgs(*args, **kwargs)
+        elif op is torch.ops.aten.to.dtype:
+            ret = ToDtypeArgs(*args, **kwargs)
+        elif op is torch.ops.aten.to.dtype_layout:
+            ret = ToDtypeLayoutArgs(*args, **kwargs)
+        else:
+            raise NotImplementedError(f"Unsupported to_copy/to operator: {op}")
+
+        return ret
+
     def define_node(
         self,
         node: torch.fx.Node,
     ) -> circle.Operator.OperatorT:
-        supported_kwargs = ["dtype", "device", "layout"]
-        if not all(k in supported_kwargs for k in node.kwargs):
-            unsupported_node_kargs = list(node.kwargs.keys())
-            for supported_key in supported_kwargs:
-                if supported_key in node.kwargs:
-                    unsupported_node_kargs.remove(supported_key)
-            raise NotYetSupportedError(
-                f"Support only {supported_kwargs} kwargs now. Do not support {unsupported_node_kargs}"
-            )
-
-        args = ToCopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type, call-arg]
+        args = ToCopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input = args.input
         dtype = args.dtype
         layout = args.layout
         # device is meaningless in circle
 
+        pin_memory = args.pin_memory
+        non_blocking = args.non_blocking
+        memory_format = args.memory_format
+
+        if pin_memory is not None:
+            raise NotYetSupportedError("Do not support pin_memory yet")
+        if non_blocking is True:
+            raise NotYetSupportedError("Do not support non_blocking yet")
+        if memory_format is not None:
+            raise NotYetSupportedError("Do not support memory_format yet")
+
         input_meta = input.meta["val"]
         # https://pytorch.org/docs/stable/tensor_attributes.html#torch-layout
         # layout is two types: torch.strided(dense Tensors), torch.sparse_coo(sparse COO Tensors)
         if "layout" in input.kwargs and input.kwargs["layout"] != input_meta:
             raise NotYetSupportedError(
-                f"Only support when node and its input have same layout: (input layout: {input_meta}), (node layout: {node.kwargs['layout']})."
+                f"Only support when node and its input have same layout: (input layout: {input_meta}), (node layout: {layout})."
             )
 
-        if dtype is not None:
-            target_type = node.kwargs["dtype"]
-        else:
-            # device and layout are meaningless
-            target_type = extract_torch_dtype(node)
-        assert isinstance(target_type, torch.dtype), type(target_type)
+        if dtype is None:
+            dtype = extract_torch_dtype(node)
+        assert isinstance(dtype, torch.dtype), type(dtype)
 
         # define cast node
         in_type: int = extract_circle_dtype(input)
-        out_type: int = to_circle_dtype(target_type)
+        out_type: int = to_circle_dtype(dtype)
         inputs = [input]
         outputs = [node]
         operator = self.define_cast_node(inputs, outputs, in_type, out_type)
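
Note: the same Python-level `.to()` call can reach the serializer as different overloads depending on how the model was traced. A minimal sketch (not part of the change; graph contents vary by torch version):

```python
import torch

class Cast(torch.nn.Module):
    def forward(self, x):
        return x.to(torch.int32)

ep = torch.export.export(Cast(), (torch.randn(1, 3),))
# After decomposition the cast usually appears as aten._to_copy.default;
# pre-decomposition traces may instead carry aten.to.dtype or
# aten.to.dtype_layout, which parse_args() dispatches on.
print(ep.graph)
```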
79 changes: 79 additions & 0 deletions tico/serialize/operators/op_topk.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.circle_mapping import (
+    circle_legalize_dtype_to,
+    extract_circle_shape,
+    extract_shape,
+    extract_torch_dtype,
+)
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import TopKArgs
+
+
+@register_node_visitor
+class TopkVisitor(NodeVisitor):
+    """Serialize the circle_custom.top_k node to Circle's TOPK_V2 operator."""
+
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.circle_custom.top_k,
+    ]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_topk_node(
+        self, inputs: List, outputs: List
+    ) -> circle.Operator.OperatorT:
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.TOPK_V2, self._op_codes
+        )
+
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.TopKV2Options
+        option = circle.TopKV2Options.TopKV2OptionsT()
+        operator.builtinOptions = option
+
+        return operator
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        args = TopKArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+        k = args.k
+
+        input_shape = extract_circle_shape(input)
+        k_i32 = circle_legalize_dtype_to(k, dtype=torch.int32)
+        assert args.dim == -1 or args.dim == len(input_shape) - 1
+
+        inputs = [input, k_i32]
+
+        outputs = list(node.users.keys())
+
+        topk_node: circle.Operator.OperatorT = self.define_topk_node(inputs, outputs)
+
+        return topk_node
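
Note the contract this relies on: TOPK_V2 takes `(input, k)` with `k` as an int32 scalar (hence `circle_legalize_dtype_to`) and produces two outputs. The `outputs` list comes from `node.users`, which holds the two `getitem` consumers in (values, indices) order because the legalize pass creates `getitem(..., 0)` before `getitem(..., 1)`.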
8 changes: 5 additions & 3 deletions tico/utils/graph.py
@@ -16,10 +16,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
+from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING
 
 if TYPE_CHECKING:
     import torch.fx
+from operator import getitem
+
 import torch
 from torch.export import ExportedProgram
 from torch.export.exported_program import InputKind, InputSpec, TensorArgument
@@ -238,7 +240,7 @@ def get_module_name_chain(node: Optional[torch.fx.Node]) -> str:
 
 def create_node(
     graph: torch.fx.Graph,
-    target: torch._ops.OpOverload,
+    target: Callable,
     args: Optional[Tuple[Any, ...]] = None,
     kwargs: Optional[Dict[str, Any]] = None,
     *,
@@ -252,7 +254,7 @@ def create_node(
     graph : torch.fx.Graph
         The graph that will own the newly-created node.
 
-    target : torch._ops.OpOverload
+    target : Callable
         The op to call (e.g. `torch.add` or "call_function" target).
 
     args : Tuple[Any, ...], optional
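
Note: widening `target` from `torch._ops.OpOverload` to `Callable` is what allows `legalize_top_k` to create `operator.getitem` nodes, since `getitem` is a plain Python function rather than an op overload, yet is a legal `call_function` target in torch.fx.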