[tx] Enable MPS backend for Apple Silicon on Jax backend #1332
pcmoritz wants to merge 1 commit into NovaSky-AI:main
Conversation
Code Review
This pull request enables the JAX MPS backend for Apple Silicon, which is a valuable addition for local development. The changes introduce the necessary dependencies and implement workarounds for current jax-mps limitations, such as the requirement for contiguous tensors and special handling for zero-sized tensor operations. The modifications across the JAX backend, LoRA layers, and utility functions are logical and well justified. Overall, this is a solid contribution. I have one minor suggestion to improve code maintainability.
```diff
 hp = optimizer.opt_state.hyperparams
-hp["learning_rate"][...] = learning_rate
-hp["b1"][...] = request_data.adam_params.beta1
-hp["b2"][...] = request_data.adam_params.beta2
-hp["eps"][...] = request_data.adam_params.eps
-hp["weight_decay"][...] = request_data.adam_params.weight_decay
+# Use direct .value assignment instead of [...] indexing to avoid MPS zero-sized tensor issues
+hp["learning_rate"].value = jnp.asarray(learning_rate, dtype=hp["learning_rate"].value.dtype)
+hp["b1"].value = jnp.asarray(request_data.adam_params.beta1, dtype=hp["b1"].value.dtype)
+hp["b2"].value = jnp.asarray(request_data.adam_params.beta2, dtype=hp["b2"].value.dtype)
+hp["eps"].value = jnp.asarray(request_data.adam_params.eps, dtype=hp["eps"].value.dtype)
+hp["weight_decay"].value = jnp.asarray(request_data.adam_params.weight_decay, dtype=hp["weight_decay"].value.dtype)
```
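The assignment pattern in the diff can be sketched in isolation. The `SimpleNamespace` objects below are hypothetical stand-ins for the real hyperparameter variables held in `optimizer.opt_state.hyperparams`; the point is only to show the dtype-preserving `.value` assignment:

```python
import jax.numpy as jnp
from types import SimpleNamespace

# Hypothetical stand-ins for the optimizer hyperparameter state; in the
# actual code these come from optimizer.opt_state.hyperparams.
hp = {
    "learning_rate": SimpleNamespace(value=jnp.asarray(1e-3, dtype=jnp.float32)),
    "b1": SimpleNamespace(value=jnp.asarray(0.9, dtype=jnp.float32)),
}

# Direct .value assignment with jnp.asarray preserves each entry's dtype
# and sidesteps the `arr[...] = scalar` indexing path that currently trips
# jax-mps on zero-sized tensor operations.
for name, new_value in {"learning_rate": 5e-4, "b1": 0.95}.items():
    hp[name].value = jnp.asarray(new_value, dtype=hp[name].value.dtype)
```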
The code for updating hyperparameters is quite repetitive. To improve readability and maintainability, you could refactor this block to use a loop over a dictionary of parameters.
```python
# Use direct .value assignment instead of [...] indexing to avoid MPS zero-sized tensor issues
hp = optimizer.opt_state.hyperparams
params_to_update = {
    "learning_rate": learning_rate,
    "b1": request_data.adam_params.beta1,
    "b2": request_data.adam_params.beta2,
    "eps": request_data.adam_params.eps,
    "weight_decay": request_data.adam_params.weight_decay,
}
for name, value in params_to_update.items():
    hp[name].value = jnp.asarray(value, dtype=hp[name].value.dtype)
```
This uses the OSS MPS pjrt / StableHLO backend https://github.com/tillahoffmann/jax-mps
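With the plugin installed, a quick way to check which backend JAX picked up is to list the visible devices. The `mps` platform name is my assumption about what the plugin registers (not verified here); on a machine without the plugin, this simply falls back to CPU:

```python
import jax
import jax.numpy as jnp

# List the devices JAX can see; with the jax-mps PJRT plugin installed on
# Apple Silicon this should include an MPS-backed device (platform name is
# an assumption). On other machines this lists the CPU.
print(jax.devices())

# Smoke test: a small matmul on whatever the default backend is.
x = jnp.ones((4, 4))
result = float((x @ x).sum())
print(result)  # 64.0
```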
This will only be really interesting once the M5 Ultra or similar hardware is released (to get the larger prefill performance of the M5), but it is very nice to see that it works. On my puny MacBook Pro it already gets a huge speedup over the CPU and is useful for local development. It also proves the point that the Jax backend is very portable.
It can be run with
The timings for
are:
Compared to the CPU, which is
A couple of limitations:
This might be a bug in Metal; it may already be fixed in a later version, as mine is currently a little outdated.
Maybe we can fix these upstream in the jax-mps backend, so this code can run without modifications.