
[fx]add all_reduce test. #2784

Open · wants to merge 8 commits into main

Conversation

linuxlonelyeagle
Member

No description provided.

@qingyunqu
Collaborator

@stellaraccident @ramiro050 Hi, happy to introduce the first PR to support communication operators.
The CI failed because the nightly-build torch has a different signature from the stable-build torch. The stable build's signature is c10d_functional::all_reduce : (Tensor, str, str, int[], int) -> (Tensor), whereas the nightly build's is _c10d_functional::all_reduce : (Tensor, str, str) -> (Tensor).
In our practice, the stable build's signature is the correct one.

Do you have any idea to fix this?

Collaborator

@stellaraccident left a comment


That's unfortunate, but I expect it happens from time to time, especially with these bleeding-edge things that aren't covered by the PyTorch team's de facto compatibility requirements.

I haven't seen one of these in a long time. Do you remember what we usually do? We can probably exclude one of them based on version?

@ramiro050
Collaborator

I don't think I've ever seen this particular issue. We do have a place where we check the PyTorch version because of differences in ops supported:

if torch_version_for_comparison() >= version.parse("2.1.0.dev"):

but the issue here is that the ODS for the ops is hard-coded. One simple workaround would be to modify the ODS generator to output two versions of that op, adding a Stable or Nightly suffix to the op names. Once things converge upstream, we can get rid of the workaround.
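
For illustration, here is a minimal sketch of how such a version gate could look (hedged: emit_op_signature and the cutover version are hypothetical stand-ins, not the actual torch_ods_gen.py code; the local helper only mirrors torch_version_for_comparison() in spirit):

# Hedged sketch: choose which all_reduce signature to emit based on the
# installed torch version. Names and the cutover version are hypothetical.
import torch
from packaging import version

def parsed_torch_version():
    # Drop local build tags such as "+cpu" before comparing.
    return version.parse(torch.__version__.split("+")[0])

def emit_op_signature(sig: str):
    print(f"would emit ODS for: {sig}")  # stand-in for the real emitter

if parsed_torch_version() >= version.parse("2.3.0.dev"):  # hypothetical cutover
    emit_op_signature("_c10d_functional::all_reduce : (Tensor, str, str) -> (Tensor)")
else:
    emit_op_signature("c10d_functional::all_reduce : (Tensor, str, str, int[], int) -> (Tensor)")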

@qingyunqu
Collaborator

qingyunqu commented Jan 24, 2024

I don't think I've ever seen this particular issue. We do have a place where we check the PyTorch version because of differences in ops supported:

if torch_version_for_comparison() >= version.parse("2.1.0.dev"):

but the issue here is that the ODS for the ops is hard-coded. One simple workaround would be to modify the ODS generator to output two versions of that op, adding a Stable or Nightly suffix to the op names. Once things converge upstream, we can get rid of the workaround.

In my local test, using torch==2.3.0.dev20240109+cpu with an unmodified GeneratedTorchOps.td, it generates:

test_import_frozen_exported_program
-----------------------------------
module {
  func.func @main(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> {
    %str = torch.constant.str "sum"
    %str_0 = torch.constant.str ""
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %0 = torch.prim.ListConstruct %int0, %int1, %int2, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %int4 = torch.constant.int 4
    %1 = torch.operator "torch.c10d_functional.all_reduce"(%arg0, %str, %str_0, %0, %int4) : (!torch.vtensor<[4],f32>, !torch.str, !torch.str, !torch.list<int>, !torch.int) -> !torch.vtensor<[4],f32>
    %2 = torch.operator "torch.c10d_functional.wait_tensor"(%1) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32>
    return %2 : !torch.vtensor<[4],f32>
  }
}

It seems that torch.export always generates the op with the signature c10d_functional::all_reduce : (Tensor, str, str, int[], int) -> (Tensor), even on nightly-build torch.

So how about manually adding another td file, CommunicationOps.td (not generated by torch_ods_gen.py), as a workaround? We can merge this td file into GeneratedTorchOps.td once things converge upstream.
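
For reference, a minimal repro sketch along the lines of the dump above (hedged: the exact test setup is not shown here, and eager execution of these collective ops normally requires an initialized process group; the module and arguments below just mirror the MLIR output):

# Hedged sketch: export a module calling functional all_reduce with the
# stable-build 5-argument signature (tensor, reduce op, tag, ranks, group size).
import torch

class AllReduceModule(torch.nn.Module):
    def forward(self, x):
        y = torch.ops.c10d_functional.all_reduce(x, "sum", "", [0, 1, 2, 3], 4)
        return torch.ops.c10d_functional.wait_tensor(y)

exported = torch.export.export(AllReduceModule(), (torch.randn(4),))
print(exported)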

@ramiro050
Collaborator

It seems that torch.export always generates the op with the signature c10d_functional::all_reduce : (Tensor, str, str, int[], int) -> (Tensor), even on nightly-build torch.

Interesting. Yeah, your proposed solution seems fine to me. No need to add an extra td file. I would just place it here next to this op:

// The corresponding without underscore variant for `torch.aten.bernoulli_.float`
// doesn't exist in the pytorch ops registry. Add it here.
def Torch_ValsemVariantAtenBernoulliFloatOp: Torch_Op<"valsem.aten.bernoulli.float", [

@qingyunqu
Collaborator

It seems that torch.export always generates the op with the signature c10d_functional::all_reduce : (Tensor, str, str, int[], int) -> (Tensor), even on nightly-build torch.

Interesting. Yeah, your proposed solution seems fine to me. No need to add an extra td file. I would just place it here next to this op:

// The corresponding without underscore variant for `torch.aten.bernoulli_.float`
// doesn't exist in the pytorch ops registry. Add it here.
def Torch_ValsemVariantAtenBernoulliFloatOp: Torch_Op<"valsem.aten.bernoulli.float", [

OK, I placed it at the end of TorchOps.td.


if __name__ == "__main__":
    world_size = 4
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
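
For context, a self-contained sketch of what this excerpt presumably drives; the run worker is not shown in the diff, so its body below is an assumption:

# Hedged sketch: a plausible `run` worker for the mp.spawn call above.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank: int, world_size: int):
    # Assumed rendezvous; a real test would pick a free port.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500",
        rank=rank, world_size=world_size,
    )
    t = torch.ones(4)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)  # every rank now holds world_size
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(4,), nprocs=4, join=True)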
Collaborator


What is the reason for the multi-process setup? This PR isn't really changing anything in the FX importer, so I don't think we need these tests. If all that is needed is the ODS, then it is fine to just have the ODS changes.

// Torch c10d Functional Communication Ops
//===----------------------------------------------------------------------===//
// These ops are manually added because the nightly-build torch signature has
// not converged. Generate them once torch has a stable op signature.
Collaborator


nit: Can you also mention something like `Autogenerated by ./build_tools/update_torch_ods.sh with torch==(some version)`, so that people don't modify these manually?
