Refine the gradient accumulation API #9078
base: master
Hmm, only PJRT_DEVICE=CUDA is having issues with the existing MLP A/B test:
I can't reproduce the SIG11 observed on https://github.com/pytorch/xla/actions/runs/14850406478/job/41698575596?pr=9078 with CUDA; it succeeds with the same 2.7 container:
The CI run doesn't even reach a single print line in the root test file (be it 1 or 2):
@tengyifei - I assume these were running earlier, since we have brought in the CI. Is there anything I am missing? I wouldn't expect it to be specific to T4s (G4dn).
@rpsilva-aws the GPU CI is not very stable. I would worry only about TPU CI and CPU CI for now, and make sure your tests are registered in those two environments!
You're welcome to file a GPU-specific issue.
@tengyifei Thanks, perfect - will do! TPU, CPU (and TRN) are all covered :) Do we have the means to disable a test for GPU?
@rpsilva-aws you can disable a GPU test by marking it as "skipped" using the unittest API. When you disable a test, you should attach the bug reference URL in the skip message. Use your judgement as to whether the test failure is truly due to something else being broken on GPU or is caused by a bug in your PR.
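For reference, a minimal sketch of what such a skip could look like with the standard `unittest` API. The test class, test name, helper, and issue URL below are hypothetical placeholders, and the sketch assumes the device is selected through the `PJRT_DEVICE` environment variable; it is not the actual test from this PR.

```python
import os
import unittest

# Hypothetical placeholder: replace with the URL of the real GPU tracking issue.
GPU_ISSUE_URL = "https://github.com/pytorch/xla/issues/<issue-number>"


def running_on_cuda() -> bool:
    # Assumption: the backend is chosen via the PJRT_DEVICE environment variable.
    return os.environ.get("PJRT_DEVICE", "").upper() == "CUDA"


class GradientAccumulationTest(unittest.TestCase):  # hypothetical test class

    def test_mlp_ab(self):
        # Skip only on CUDA and keep the bug reference in the skip message,
        # so the reason shows up in the test report.
        if running_on_cuda():
            self.skipTest(f"Crashes on GPU CI, see {GPU_ISSUE_URL}")
        # Stand-in body; the real test would compare the two accumulation paths.
        self.assertEqual(sum([1, 2, 3]), 6)
```

Checking the condition inside the test (via `self.skipTest`) evaluates the environment at run time. The `@unittest.skipIf(condition, reason)` decorator is an equivalent alternative when the condition is already known at import time.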
Absolutely. Had I not tried to reproduce with an A100, it would be harder to judge, but given that it succeeded there (all other devices aside), I don't think it is a bug in the PR. In any case, I will take the slow approach: flush a few logs, partially skip the test, and make a better decision after a couple of CI runs.
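One possible way to get those logs out before a crash, sketched with the Python standard library only. The helper names and the log prefix are hypothetical, and this is a generic debugging aid rather than what was actually added to the test: `faulthandler` dumps the Python traceback when the process receives a fatal signal such as SIGSEGV, and `flush=True` ensures each message is written before a later crash can lose it.

```python
import faulthandler
import sys


def enable_crash_diagnostics(stream=None):
    # Dump the Python traceback of all threads on fatal signals (e.g. SIGSEGV).
    # The stream must be backed by a real file descriptor; defaults to stderr.
    faulthandler.enable(file=stream if stream is not None else sys.stderr,
                        all_threads=True)


def log(message):
    # Hypothetical prefix; flush immediately so the line survives a hard crash.
    print(f"[grad-acc-test] {message}", file=sys.stderr, flush=True)
```

Calling `enable_crash_diagnostics()` at the top of the root test file and `log(...)` at a few checkpoints should narrow down how far the CI run gets before the signal.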
In this PR, we refine the gradient accumulation API to include:
Testing:
cc: @mcuiaws