diff --git a/pyproject.toml b/pyproject.toml index 4012edbc74..292dc685a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,8 @@ jax = [ "jax>=0.8,<1.0", "flax>=0.12.2", "optax>=0.2.5", + "ml_dtypes>=0.5.0", + "jax-mps>=0.0.6; sys_platform == 'darwin' and platform_machine == 'arm64'", ] skyrl-train = [ @@ -208,7 +210,7 @@ override-dependencies = [ "causal-conv1d; sys_platform == 'never'", "transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'", "megatron-core==0.16.0; sys_platform == 'linux'", - "ml_dtypes>=0.5.0; sys_platform == 'linux'", + "ml_dtypes>=0.5.0", ] [tool.uv.extra-build-dependencies] diff --git a/skyrl/backends/jax.py b/skyrl/backends/jax.py index cc17943058..0adcd287b6 100644 --- a/skyrl/backends/jax.py +++ b/skyrl/backends/jax.py @@ -787,12 +787,13 @@ def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types logger.warning(f"No accumulated gradients for model {model_id}; applying step with zero gradients") # Update hyperparameters from the request + # Use direct .value assignment instead of [...] indexing to avoid MPS zero-sized tensor issues hp = optimizer.opt_state.hyperparams - hp["learning_rate"][...] = learning_rate - hp["b1"][...] = request_data.adam_params.beta1 - hp["b2"][...] = request_data.adam_params.beta2 - hp["eps"][...] = request_data.adam_params.eps - hp["weight_decay"][...] = request_data.adam_params.weight_decay + hp["learning_rate"].value = jnp.asarray(learning_rate, dtype=hp["learning_rate"].value.dtype) + hp["b1"].value = jnp.asarray(request_data.adam_params.beta1, dtype=hp["b1"].value.dtype) + hp["b2"].value = jnp.asarray(request_data.adam_params.beta2, dtype=hp["b2"].value.dtype) + hp["eps"].value = jnp.asarray(request_data.adam_params.eps, dtype=hp["eps"].value.dtype) + hp["weight_decay"].value = jnp.asarray(request_data.adam_params.weight_decay, dtype=hp["weight_decay"].value.dtype) # JIT-compiled: compute full gradients, apply optimizer update, and reset accumulated grads with jax.set_mesh(self.mesh): diff --git a/skyrl/tx/layers/lora.py b/skyrl/tx/layers/lora.py index 0aece0aa71..9cfbacf6e7 100644 --- a/skyrl/tx/layers/lora.py +++ b/skyrl/tx/layers/lora.py @@ -361,7 +361,10 @@ def init_adapter(path, value): if key_name == "lora_A": # Reinitialize with he_uniform, then zero columns beyond rank new_A = nnx.initializers.he_uniform()(rngs.params(), value[idx].shape, value.dtype) - new_A = new_A.at[..., effective_rank:].set(0.0) + # Zero columns beyond rank using multiplication (avoids MPS zero-sized tensor issues) + if new_A.shape[-1] > 0: + mask = jnp.arange(new_A.shape[-1]) < effective_rank + new_A = new_A * mask.astype(new_A.dtype) return value.at[idx].set(new_A) if key_name == "lora_B": # Explicitly zero lora_B diff --git a/skyrl/tx/utils/models.py b/skyrl/tx/utils/models.py index 9a2bf9bbee..e9832f8b83 100644 --- a/skyrl/tx/utils/models.py +++ b/skyrl/tx/utils/models.py @@ -194,7 +194,8 @@ def load_safetensors( tensor = tensor.reshape(param.shape) assert param.shape == tensor.shape, f"shape mismatch for {key}" # ArrayRef.set_raw_value writes through to the stacked parent variable - param.set_raw_value(jax.device_put(tensor.astype(param.dtype), param.sharding)) + # Use np.ascontiguousarray to ensure data is contiguous (required by jax-mps backend) + param.set_raw_value(jax.device_put(np.ascontiguousarray(tensor.astype(param.dtype)), param.sharding)) def save_safetensors( diff --git a/tests/tx/models/test_qwen3.py b/tests/tx/models/test_qwen3.py index 0391ecf70d..b1e265ac72 100644 --- a/tests/tx/models/test_qwen3.py +++ b/tests/tx/models/test_qwen3.py @@ -55,11 +55,11 @@ def test_qwen3(tp: int): def load_moe_base_weights(jax_moe_layer: Qwen3MoeSparseMoeBlock, hf_moe_layer: HFQwen3MoeSparseMoeBlock) -> None: """Load base weights from HF MoE layer to JAX MoE layer.""" - jax_moe_layer.gate.kernel[:] = hf_moe_layer.gate.weight.detach().numpy().T + jax_moe_layer.gate.kernel[:] = np.ascontiguousarray(hf_moe_layer.gate.weight.detach().numpy().T) for i, expert in enumerate(hf_moe_layer.experts): - jax_moe_layer.experts.gate_proj.weight[i, :, :] = expert.gate_proj.weight.detach().numpy().T - jax_moe_layer.experts.up_proj.weight[i, :, :] = expert.up_proj.weight.detach().numpy().T - jax_moe_layer.experts.down_proj.weight[i, :, :] = expert.down_proj.weight.detach().numpy().T + jax_moe_layer.experts.gate_proj.weight[i, :, :] = np.ascontiguousarray(expert.gate_proj.weight.detach().numpy().T) + jax_moe_layer.experts.up_proj.weight[i, :, :] = np.ascontiguousarray(expert.up_proj.weight.detach().numpy().T) + jax_moe_layer.experts.down_proj.weight[i, :, :] = np.ascontiguousarray(expert.down_proj.weight.detach().numpy().T) @pytest.mark.parametrize("ep,tp", [(1, 1), (1, 2), (2, 1)]) @@ -100,8 +100,9 @@ def load_lora_weights( and jax_module.lora_scaling is not None and jax_module.lora_ranks is not None ) - jax_module.lora_A[...] = jax_module.lora_A[...].at[adapter_idx].set(jnp.array(lora_A_weights)) - jax_module.lora_B[...] = jax_module.lora_B[...].at[adapter_idx].set(jnp.array(lora_B_weights)) + # Use np.ascontiguousarray to ensure data is contiguous (required by jax-mps backend) + jax_module.lora_A[...] = jax_module.lora_A[...].at[adapter_idx].set(jnp.array(np.ascontiguousarray(lora_A_weights))) + jax_module.lora_B[...] = jax_module.lora_B[...].at[adapter_idx].set(jnp.array(np.ascontiguousarray(lora_B_weights))) jax_module.lora_scaling[...] = jax_module.lora_scaling[...].at[adapter_idx].set(scaling) jax_module.lora_ranks[...] = jax_module.lora_ranks[...].at[adapter_idx].set(rank)