diff --git a/skyrl-agent/skyrl_agent/functional/utils.py b/skyrl-agent/skyrl_agent/functional/utils.py index a5436bb65d..3796b68954 100644 --- a/skyrl-agent/skyrl_agent/functional/utils.py +++ b/skyrl-agent/skyrl_agent/functional/utils.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional from dataclasses import dataclass, field from functools import wraps +from loguru import logger # Type aliases for Transition Observation = Dict[str, Any] @@ -122,7 +123,7 @@ class TrainingDatum: input_tokens: List[int] response_tokens: List[int] - response_logprobs: List[float] + response_logprobs: Optional[List[float]] # None when logprobs unavailable (e.g. external actions) response_mask: List[float] # 0 for observation tokens, 1 for action tokens @@ -157,25 +158,49 @@ def transitions_to_training_data( List of TrainingDatum objects """ - # Accumulator state for building sequences + if not transitions: + return [] + + data: List[TrainingDatum] = [] + full_sequence: List[int] = [] sampled_logprobs: List[float] = [] mask: List[float] = [] + has_valid_logprobs: bool = True - data: List[TrainingDatum] = [] - - def make_datum(): + def finalize_datum() -> Optional[TrainingDatum]: """Create a TrainingDatum from current accumulated state.""" if not full_sequence: return None first_nonzero = mask.index(1) if 1 in mask else len(mask) - # till the first non-zero mask + if first_nonzero == len(mask): + logger.warning("Datum has no action tokens (all mask is 0), skipping") + return None + input_tokens = full_sequence[:first_nonzero] response_tokens = full_sequence[first_nonzero:] - response_logprobs = sampled_logprobs[first_nonzero:] response_mask = mask[first_nonzero:] + if has_valid_logprobs: + response_logprobs = sampled_logprobs[first_nonzero:] + + if len(response_logprobs) != len(response_tokens): + logger.error( + f"response_logprobs length ({len(response_logprobs)}) " + f"!= response_tokens length ({len(response_tokens)})" + ) + return None + else: + response_logprobs = None + + if len(response_mask) != len(response_tokens): + logger.error( + f"Length mismatch: response_mask ({len(response_mask)}) " + f"!= response_tokens ({len(response_tokens)})" + ) + return None + return TrainingDatum( input_tokens=input_tokens, response_tokens=response_tokens, @@ -183,52 +208,73 @@ def make_datum(): response_mask=response_mask, ) - def clear_accumulator(): - """Clear the accumulator state.""" - nonlocal full_sequence, sampled_logprobs, mask + def reset_accumulator(): + """Clear accumulator for next datum.""" + nonlocal full_sequence, sampled_logprobs, mask, has_valid_logprobs full_sequence = [] sampled_logprobs = [] mask = [] + has_valid_logprobs = True # Process each transition - for transition in transitions: - # Get observation tokens + for idx, transition in enumerate(transitions): + # Validation + if transition.ob is None: + logger.warning(f"Transition {idx} has None observation, skipping") + continue + if transition.ac is None: + logger.warning(f"Transition {idx} has None action, skipping") + continue + ob_tokens = transition.ob.input_ids + if not ob_tokens: + logger.warning(f"Transition {idx} has empty observation tokens, skipping") + continue - # Get action tokens and logprobs ac_tokens = transition.ac.token_ids - ac_logprobs = transition.ac.logprobs or [0.0] * len(ac_tokens) + if not ac_tokens: + logger.warning(f"Transition {idx} has empty action tokens, skipping") + continue + + ac_logprobs = transition.ac.logprobs + transition_has_valid_logprobs = ( + ac_logprobs is not None and len(ac_logprobs) == len(ac_tokens) + ) + + if not transition_has_valid_logprobs: + if ac_logprobs is None: + logger.debug(f"Transition {idx} has no logprobs") + else: + logger.warning( + f"Transition {idx} has mismatched logprobs " + f"({len(ac_logprobs)} logprobs vs {len(ac_tokens)} tokens)" + ) - # Determine delta observation (new tokens not in accumulated sequence) if len(full_sequence) == 0: - # First transition, use all observation tokens + # First transition always starts a new datum delta_ob_tokens = ob_tokens elif _is_prefix(full_sequence, ob_tokens): - # Current observation extends previous sequence - # Only add the delta (new tokens) - delta_ob_tokens = ob_tokens[len(full_sequence) :] + delta_ob_tokens = ob_tokens[len(full_sequence):] else: - # Current observation doesn't extend previous sequence - # Save current accumulated datum and start fresh - datum = make_datum() + datum = finalize_datum() if datum: data.append(datum) - clear_accumulator() + reset_accumulator() delta_ob_tokens = ob_tokens - # Add delta observation tokens to sequence + has_valid_logprobs = has_valid_logprobs and transition_has_valid_logprobs + + # Accumulate tokens (pad logprobs with 0.0 when invalid — safe since has_valid_logprobs gates usage) full_sequence.extend(delta_ob_tokens) sampled_logprobs.extend([0.0] * len(delta_ob_tokens)) mask.extend([0.0] * len(delta_ob_tokens)) - # Add action tokens to sequence full_sequence.extend(ac_tokens) - sampled_logprobs.extend(ac_logprobs) + sampled_logprobs.extend(ac_logprobs if transition_has_valid_logprobs else [0.0] * len(ac_tokens)) mask.extend([1.0] * len(ac_tokens)) - # Create final datum from remaining accumulated state if full_sequence: - datum = make_datum() + datum = finalize_datum() if datum: data.append(datum) diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index cfef56ace2..2c982b2c00 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -9,7 +9,7 @@ import time from dataclasses import dataclass from http import HTTPStatus -from types import SimpleNamespace +from types import SimpleNamespace, MethodType from uuid import uuid4 import ray @@ -342,6 +342,7 @@ class AsyncVLLMInferenceEngine(BaseVLLMInferenceEngine): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._weight_loader = VLLMWeightLoader(self.llm, is_async=True) + self._active_lora_id = None def _create_engine(self, *args, **kwargs): openai_kwargs = pop_openai_kwargs(kwargs) @@ -404,6 +405,25 @@ def _create_engine(self, *args, **kwargs): **openai_kwargs, ) + if self._is_lora: + original = self.openai_serving_chat._maybe_get_adapters + async_engine = self # capture outer self safely + + def patched(self_chat, request, *args, **kwargs): + active_lora_id = getattr(async_engine, "_active_lora_id", None) + if active_lora_id is not None: + return LoRARequest( + lora_name=str(active_lora_id), + lora_int_id=active_lora_id, + lora_path="/dummy_lora_path", + ) + return original(request, *args, **kwargs) + + self.openai_serving_chat._maybe_get_adapters = MethodType( + patched, + self.openai_serving_chat, + ) + # TODO(Charlie): revisit kwargs `return_tokens_as_token_ids`, # `enable_prompt_tokens_details`, `enable_force_include_usage`. self.openai_serving_completion = OpenAIServingCompletion( @@ -438,27 +458,55 @@ def _create_ray_prometheus_stat_loggers(self): ) return None + async def abort_generation(self) -> None: + """Abort all running and waiting requests.""" + engine = self._get_engine() + unfinished_request_ids = self._get_unfinished_request_ids(engine.output_processor) + if unfinished_request_ids: + await engine.abort(unfinished_request_ids) + await self.reset_prefix_cache() # avoid KV-cache pollution + logger.info(f"abort_generation() finished, aborted {len(unfinished_request_ids)} requests") + async def _load_lora_from_disk(self, lora_path: str): - """Load LoRA adapters from disk using vLLM's native add_lora method.""" - lora_id = int(time.time_ns() % 0x7FFFFFFF) - lora_request = LoRARequest(lora_name=f"{lora_id}", lora_int_id=lora_id, lora_path=lora_path) - result = await self.llm.add_lora(lora_request) - return result + """Swap LoRA adapter: abort in-flight requests, remove old, add new, reset cache.""" + await self.abort_generation() + + if self._active_lora_id is not None: + try: + await self.llm.remove_lora(self._active_lora_id) + logger.info(f"Removed old LoRA {self._active_lora_id}") + except Exception as e: + logger.error(f"Failed removing old LoRA: {e}") + + new_id = uuid4().int & 0x7FFFFFFF + + await self.llm.add_lora( + LoRARequest( + lora_name=str(new_id), + lora_int_id=new_id, + lora_path=lora_path, + ) + ) + + self._active_lora_id = new_id + + await self.reset_prefix_cache() + + logger.info(f"Loaded new LoRA {new_id}") + return {"status": "ok", "lora_id": new_id} async def _collect_outputs(self, prompt_token_ids, request_id: str, sampling_params: SamplingParams): """Collect outputs for a single prompt.""" - # Check if LoRA is enabled and create LoRA request + # Check if LoRA is enabled and create LoRA request using tracked _active_lora_id final_output = None lora_request = None - if self._is_lora: - lora_int_ids = list(await self.llm.list_loras()) - if len(lora_int_ids) > 0: - lora_int_id = lora_int_ids[0] - # dummy_lora_path for placeholder (actual loading done in add_lora()) - lora_request = LoRARequest( - lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/dummy_lora_path" - ) + if self._is_lora and self._active_lora_id is not None: + lora_request = LoRARequest( + lora_name=str(self._active_lora_id), + lora_int_id=self._active_lora_id, + lora_path="/dummy_lora_path", + ) async for request_output in self.llm.generate( prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),