
copybara-service[bot]

[JAX SC] Implement pipelined SparseCore embedding layer in Flax

This change introduces `PipelinedSparseCoreEmbed` to enable pipelining of SparseCore operations with TensorCore. The layer stores inputs and activations/gradients across steps so that SparseCore and TensorCore work can overlap. The Shakespeare example is updated to use the new pipelined layer, which requires changes to how variables are handled in the training loop, including passing `mutable=True` to `model.apply`. A test target for the Shakespeare example is also added.
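The core idea of the pipelining can be illustrated with a minimal, framework-free sketch: the layer buffers each step's embedding lookup and hands the result to the dense compute one step later, so the lookup for step t can run concurrently with the TensorCore work on step t-1's activations. The class and method names below are illustrative only and are not the actual `PipelinedSparseCoreEmbed` API:

```python
class OneStepPipelineSketch:
    """Illustrative one-step pipeline buffer (not the real API).

    The activation returned at step t is the lookup performed at step
    t-1, so the (simulated) SparseCore lookup for the current step can
    overlap TensorCore compute on the previous step's result.
    """

    def __init__(self, table):
        self.table = table    # embedding table: id -> vector
        self.pending = None   # lookup result not yet consumed downstream

    def step(self, ids):
        # "SparseCore" work for this step: gather rows for the new ids.
        looked_up = [self.table[i] for i in ids]
        # Hand back last step's result; stash this step's for next time.
        ready, self.pending = self.pending, looked_up
        return ready          # None on the very first (warm-up) step


table = {0: [0.0, 1.0], 1: [1.0, 0.0], 2: [0.5, 0.5]}
layer = OneStepPipelineSketch(table)
print(layer.step([0, 1]))  # None: pipeline warm-up, nothing ready yet
print(layer.step([2]))     # [[0.0, 1.0], [1.0, 0.0]]: step-0 activations, one step late
```

This delayed hand-off is also why the training loop needs `mutable=True` in `model.apply`: the buffered inputs and activations live in a mutable variable collection that is updated on every step.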

With pipelining enabled, the model converges as shown below:
```
I 2025-09-28T13:41:31.529508-07:00 4428  jax_sc_shakespeare_jit_flax.py:415] Step 9: Loss = 7.518053
I 2025-09-28T13:41:31.599092-07:00 4428  jax_sc_shakespeare_jit_flax.py:415] Step 19: Loss = 7.317211
I 2025-09-28T13:41:31.668971-07:00 4428  jax_sc_shakespeare_jit_flax.py:415] Step 29: Loss = 7.0158334
I 2025-09-28T13:41:31.738482-07:00 4428  jax_sc_shakespeare_jit_flax.py:415] Step 39: Loss = 6.609592
I 2025-09-28T13:41:31.807324-07:00 4428  jax_sc_shakespeare_jit_flax.py:415] Step 49: Loss = 6.140646
I 2025-09-28T13:41:31.876329-07:00 4428  jax_sc_shakespeare_jit_flax.py:415] Step 59: Loss = 5.7092185
I 2025-09-28T13:41:31.949826-07:00 4428  jax_sc_shakespeare_jit_flax.py:415] Step 69: Loss = 5.3921995
I 2025-09-28T13:41:32.018904-07:00 4428  jax_sc_shakespeare_jit_flax.py:415] Step 79: Loss = 5.183298
I 2025-09-28T13:41:32.088231-07:00 4428  jax_sc_shakespeare_jit_flax.py:415] Step 89: Loss = 5.0565457
I 2025-09-28T13:41:32.156878-07:00 4428  jax_sc_shakespeare_jit_flax.py:415] Step 99: Loss = 4.9826536
...
I 2025-09-28T13:41:38.309954-07:00 4428  jax_sc_shakespeare_jit_flax.py:415] Step 999: Loss = 0.31050137
```

PiperOrigin-RevId: 812524634
copybara-service bot closed this Oct 7, 2025
copybara-service bot deleted the test_812524634 branch October 7, 2025 19:14