
Commit d433384

Cherrypick #3440 for release/2.7 (#3485)
1 parent db7bd90 commit d433384

3 files changed: +262 -7 lines changed
@@ -0,0 +1,241 @@
from typing import Sequence

import flashinfer
import torch
import torch_tensorrt
from torch.fx.passes.shape_prop import TensorMetadata
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
)
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)
from transformers import LlamaConfig, LlamaForCausalLM


@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())  # type: ignore[misc]
def flashinfer_rmsnorm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # flashinfer's rmsnorm expects a 2D (num_tokens, hidden_dim) input
    return flashinfer.norm.rmsnorm(input, weight, eps)


@torch.library.register_fake("flashinfer::rmsnorm")
def _(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Fake impl: the output has the same shape and dtype as the input
    return input


# Auto-generate a TensorRT plugin and converter for the op registered above
torch_tensorrt.dynamo.conversion.plugins.custom_op(
    "flashinfer::rmsnorm", supports_dynamic_shapes=True
)
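
# Optional smoke test of the registered op (illustrative; assumes a CUDA
# device and a working flashinfer install, and is not part of the pass):
if torch.cuda.is_available():
    _x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
    _w = torch.ones(4096, device="cuda", dtype=torch.float16)
    _out = torch.ops.flashinfer.rmsnorm(_x, _w, 1e-6)
    # Eager reference: x * rsqrt(mean(x^2, dim=-1) + eps) * weight
    _ref = _x * torch.rsqrt(_x.pow(2).mean(-1, keepdim=True) + 1e-6) * _w
    assert torch.allclose(_out, _ref, atol=1e-2, rtol=1e-2)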


@_aten_lowering_pass
def replace_rmsnorm(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    modified_graph = False

    for node in gm.graph.nodes:
        # Match the root of the decomposed RMSNorm pattern: an upcast to
        # float32 feeding both the variance branch (pow) and the final mul.
        if (
            node.target == torch.ops.aten._to_copy.default
            and node.kwargs.get("dtype") is torch.float32
            and len(node.users) == 2
        ):
            if (
                list(node.users)[0].target == torch.ops.aten.pow.Tensor_Scalar
                and list(node.users)[1].target == torch.ops.aten.mul.Tensor
            ):
                pow_node = list(node.users)[0]
                if (
                    len(pow_node.users) == 1
                    and list(pow_node.users)[0].target == torch.ops.aten.mean.dim
                ):
                    mean_node = list(pow_node.users)[0]
                    if (
                        len(mean_node.users) == 1
                        and list(mean_node.users)[0].target == torch.ops.aten.add.Tensor
                    ):
                        add_node = list(mean_node.users)[0]
                        if (
                            len(add_node.users) == 1
                            and list(add_node.users)[0].target
                            == torch.ops.aten.sqrt.default
                        ):
                            sqrt_node = list(add_node.users)[0]
                            if (
                                len(sqrt_node.users) == 1
                                and list(sqrt_node.users)[0].target
                                == torch.ops.aten.div.Tensor
                            ):
                                div_node = list(sqrt_node.users)[0]
                                if list(div_node.users)[0] == list(node.users)[1]:
                                    mul_node = list(div_node.users)[0]
                                    copy_node = list(mul_node.users)[0]
                                    weight_mul_node = list(copy_node.users)[0]

                                    weight = weight_mul_node.args[0]

                                    original_meta = weight_mul_node.meta.get(
                                        "tensor_meta", {}
                                    )
                                    memory_format = original_meta.memory_format

                                    with gm.graph.inserting_after(weight_mul_node):
                                        # Symbolic [batch, seq, hidden] sizes of
                                        # the normalized input
                                        b = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.aten.sym_size.int,
                                            args=(node.args[0], 0),
                                        )
                                        b.meta["tensor_meta"] = TensorMetadata(
                                            shape=torch.Size([]),
                                            dtype=torch.int64,
                                            requires_grad=False,
                                            stride=None,
                                            memory_format=memory_format,
                                            is_quantized=False,
                                            qparams={},
                                        )
                                        s = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.aten.sym_size.int,
                                            args=(node.args[0], 1),
                                        )
                                        s.meta.update(b.meta)

                                        d = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.aten.sym_size.int,
                                            args=(node.args[0], 2),
                                        )
                                        d.meta.update(b.meta)

                                    with gm.graph.inserting_after(b):
                                        new_first_dim = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.aten.mul.Scalar,
                                            args=(b, s),
                                        )
                                        new_first_dim.meta.update(b.meta)

                                    with gm.graph.inserting_after(new_first_dim):
                                        # flashinfer's rmsnorm expects a 2D input,
                                        # so flatten [b, s, d] -> [b * s, d]
                                        reshape_node = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.aten.reshape.default,
                                            args=(node.args[0], [new_first_dim, d]),
                                        )
                                        b_val = original_meta.shape[0]
                                        s_val = original_meta.shape[1]
                                        d_val = original_meta.shape[2]

                                        reshape_node.meta["tensor_meta"] = TensorMetadata(
                                            shape=torch.Size([b_val * s_val, d_val]),
                                            dtype=original_meta.dtype,
                                            requires_grad=True,
                                            stride=None,
                                            memory_format=memory_format,
                                            is_quantized=False,
                                            qparams={},
                                        )

                                    with gm.graph.inserting_after(reshape_node):
                                        flashinfer_rmsnorm_node = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.flashinfer.rmsnorm.default,
                                            args=(
                                                reshape_node,
                                                weight,
                                                add_node.args[1],  # eps
                                            ),
                                        )
                                        flashinfer_rmsnorm_node.meta.update(
                                            reshape_node.meta
                                        )

                                    with gm.graph.inserting_after(
                                        flashinfer_rmsnorm_node
                                    ):
                                        # Restore the original [b, s, d] shape
                                        reshape_back_node = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.aten.reshape.default,
                                            args=(
                                                flashinfer_rmsnorm_node,
                                                [b, s, d],
                                            ),
                                        )

                                    weight_mul_node.replace_all_uses_with(
                                        reshape_back_node
                                    )
                                    reshape_back_node.meta.update(weight_mul_node.meta)

                                    modified_graph = True

                                    # Erase the now-dead pattern, users first
                                    gm.graph.erase_node(weight_mul_node)
                                    gm.graph.erase_node(copy_node)
                                    gm.graph.erase_node(mul_node)
                                    gm.graph.erase_node(div_node)
                                    gm.graph.erase_node(sqrt_node)
                                    gm.graph.erase_node(add_node)
                                    gm.graph.erase_node(mean_node)
                                    gm.graph.erase_node(pow_node)
                                    gm.graph.erase_node(node)

    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)

    return gm

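# For reference, the chain matched above is the decomposed form of an
# HF-Llama-style RMSNorm (a sketch; the exact aten ops can vary by version,
# and rsqrt shows up here as 1/sqrt after decomposition):
#
#     x32 = hidden_states.to(torch.float32)          # aten._to_copy (2 users)
#     var = x32.pow(2).mean(-1, keepdim=True)        # aten.pow -> aten.mean
#     x32 = x32 * (1.0 / torch.sqrt(var + eps))      # add -> sqrt -> div -> mul
#     out = weight * x32.to(input_dtype)             # _to_copy -> weight mul
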
# 1. Create a custom config with 1 layer
config = LlamaConfig(
    vocab_size=32000,
    hidden_size=4096,  # LLaMA2-7B hidden dimension
    intermediate_size=11008,  # LLaMA2-7B FFN width: ~2/3 * 4 * 4096, rounded (SwiGLU)
    num_hidden_layers=1,  # Only 1 decoder layer
    num_attention_heads=32,
    max_position_embeddings=4096,
    use_cache=False,  # Disable KV caching for export
)

# 2. Initialize model (random weights)
with torch.no_grad():
    model = LlamaForCausalLM(config).eval().half()

# 3. Export with static shapes
input_ids = torch.randint(0, 32000, (1, 64))  # Static [batch=1, seq=64]
exported = torch.export.export(
    model,
    (input_ids,),
    dynamic_shapes=None,  # Fully static
)

# Test forward pass in eager mode
input_ids = torch.randint(0, 32000, (1, 64))
output = model(input_ids)
print(output)

# Export validation

DEVICE = torch.device("cuda:0")

with torch_tensorrt.logging.errors():
    trt_model = torch_tensorrt.dynamo.compile(
        exported,
        inputs=[input_ids],
        enabled_precisions={torch.float32, torch.float16},
        truncate_double=True,
        device=DEVICE,
        disable_tf32=True,
        use_explicit_typing=False,
        use_fp32_acc=True,
        # debug=True,
    )

input_ids = input_ids.to(DEVICE)

res = trt_model.forward(input_ids)
print(res)
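
Beyond eyeballing the two printouts, a quick tolerance check can quantify the eager-vs-TensorRT gap. A sketch that could be appended to the script, under the assumption (not guaranteed across torch_tensorrt versions) that the compiled module returns the logits as its first output:

# Hypothetical follow-up check; `output`, `res`, and DEVICE come from the script above
pyt_logits = output.logits.to(DEVICE)
trt_logits = res[0] if isinstance(res, (list, tuple)) else res
print("max abs diff:", (pyt_logits - trt_logits.to(pyt_logits.dtype)).abs().max())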

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py

+15 -6
@@ -1,3 +1,4 @@
+import itertools
 import logging
 from types import FunctionType
 from typing import Any, Callable, Tuple
@@ -108,7 +109,6 @@ def generate_signature(

 def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
     shape_env = ShapeEnv()
-    fake_mode = FakeTensorMode(shape_env=shape_env)
     syms_args = []
     tensor_args = [elem for elem in args if isinstance(elem, trtp.TensorDesc)]

@@ -121,7 +121,7 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
         ]
         syms_args.append(syms_arg)

-    with FakeTensorMode() as fake_mode:
+    with FakeTensorMode(shape_env=shape_env) as fake_mode:
         fake_args = []
         for syms_arg in syms_args:
             fake_arg = torch.randn(syms_arg)
@@ -130,16 +130,25 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
         output = torch_op(*fake_args, **kwargs)

     # We assume that number of dimensions are the same in torch op
-    shape_calc_fns = [None] * args[0].ndim
-    for i in range(args[0].ndim):
-        input_node_expr = [syms_arg[i].node.expr for syms_arg in syms_args]
+    shape_calc_fns = [None] * output.ndim
+
+    for i in range(output.ndim):
+        input_node_expr = list(
+            itertools.chain.from_iterable(
+                [sym.node.expr for sym in syms_arg] for syms_arg in syms_args
+            )
+        )
+
         shape_calc_fns[i] = lambdify(
             tuple(input_node_expr), output.shape[i].node.expr, "math"
         )

     out_desc = tensor_args[0].like()
     for i in range(out_desc.ndim):
-        input_shape_expr = [tensor_arg.shape_expr[i] for tensor_arg in tensor_args]
+        input_shape_expr = list(
+            itertools.chain.from_iterable(arg.shape_expr for arg in tensor_args)
+        )
+
         if output.shape[i].node.expr is None:
             raise ValueError(f"output.shape[{i}].node.expr cannot be None")
         out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr)  # type: ignore[misc]
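
The reworked hunk sizes `shape_calc_fns` by the op's output rank rather than the first input's, and flattens every input dimension symbol (via `itertools.chain`) before handing them to `lambdify`, which turns each output dimension's symbolic expression into a plain callable that can later be evaluated on the inputs' shape expressions. A minimal sketch of that lambdify step, assuming only that sympy is installed (the symbol names are illustrative):

import sympy
from sympy import lambdify

m, n = sympy.symbols("m n")                    # stand-ins for symbolic input dims
out_expr = m * n                               # e.g. an output dim collapsing two inputs
shape_fn = lambdify((m, n), out_expr, "math")  # evaluate with the stdlib math backend
print(shape_fn(2, 64))                         # -> 128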

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py

+6 -1
@@ -1,4 +1,5 @@
 import logging
+import uuid
 from typing import Callable, Dict, Optional, Sequence, Tuple, Union

 import numpy as np
@@ -47,11 +48,15 @@ def custom_kernel_converter(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
+
     plugin = getattr(getattr(trtp.op, namespace), op_name)
+
     tensor_inputs = plugin.input_tensor_names
     tensor_args = args[0 : len(tensor_inputs)]
+
+    unique_id = uuid.uuid4()
     itensor_args = [
-        get_trt_tensor(ctx, t, f"{t_name}")
+        get_trt_tensor(ctx, t, f"{t_name}_{unique_id}")
         for (t, t_name) in zip(tensor_args, tensor_inputs)
     ]
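
The `uuid` suffix appears intended to keep the tensor names passed to `get_trt_tensor` unique when the same plugin op is converted more than once in a graph; with a bare `t_name`, a second instance would reuse the first instance's name. A toy illustration of the naming scheme (values are illustrative only):

import uuid

t_name = "input"  # example name; the real names come from plugin.input_tensor_names
print(f"{t_name}_{uuid.uuid4()}")  # e.g. "input_1b9d6bcd-...", unique per conversion call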
