Don't crash if tensor_meta is not available for output spec #268
Conversation
PyTorch DTensor strategy can legitimately populate DTensorSpec without tensor_meta. In such cases, we attempt to do fake tensor propagation to populate tensor_meta, but for some ops, one or more outputs can legitimately be None depending on inputs (e.g., convolution.backward with certain output_mask). Switch validation to throw a warning in such cases. If tensor_meta is legitimately not known, and the output of an op is subsequently an input to a downstream op, we will fail during the input_spec validation.
Testing: Adding a convolution test that revealed this issue.
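A rough sketch of what the relaxed check might look like (the function and attribute names here are illustrative, not the actual autoparallel code):

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical sketch: relax the output-spec check from a hard error to a
# warning when tensor_meta could not be resolved (e.g. a None output).
def validate_output_specs(op, output_specs):
    for idx, spec in enumerate(output_specs):
        if spec is None:
            # The op legitimately returned nothing for this output.
            continue
        if spec.tensor_meta is None:
            # Previously this raised. If this output later feeds a
            # downstream op, input_spec validation will still fail there.
            logger.warning(
                "tensor_meta missing for output %d of %s", idx, op
            )
```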
fmassa left a comment
Thanks for the diff @aditvenk !
So, one thing I'd like to clarify with DTensor folks (@wconstab @zpcore ) is around the expected outputs of DTensor specs.
My understanding is that there is some logic in DTensor that converts None placements to Replicate() somewhere down the road during DTensor sharding propagation (a shortcut which IMO should just be cleaned up, as we are querying the placements prior to this logic).
One case was in SDPA, which I worked around in autoparallel/autoparallel/utils.py (lines 84 to 90 in e794cc2):

```python
# This is overcoming some limitations of the lack of
# tensor_meta for sdpa which returns None
# we should just fix this all across the board
if ospec.tensor_meta is None:
    ospec.tensor_meta = tm
else:
    assert tm is None
```

so that tensor_meta is always present.
IMO we should always have the invariant that tensor_meta / redistribute_cost / etc are always present and populated in DTensorSpec, so that we can consistently rely on them.
Thoughts?
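For what it's worth, the invariant could be stated as a small check applied to every spec produced by sharding propagation (a hypothetical helper, not existing DTensor/autoparallel API):

```python
# Hypothetical invariant check, assuming a DTensorSpec-like object with
# `placements` and `tensor_meta` attributes.
def assert_spec_populated(spec, origin=""):
    assert spec.placements is not None, f"placements missing {origin}"
    assert spec.tensor_meta is not None, f"tensor_meta missing {origin}"
```

Applied at the producer, this would surface missing metadata where the spec is created rather than at a downstream consumer.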
The specific place where I saw the tensor_meta as not populated came from here (there is a TODO here too).
@wconstab wdyt? Should we just enforce that we always return tensor_meta?
I like the idea of enforcing this. Annoyingly, I realized that in my single-dim rules currently, I am relying on not populating the output tensor_meta, because it's weird to have to do shape inference inside the rule expansion infra, but I'm forced to create a new output spec. I still think we should figure out how to make this change, even if it requires a bit more of a refactor.
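One possible way to fill in the missing metadata without doing shape inference by hand is to run the op once under fake tensor mode and read shapes/strides/dtypes off the fake outputs. A minimal sketch, assuming recent PyTorch internals (the TensorMeta import path differs across versions, and `infer_output_tensor_meta` is a made-up helper name):

```python
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor._dtensor_spec import TensorMeta  # path varies by version

def infer_output_tensor_meta(op, *args, **kwargs):
    """Run `op` under FakeTensorMode and build a TensorMeta per output,
    keeping None for outputs the op legitimately does not produce."""
    with FakeTensorMode(allow_non_fake_inputs=True):
        outs = op(*args, **kwargs)
    if not isinstance(outs, (tuple, list)):
        outs = (outs,)
    return tuple(
        None
        if o is None
        else TensorMeta(shape=o.shape, stride=o.stride(), dtype=o.dtype)
        for o in outs
    )
```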
@aditvenk to unblock getting conv_backward to work, what do you think about doing a similar workaround to what I did for SDPA in autoparallel/autoparallel/utils.py (lines 84 to 90 in e794cc2)?
Let me know if you would like me to tackle this part. Thanks!
@fmassa -- thanks, I will be happy to patch this PR per your suggestion by EOW :)
@fmassa -- actually we cannot do the exact same thing as SDPA, because for conv_backward, some outputs are returned as None (depending on the value of output_mask). In that case, the ospec returned from PyTorch is a spec for an output that is None, and we don't have a way to compute a tensor_meta for it at the AP level. We can only relax the constraint that tensor_meta must always be present for the outputs of an op. I also wonder if it is correct for PyTorch to output a spec for a None output.
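To make the failure mode concrete, here is a standalone illustration (not the PR's actual test) of convolution_backward returning None when the corresponding output_mask entry is False, which is why there is nothing to attach a tensor_meta to:

```python
import torch

x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
grad_out = torch.randn(2, 4, 6, 6)

# With output_mask = [True, True, False] the bias gradient is not computed,
# so the third output of the op is simply None.
grad_input, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
    grad_out, x, w,
    bias_sizes=[0],
    stride=[1, 1], padding=[0, 0], dilation=[1, 1],
    transposed=False, output_padding=[0, 0], groups=1,
    output_mask=[True, True, False],
)
assert grad_bias is None  # no tensor, hence no tensor_meta for this output
```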
Oh, I just realized that @wconstab @zpcore given that
I looked a bit more into it, I think the issue is that given that
The op_schema for
PyTorch DTensor strategy can legitimately populate DTensorSpec without tensor_meta. In such cases, we attempt to do fake tensor propagation to populate tensor_meta, but for some ops, one or more outputs can legitimately be None depending on inputs (e.g., convolution.backward with certain output_mask).
In such cases, fake tensor prop cannot resolve the output tensor_meta, and we currently throw an error in validation.
Switch validation to instead emit a warning in such case. If tensor_meta is unknown, and that tensor is subsequently an input to a downstream op, we will fail during the input_spec validation.
Testing: Adding a convolution test that revealed this issue.