Why does `check_rep` fail?
#30567
Unanswered
eitanporat asked this question in Q&A
Replies: 1 comment
I had the same issue. TL;DR: use `all_gather_invariant` (recently added and undocumented). `all_gather` returns a varying result, for efficiency, because it is the transpose of `psum_scatter`, as explained here: https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html. I needed an explicit all-gather instead of relying on `out_specs` alone because jit inserted a large number of collective permutes, for reasons I could not determine.
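For anyone landing here, a minimal sketch of the fix described above. It is not the original poster's code, which is not shown in this thread. Assumptions: a 1-D mesh with a hypothetical axis name `"x"`, four emulated CPU devices set through `XLA_FLAGS`, and a JAX release that exposes `jax.lax.all_gather_invariant`. If the function is missing, the sketch falls back to a psum-based gather, which is a substitute workaround and not something proposed in this thread.

```python
import os

# Assumption: emulate 4 CPU devices so the sketch runs without accelerators.
# This only takes effect if it is set before JAX initializes its backends.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

mesh = Mesh(np.array(jax.devices()), ("x",))  # hypothetical 1-D mesh
n = mesh.shape["x"]

# None on releases that predate all_gather_invariant.
gather_invariant = getattr(jax.lax, "all_gather_invariant", None)


def replicate_rows(block):
    """Turn a per-device row block into the full array on every device."""
    if gather_invariant is not None:
        return gather_invariant(block, "x", axis=0, tiled=True)
    # Fallback (assumption, not from the thread): each device writes its block
    # into its own slot of a zero buffer; psum then yields a replicated result.
    slots = jnp.zeros((n,) + block.shape, block.dtype)
    slots = slots.at[jax.lax.axis_index("x")].set(block)
    full = jax.lax.psum(slots, "x")
    return full.reshape((n * block.shape[0],) + block.shape[1:])


x = jnp.arange(n * 2 * 3, dtype=jnp.float32).reshape(n * 2, 3)

# in_specs shards rows over "x"; out_specs=P() claims the result is replicated,
# and the replication check is left at its default (enabled).
gathered = jax.jit(
    shard_map(replicate_rows, mesh=mesh, in_specs=P("x", None), out_specs=P())
)
out = gathered(x)
print(out.shape)
```

Swapping `replicate_rows` for a plain `jax.lax.all_gather(block, "x", axis=0, tiled=True)` with the same `out_specs=P()` should be rejected by the check, per the explanation above, although whether that matches the original failure cannot be confirmed from this thread. Note that newer releases spell the flag `check_vma` on `jax.shard_map` rather than `check_rep`.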
From the JAX Scaling Book, "Programming TPUs in JAX", Problem 1.
This is my solution
It works with `check_rep=False`, but without it, it fails with an error. What am I doing wrong? Doesn't AllGather replicate the results along the specified axis, so that the output has sharding (None, Y), meaning the X axis is replicated rather than sharded?
Running a small test with `check_rep=False` and looking at the shards, they appear to be replicated.