slight code reorg and bug correction for cross_compile #3472
Conversation
# insert the no_op placeholder node in the graph; it will be replaced with
# the actual execute_engine node when the module is loaded on Windows
trt_node = gm.graph.call_function(
    torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
    (trt_module_node.args, *engine_info),
Do we still need to unpack this list?
We still need to unpack the list. Otherwise, loading on Windows fails with:
File "C:\Users\abose\Documents\work\TensorRT\torchTRT\Lib\site-packages\torch\_export\serde\serialize.py", line 2258, in deserialize_inputs
args.append(actual_args[schema_arg.name])
~~~~~~~~~~~^^^^^^^^^^^^^^^^^
KeyError: 'name'
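For context, here is a minimal, self-contained torch.fx sketch (with a made-up placeholder function and illustrative engine_info values, not the actual torch_tensorrt implementation) of swapping a placeholder call_function node into a graph with the engine fields unpacked as separate positional args:

```python
import operator
import torch
import torch.fx as fx

def noop_placeholder(inputs, *engine_fields):
    # Stands in for the real execute_engine op until load time; just
    # echoes the first input so the graph stays runnable.
    return inputs[0]

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = fx.symbolic_trace(M())
engine_info = ("serialized-engine", "serialized-metadata")  # illustrative only
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is operator.add:
        with gm.graph.inserting_after(node):
            # each engine field becomes its own positional arg (*engine_info),
            # matching the op schema one-to-one as discussed above
            trt_node = gm.graph.call_function(
                noop_placeholder, (node.args, *engine_info)
            )
        node.replace_all_uses_with(trt_node)
        gm.graph.erase_node(node)
gm.recompile()
print(gm(torch.zeros(2)))  # the placeholder echoes the input tensor
```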
py/torch_tensorrt/runtime/_utils.py (outdated)
@@ -144,6 +144,7 @@ def no_op_placeholder_for_execute_engine(
     serialized_hardware_compatible: str,
     serialized_metadata: str,
     serialized_target_platform: str,
+    serialized_require_output_allocator: str,
Move this placeholder op to runtime/meta_ops
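For reference, a hedged sketch of what such a placeholder op could look like once it lives under runtime/meta_ops; the inputs parameter and the op namespace here are assumptions, only the serialized_* names come from the diff above:

```python
from typing import List
import torch

@torch.library.custom_op("meta_ops_sketch::no_op_placeholder", mutates_args=())
def no_op_placeholder(
    inputs: List[torch.Tensor],              # assumed parameter, not in the diff
    serialized_metadata: str,
    serialized_target_platform: str,
    serialized_require_output_allocator: str,
) -> List[torch.Tensor]:
    # Never meant to run real inference; it only carries the serialized
    # engine fields through export until load time swaps in execute_engine.
    return [t.clone() for t in inputs]
```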
getitem_nodes = trt_node.users
for idx, getitem_node in enumerate(getitem_nodes):
    getitem_node.meta["val"] = trt_node.meta["val"][idx]
no_op_placeholder_node.replace_all_uses_with(trt_node)
Can you add a multi-output test case to the cross-compile tests?
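Something like this hypothetical module (illustrative, not the repo's actual test) would produce multiple outputs and hence getitem nodes:

```python
import torch

class AddMul(torch.nn.Module):
    def forward(self, a, b):
        # two outputs mean the compiled graph gains getitem nodes,
        # exercising the meta["val"] indexing discussed above
        return torch.add(a, b), torch.mul(a, b)

out = AddMul()(torch.ones(2), torch.ones(2))  # (tensor of 2s, tensor of 1s)
```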
no_op_placeholder_node.replace_all_uses_with(trt_node)
getitem_nodes = trt_node.users
for idx, getitem_node in enumerate(getitem_nodes):
    getitem_node.meta["val"] = trt_node.meta["val"][idx]
@narendasan this is the part that should address the bug.
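The reasoning: before replace_all_uses_with, the getitem nodes are still users of the placeholder, so iterating trt_node.users copies nothing. A standalone torch.fx sketch (illustrative names) of that ordering effect:

```python
import operator
import torch
import torch.fx as fx

class Two(torch.nn.Module):
    def forward(self, x):
        return x + 1, x * 2

gm = fx.symbolic_trace(Two())
old = next(n for n in gm.graph.nodes if n.target is operator.add)
with gm.graph.inserting_after(old):
    new = gm.graph.call_function(operator.add, old.args)

print(len(new.users))            # 0: nothing points at the new node yet
old.replace_all_uses_with(new)
print(len(new.users))            # 1: users only exist after the rewrite
```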
@@ -22,6 +22,8 @@ class Add(torch.nn.Module):
     def forward(self, a, b):
         return torch.add(a, b)
 
+print("here")
Remove this
Force-pushed from 3f8ab4c to 2934660
LGTM
@bowang007 on the Linux converter tests I see-
Would you know what is going wrong?
Hi @apbose, when you do the cross-compile, what is the SM version that you are compiling for?
Hmm @bowang007, are you suggesting the above with respect to the Linux tests or the Windows test? The error seems to come specifically from the pytorch/TensorRT/tree/main/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py tests on Linux.
Hi @apbose,
@apbose can we do something like turning these tests off for other platforms for now?
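One possible way to gate them (a sketch using unittest.skipIf, not necessarily how the repo's test harness does it):

```python
import platform
import unittest

@unittest.skipIf(
    platform.system() != "Linux",
    "flashinfer-based plugin tests currently only run on Linux",
)
class TestFlashInferRMSNorm(unittest.TestCase):
    def test_smoke(self):
        self.assertTrue(True)  # stand-in for the real rmsnorm checks
```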
Force-pushed from 2934660 to f8f0f55
Addresses the following for the cross_compile_for_windows feature:
- replaces cross_compile_flag with cross_compile_module
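For readers landing here, a hedged end-to-end sketch of the workflow this PR fixes; the exact API names and signatures should be checked against the torch_tensorrt docs for your version:

```python
import torch
import torch_tensorrt

class Add(torch.nn.Module):
    def forward(self, a, b):
        return torch.add(a, b)

inputs = [torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda()]

# On Linux: produce a Windows-loadable artifact; the saved graph carries the
# no-op placeholder op instead of execute_engine.
torch_tensorrt.cross_compile_for_windows(
    Add().eval().cuda(), "add_trt.ep", inputs=inputs
)

# On Windows: loading replaces the placeholder with the real execute_engine
# node (the step this PR's getitem meta fix affects).
# trt_module = torch_tensorrt.load_cross_compiled_exported_program("add_trt.ep")
```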