I'm encountering an issue where `jax.random.beta` and `jax.random.gamma` fail when called inside `shard_map`. The error indicates a problem with `while_loop` and varying manual axes. This is problematic for distributed training scenarios where we need to sample from these distributions within sharded computations.
Output:

```text
JAX version: 0.7.0
Devices: [CudaDevice(id=0)]
alpha shape: (64, 1), sharding: NamedSharding(mesh=Mesh('dp': 1, axis_types=(Auto,)), spec=PartitionSpec('dp', None), memory_kind=device)
beta shape: (64, 1), sharding: NamedSharding(mesh=Mesh('dp': 1, axis_types=(Auto,)), spec=PartitionSpec('dp', None), memory_kind=device)
rng sharding: NamedSharding(mesh=Mesh('dp': 1, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=device)
Success! actions shape: (64, 1)

Trying with jit:
Success! actions shape: (64, 1)

Trying with shard_map:
Inside shard_map - alpha shape: (64, 1)
Inside shard_map - beta shape: (64, 1)
Inside shard_map - rng type: <class 'jax._src.shard_map.ShardMapTracer'>
Error with shard_map: TypeError: while_loop body function carry input and carry output must have equal types, but they differ:
The input carry component kxv[2] has type float32[] but the corresponding output carry component has type float32[]{dp}, so the varying manual axes do not match.
This might be fixed by applying `jax.lax.pvary(..., ('dp',))` to the initial carry value corresponding to the input carry component kxv[2].
See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma for more information.
Revise the function so that all output types match the corresponding input types.

Trying with jit + shard_map:
Error with jit + shard_map: TypeError: while_loop body function carry input and carry output must have equal types, but they differ:
The input carry component kxv[2] has type float32[] but the corresponding output carry component has type float32[]{dp}, so the varying manual axes do not match.
This might be fixed by applying `jax.lax.pvary(..., ('dp',))` to the initial carry value corresponding to the input carry component kxv[2].
See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma for more information.
Revise the function so that all output types match the corresponding input types.

Traceback (most recent call last):
  File "/mnt/c/Users/admin/GitHub/jax_bug_001/01_shard_random.py", line 101, in <module>
    main()
  File "/mnt/c/Users/admin/GitHub/jax_bug_001/01_shard_random.py", line 69, in main
    actions = shard_beta_fn(rng_sharded, alpha_sharded, beta_sharded)
  File "/mnt/c/Users/admin/GitHub/jax_bug_001/01_shard_random.py", line 59, in shard_beta_fn_impl
    return random.beta(rng, alpha, beta)
  File "/home/admin/.local/lib/python3.12/site-packages/jax/_src/random.py", line 1036, in beta
    return _beta(key, a, b, shape, dtype)
  File "/home/admin/.local/lib/python3.12/site-packages/jax/_src/random.py", line 1050, in _beta
    log_gamma_a = loggamma(key_a, a, shape, dtype)
  File "/home/admin/.local/lib/python3.12/site-packages/jax/_src/random.py", line 1416, in loggamma
    return _gamma(key, a, shape=shape, dtype=dtype, log_space=True)
  File "/home/admin/.local/lib/python3.12/site-packages/jax/_src/random.py", line 1429, in _gamma
    return random_gamma_p.bind(key, a, log_space=log_space)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: while_loop body function carry input and carry output must have equal types, but they differ:
The input carry component kxv[2] has type float32[] but the corresponding output carry component has type float32[]{dp}, so the varying manual axes do not match.
This might be fixed by applying `jax.lax.pvary(..., ('dp',))` to the initial carry value corresponding to the input carry component kxv[2].
See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma for more information.
Revise the function so that all output types match the corresponding input types.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/c/Users/admin/GitHub/jax_bug_001/01_shard_random.py", line 89, in main
    actions = jit_shard_beta_fn(rng_sharded, alpha_sharded, beta_sharded)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/jax/_src/random.py", line 1300, in _gamma_impl
    samples = vmap(partial(_gamma_one, log_space=log_space))(keys, alphas)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/jax/_src/random.py", line 1259, in _gamma_one
    _, _, V, _ = lax_control_flow.while_loop(_cond_fn, _body_fn, (key, zero, one, lax._const(alpha, 2)))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/jax/_src/random.py", line 1251, in _body_fn
    _, x, v = lax_control_flow.while_loop(lambda kxv: lax.le(kxv[2], zero), _next_kxv, (x_key, zero, minus_one))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: while_loop body function carry input and carry output must have equal types, but they differ:
The input carry component kxv[2] has type float32[] but the corresponding output carry component has type float32[]{dp}, so the varying manual axes do not match.
This might be fixed by applying `jax.lax.pvary(..., ('dp',))` to the initial carry value corresponding to the input carry component kxv[2].
See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma for more information.
Revise the function so that all output types match the corresponding input types.

--------------------

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Process finished with exit code 0
```
Is there a recommended workaround for sampling from Beta/Gamma distributions inside `shard_map`? Is there a way to use `lax.pvary` to fix this that I'm missing?
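Not an authoritative fix, but for anyone hitting the same error: since the plain-`jit` path succeeds in the output above, one pragmatic workaround is to stay in automatic partitioning mode and use sharding constraints instead of entering `shard_map`'s manual mode, so the VMA check on the internal `while_loop` never applies. A sketch under the same `dp`/`(64, 1)` assumptions as the report (names are illustrative):

```python
# Workaround sketch: let the compiler partition random.beta under jit,
# rather than sampling inside shard_map's manual mode.
import jax
import jax.numpy as jnp
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:1], axis_names=("dp",))
spec = NamedSharding(mesh, P("dp", None))

alpha = jax.device_put(jnp.ones((64, 1)), spec)
beta = jax.device_put(jnp.ones((64, 1)), spec)

@jax.jit
def sample_beta(rng, a, b):
    # Automatic (GSPMD) mode: sampling is partitioned by the compiler,
    # so the varying-manual-axes check never fires.
    out = random.beta(rng, a, b)
    return jax.lax.with_sharding_constraint(out, spec)

actions = sample_beta(random.key(0), alpha, beta)
print(actions.shape)  # (64, 1)
```

If manual mode is genuinely needed, `shard_map`'s replication consistency check can reportedly be disabled (`check_vma=False` in recent releases, `check_rep=False` in older ones), which sidesteps this error at the cost of losing that safety check; I have not verified this against 0.7.0 specifically, so treat it as something to test.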
More system info:

```text
jax:    0.7.0
jaxlib: 0.7.0
numpy:  1.26.4
python: 3.12.11 (main, Jun  4 2025, 08:56:18) [GCC 11.4.0]
device info: NVIDIA GeForce RTX 4080 SUPER-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='admin-pc', release='6.6.87.2-microsoft-standard-WSL2', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 5 18:30:46 UTC 2025', machine='x86_64')

JAX_COMPILATION_CACHE_DIR=/home/admin/jax_compilation_cache
```
```text
$ nvidia-smi
Fri Aug  1 12:18:36 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.124.06             Driver Version: 572.70         CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4080 ...    On  |   00000000:01:00.0  On |                  N/A |
|  0%   40C    P2             13W / 320W  |   2715MiB / 16376MiB   |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     114455      C   /python3.12                                     N/A |
|    0   N/A  N/A     386366      C   /python3.12                                     N/A |
+-----------------------------------------------------------------------------------------+
```