Refine the gradient accumulation API #9078

Open · wants to merge 1 commit into master from rpsilva_grad_acc_v2

Conversation

@rpsilva-aws (Collaborator) commented May 1, 2025

In this PR, we refine the gradient accumulation API to:

  • Make the body function wrapper pure, with no side effects
  • Enforce the purity requirements on the train step
  • Simplify the mapping logic
  • Move all loop logic that is not train-step specific into the body wrapper
  • Initialize the local accumulated gradients on the device (avoids requiring a data transfer when they are not already present)
  • Change the API to return a tuple of carried tensors instead of unpacking them (see the sketch below)
  • Remove the explicit buffer donation, given that the function is pure and given #9054 (Extend device data node binding API to not clone specified input tensors)
  • Fix the RNG handling for all iterations
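
Below is a minimal, self-contained sketch of the calling pattern described above: a pure per-step body and a tuple of carried tensors returned to the caller. The helper and names here are illustrative stand-ins, not the exact torch_xla API, and a plain Python loop stands in for the traced accumulation helper.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 4)
params = tuple(model.parameters())


def step_fn(x, y, carried_loss, carried_grads):
  """Pure per-microbatch body: reads its inputs and carries, returns new carries."""
  loss = torch.nn.functional.cross_entropy(model(x), y)
  grads = torch.autograd.grad(loss, params)
  return (carried_loss + loss.detach(),
          tuple(g + dg for g, dg in zip(carried_grads, grads)))


def accumulate(step_fn, xs, ys, carried):
  # Stand-in for the accumulation helper: iterate over the leading
  # (num-steps) dimension and thread the carried tuple through each step,
  # returning it to the caller instead of unpacking/mutating external state.
  for x, y in zip(xs, ys):
    carried = step_fn(x, y, *carried)
  return carried


xs = torch.randn(4, 2, 8)  # 4 accumulation steps, microbatch size 2
ys = torch.randint(0, 4, (4, 2))
zero_grads = tuple(torch.zeros_like(p) for p in params)  # local accumulators
loss_sum, grad_sums = accumulate(step_fn, xs, ys, (torch.zeros(()), zero_grads))
for p, g in zip(params, grad_sums):
  p.grad = g / len(xs)  # hand the averaged gradients to the optimizer
print(loss_sum / len(xs))
```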

Testing:

  • Validated the existing MLP A/B test, with and without gradient checkpointing
  • Added a few basic sanity tests
  • Validated the API with Llama 3.1 8B

cc: @mcuiaws

@rpsilva-aws force-pushed the rpsilva_grad_acc_v2 branch 7 times, most recently from 9bfb1f7 to 930b208 on May 5, 2025 20:24
@rpsilva-aws assigned tengyifei and unassigned tengyifei on May 5, 2025
@rpsilva-aws requested a review from tengyifei on May 5, 2025 20:33
@rpsilva-aws force-pushed the rpsilva_grad_acc_v2 branch 4 times, most recently from cf46ce1 to e0df762 on May 6, 2025 02:31
@rpsilva-aws marked this pull request as ready for review on May 6, 2025 03:04
@rpsilva-aws (Collaborator, Author) commented:

Hmm, only PJRT_DEVICE=CUDA is having issues with the existing MLP A/B test: SIG11 on torch_xla::runtime::PjRtComputationClient::PjRtShardedData::GetHandle(). I'll look into it.

@rpsilva-aws (Collaborator, Author) commented May 8, 2025

I can't reproduce the SIG11 observed in https://github.com/pytorch/xla/actions/runs/14850406478/job/41698575596?pr=9078 with CUDA; it succeeds on an NVIDIA A100-SXM4-40GB:

| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |

with the same 2.7 container: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.7.0_3.10_cuda_12.6

The CI run doesn't even hit a single print line in the root test file (be it 1 or 2):

+ python3 /__w/xla/xla/pytorch/xla/test/spmd/test_train_spmd_linear_model.py --skip-gradient-checkpointing
/usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:236: UserWarning: XLA_USE_SPMD is being deprecated. Use torch_xla.runtime.use_spmd() without setting XLA_USE_SPMD env-var.
  warnings.warn("XLA_USE_SPMD is being deprecated. "
./usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:242: UserWarning: Replicating tensors already initialized on non-virtual XLA device for SPMD to force SPMD mode. This is one-time overhead to setup, and to minimize such, please set SPMD mode before initializting tensors (i.e., call use_spmd() in the beginning of the program).
  warnings.warn(
*** Received signal 11 ***
...

@tengyifei - I assume these were running earlier, since we brought them into the CI. Is there anything I am missing? I wouldn't expect it to be specific to T4s (G4dn).

@tengyifei (Collaborator) commented:

@rpsilva-aws the GPU CI is not very stable. I would worry only about TPU CI and CPU CI for now, and make sure your tests are registered in those two environments!

@tengyifei (Collaborator) commented:

You're welcome to file a GPU-specific issue.

@rpsilva-aws (Collaborator, Author) commented May 9, 2025

@tengyifei Thanks, perfect - will do! TPU, CPU (and TRN) are all covered :) Do we have the means to disable a test for GPU?

@rpsilva-aws (Collaborator, Author) commented:

#9128

@tengyifei (Collaborator) commented:

@rpsilva-aws you can disable a GPU test by marking it as "skipped" using the unittest API. When you disable a test, you should attach the bug reference URL in the message. Use your judgement as to whether the failure is truly due to something else broken in GPU versus a bug in your PR.
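
A minimal sketch of that skip pattern, referencing the GPU issue filed above. The test class and method names are placeholders, and it assumes torch_xla.runtime.device_type() is available to detect the accelerator:

```python
import unittest

import torch_xla.runtime as xr


class GradientAccumulationTest(unittest.TestCase):

  # Skip only on GPU, with the bug reference in the message.
  @unittest.skipIf(xr.device_type() == "CUDA",
                   "SIG11 on the GPU CI; see #9128")
  def test_gradient_accumulation_matches_baseline(self):
    ...  # placeholder for the actual A/B comparison


if __name__ == "__main__":
  unittest.main()
```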

@rpsilva-aws force-pushed the rpsilva_grad_acc_v2 branch 3 times, most recently from 498d10c to 0c84b82 on May 9, 2025 03:17
@rpsilva-aws (Collaborator, Author) commented May 9, 2025

Absolutely - had I not tried to reproduce on an A100, it would be harder to judge, but given that it succeeded there (all other devices aside), I don't think it is a bug in the PR.

In any case, I will take the slow approach: flush a few logs, partially skip the test, and make a better decision after a couple of CI runs.

@rpsilva-aws force-pushed the rpsilva_grad_acc_v2 branch 2 times, most recently from 128fb42 to c6acb9c on May 9, 2025 05:57
@rpsilva-aws force-pushed the rpsilva_grad_acc_v2 branch from c6acb9c to 82c5864 on May 9, 2025 18:24