Skip to content

Conversation

@aditvenk
Copy link
Contributor

PyTorch DTensor strategy can legitimately populate DTensorSpec without tensor_meta. In such cases, we attempt to do fake tensor propagation to populate tensor_meta, but for some ops, one or more outputs can legitimately be None depending on inputs (e.g., convolution.backward with certain output_mask).

In such cases, fake tensor prop cannot resolve the output tensor_meta, and we currently throw an error in validation.
Switch validation to instead emit a warning in such case. If tensor_meta is unknown, and that tensor is subsequently an input to a downstream op, we will fail during the input_spec validation.

Testing: Adding convolution test that revealed this issue.

PyTorch DTensor strategy can legitimately populate DTensorSpec without tensor_meta. In such cases, we attempt to do fake tensor propagation to populate tensor_meta, but for some ops, one or more outputs can legtimiately be None depending on inputs (e.g., convolution.backward with certain output_mask).

Switch validation to throw a warning in such case. If tensor_meta is legtimately not known, and the output of an op is subsequently an input to a downstream op, we will fail during the input_spec validation.

Testing: Adding convolution test that revealed this issue.

<!-- ps-id: 9c2841a8-6f6b-44c4-a07a-16fb9b32ac70 -->
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Nov 26, 2025
Copy link
Contributor

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the diff @aditvenk !

So, one thing I'd like to clarify with DTensor folks (@wconstab @zpcore ) is around the expected outputs of DTensor specs.

My understanding is that there is some logic in DTensor that converts None placements as Replicate() somewhere down the road during DTensor sharding propagation (which is a shortcut which IMO should just be cleaned up, as we are querying the placements prior to this logic).
One case was in SDPA, which I worked around in

# This is overcoming some limitations of the lack of
# tensor_meta for sdpa which returns None
# we should just fix this all across the board
if ospec.tensor_meta is None:
ospec.tensor_meta = tm
else:
assert tm is None
so that we can assume tensor_meta is always present.

IMO we should always have the invariant that tensor_meta / redistribute_cost / etc are always present and populated in DTensorSpec, so that we can consistently rely on them.

Thoughts?

@aditvenk
Copy link
Contributor Author

Thanks for the diff @aditvenk !

So, one thing I'd like to clarify with DTensor folks (@wconstab @zpcore ) is around the expected outputs of DTensor specs.

My understanding is that there is some logic in DTensor that converts None placements as Replicate() somewhere down the road during DTensor sharding propagation (which is a shortcut which IMO should just be cleaned up, as we are querying the placements prior to this logic). One case was in SDPA, which I worked around in

# This is overcoming some limitations of the lack of
# tensor_meta for sdpa which returns None
# we should just fix this all across the board
if ospec.tensor_meta is None:
ospec.tensor_meta = tm
else:
assert tm is None

so that we can assume tensor_meta is always present.
IMO we should always have the invariant that tensor_meta / redistribute_cost / etc are always present and populated in DTensorSpec, so that we can consistently rely on them.

Thoughts?

The specific place where I saw the tensor meta as not populated came from here ( there is a TODO here too)
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/utils.py#L275

@fmassa
Copy link
Contributor

fmassa commented Dec 1, 2025

@wconstab wdty? Should we just enforce that we always return tensor_meta ?

@wconstab
Copy link
Contributor

wconstab commented Dec 1, 2025

Should we just enforce that we always return tensor_meta ?

I like the idea of enforcing this.

Annoyingly, I realized that in my single-dim rules currently, I am relying on not populating the output tensormeta because its weird to have to do shape inference inside the rule expansion infra, but i'm forced to create a new output spec. I still think we should figure out how to make this change, even if it requires a bit more of a refactor.

@fmassa
Copy link
Contributor

fmassa commented Dec 3, 2025

@aditvenk to unblock getting conv_backward to work, what do you think about doing a similar workaround as what I did for SDPA in

# This is overcoming some limitations of the lack of
# tensor_meta for sdpa which returns None
# we should just fix this all across the board
if ospec.tensor_meta is None:
ospec.tensor_meta = tm
else:
assert tm is None
, while we wait for PyTorch DTensor to fix its behavior?

Let me know if you would like me to tackle this part.

Thanks!

@aditvenk
Copy link
Contributor Author

aditvenk commented Dec 3, 2025

@aditvenk to unblock getting conv_backward to work, what do you think about doing a similar workaround as what I did for SDPA in

# This is overcoming some limitations of the lack of
# tensor_meta for sdpa which returns None
# we should just fix this all across the board
if ospec.tensor_meta is None:
ospec.tensor_meta = tm
else:
assert tm is None

, while we wait for PyTorch DTensor to fix its behavior?
Let me know if you would like me to tackle this part.

Thanks!

@fmassa --thanks, I will be happy to patch this PR as per your suggestion by EOW :)

@aditvenk
Copy link
Contributor Author

aditvenk commented Dec 5, 2025

@aditvenk to unblock getting conv_backward to work, what do you think about doing a similar workaround as what I did for SDPA in

# This is overcoming some limitations of the lack of
# tensor_meta for sdpa which returns None
# we should just fix this all across the board
if ospec.tensor_meta is None:
ospec.tensor_meta = tm
else:
assert tm is None

, while we wait for PyTorch DTensor to fix its behavior?
Let me know if you would like me to tackle this part.

Thanks!

@fmassa -- actually we cannot do the exact same thing as SDPA, because for conv_backward, some outputs are returned as None (depending on the value of output_mask). In that case, _get_meta_tensors_for_op() will return None as well for the tensor_meta, so we cannot fix it after the fact,.

e.g., ospec returned from PyTorch is Spec(unknown dtypeunknown shape(R)) and tm is None.

We don't have a way to compute a tensor_meta for this in AP level. We can only relax the constraint to validate that tensor_meta must always be present for outputs of op. I also wonder if it is correct for PyTorch to output Spec for a None output.

@fmassa
Copy link
Contributor

fmassa commented Dec 10, 2025

Oh, I just realized that conv and conv_backward in PyTorch are implemented using prop_rules, which basically mean that only a single output sharding is provided. This could maybe explain the issues you were facing.

@wconstab @zpcore given that conv / conv_backward are fairly common operations, shouldn't we just implement the sharding rules in function of op_strategy?

@fmassa
Copy link
Contributor

fmassa commented Dec 10, 2025

I looked a bit more into it, I think the issue is that given that conv and conv_transpose doesn't have a op_strategy implemented, they fall-back to the replicate_strategy from @zpcore , but there seems to be some limitations with its implementation that don't handle the cases from conv and conv_transpose properly

@zpcore
Copy link
Contributor

zpcore commented Dec 12, 2025

I looked a bit more into it, I think the issue is that given that conv and conv_transpose doesn't have a op_strategy implemented, they fall-back to the replicate_strategy from @zpcore , but there seems to be some limitations with its implementation that don't handle the cases from conv and conv_transpose properly

The op_schema for conv and conv_transpose may not fit the fallback replicate_strategy well, let me try a fix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants