Use op.dtype to create aten.empty.memory_format during decomposition. #3941
base: main
Conversation
Hi @ramiro050, I see that I am reverting your change here. It's not clear to me why the
Hello @ramiro050 and @vivekkhandelwal1, can you please review this PR when you have a chance? Thanks!
```mlir
func.func @torch.aten.empty.memory_format$noneDtype() -> !torch.vtensor<[200,200,26],f64> attributes {torch.assume_strict_symbolic_shapes} {
  %int200 = torch.constant.int 200
  %int26 = torch.constant.int 26
  %false = torch.constant.bool false
  %none = torch.constant.none
  %0 = torch.prim.ListConstruct %int200, %int200, %int26 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.empty.memory_format %0, %none, %none, %none, %false, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
  return %1 : !torch.vtensor<[200,200,26],f64>
}
```
Hi @sahas3, the above test case is invalid, since the tensor element type is `f32` when `none` is specified for the `aten.empty.memory_format` op.
I just tried it. See:

```python
>>> a = torch.ops.aten.empty.memory_format([2, 3])
>>> a
tensor([[-3.4054e+29,  4.2543e-41, -3.4085e+29],
        [ 4.2543e-41, -3.4086e+29,  4.2543e-41]])
>>> a.dtype
torch.float32
```
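For completeness, a quick sanity check (assuming a standard PyTorch install) that the element type of `aten.empty.memory_format` with `dtype=None` follows PyTorch's global default dtype rather than the annotated result type, and that an explicit `dtype` overrides it:

```python
import torch

# With dtype=None, aten.empty.memory_format falls back to PyTorch's
# global default dtype, which is float32 unless overridden.
a = torch.ops.aten.empty.memory_format([2, 3])
print(a.dtype)                    # torch.float32
print(torch.get_default_dtype())  # torch.float32

# Passing an explicit dtype overrides the default.
b = torch.ops.aten.empty.memory_format([2, 3], dtype=torch.float64)
print(b.dtype)                    # torch.float64
```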
Also, the changes that you've done are not correct and hence this patch should not be merged.
Great flag, @vivekkhandelwal1. I missed that `torch.ops.aten.empty.memory_format` produces `f32` for `dtype=none` in PyTorch, as I assumed the `aten.empty.memory_format` op being produced during conversion of `ExportedProgram` to the `torch` dialect was correct.

Upon further investigation, I think the bug is instead in the `DecomposeComplexOps` pass, where we decompose various torch ops to `aten.empty.memory_format`. For the `empty_like` op, if `dtype` is not specified, it defaults to the input tensor's dtype as per https://pytorch.org/docs/stable/generated/torch.empty_like.html. This was not being captured when decomposing to `empty.memory_format`.
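The `empty_like` dtype-defaulting behavior described above can be demonstrated directly in PyTorch (a quick sanity check, not part of this PR):

```python
import torch

# empty_like with no dtype inherits the input tensor's dtype...
x = torch.zeros(2, 3, dtype=torch.float64)
print(torch.empty_like(x).dtype)  # torch.float64

# ...whereas a naive decomposition to empty.memory_format that does not
# forward the dtype falls back to the f32 default.
print(torch.ops.aten.empty.memory_format(x.shape).dtype)  # torch.float32

# Forwarding the input's dtype, as the decomposition should,
# preserves the element type.
print(torch.ops.aten.empty.memory_format(x.shape, dtype=x.dtype).dtype)  # torch.float64
```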
I have pushed new changes that revert the original change and instead address this issue in the decomposition. Can you please take another look at these new changes? Thanks!
Prior to the change in this PR, `torch-mlir-opt --convert-torch-to-linalg` was running into an error. This is because when the `dtype` of `aten.empty.memory_format` is `none`, `f32` was being selected by default as the element type of the resulting tensor, which doesn't match the actual element type of the result.