diff --git a/skyrl/backends/backend.py b/skyrl/backends/backend.py index 15c169144b..d4cbd11a1a 100644 --- a/skyrl/backends/backend.py +++ b/skyrl/backends/backend.py @@ -43,7 +43,7 @@ def __init__(self, base_model: str, config: BaseModel): pass @abstractmethod - def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: + def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None: """Create a new model in the backend. Creates optimizer and configures LoRA adapter. @@ -51,6 +51,7 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: Args: model_id: The model identifier lora_config: LoRA configuration with rank and alpha + model_role: Logical role for the model (e.g. policy or critic) """ pass diff --git a/skyrl/backends/jax.py b/skyrl/backends/jax.py index bb9a9de69f..787f25e8b0 100644 --- a/skyrl/backends/jax.py +++ b/skyrl/backends/jax.py @@ -546,11 +546,13 @@ def has_model(self, model_id: str) -> bool: """Check if a model is registered with the backend.""" return model_id in self.models - def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: + def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None: """Create a new model in the backend. Creates optimizer and configures LoRA adapter. Allocates adapter_index internally. """ + if model_role != "policy": + raise ValueError(f"JaxBackend only supports model_role='policy', got {model_role!r}") # Allocate adapter index for this model_id (find first available slot) # Index 0 is reserved for base model, so user models use indices 1 to max_lora_adapters-1 used_indices = {m.adapter_index for m in self.models.values()} @@ -615,6 +617,8 @@ def _model_pass( """ if not prepared_batch.all_model_inputs: return {} + if "ppo_critic" in prepared_batch.all_loss_fns: + raise ValueError("ppo_critic is only supported by the SkyRL-Train backend") results = {} @@ -1105,8 +1109,8 @@ def serialize(k, v): ) return getattr(super(), method)(**kwargs) - def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: - self._broadcast_and_call("create_model", model_id=model_id, lora_config=lora_config) + def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None: + self._broadcast_and_call("create_model", model_id=model_id, lora_config=lora_config, model_role=model_role) def forward_backward(self, prepared_batch: types.PreparedModelPassBatch): return self._broadcast_and_call("forward_backward", prepared_batch=prepared_batch) diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index 40efbf02e5..e43a2e64a1 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -244,6 +244,10 @@ def empty_cache(self) -> None: """Empty GPU memory cache on Worker's CUDA device""" torch.cuda.empty_cache() + def set_algorithm_config(self, **kwargs) -> None: + for key, value in kwargs.items(): + setattr(self.cfg.algorithm, key, value) + def offload_to_cpu(self, pin_memory=True, non_blocking=True): """Offload all worker state to CPU. diff --git a/skyrl/backends/skyrl_train/workers/worker_dispatch.py b/skyrl/backends/skyrl_train/workers/worker_dispatch.py index 54bacddf72..a2a27cdf9b 100644 --- a/skyrl/backends/skyrl_train/workers/worker_dispatch.py +++ b/skyrl/backends/skyrl_train/workers/worker_dispatch.py @@ -70,6 +70,10 @@ def __init__( # GPU state tracking (only matters when colocated) self._gpu_state: Dict[str, GPUState] = {name: GPUState() for name in self._actor_groups.keys()} + def register_actor_group(self, model: str, actor_group: PPORayActorGroup) -> None: + self._actor_groups[model] = actor_group + self._gpu_state[model] = GPUState() + def get_lcm_dp_size(self) -> int: """Get LCM of all models' dp_size.""" import math @@ -288,6 +292,11 @@ def set_lr(self, model: str, learning_rate: float) -> None: self._ensure_on_gpu(model, need_optimizer=True, need_model=False) ray.get(self._actor_groups[model].async_run_ray_method("pass_through", "set_lr", learning_rate=learning_rate)) + def set_algorithm_config(self, model: str, **kwargs) -> None: + """Update algorithm config fields on all workers for a model.""" + self._ensure_on_gpu(model, need_optimizer=False, need_model=False) + ray.get(self._actor_groups[model].async_run_ray_method("pass_through", "set_algorithm_config", **kwargs)) + def _save_memory_snapshot(self, model: str, tag: str) -> None: """Save memory snapshot on workers.""" ray.get( diff --git a/skyrl/backends/skyrl_train_backend.py b/skyrl/backends/skyrl_train_backend.py index 4622d60d88..214992a51b 100644 --- a/skyrl/backends/skyrl_train_backend.py +++ b/skyrl/backends/skyrl_train_backend.py @@ -1,8 +1,4 @@ -"""SkyRL-Train backend for TinkerEngine. - -Uses SkyRL-Train infrastructure for supervised training with cross-entropy loss. -Currently supports a single model only. -""" +"""SkyRL-Train backend for TinkerEngine.""" import asyncio import os @@ -74,11 +70,14 @@ def _build_skyrl_train_config( # NOTE: It is better to add this as a part of the CLI overrides since we have post_init logic # that will resolve other attributes such as the reference model path based on the policy model path. user_overrides["trainer.policy.model.path"] = base_model + user_overrides["trainer.critic.model.path"] = base_model cfg = SkyRLTrainConfig.from_cli_overrides(user_overrides) # Disable scheduler - Tinker manages learning rate externally via set_lr() cfg.trainer.policy.optimizer_config.scheduler = "constant_with_warmup" cfg.trainer.policy.optimizer_config.num_warmup_steps = 0 + cfg.trainer.critic.optimizer_config.scheduler = "constant_with_warmup" + cfg.trainer.critic.optimizer_config.num_warmup_steps = 0 # TODO(tyler): Support KL Loss cfg.trainer.algorithm.use_kl_loss = False @@ -114,18 +113,35 @@ def __init__(self, base_model: str, config: SkyRLTrainBackendOverrides): self.base_model = base_model # NOTE: We currently have two config attributes "config" which is just config overrides and "_cfg" which is the actual config object. This is a temporary state given that the Tinker engine expects a .config attribute self.config = config - self._model_id: str | None = None - self._model_metadata: types.ModelMetadata | None = None + self._model_ids: dict[str, str] = {} + self._model_metadata: dict[str, types.ModelMetadata] = {} self._cfg = None self._dispatch: WorkerDispatch | None = None + self._colocate_pg: ResolvedPlacementGroup | None = None self._tokenizer: AutoTokenizer = get_tokenizer(self.base_model) self._inference_engine_client = None self._inference_engines_initialized = False def has_model(self, model_id: str) -> bool: - return self._model_id == model_id - - def build_models(self, PolicyWorker): + return model_id in self._model_ids + + def _get_role(self, model_id: str) -> str: + try: + return self._model_ids[model_id] + except KeyError as exc: + raise ValueError(f"Model {model_id} not found") from exc + + def _get_batch_role(self, model_ids: list[str]) -> str: + if not model_ids: + return "policy" + roles = {self._get_role(model_id) for model_id in model_ids} + if len(roles) != 1: + raise ValueError(f"Mixed model roles in one batch are not supported: {sorted(roles)}") + if len(set(model_ids)) != 1: + raise ValueError(f"Mixed model_ids in one batch are not supported: {sorted(set(model_ids))}") + return next(iter(roles)) + + def _build_policy(self, PolicyWorker): cfg = self._cfg colocate_all = cfg.trainer.placement.colocate_all pg = self._colocate_pg @@ -182,6 +198,36 @@ def build_models(self, PolicyWorker): logger.info("init policy model done") + def _build_critic(self, CriticWorker, lora_config: types.LoraConfig) -> None: + cfg = self._cfg + colocate_all = cfg.trainer.placement.colocate_all + if colocate_all: + num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes + num_critic_gpus = cfg.trainer.placement.critic_num_gpus_per_node * cfg.trainer.placement.critic_num_nodes + assert ( + num_policy_gpus == num_critic_gpus + ), "num_policy_gpus and num_critic_gpus must be the same when colocating policy and critic model" + + cfg.trainer.critic.model.lora.rank = lora_config.rank + cfg.trainer.critic.model.lora.alpha = int(lora_config.alpha) + critic_model = PPORayActorGroup( + cfg.trainer, + cfg.trainer.placement.critic_num_nodes, + cfg.trainer.placement.critic_num_gpus_per_node, + CriticWorker, + pg=self._colocate_pg, + num_gpus_per_actor=0.2 if colocate_all else 1, + colocate_all=colocate_all, + sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size, + ) + self._dispatch.register_actor_group("critic", critic_model) + self._dispatch.init_model("critic", cfg.trainer.critic.model.path, num_training_steps=1e9) + ray.get(critic_model.async_run_ray_method("pass_through", "_set_pad_token_id", self._tokenizer.pad_token_id)) + if colocate_all: + critic_model.offload_to_cpu() + self._dispatch.mark_all_offloaded() + logger.info("init critic model done") + def init_weight_sync_state(self): """ Setup the connection between policy model and inference engine for weight syncing. @@ -206,39 +252,52 @@ def _ensure_inference_engines(self): self.init_weight_sync_state() self._inference_engines_initialized = True - def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: - if self._model_id is not None: - raise ValueError(f"Model '{self._model_id}' already exists. Only one model supported.") + def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role: str = "policy") -> None: + if model_id in self._model_ids: + raise ValueError(f"Model '{model_id}' already exists") + if model_role in self._model_ids.values(): + raise ValueError(f"SkyRLTrainBackend already has a '{model_role}' model") - # Build config - self._cfg = _build_skyrl_train_config(self.base_model, self.config, lora_config) + if model_role == "policy": + self._cfg = _build_skyrl_train_config(self.base_model, self.config, lora_config) - if not ray.is_initialized(): - logger.info("Initializing Ray with runtime environment") - initialize_ray(self._cfg) + if not ray.is_initialized(): + logger.info("Initializing Ray with runtime environment") + initialize_ray(self._cfg) - # Create shared placement group only when colocating training + inference - if self._cfg.trainer.placement.colocate_all: - self._colocate_pg = self._create_colocate_pg() - else: - self._colocate_pg = None - - # Get worker types based on strategy - if self._cfg.trainer.strategy in ("fsdp", "fsdp2"): - from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker - elif self._cfg.trainer.strategy == "megatron": - from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import ( - PolicyWorker, - ) - else: - raise ValueError(f"Unknown strategy type: {self._cfg.trainer.strategy}") + self._colocate_pg = self._create_colocate_pg() if self._cfg.trainer.placement.colocate_all else None - logger.info("Building models.") - self.build_models(PolicyWorker) + if self._cfg.trainer.strategy in ("fsdp", "fsdp2"): + from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import ( + PolicyWorker, + ) + elif self._cfg.trainer.strategy == "megatron": + from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import ( + PolicyWorker, + ) + else: + raise ValueError(f"Unknown strategy type: {self._cfg.trainer.strategy}") + + logger.info("Building models.") + self._build_policy(PolicyWorker) + elif model_role == "critic": + if "policy" not in self._model_ids.values(): + raise ValueError("Create a policy model before creating a critic model") + if self._cfg.trainer.strategy in ("fsdp", "fsdp2"): + from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import ( + CriticWorker, + ) + elif self._cfg.trainer.strategy == "megatron": + raise NotImplementedError("Critic model support is not implemented for the Megatron backend yet") + else: + raise ValueError(f"Unknown strategy type: {self._cfg.trainer.strategy}") + self._build_critic(CriticWorker, lora_config) + else: + raise ValueError(f"Unknown model_role: {model_role}") - self._model_id = model_id - self._model_metadata = types.ModelMetadata(adapter_index=0, lora_config=lora_config) - logger.info(f"Created model {model_id} using RayPPOTrainer") + self._model_ids[model_id] = model_role + self._model_metadata[model_id] = types.ModelMetadata(adapter_index=0, lora_config=lora_config) + logger.info(f"Created {model_role} model {model_id} using RayPPOTrainer") def _create_colocate_pg(self): """Create a placement group for colocated training + inference.""" @@ -256,12 +315,11 @@ def _create_colocate_pg(self): return ResolvedPlacementGroup(pg) def delete_model(self, model_id: str) -> None: - if self._model_id != model_id: - raise ValueError(f"Model {model_id} not found") + self._get_role(model_id) # TODO: For now, prefer shutting down the backend and re-launching. Will be improved shortly. raise NotImplementedError("Deleting models not yet implemented") - def _to_training_batch(self, prepared_batch: types.PreparedModelPassBatch) -> TrainingInputBatch: + def _to_training_batch(self, prepared_batch: types.PreparedModelPassBatch, role: str) -> TrainingInputBatch: """Convert PreparedModelPassBatch to TrainingInputBatch.""" if not prepared_batch.all_model_inputs: return TrainingInputBatch({}) @@ -281,12 +339,15 @@ def _to_training_batch(self, prepared_batch: types.PreparedModelPassBatch) -> Tr sequences, attention_masks, loss_masks, response_masks = [], [], [], [] action_log_probs_list, advantages_list = [], [] + values_list, returns_list = [], [] - for seq, weights, logprobs, advs in zip( + for seq, weights, logprobs, advs, values, returns in zip( full_sequences, prepared_batch.all_token_weights, prepared_batch.all_sampling_logprobs, prepared_batch.all_advantages, + prepared_batch.all_values, + prepared_batch.all_returns, ): pad_len = max_seq_len - len(seq) sequences.append([self._tokenizer.pad_token_id] * pad_len + list(seq)) @@ -296,6 +357,8 @@ def _to_training_batch(self, prepared_batch: types.PreparedModelPassBatch) -> Tr response_masks.append([0] * action_pad + [1] * len(weights)) action_log_probs_list.append([0.0] * action_pad + [float(lp) for lp in logprobs]) advantages_list.append([0.0] * action_pad + [float(a) for a in advs]) + values_list.append([0.0] * action_pad + [float(v) for v in values]) + returns_list.append([0.0] * action_pad + [float(r) for r in returns]) sequences_tensor = torch.tensor(sequences, dtype=torch.long) attention_mask_tensor = torch.tensor(attention_masks, dtype=torch.long) @@ -316,6 +379,9 @@ def _to_training_batch(self, prepared_batch: types.PreparedModelPassBatch) -> Tr batch_dict["action_log_probs"] = torch.tensor(action_log_probs_list, dtype=torch.float32) if has_advantages: batch_dict["advantages"] = torch.tensor(advantages_list, dtype=torch.float32) + if role == "critic": + batch_dict["values"] = torch.tensor(values_list, dtype=torch.float32) + batch_dict["returns"] = torch.tensor(returns_list, dtype=torch.float32) batch = TrainingInputBatch(batch_dict) batch.metadata = {"response_length": max_response_len} @@ -377,8 +443,18 @@ def _extract_metrics(self, data: dict) -> dict[str, float]: metrics["pg_loss:sum"] = float(data["policy_loss"]) if "policy_entropy" in data: metrics["entropy_loss:sum"] = float(data["policy_entropy"]) + if "critic_loss" in data: + metrics["critic_loss:sum"] = float(data["critic_loss"]) + if "values_mean" in data: + metrics["values_mean:mean"] = float(data["values_mean"]) + if "values_clipfrac" in data: + metrics["values_clipfrac:mean"] = float(data["values_clipfrac"]) if "response_length" in data: metrics["num_tokens:sum"] = float(data["response_length"]) + if "policy_lr" in data: + metrics["policy_lr:last"] = float(data["policy_lr"]) + if "critic_lr" in data: + metrics["critic_lr:last"] = float(data["critic_lr"]) return metrics @@ -387,6 +463,12 @@ def _sleep_inference_engines(self): if self._inference_engines_initialized and self._cfg.trainer.placement.colocate_all: asyncio.run(self._inference_engine_client.sleep()) + def _validate_batch_role_and_loss(self, role: str, loss_fn: str): + if role == "critic" and loss_fn != "ppo_critic": + raise ValueError(f"Critic batches must use loss_fn='ppo_critic', got {loss_fn!r}") + if role != "critic" and loss_fn == "ppo_critic": + raise ValueError("loss_fn='ppo_critic' is only valid for critic models") + def forward_backward( self, prepared_batch: types.PreparedModelPassBatch, @@ -395,7 +477,17 @@ def forward_backward( return {} self._sleep_inference_engines() - batch = self._to_training_batch(prepared_batch) + role = self._get_batch_role(prepared_batch.all_model_ids) + loss_fn = prepared_batch.all_loss_fns[0] + self._validate_batch_role_and_loss(role, loss_fn) + if role == "critic" and any( + len(values) != len(weights) or len(returns) != len(weights) + for values, returns, weights in zip( + prepared_batch.all_values, prepared_batch.all_returns, prepared_batch.all_token_weights + ) + ): + raise ValueError("Critic forward_backward requires values and returns for every response token") + batch = self._to_training_batch(prepared_batch, role) micro_bs = ( self._cfg.trainer.micro_train_batch_size_per_gpu if self._cfg.trainer.strategy == "megatron" else None ) @@ -409,12 +501,19 @@ def forward_backward( loss_fn, ) loss_fn_config = next((c for c in prepared_batch.all_loss_fn_configs if c is not None), None) - data = self._dispatch.forward_backward( - "policy", - batch, - loss_fn=loss_fn, - loss_fn_config=loss_fn_config, - ) + if role == "critic": + self._dispatch.set_algorithm_config( + "critic", + value_clip=(loss_fn_config or {}).get("value_clip", self._cfg.trainer.algorithm.value_clip), + ) + data = self._dispatch.forward_backward("critic", batch) + else: + data = self._dispatch.forward_backward( + role, + batch, + loss_fn=loss_fn, + loss_fn_config=loss_fn_config, + ) # Trim padding entries from loss_fn_outputs if pad_size > 0 and "loss_fn_outputs" in data: @@ -424,25 +523,22 @@ def forward_backward( results = {} for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices: - loss_fn_outputs = [] - for i in range(start_idx, end_idx): - raw_output = data["loss_fn_outputs"][i] - logprobs = list(raw_output.get("logprobs", [])) - elementwise_loss = list(raw_output.get("elementwise_loss", [])) - loss_fn_outputs.append( - { - "elementwise_loss": { - "data": elementwise_loss, - "dtype": "float32", - "shape": [len(elementwise_loss)], - }, - "logprobs": { - "data": logprobs, - "dtype": "float32", - "shape": [len(logprobs)], - }, - } - ) + if "loss_fn_outputs" in data: + loss_fn_outputs = [] + for i in range(start_idx, end_idx): + raw_output = data["loss_fn_outputs"][i] + formatted_output = {} + for key in ("elementwise_loss", "logprobs", "values"): + values = list(raw_output.get(key, [])) + if values or key in raw_output: + formatted_output[key] = { + "data": values, + "dtype": "float32", + "shape": [len(values)], + } + loss_fn_outputs.append(formatted_output) + else: + loss_fn_outputs = [{} for _ in range(end_idx - start_idx)] results[request_id] = types.ForwardBackwardOutput( loss_fn_output_type="scalar", loss_fn_outputs=loss_fn_outputs, @@ -458,18 +554,21 @@ def forward( return {} self._sleep_inference_engines() - batch = self._to_training_batch(prepared_batch) + role = self._get_batch_role(prepared_batch.all_model_ids) + loss_fn = prepared_batch.all_loss_fns[0] + self._validate_batch_role_and_loss(role, loss_fn) + batch = self._to_training_batch(prepared_batch, role) micro_bs = ( self._cfg.trainer.micro_forward_batch_size_per_gpu if self._cfg.trainer.strategy == "megatron" else None ) batch, pad_size = self._pad_batch(batch, micro_batch_size=micro_bs) - data = self._dispatch.forward("policy", batch) + data = self._dispatch.forward(role, batch) # dispatch.forward() returns TrainingOutputBatch({"output": tensor[batch, max_response_len]}) # Trim padding entries from output - output_logprobs = data["output"] + output_tensor = data["output"] if pad_size > 0: - output_logprobs = output_logprobs[:-pad_size] + output_tensor = output_tensor[:-pad_size] results = {} for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices: @@ -477,14 +576,15 @@ def forward( for i in range(start_idx, end_idx): # Use token weights length to determine each example's actual response length valid_len = len(prepared_batch.all_token_weights[i]) - start = max(output_logprobs.shape[1] - valid_len, 0) - logprobs = output_logprobs[i, start:].tolist() + start = max(output_tensor.shape[1] - valid_len, 0) + values = output_tensor[i, start:].tolist() + output_key = "values" if role == "critic" else "logprobs" loss_fn_outputs.append( { - "logprobs": { - "data": logprobs, + output_key: { + "data": values, "dtype": "float32", - "shape": [len(logprobs)], + "shape": [len(values)], }, } ) @@ -496,15 +596,14 @@ def forward( return results def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput: - if model_id != self._model_id: - raise ValueError(f"Model {model_id} not found") + role = self._get_role(model_id) # Apply learning rate from AdamParams before optimizer step # Note: beta1, beta2, eps are fixed at optimizer creation and cannot be changed dynamically adam_params = request_data.adam_params - self._dispatch.set_lr("policy", adam_params.learning_rate) + self._dispatch.set_lr(role, adam_params.learning_rate) - grad_norm = self._dispatch.optim_step("policy") + grad_norm = self._dispatch.optim_step(role) logger.info(f"optim_step: lr={adam_params.learning_rate}, grad_norm={grad_norm}") metrics: dict[str, float] = {} @@ -528,9 +627,15 @@ def sample( # 2. Validate single model unique_models = set(prepared_batch.all_model_ids) - if unique_models != {self._model_id}: + if len(unique_models) != 1: + error = types.ErrorResponse( + error=f"Expected exactly one model_id for sampling, got {unique_models}", status="error" + ) + return {req_id: error for req_id, _, _, _, _ in prepared_batch.request_batch_slices} + model_id = next(iter(unique_models)) + if self._get_role(model_id) != "policy": error = types.ErrorResponse( - error=f"Model mismatch. Expected {self._model_id}, got {unique_models}", status="error" + error=f"Sampling is only supported for policy models, got '{model_id}'", status="error" ) return {req_id: error for req_id, _, _, _, _ in prepared_batch.request_batch_slices} @@ -645,8 +750,7 @@ def _aggregate_sample_results( def _validate_model_state(self, model_id: str) -> None: """Validate that model exists and is initialized.""" - if model_id != self._model_id: - raise ValueError(f"Model {model_id} not found") + self._get_role(model_id) if self._dispatch is None: raise RuntimeError("Model not initialized") @@ -662,13 +766,14 @@ def _create_tar_from_directory(self, source_dir: str, output_path: str) -> None: def save_checkpoint(self, output_path, model_id: str) -> None: """Save full training checkpoint (model + optimizer + scheduler) as tar.""" self._validate_model_state(model_id) + role = self._get_role(model_id) # Create temp directory for checkpoint with tempfile.TemporaryDirectory() as temp_dir: ckpt_dir = os.path.join(temp_dir, "checkpoint") # Save checkpoint directory (includes optimizer state automatically) - self._dispatch.save_checkpoint(model="policy", ckpt_dir=ckpt_dir, tokenizer=self._tokenizer) + self._dispatch.save_checkpoint(model=role, ckpt_dir=ckpt_dir, tokenizer=self._tokenizer) # Create tar archive self._create_tar_from_directory(ckpt_dir, output_path) @@ -678,6 +783,7 @@ def save_checkpoint(self, output_path, model_id: str) -> None: def load_checkpoint(self, checkpoint_path, model_id: str) -> None: """Load full training checkpoint (model + optimizer + scheduler) from tar.""" self._validate_model_state(model_id) + role = self._get_role(model_id) # Extract tar to temp directory (filter='data' prevents path traversal attacks) with tempfile.TemporaryDirectory() as temp_dir: @@ -686,7 +792,7 @@ def load_checkpoint(self, checkpoint_path, model_id: str) -> None: # Load checkpoint (includes optimizer and scheduler states) self._dispatch.load_checkpoint( - model="policy", ckpt_dir=temp_dir, load_optimizer_states=True, load_lr_scheduler_states=True + model=role, ckpt_dir=temp_dir, load_optimizer_states=True, load_lr_scheduler_states=True ) logger.info(f"Loaded checkpoint for {model_id} from {checkpoint_path}") @@ -699,6 +805,8 @@ def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = Tr loops) the expensive HuggingFace model export is skipped entirely. """ self._validate_model_state(model_id) + if self._get_role(model_id) != "policy": + raise ValueError("save_sampler_checkpoint is only supported for policy models") # Lazily create inference engines on first sampling-related call self._ensure_inference_engines() diff --git a/skyrl/tinker/api.py b/skyrl/tinker/api.py index f505be77a4..bdf3f37d81 100644 --- a/skyrl/tinker/api.py +++ b/skyrl/tinker/api.py @@ -250,6 +250,7 @@ class CreateModelRequest(BaseModel): session_id: str base_model: str lora_config: LoRAConfig + model_role: str = "policy" class CreateModelResponse(BaseModel): @@ -399,6 +400,8 @@ def to_types(self) -> types.Datum: weights=weights, advantages=inp["advantages"].to_types() if "advantages" in inp else types.TensorData(data=[]), logprobs=inp["logprobs"].to_types() if "logprobs" in inp else types.TensorData(data=[]), + values=inp["values"].to_types() if "values" in inp else types.TensorData(data=[]), + returns=inp["returns"].to_types() if "returns" in inp else types.TensorData(data=[]), ), model_input=self.model_input.to_types(), ) @@ -410,10 +413,11 @@ class ForwardBackwardInput(BaseModel): "importance_sampling": set(), "ppo": {"clip_low_threshold", "clip_high_threshold"}, "cispo": {"clip_low_threshold", "clip_high_threshold"}, + "ppo_critic": {"value_clip"}, } data: list[Datum] - loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo"] + loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "ppo_critic"] loss_fn_config: dict[str, float] | None = None @model_validator(mode="after") @@ -749,7 +753,7 @@ async def create_model(request: CreateModelRequest, session: AsyncSession = Depe session=session, request_type=types.RequestType.CREATE_MODEL, model_id=model_id, - request_data=types.CreateModelInput(lora_config=lora_config), + request_data=types.CreateModelInput(lora_config=lora_config, model_role=request.model_role), ) model_db = ModelDB( diff --git a/skyrl/tinker/engine.py b/skyrl/tinker/engine.py index 6c449eb9f7..b75d17d93d 100644 --- a/skyrl/tinker/engine.py +++ b/skyrl/tinker/engine.py @@ -119,6 +119,8 @@ def prepare_model_pass_batch( all_model_ids = [] all_sampling_logprobs = [] all_advantages = [] + all_values = [] + all_returns = [] all_loss_fns = [] all_loss_fn_configs = [] request_batch_slices = [] @@ -136,6 +138,8 @@ def prepare_model_pass_batch( all_token_weights.append(loss_fn_inputs.weights.data) all_sampling_logprobs.append(loss_fn_inputs.logprobs.data) all_advantages.append(loss_fn_inputs.advantages.data) + all_values.append(loss_fn_inputs.values.data) + all_returns.append(loss_fn_inputs.returns.data) all_model_ids.append(model_id) all_loss_fns.append(request_data.loss_fn) all_loss_fn_configs.append(request_data.loss_fn_config) @@ -148,6 +152,8 @@ def prepare_model_pass_batch( all_token_weights=all_token_weights, all_sampling_logprobs=all_sampling_logprobs, all_advantages=all_advantages, + all_values=all_values, + all_returns=all_returns, all_model_ids=all_model_ids, all_loss_fns=all_loss_fns, all_loss_fn_configs=all_loss_fn_configs, @@ -405,7 +411,7 @@ def find_single_requests(self, session: Session) -> dict[str, tuple[str, types.R def process_create_model(self, model_id: str, request_data: types.CreateModelInput) -> types.CreateModelOutput: """Create and initialize a model.""" # Create model in backend (allocates adapter_index, creates optimizer, and configures adapter) - self.backend.create_model(model_id, request_data.lora_config) + self.backend.create_model(model_id, request_data.lora_config, model_role=request_data.model_role) logger.info(f"Created LoRA model {model_id}") diff --git a/skyrl/tinker/loss_fns.py b/skyrl/tinker/loss_fns.py index 2a394cb64f..e5475493ea 100644 --- a/skyrl/tinker/loss_fns.py +++ b/skyrl/tinker/loss_fns.py @@ -78,12 +78,23 @@ def cispo_loss( return -safe_loss_mask(cispo_objective, loss_mask) +def ppo_critic_loss( + _target_logprobs: jax.Array, + _loss_mask: jax.Array, + _sampling_logprobs: jax.Array, + _advantages: jax.Array, + _loss_fn_config: LossFnConfig, +) -> jax.Array: + return jnp.zeros_like(_loss_mask) + + # Map from string names to loss functions LOSS_FUNCTION_MAP = { "cross_entropy": cross_entropy_loss, "importance_sampling": importance_sampling_loss, "ppo": ppo_loss, "cispo": cispo_loss, + "ppo_critic": ppo_critic_loss, } # Build list of functions indexed by LOSS_TYPES values (for jax.lax.switch) diff --git a/skyrl/tinker/types.py b/skyrl/tinker/types.py index a0ec7b987b..6105ab5141 100644 --- a/skyrl/tinker/types.py +++ b/skyrl/tinker/types.py @@ -9,7 +9,7 @@ from typing import Annotated, Literal from urllib.parse import urlparse -from pydantic import Base64Bytes, BaseModel, Discriminator +from pydantic import Base64Bytes, BaseModel, Discriminator, Field class RequestType(str, Enum): @@ -74,6 +74,7 @@ class LoraConfig(BaseModel): class CreateModelInput(BaseModel): lora_config: LoraConfig + model_role: str = "policy" class CreateModelOutput(BaseModel): @@ -143,6 +144,8 @@ class LossFnInputs(BaseModel): weights: TensorData advantages: TensorData logprobs: TensorData + values: TensorData = Field(default_factory=lambda: TensorData(data=[])) + returns: TensorData = Field(default_factory=lambda: TensorData(data=[])) class Datum(BaseModel): @@ -152,7 +155,7 @@ class Datum(BaseModel): class ForwardBackwardInput(BaseModel): data: list[Datum] - loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo"] + loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "ppo_critic"] loss_fn_config: dict[str, float] | None = None @@ -264,6 +267,8 @@ class PreparedModelPassBatch(BaseModel): all_token_weights: list[list[float]] all_sampling_logprobs: list[list[float]] all_advantages: list[list[float]] + all_values: list[list[float]] + all_returns: list[list[float]] # Per-example scalars all_model_ids: list[str] @@ -300,4 +305,5 @@ class PreparedSampleBatch(BaseModel): "importance_sampling": 1, "ppo": 2, "cispo": 3, + "ppo_critic": 4, } diff --git a/tests/backends/test_jax_backend.py b/tests/backends/test_jax_backend.py index 751fd20d1f..b6cce68525 100644 --- a/tests/backends/test_jax_backend.py +++ b/tests/backends/test_jax_backend.py @@ -44,6 +44,12 @@ def test_delete_model_basic(): assert not backend.has_model(model_id) +def test_create_model_rejects_non_policy_role(): + backend = create_backend() + with pytest.raises(ValueError, match="model_role='policy'"): + backend.create_model("critic_model", LoraConfig(rank=LORA_RANK, alpha=16, seed=0), model_role="critic") + + def test_delete_non_existent_model(): """Test deleting a non-existent model raises ValueError.""" backend = create_backend() @@ -166,6 +172,34 @@ def make_fwd_bwd_input(token_lists: list[list[int]]) -> types.ForwardBackwardInp return types.ForwardBackwardInput(data=samples, loss_fn="cross_entropy") +def test_forward_backward_rejects_ppo_critic_loss(): + backend = create_backend() + create_model(backend, "model_1") + reqs = { + "req1": ( + "model_1", + types.ForwardBackwardInput( + data=[ + types.Datum( + model_input=types.ModelInput(chunks=[types.EncodedTextChunk(tokens=[1, 2, 3])]), + loss_fn_inputs=types.LossFnInputs( + target_tokens=types.TensorData(data=[2, 3, 0]), + weights=types.TensorData(data=[1.0, 1.0, 1.0]), + advantages=types.TensorData(data=[]), + logprobs=types.TensorData(data=[]), + values=types.TensorData(data=[0.1, 0.2, 0.3]), + returns=types.TensorData(data=[0.4, 0.5, 0.6]), + ), + ) + ], + loss_fn="ppo_critic", + ), + ) + } + with pytest.raises(ValueError, match="ppo_critic is only supported"): + backend.forward_backward(prepare_model_pass_batch(reqs)) + + def _assert_tree_allclose(t1, t2, rtol=1e-3, atol=1e-3, min_match_pct=99.0): """Assert that at least min_match_pct% of elements in two trees are close.""" leaves1 = jax.tree.leaves(t1) diff --git a/tests/tinker/test_api_validation.py b/tests/tinker/test_api_validation.py index ae0d032dc8..4599123394 100644 --- a/tests/tinker/test_api_validation.py +++ b/tests/tinker/test_api_validation.py @@ -27,6 +27,15 @@ def test_forward_backward_input_accepts_ppo_threshold_keys(): assert req.loss_fn_config == {"clip_low_threshold": 0.9, "clip_high_threshold": 1.1} +def test_forward_backward_input_accepts_ppo_critic_value_clip(): + req = api.ForwardBackwardInput( + data=[_make_datum()], + loss_fn="ppo_critic", + loss_fn_config={"value_clip": 0.2}, + ) + assert req.loss_fn_config == {"value_clip": 0.2} + + def test_forward_backward_input_rejects_invalid_ppo_loss_fn_config_keys(): with pytest.raises(ValidationError, match="Invalid loss_fn_config keys"): api.ForwardBackwardInput( @@ -45,6 +54,27 @@ def test_forward_backward_input_rejects_loss_fn_config_for_cross_entropy(): ) +def test_datum_to_types_defaults_values_and_returns_to_empty(): + datum = _make_datum().to_types() + assert datum.loss_fn_inputs.values.data == [] + assert datum.loss_fn_inputs.returns.data == [] + + +def test_datum_to_types_preserves_values_and_returns(): + datum = api.Datum( + model_input=api.ModelInput(chunks=[api.EncodedTextChunk(tokens=[1, 2, 3])]), + loss_fn_inputs={ + "target_tokens": api.TensorData(data=[2, 3, 4]), + "weights": api.TensorData(data=[1.0, 1.0, 1.0]), + "values": api.TensorData(data=[0.1, 0.2, 0.3]), + "returns": api.TensorData(data=[0.4, 0.5, 0.6]), + }, + ).to_types() + + assert datum.loss_fn_inputs.values.data == [0.1, 0.2, 0.3] + assert datum.loss_fn_inputs.returns.data == [0.4, 0.5, 0.6] + + # --- ModelInputChunk discriminator tests (api) --- _api_adapter = TypeAdapter(api.ModelInputChunk) diff --git a/tests/tinker/test_engine.py b/tests/tinker/test_engine.py index 2d3db81946..4108029b0f 100644 --- a/tests/tinker/test_engine.py +++ b/tests/tinker/test_engine.py @@ -83,23 +83,28 @@ def test_cleanup_stale_sessions(): @pytest.mark.parametrize( - ("loss_fn", "loss_fn_config", "advantages", "logprobs"), + ("loss_fn", "loss_fn_config", "advantages", "logprobs", "values", "returns"), [ pytest.param( "ppo", {"clip_low_threshold": 0.7, "clip_high_threshold": 1.3}, [], [], + [], + [], id="ppo_with_loss_fn_config", ), - pytest.param("cross_entropy", None, [], [], id="cross_entropy_default_config"), + pytest.param("cross_entropy", None, [], [], [], [], id="cross_entropy_default_config"), pytest.param( "cispo", {"clip_low_threshold": 0.7, "clip_high_threshold": 1.3}, [0.1, 0.2, 0.3], [-1.1, -1.0, -0.9], + [], + [], id="cispo", ), + pytest.param("ppo_critic", {"value_clip": 0.2}, [], [], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], id="ppo_critic"), ], ) def test_prepare_model_pass_batch_loss_fn_and_config( @@ -107,6 +112,8 @@ def test_prepare_model_pass_batch_loss_fn_and_config( loss_fn_config: dict[str, float] | None, advantages: list[float], logprobs: list[float], + values: list[float], + returns: list[float], ): """Test that prepare_model_pass_batch preserves loss_fn and loss_fn_config values.""" datum = types.Datum( @@ -116,6 +123,8 @@ def test_prepare_model_pass_batch_loss_fn_and_config( weights=types.TensorData(data=[1.0, 1.0, 1.0]), advantages=types.TensorData(data=advantages), logprobs=types.TensorData(data=logprobs), + values=types.TensorData(data=values), + returns=types.TensorData(data=returns), ), ) @@ -134,3 +143,5 @@ def test_prepare_model_pass_batch_loss_fn_and_config( assert batch.all_loss_fns == [loss_fn] assert batch.all_loss_fn_configs == [loss_fn_config] assert batch.all_model_inputs == [datum.model_input] + assert batch.all_values == [values] + assert batch.all_returns == [returns]