
[tx] Enable MPS backend for Apple Silicon on Jax backend #1332

Open
pcmoritz wants to merge 1 commit into NovaSky-AI:main from pcmoritz:tx-mps

Collaborator

@pcmoritz pcmoritz commented Mar 17, 2026

This uses the open-source MPS PJRT / StableHLO backend jax-mps: https://github.com/tillahoffmann/jax-mps

This will only become really interesting once the M5 Ultra or similar hardware is released (to get the higher prefill performance of the M5), but it is very nice to see that it works. On my puny MacBook Pro it already gets a large speedup over CPU (roughly 3x to 10x per step in the timings below) and is useful for local development. It also proves the point that the JAX backend is very portable.

It can be run with

JAX_PLATFORMS=mps uv run --extra gpu --extra tinker -m skyrl.tinker.api --base-model Qwen/Qwen3-0.6B --backend-config '{"max_lora_adapters": 2, "max_lora_rank": 1, "train_micro_batch_size": 1}'
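
The `--backend-config` value is a JSON string, which is easy to mis-quote in a shell. As a minimal sketch (not part of this PR), the same command line can be assembled in Python so that the JSON and the quoting are generated rather than typed by hand. The variable names here are illustrative; the flags and values are copied from the command above.

```python
import json
import shlex

# Backend options copied from the command above.
backend_config = {
    "max_lora_adapters": 2,
    "max_lora_rank": 1,
    "train_micro_batch_size": 1,
}

command = [
    "uv", "run", "--extra", "gpu", "--extra", "tinker",
    "-m", "skyrl.tinker.api",
    "--base-model", "Qwen/Qwen3-0.6B",
    "--backend-config", json.dumps(backend_config),
]

# JAX_PLATFORMS=mps selects the MPS platform for the launched process.
full_command = "JAX_PLATFORMS=mps " + shlex.join(command)
print(full_command)
```

`shlex.join` wraps the JSON argument in single quotes, so the printed line can be pasted into a POSIX shell as is.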

The timings for

export TINKER_API_KEY="tml-dummy"
uv run --with wandb --with tinker sl_loop.py \
    base_url=http://localhost:8000 \
    model_name=Qwen/Qwen3-0.6B lora_rank=1 train_on_what=LAST_ASSISTANT_MESSAGE

are:

Step 0                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000100   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 37834      │
│ progress                       │ 0.000000   │
│ skyrl.ai/grad_norm             │ 3.765625   │
│ skyrl.ai/learning_rate         │ 0.000100   │
│ time_total                     │ 331.295395 │
│ train_mean_nll                 │ 2.795922   │
└────────────────────────────────┴────────────┘
tinker_cookbook.utils.ml_log:147 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:199 [INFO] 
                    Step 1                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000099   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 35654      │
│ progress                       │ 0.013514   │
│ skyrl.ai/grad_norm             │ 4.781250   │
│ skyrl.ai/learning_rate         │ 0.000099   │
│ time_total                     │ 331.346420 │
│ train_mean_nll                 │ 2.779400   │
└────────────────────────────────┴────────────┘
tinker_cookbook.utils.ml_log:147 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:199 [INFO] 
                     Step 2                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Metric                         ┃ Value       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000097    │
│ num_sequences                  │ 128         │
│ num_tokens                     │ 38474       │
│ progress                       │ 0.027027    │
│ skyrl.ai/grad_norm             │ 4.125000    │
│ skyrl.ai/learning_rate         │ 0.000097    │
│ time_total                     │ 1306.594704 │
│ train_mean_nll                 │ 2.617733    │
└────────────────────────────────┴─────────────┘
tinker_cookbook.utils.ml_log:147 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:199 [INFO] 
                    Step 3                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000096   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 39781      │
│ progress                       │ 0.040541   │
│ skyrl.ai/grad_norm             │ 4.281250   │
│ skyrl.ai/learning_rate         │ 0.000096   │
│ time_total                     │ 285.588389 │
│ train_mean_nll                 │ 2.693189   │
└────────────────────────────────┴────────────┘
tinker_cookbook.utils.ml_log:147 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:199 [INFO] 
                    Step 4                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric                         ┃ Value      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000095   │
│ num_sequences                  │ 128        │
│ num_tokens                     │ 37346      │
│ progress                       │ 0.054054   │
│ skyrl.ai/grad_norm             │ 3.828125   │
│ skyrl.ai/learning_rate         │ 0.000094   │
│ time_total                     │ 151.503146 │
│ train_mean_nll                 │ 2.667086   │
└────────────────────────────────┴────────────┘

For comparison, the timings on the CPU are:

Step 0                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Metric                         ┃ Value       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000100    │
│ num_sequences                  │ 128         │
│ num_tokens                     │ 37834       │
│ progress                       │ 0.000000    │
│ skyrl.ai/grad_norm             │ 3.859375    │
│ skyrl.ai/learning_rate         │ 0.000100    │
│ time_total                     │ 2165.875092 │
│ train_mean_nll                 │ 2.794540    │
└────────────────────────────────┴─────────────┘
tinker_cookbook.utils.ml_log:147 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:199 [INFO] 
                     Step 1                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Metric                         ┃ Value       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000099    │
│ num_sequences                  │ 128         │
│ num_tokens                     │ 35654       │
│ progress                       │ 0.013514    │
│ skyrl.ai/grad_norm             │ 5.000000    │
│ skyrl.ai/learning_rate         │ 0.000099    │
│ time_total                     │ 2148.938093 │
│ train_mean_nll                 │ 2.775640    │
└────────────────────────────────┴─────────────┘
tinker_cookbook.utils.ml_log:147 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:199 [INFO] 
                     Step 2                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Metric                         ┃ Value       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000097    │
│ num_sequences                  │ 128         │
│ num_tokens                     │ 38474       │
│ progress                       │ 0.027027    │
│ skyrl.ai/grad_norm             │ 4.343750    │
│ skyrl.ai/learning_rate         │ 0.000097    │
│ time_total                     │ 4486.724472 │
│ train_mean_nll                 │ 2.611415    │
└────────────────────────────────┴─────────────┘
tinker_cookbook.utils.ml_log:147 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:199 [INFO] 
                     Step 3                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Metric                         ┃ Value       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000096    │
│ num_sequences                  │ 128         │
│ num_tokens                     │ 39781       │
│ progress                       │ 0.040541    │
│ skyrl.ai/grad_norm             │ 4.187500    │
│ skyrl.ai/learning_rate         │ 0.000096    │
│ time_total                     │ 2290.050567 │
│ train_mean_nll                 │ 2.682926    │
└────────────────────────────────┴─────────────┘
tinker_cookbook.utils.ml_log:147 [INFO] Wrote metrics to /tmp/tinker-examples/sl-loop/metrics.jsonl
tinker_cookbook.utils.ml_log:199 [INFO] 
                     Step 4                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Metric                         ┃ Value       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ learning_rate                  │ 0.000095    │
│ num_sequences                  │ 128         │
│ num_tokens                     │ 37346       │
│ progress                       │ 0.054054    │
│ skyrl.ai/grad_norm             │ 3.796875    │
│ skyrl.ai/learning_rate         │ 0.000094    │
│ time_total                     │ 1484.751887 │
│ train_mean_nll                 │ 2.659317    │
└────────────────────────────────┴─────────────┘
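
To put a number on the comparison, the following sketch computes the CPU/MPS ratio of `time_total` for each step, plus the ratio of the summed times. The values are copied from the two sets of tables above; treating `time_total` as wall-clock seconds per step is an assumption.

```python
# time_total values copied from the per-step tables above (assumed to be seconds).
mps_seconds = [331.295395, 331.346420, 1306.594704, 285.588389, 151.503146]
cpu_seconds = [2165.875092, 2148.938093, 4486.724472, 2290.050567, 1484.751887]


def speedups(cpu, mps):
    """Return the per-step ratio of CPU time to MPS time."""
    return [c / m for c, m in zip(cpu, mps)]


per_step = speedups(cpu_seconds, mps_seconds)
overall = sum(cpu_seconds) / sum(mps_seconds)

for step, ratio in enumerate(per_step):
    print(f"step {step}: {ratio:.1f}x")
print(f"all five steps: {overall:.1f}x")
```

On these five steps the per-step ratio ranges from about 3.4x (step 2, which is the slowest step in both runs) to about 9.8x (step 4), and the ratio of the summed times is about 5.2x.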

It runs with the fully OSS backend https://github.com/tillahoffmann/jax-mps

A couple of limitations:

  • Currently crashes intermittently with:
libc++abi: terminating due to uncaught exception of type std::runtime_error: [METAL] Command buffer execution failed: Internal Error (0000000e:Internal Error)
/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/resource_tracker.py:276: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown: {'/mp-dam64i5u'}

This might be a bug in Metal. It may already be fixed in a later version, since mine is a little outdated.

  • Needs tensors to be contiguous
  • Doesn't support zero dimensions

Maybe we can address these in the jax-mps backend itself, so that the code can run on it without modifications.
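
To make the last two limitations concrete, here is a backend-independent sketch of the two properties involved. It is not code from this PR: the function names are hypothetical, and it assumes that "zero dimensions" means a dimension of extent 0 and that "contiguous" means a dense row-major layout.

```python
from typing import Sequence


def is_zero_sized(shape: Sequence[int]) -> bool:
    """True if any dimension has extent 0, i.e. the array holds no elements."""
    return any(dim == 0 for dim in shape)


def is_row_major_contiguous(
    shape: Sequence[int], strides: Sequence[int], itemsize: int
) -> bool:
    """True if the byte strides describe a dense row-major (C-order) layout."""
    if is_zero_sized(shape):
        return True  # no elements, so there is nothing to lay out
    expected = itemsize
    # Walk from the innermost dimension outwards; each stride must equal
    # the total size in bytes of everything inside that dimension.
    for dim, stride in zip(reversed(shape), reversed(strides)):
        if dim != 1 and stride != expected:
            return False
        expected *= dim
    return True


# A 2x3 float32 array stored densely: rows are 12 bytes apart, elements 4.
dense = is_row_major_contiguous((2, 3), (12, 4), itemsize=4)
# Its transpose viewed without copying: same bytes, swapped strides.
transposed_view = is_row_major_contiguous((3, 2), (4, 12), itemsize=4)
print(dense, transposed_view, is_zero_sized((0, 8)))
```

A transposed or sliced view is the typical case that fails the contiguity check. In NumPy, `np.ascontiguousarray` returns a dense copy for such inputs.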



@pcmoritz pcmoritz added the tx label Mar 17, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables the JAX MPS backend for Apple Silicon, which is a valuable addition for local development. The changes introduce necessary dependencies and implement workarounds for current jax-mps limitations, such as requiring contiguous tensors and handling of zero-sized tensor operations. The modifications across the JAX backend, LoRA layers, and utility functions are logical and well-justified. Overall, this is a solid contribution. I have one minor suggestion to improve code maintainability.

Comment on lines +790 to +796
+ # Use direct .value assignment instead of [...] indexing to avoid MPS zero-sized tensor issues
  hp = optimizer.opt_state.hyperparams
- hp["learning_rate"][...] = learning_rate
- hp["b1"][...] = request_data.adam_params.beta1
- hp["b2"][...] = request_data.adam_params.beta2
- hp["eps"][...] = request_data.adam_params.eps
- hp["weight_decay"][...] = request_data.adam_params.weight_decay
+ hp["learning_rate"].value = jnp.asarray(learning_rate, dtype=hp["learning_rate"].value.dtype)
+ hp["b1"].value = jnp.asarray(request_data.adam_params.beta1, dtype=hp["b1"].value.dtype)
+ hp["b2"].value = jnp.asarray(request_data.adam_params.beta2, dtype=hp["b2"].value.dtype)
+ hp["eps"].value = jnp.asarray(request_data.adam_params.eps, dtype=hp["eps"].value.dtype)
+ hp["weight_decay"].value = jnp.asarray(request_data.adam_params.weight_decay, dtype=hp["weight_decay"].value.dtype)

medium

The code for updating hyperparameters is quite repetitive. To improve readability and maintainability, you could refactor this block to use a loop over a dictionary of parameters.

        # Use direct .value assignment instead of [...] indexing to avoid MPS zero-sized tensor issues
        hp = optimizer.opt_state.hyperparams
        params_to_update = {
            "learning_rate": learning_rate,
            "b1": request_data.adam_params.beta1,
            "b2": request_data.adam_params.beta2,
            "eps": request_data.adam_params.eps,
            "weight_decay": request_data.adam_params.weight_decay,
        }
        for name, value in params_to_update.items():
            hp[name].value = jnp.asarray(value, dtype=hp[name].value.dtype)

Contributor

@devin-ai-integration devin-ai-integration bot left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

View in Devin Review to see 3 additional findings.
