Improved LoRA weight swap and robust transitions_to_training_data #1368

ashutoshuiuc wants to merge 2 commits into NovaSky-AI:main from
Conversation
LoRA weight swap improvements:
- Abort in-flight generation before swapping adapters
- Remove old adapter before adding new one (prevents stale adapter buildup)
- Reset prefix cache after swap for correctness
- Track `active_lora_id` explicitly instead of querying `list_loras()`
- Monkey-patch `_maybe_get_adapters` for consistent adapter lookup in chat completions

`transitions_to_training_data` robustness:
- Validate None/empty observations and actions per-transition
- Track logprobs validity per-datum (handle external actions without logprobs)
- Allow `response_logprobs=None` in `TrainingDatum` when logprobs are unavailable
- Explicit length-mismatch checks between response tokens, logprobs, and mask
- Skip datums with no action tokens (all-zero mask)

Closes NovaSky-AI#1297
Code Review
This pull request introduces significant robustness improvements to transitions_to_training_data and refactors LoRA weight swapping. The changes in transitions_to_training_data add comprehensive validation for transitions, observations, and actions, which is a great improvement. The LoRA weight swapping logic is now more explicit and robust, correctly handling adapter removal and tracking the active adapter.
My review includes two suggestions for improvement: one to simplify a boolean condition for better readability in utils.py, and another to use a more robust method for generating LoRA IDs in vllm_engine.py to prevent potential collisions.
```python
if not transition_has_valid_logprobs and has_valid_logprobs:
    has_valid_logprobs = False
```
This conditional logic can be simplified for better readability. The current implementation is correct, but a more direct and idiomatic way to express that `has_valid_logprobs` should become `False` once any `transition_has_valid_logprobs` is `False` is a boolean `and`:

```python
has_valid_logprobs = has_valid_logprobs and transition_has_valid_logprobs
```

```python
except Exception as e:
    logger.error(f"Failed removing old LoRA: {e}")

new_id = int(time.time_ns() % 0x7FFFFFFF)
```
Using time.time_ns() for ID generation can lead to collisions if this function is called in rapid succession, especially with the modulo operation. A more robust approach is to use a random source. Since uuid4 is already imported, you can use it to generate a random 31-bit integer. This significantly reduces the chance of collision.
Note that the int() cast in the original code is redundant as time.time_ns() % ... already produces an integer.
```diff
- new_id = int(time.time_ns() % 0x7FFFFFFF)
+ new_id = uuid4().int & 0x7FFFFFFF
```
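To illustrate the suggestion, here is a minimal, self-contained sketch of the two ID-generation approaches (the function names are illustrative, not from the PR). The lowest 31 bits of a `uuid4()` are uniformly random, so masking them yields an ID whose collision probability is governed only by the birthday bound, independent of call timing:

```python
import time
from uuid import uuid4


def lora_id_from_time() -> int:
    # Susceptible to collisions under rapid calls: time_ns() has limited
    # resolution, and the modulo further folds distinct timestamps together.
    return time.time_ns() % 0x7FFFFFFF


def lora_id_from_uuid() -> int:
    # uuid4() draws 122 random bits; masking keeps a non-negative
    # 31-bit integer regardless of when or how often this is called.
    return uuid4().int & 0x7FFFFFFF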
Pull request overview
This PR improves runtime robustness in two areas: (1) vLLM LoRA adapter hot-swapping during training/inference weight sync, and (2) conversion of agent transitions into training datums, especially when observations/actions/logprobs may be missing or inconsistent.
Changes:
- Update `AsyncVLLMInferenceEngine` LoRA swapping to explicitly track the active adapter ID, remove the previous adapter, add the new one, and reset the prefix cache; also monkey-patch adapter lookup for OpenAI-serving paths.
- Make `transitions_to_training_data` more defensive: validate observations/actions, handle per-datum logprob availability, add explicit length checks, and skip invalid datums.
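For readers outside the repository, the defensive conversion described above can be sketched as follows. This is a simplified stand-in, not the actual `skyrl_agent` code: the `Transition` fields, the zero-filled observation logprobs, and the single-datum return shape are all assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Transition:
    observation_tokens: Optional[List[int]]
    action_tokens: Optional[List[int]]
    action_logprobs: Optional[List[float]]  # None for external actions


@dataclass
class TrainingDatum:
    response_tokens: List[int]
    response_logprobs: Optional[List[float]]  # None when any transition lacked logprobs
    loss_mask: List[int]


def transitions_to_training_datum(transitions: List[Transition]) -> Optional[TrainingDatum]:
    tokens: List[int] = []
    logprobs: List[float] = []
    mask: List[int] = []
    has_valid_logprobs = True
    for t in transitions:
        # Validate None/empty observations and actions per-transition.
        if not t.observation_tokens or not t.action_tokens:
            return None
        transition_has_valid_logprobs = (
            t.action_logprobs is not None
            and len(t.action_logprobs) == len(t.action_tokens)
        )
        has_valid_logprobs = has_valid_logprobs and transition_has_valid_logprobs
        tokens += t.observation_tokens + t.action_tokens
        mask += [0] * len(t.observation_tokens) + [1] * len(t.action_tokens)
        if transition_has_valid_logprobs:
            logprobs += [0.0] * len(t.observation_tokens) + t.action_logprobs
    # Skip datums with no action tokens (all-zero mask).
    if not any(mask):
        return None
    # Explicit length-mismatch check between tokens, logprobs, and mask.
    if has_valid_logprobs and len(logprobs) != len(tokens):
        return None
    return TrainingDatum(tokens, logprobs if has_valid_logprobs else None, mask)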
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py` | Implements explicit active LoRA tracking and a more controlled LoRA swap flow for the async vLLM engine. |
| `skyrl-agent/skyrl_agent/functional/utils.py` | Hardens `transitions_to_training_data` and updates `TrainingDatum` to allow missing logprobs. |
```python
"""Swap LoRA adapter: abort in-flight requests, remove old, add new, reset cache."""
await self.abort_generation()

if self._active_lora_id is not None:
```
`_load_lora_from_disk()` calls `await self.abort_generation()`, but `AsyncVLLMInferenceEngine` (and its base classes) don't define `abort_generation`. As written, this will raise `AttributeError` the first time a LoRA swap is attempted. Add an `abort_generation()` implementation (likely similar to the unfinished-request abort logic used in `sleep()`), or replace this with the correct existing control method (`pause_generation`/`resume_generation`, or a direct `engine.abort` of unfinished request IDs).
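A minimal sketch of what such an `abort_generation()` could look like, assuming the engine tracks unfinished request IDs and the underlying async engine exposes an `abort` coroutine. `StubAsyncEngine`, `_unfinished_request_ids`, and the class name below are illustrative stand-ins, not vLLM's actual API:

```python
import asyncio
from typing import Set


class StubAsyncEngine:
    """Stand-in for the underlying async engine; the real vLLM API differs."""

    def __init__(self) -> None:
        self.aborted: Set[str] = set()

    async def abort(self, request_id: str) -> None:
        self.aborted.add(request_id)


class AsyncVLLMInferenceEngineSketch:
    def __init__(self, engine: StubAsyncEngine) -> None:
        self.engine = engine
        self._unfinished_request_ids: Set[str] = set()

    async def abort_generation(self) -> None:
        # Abort every in-flight request before swapping LoRA adapters,
        # mirroring the unfinished-request abort the reviewer describes
        # for sleep(). Copy the set: aborting mutates tracking state.
        for request_id in list(self._unfinished_request_ids):
            await self.engine.abort(request_id)
        self._unfinished_request_ids.clear()


engine = StubAsyncEngine()
sketch = AsyncVLLMInferenceEngineSketch(engine)
sketch._unfinished_request_ids.update({"req-1", "req-2"})
asyncio.run(sketch.abort_generation())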
- Add `abort_generation()` to `AsyncVLLMInferenceEngine` (was missing from base)
- Use `uuid4` for LoRA adapter IDs instead of `time.time_ns()` (avoids collisions)
- Simplify `has_valid_logprobs` tracking with boolean AND
```python
if self._is_lora:
    original = self.openai_serving_chat._maybe_get_adapters
    async_engine = self  # capture outer self safely

    def patched(self_chat, request, *args, **kwargs):
        active_lora_id = getattr(async_engine, "_active_lora_id", None)
        if active_lora_id is not None:
            return LoRARequest(
                lora_name=str(active_lora_id),
                lora_int_id=active_lora_id,
                lora_path="/dummy_lora_path",
            )
        return original(request, *args, **kwargs)

    self.openai_serving_chat._maybe_get_adapters = MethodType(
        patched,
        self.openai_serving_chat,
    )
```
🔴 LoRA adapter not injected for /completions OpenAI endpoint
The PR monkey-patches _maybe_get_adapters on self.openai_serving_chat to inject the active LoRA adapter for chat completions, but the same patch is not applied to self.openai_serving_completion. In vLLM 0.16.0, the completion serving path (vllm/entrypoints/openai/completion/serving.py) also calls _maybe_get_adapters to resolve adapter requests. When a LoRA-enabled engine serves /completions requests (via vllm_engine.py:683), the active LoRA adapter won't be applied, and inference will fall back to the base model weights.
Prompt for agents
In skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py, the _maybe_get_adapters monkey-patch at lines 408-425 is applied only to self.openai_serving_chat but not to self.openai_serving_completion. Apply the same monkey-patch logic to self.openai_serving_completion as well. You can extract the patching logic into a helper function and call it for both serving objects. For example, after line 425, add the same patching for self.openai_serving_completion:
```python
original_completion = self.openai_serving_completion._maybe_get_adapters

def patched_completion(self_completion, request, *args, **kwargs):
    active_lora_id = getattr(async_engine, "_active_lora_id", None)
    if active_lora_id is not None:
        return LoRARequest(
            lora_name=str(active_lora_id),
            lora_int_id=active_lora_id,
            lora_path="/dummy_lora_path",
        )
    return original_completion(request, *args, **kwargs)

self.openai_serving_completion._maybe_get_adapters = MethodType(
    patched_completion,
    self.openai_serving_completion,
)
```
Summary
- Track `active_lora_id` explicitly instead of querying `list_loras()`; monkey-patch `_maybe_get_adapters` for consistent adapter lookup
- `transitions_to_training_data`: validate None/empty observations and actions, track logprobs validity per-datum (handle external actions without logprobs), add explicit length-mismatch checks, skip all-zero-mask datums

Split from #1298 per maintainer feedback.
Closes #1297