Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion skyrl/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,15 @@ def __init__(self, base_model: str, config: BaseModel):
pass

@abstractmethod
def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None:
"""Create a new model in the backend.

Creates optimizer and configures LoRA adapter.

Args:
model_id: The model identifier
lora_config: LoRA configuration with rank and alpha
model_role: Logical role for the model (e.g. policy or critic)
"""
pass

Expand Down
10 changes: 7 additions & 3 deletions skyrl/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,11 +546,13 @@ def has_model(self, model_id: str) -> bool:
"""Check if a model is registered with the backend."""
return model_id in self.models

def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None:
"""Create a new model in the backend.

Creates optimizer and configures LoRA adapter. Allocates adapter_index internally.
"""
if model_role != "policy":
raise ValueError(f"JaxBackend only supports model_role='policy', got {model_role!r}")
# Allocate adapter index for this model_id (find first available slot)
# Index 0 is reserved for base model, so user models use indices 1 to max_lora_adapters-1
used_indices = {m.adapter_index for m in self.models.values()}
Expand Down Expand Up @@ -615,6 +617,8 @@ def _model_pass(
"""
if not prepared_batch.all_model_inputs:
return {}
if "ppo_critic" in prepared_batch.all_loss_fns:
raise ValueError("ppo_critic is only supported by the SkyRL-Train backend")

results = {}

Expand Down Expand Up @@ -1105,8 +1109,8 @@ def serialize(k, v):
)
return getattr(super(), method)(**kwargs)

def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
self._broadcast_and_call("create_model", model_id=model_id, lora_config=lora_config)
def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None:
self._broadcast_and_call("create_model", model_id=model_id, lora_config=lora_config, model_role=model_role)

def forward_backward(self, prepared_batch: types.PreparedModelPassBatch):
return self._broadcast_and_call("forward_backward", prepared_batch=prepared_batch)
Expand Down
4 changes: 4 additions & 0 deletions skyrl/backends/skyrl_train/workers/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,10 @@ def empty_cache(self) -> None:
"""Empty GPU memory cache on Worker's CUDA device"""
torch.cuda.empty_cache()

def set_algorithm_config(self, **kwargs) -> None:
for key, value in kwargs.items():
setattr(self.cfg.algorithm, key, value)

def offload_to_cpu(self, pin_memory=True, non_blocking=True):
"""Offload all worker state to CPU.

Expand Down
9 changes: 9 additions & 0 deletions skyrl/backends/skyrl_train/workers/worker_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def __init__(
# GPU state tracking (only matters when colocated)
self._gpu_state: Dict[str, GPUState] = {name: GPUState() for name in self._actor_groups.keys()}

def register_actor_group(self, model: str, actor_group: PPORayActorGroup) -> None:
self._actor_groups[model] = actor_group
self._gpu_state[model] = GPUState()

def get_lcm_dp_size(self) -> int:
"""Get LCM of all models' dp_size."""
import math
Expand Down Expand Up @@ -288,6 +292,11 @@ def set_lr(self, model: str, learning_rate: float) -> None:
self._ensure_on_gpu(model, need_optimizer=True, need_model=False)
ray.get(self._actor_groups[model].async_run_ray_method("pass_through", "set_lr", learning_rate=learning_rate))

def set_algorithm_config(self, model: str, **kwargs) -> None:
"""Update algorithm config fields on all workers for a model."""
self._ensure_on_gpu(model, need_optimizer=False, need_model=False)
ray.get(self._actor_groups[model].async_run_ray_method("pass_through", "set_algorithm_config", **kwargs))

def _save_memory_snapshot(self, model: str, tag: str) -> None:
"""Save memory snapshot on workers."""
ray.get(
Expand Down
Loading
Loading