Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ jax = [
"jax>=0.8,<1.0",
"flax>=0.12.2",
"optax>=0.2.5",
"ml_dtypes>=0.5.0",
"jax-mps>=0.0.6; sys_platform == 'darwin' and platform_machine == 'arm64'",
]

skyrl-train = [
Expand Down Expand Up @@ -208,7 +210,7 @@ override-dependencies = [
"causal-conv1d; sys_platform == 'never'",
"transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'",
"megatron-core==0.16.0; sys_platform == 'linux'",
"ml_dtypes>=0.5.0; sys_platform == 'linux'",
"ml_dtypes>=0.5.0",
]

[tool.uv.extra-build-dependencies]
Expand Down
11 changes: 6 additions & 5 deletions skyrl/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,12 +787,13 @@ def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types
logger.warning(f"No accumulated gradients for model {model_id}; applying step with zero gradients")

# Update hyperparameters from the request
# Use direct .value assignment instead of [...] indexing to avoid MPS zero-sized tensor issues
hp = optimizer.opt_state.hyperparams
hp["learning_rate"][...] = learning_rate
hp["b1"][...] = request_data.adam_params.beta1
hp["b2"][...] = request_data.adam_params.beta2
hp["eps"][...] = request_data.adam_params.eps
hp["weight_decay"][...] = request_data.adam_params.weight_decay
hp["learning_rate"].value = jnp.asarray(learning_rate, dtype=hp["learning_rate"].value.dtype)
hp["b1"].value = jnp.asarray(request_data.adam_params.beta1, dtype=hp["b1"].value.dtype)
hp["b2"].value = jnp.asarray(request_data.adam_params.beta2, dtype=hp["b2"].value.dtype)
hp["eps"].value = jnp.asarray(request_data.adam_params.eps, dtype=hp["eps"].value.dtype)
hp["weight_decay"].value = jnp.asarray(request_data.adam_params.weight_decay, dtype=hp["weight_decay"].value.dtype)
Comment on lines +790 to +796
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The code for updating hyperparameters is quite repetitive. To improve readability and maintainability, you could refactor this block to use a loop over a dictionary of parameters.

        # Use direct .value assignment instead of [...] indexing to avoid MPS zero-sized tensor issues
        hp = optimizer.opt_state.hyperparams
        params_to_update = {
            "learning_rate": learning_rate,
            "b1": request_data.adam_params.beta1,
            "b2": request_data.adam_params.beta2,
            "eps": request_data.adam_params.eps,
            "weight_decay": request_data.adam_params.weight_decay,
        }
        for name, value in params_to_update.items():
            hp[name].value = jnp.asarray(value, dtype=hp[name].value.dtype)


# JIT-compiled: compute full gradients, apply optimizer update, and reset accumulated grads
with jax.set_mesh(self.mesh):
Expand Down
5 changes: 4 additions & 1 deletion skyrl/tx/layers/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,10 @@ def init_adapter(path, value):
if key_name == "lora_A":
# Reinitialize with he_uniform, then zero columns beyond rank
new_A = nnx.initializers.he_uniform()(rngs.params(), value[idx].shape, value.dtype)
new_A = new_A.at[..., effective_rank:].set(0.0)
# Zero columns beyond rank using multiplication (avoids MPS zero-sized tensor issues)
if new_A.shape[-1] > 0:
mask = jnp.arange(new_A.shape[-1]) < effective_rank
new_A = new_A * mask.astype(new_A.dtype)
return value.at[idx].set(new_A)
if key_name == "lora_B":
# Explicitly zero lora_B
Expand Down
3 changes: 2 additions & 1 deletion skyrl/tx/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ def load_safetensors(
tensor = tensor.reshape(param.shape)
assert param.shape == tensor.shape, f"shape mismatch for {key}"
# ArrayRef.set_raw_value writes through to the stacked parent variable
param.set_raw_value(jax.device_put(tensor.astype(param.dtype), param.sharding))
# Use np.ascontiguousarray to ensure data is contiguous (required by jax-mps backend)
param.set_raw_value(jax.device_put(np.ascontiguousarray(tensor.astype(param.dtype)), param.sharding))


def save_safetensors(
Expand Down
13 changes: 7 additions & 6 deletions tests/tx/models/test_qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ def test_qwen3(tp: int):

def load_moe_base_weights(jax_moe_layer: Qwen3MoeSparseMoeBlock, hf_moe_layer: HFQwen3MoeSparseMoeBlock) -> None:
    """Load base weights from HF MoE layer to JAX MoE layer.

    Copies the router gate kernel and, for each expert, the gate/up/down
    projection weights. Each HF tensor is detached, converted to numpy, and
    transposed (`.T`) to match the JAX layer's layout; since `.T` returns a
    non-contiguous view, `np.ascontiguousarray` is applied so the buffer is
    contiguous (required by the jax-mps backend, per the diff comment).

    Args:
        jax_moe_layer: Destination JAX MoE block whose weights are written in place.
        hf_moe_layer: Source HuggingFace MoE block providing the weights.
    """
    # Router gate: single 2-D kernel, transposed from HF (out, in) orientation.
    jax_moe_layer.gate.kernel[:] = np.ascontiguousarray(hf_moe_layer.gate.weight.detach().numpy().T)
    # Experts are stacked along the leading axis in the JAX layer; copy one
    # expert's projections per iteration.
    for i, expert in enumerate(hf_moe_layer.experts):
        jax_moe_layer.experts.gate_proj.weight[i, :, :] = np.ascontiguousarray(expert.gate_proj.weight.detach().numpy().T)
        jax_moe_layer.experts.up_proj.weight[i, :, :] = np.ascontiguousarray(expert.up_proj.weight.detach().numpy().T)
        jax_moe_layer.experts.down_proj.weight[i, :, :] = np.ascontiguousarray(expert.down_proj.weight.detach().numpy().T)


@pytest.mark.parametrize("ep,tp", [(1, 1), (1, 2), (2, 1)])
Expand Down Expand Up @@ -100,8 +100,9 @@ def load_lora_weights(
and jax_module.lora_scaling is not None
and jax_module.lora_ranks is not None
)
jax_module.lora_A[...] = jax_module.lora_A[...].at[adapter_idx].set(jnp.array(lora_A_weights))
jax_module.lora_B[...] = jax_module.lora_B[...].at[adapter_idx].set(jnp.array(lora_B_weights))
# Use np.ascontiguousarray to ensure data is contiguous (required by jax-mps backend)
jax_module.lora_A[...] = jax_module.lora_A[...].at[adapter_idx].set(jnp.array(np.ascontiguousarray(lora_A_weights)))
jax_module.lora_B[...] = jax_module.lora_B[...].at[adapter_idx].set(jnp.array(np.ascontiguousarray(lora_B_weights)))
jax_module.lora_scaling[...] = jax_module.lora_scaling[...].at[adapter_idx].set(scaling)
jax_module.lora_ranks[...] = jax_module.lora_ranks[...].at[adapter_idx].set(rank)

Expand Down
Loading