Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 74 additions & 28 deletions skyrl-agent/skyrl_agent/functional/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field
from functools import wraps
from loguru import logger

# Type aliases for Transition
Observation = Dict[str, Any]
Expand Down Expand Up @@ -122,7 +123,7 @@ class TrainingDatum:

input_tokens: List[int]
response_tokens: List[int]
response_logprobs: List[float]
response_logprobs: Optional[List[float]] # None when logprobs unavailable (e.g. external actions)
response_mask: List[float] # 0 for observation tokens, 1 for action tokens


Expand Down Expand Up @@ -157,78 +158,123 @@ def transitions_to_training_data(
List of TrainingDatum objects
"""

# Accumulator state for building sequences
if not transitions:
return []

data: List[TrainingDatum] = []

full_sequence: List[int] = []
sampled_logprobs: List[float] = []
mask: List[float] = []
has_valid_logprobs: bool = True

data: List[TrainingDatum] = []

def make_datum():
def finalize_datum() -> Optional[TrainingDatum]:
"""Create a TrainingDatum from current accumulated state."""
if not full_sequence:
return None

first_nonzero = mask.index(1) if 1 in mask else len(mask)
# Tokens before the first non-zero mask entry form the prompt (input) portion
if first_nonzero == len(mask):
logger.warning("Datum has no action tokens (all mask is 0), skipping")
return None

input_tokens = full_sequence[:first_nonzero]
response_tokens = full_sequence[first_nonzero:]
response_logprobs = sampled_logprobs[first_nonzero:]
response_mask = mask[first_nonzero:]

if has_valid_logprobs:
response_logprobs = sampled_logprobs[first_nonzero:]

if len(response_logprobs) != len(response_tokens):
logger.error(
f"response_logprobs length ({len(response_logprobs)}) "
f"!= response_tokens length ({len(response_tokens)})"
)
return None
else:
response_logprobs = None

if len(response_mask) != len(response_tokens):
logger.error(
f"Length mismatch: response_mask ({len(response_mask)}) "
f"!= response_tokens ({len(response_tokens)})"
)
return None

return TrainingDatum(
input_tokens=input_tokens,
response_tokens=response_tokens,
response_logprobs=response_logprobs,
response_mask=response_mask,
)

def clear_accumulator():
"""Clear the accumulator state."""
nonlocal full_sequence, sampled_logprobs, mask
def reset_accumulator():
"""Clear accumulator for next datum."""
nonlocal full_sequence, sampled_logprobs, mask, has_valid_logprobs
full_sequence = []
sampled_logprobs = []
mask = []
has_valid_logprobs = True

# Process each transition
for transition in transitions:
# Get observation tokens
for idx, transition in enumerate(transitions):
# Validation
if transition.ob is None:
logger.warning(f"Transition {idx} has None observation, skipping")
continue
if transition.ac is None:
logger.warning(f"Transition {idx} has None action, skipping")
continue

ob_tokens = transition.ob.input_ids
if not ob_tokens:
logger.warning(f"Transition {idx} has empty observation tokens, skipping")
continue

# Get action tokens and logprobs
ac_tokens = transition.ac.token_ids
ac_logprobs = transition.ac.logprobs or [0.0] * len(ac_tokens)
if not ac_tokens:
logger.warning(f"Transition {idx} has empty action tokens, skipping")
continue

ac_logprobs = transition.ac.logprobs
transition_has_valid_logprobs = (
ac_logprobs is not None and len(ac_logprobs) == len(ac_tokens)
)

if not transition_has_valid_logprobs:
if ac_logprobs is None:
logger.debug(f"Transition {idx} has no logprobs")
else:
logger.warning(
f"Transition {idx} has mismatched logprobs "
f"({len(ac_logprobs)} logprobs vs {len(ac_tokens)} tokens)"
)

# Determine delta observation (new tokens not in accumulated sequence)
if len(full_sequence) == 0:
# First transition, use all observation tokens
# First transition always starts a new datum
delta_ob_tokens = ob_tokens
elif _is_prefix(full_sequence, ob_tokens):
# Current observation extends previous sequence
# Only add the delta (new tokens)
delta_ob_tokens = ob_tokens[len(full_sequence) :]
delta_ob_tokens = ob_tokens[len(full_sequence):]
else:
# Current observation doesn't extend previous sequence
# Save current accumulated datum and start fresh
datum = make_datum()
datum = finalize_datum()
if datum:
data.append(datum)
clear_accumulator()
reset_accumulator()
delta_ob_tokens = ob_tokens

# Add delta observation tokens to sequence
has_valid_logprobs = has_valid_logprobs and transition_has_valid_logprobs

# Accumulate tokens (pad logprobs with 0.0 when invalid — safe since has_valid_logprobs gates usage)
full_sequence.extend(delta_ob_tokens)
sampled_logprobs.extend([0.0] * len(delta_ob_tokens))
mask.extend([0.0] * len(delta_ob_tokens))

# Add action tokens to sequence
full_sequence.extend(ac_tokens)
sampled_logprobs.extend(ac_logprobs)
sampled_logprobs.extend(ac_logprobs if transition_has_valid_logprobs else [0.0] * len(ac_tokens))
mask.extend([1.0] * len(ac_tokens))

# Create final datum from remaining accumulated state
if full_sequence:
datum = make_datum()
datum = finalize_datum()
if datum:
data.append(datum)

Expand Down
78 changes: 63 additions & 15 deletions skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
from dataclasses import dataclass
from http import HTTPStatus
from types import SimpleNamespace
from types import SimpleNamespace, MethodType
from uuid import uuid4

import ray
Expand Down Expand Up @@ -342,6 +342,7 @@ class AsyncVLLMInferenceEngine(BaseVLLMInferenceEngine):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._weight_loader = VLLMWeightLoader(self.llm, is_async=True)
self._active_lora_id = None

def _create_engine(self, *args, **kwargs):
openai_kwargs = pop_openai_kwargs(kwargs)
Expand Down Expand Up @@ -404,6 +405,25 @@ def _create_engine(self, *args, **kwargs):
**openai_kwargs,
)

if self._is_lora:
original = self.openai_serving_chat._maybe_get_adapters
async_engine = self # capture outer self safely

def patched(self_chat, request, *args, **kwargs):
active_lora_id = getattr(async_engine, "_active_lora_id", None)
if active_lora_id is not None:
return LoRARequest(
lora_name=str(active_lora_id),
lora_int_id=active_lora_id,
lora_path="/dummy_lora_path",
)
return original(request, *args, **kwargs)

self.openai_serving_chat._maybe_get_adapters = MethodType(
patched,
self.openai_serving_chat,
)
Comment on lines +408 to +425
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 LoRA adapter not injected for /completions OpenAI endpoint

The PR monkey-patches _maybe_get_adapters on self.openai_serving_chat to inject the active LoRA adapter for chat completions, but the same patch is not applied to self.openai_serving_completion. In vLLM 0.16.0, the completion serving path (vllm/entrypoints/openai/completion/serving.py) also calls _maybe_get_adapters to resolve adapter requests. When a LoRA-enabled engine serves /completions requests (via vllm_engine.py:683), the active LoRA adapter won't be applied, and inference will fall back to the base model weights.

Prompt for agents
In skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py, the _maybe_get_adapters monkey-patch at lines 408-425 is applied only to self.openai_serving_chat but not to self.openai_serving_completion. Apply the same monkey-patch logic to self.openai_serving_completion as well. You can extract the patching logic into a helper function and call it for both serving objects. For example, after line 425, add the same patching for self.openai_serving_completion:

  original_completion = self.openai_serving_completion._maybe_get_adapters
  def patched_completion(self_completion, request, *args, **kwargs):
      active_lora_id = getattr(async_engine, "_active_lora_id", None)
      if active_lora_id is not None:
          return LoRARequest(
              lora_name=str(active_lora_id),
              lora_int_id=active_lora_id,
              lora_path="/dummy_lora_path",
          )
      return original_completion(request, *args, **kwargs)
  self.openai_serving_completion._maybe_get_adapters = MethodType(
      patched_completion,
      self.openai_serving_completion,
  )
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.


# TODO(Charlie): revisit kwargs `return_tokens_as_token_ids`,
# `enable_prompt_tokens_details`, `enable_force_include_usage`.
self.openai_serving_completion = OpenAIServingCompletion(
Expand Down Expand Up @@ -438,27 +458,55 @@ def _create_ray_prometheus_stat_loggers(self):
)
return None

async def abort_generation(self) -> None:
"""Abort all running and waiting requests."""
engine = self._get_engine()
unfinished_request_ids = self._get_unfinished_request_ids(engine.output_processor)
if unfinished_request_ids:
await engine.abort(unfinished_request_ids)
await self.reset_prefix_cache() # avoid KV-cache pollution
logger.info(f"abort_generation() finished, aborted {len(unfinished_request_ids)} requests")

async def _load_lora_from_disk(self, lora_path: str):
"""Load LoRA adapters from disk using vLLM's native add_lora method."""
lora_id = int(time.time_ns() % 0x7FFFFFFF)
lora_request = LoRARequest(lora_name=f"{lora_id}", lora_int_id=lora_id, lora_path=lora_path)
result = await self.llm.add_lora(lora_request)
return result
"""Swap LoRA adapter: abort in-flight requests, remove old, add new, reset cache."""
await self.abort_generation()

if self._active_lora_id is not None:
Comment on lines +471 to +474
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_load_lora_from_disk() calls await self.abort_generation(), but AsyncVLLMInferenceEngine (and its base classes) don't define abort_generation. As written this will raise AttributeError the first time a LoRA swap is attempted. Add an abort_generation() implementation (likely similar to the unfinished-request abort logic used in sleep()), or replace this with the correct existing control method (pause_generation/resume_generation or direct engine.abort of unfinished request IDs).

Copilot uses AI. Check for mistakes.
try:
await self.llm.remove_lora(self._active_lora_id)
logger.info(f"Removed old LoRA {self._active_lora_id}")
except Exception as e:
logger.error(f"Failed removing old LoRA: {e}")

new_id = uuid4().int & 0x7FFFFFFF

await self.llm.add_lora(
LoRARequest(
lora_name=str(new_id),
lora_int_id=new_id,
lora_path=lora_path,
)
)

self._active_lora_id = new_id

await self.reset_prefix_cache()

logger.info(f"Loaded new LoRA {new_id}")
return {"status": "ok", "lora_id": new_id}

async def _collect_outputs(self, prompt_token_ids, request_id: str, sampling_params: SamplingParams):
"""Collect outputs for a single prompt."""
# Check if LoRA is enabled and create LoRA request
# Check if LoRA is enabled and create LoRA request using tracked _active_lora_id
final_output = None
lora_request = None

if self._is_lora:
lora_int_ids = list(await self.llm.list_loras())
if len(lora_int_ids) > 0:
lora_int_id = lora_int_ids[0]
# dummy_lora_path for placeholder (actual loading done in add_lora())
lora_request = LoRARequest(
lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/dummy_lora_path"
)
if self._is_lora and self._active_lora_id is not None:
lora_request = LoRARequest(
lora_name=str(self._active_lora_id),
lora_int_id=self._active_lora_id,
lora_path="/dummy_lora_path",
)

async for request_output in self.llm.generate(
prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
Expand Down
Loading