-
Notifications
You must be signed in to change notification settings - Fork 286
Improved LoRA weight swap and robust transitions_to_training_data #1368
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,7 +9,7 @@ | |
| import time | ||
| from dataclasses import dataclass | ||
| from http import HTTPStatus | ||
| from types import SimpleNamespace | ||
| from types import SimpleNamespace, MethodType | ||
| from uuid import uuid4 | ||
|
|
||
| import ray | ||
|
|
@@ -342,6 +342,7 @@ class AsyncVLLMInferenceEngine(BaseVLLMInferenceEngine): | |
| def __init__(self, *args, **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
| self._weight_loader = VLLMWeightLoader(self.llm, is_async=True) | ||
| self._active_lora_id = None | ||
|
|
||
| def _create_engine(self, *args, **kwargs): | ||
| openai_kwargs = pop_openai_kwargs(kwargs) | ||
|
|
@@ -404,6 +405,25 @@ def _create_engine(self, *args, **kwargs): | |
| **openai_kwargs, | ||
| ) | ||
|
|
||
| if self._is_lora: | ||
| original = self.openai_serving_chat._maybe_get_adapters | ||
| async_engine = self # capture outer self safely | ||
|
|
||
| def patched(self_chat, request, *args, **kwargs): | ||
| active_lora_id = getattr(async_engine, "_active_lora_id", None) | ||
| if active_lora_id is not None: | ||
| return LoRARequest( | ||
| lora_name=str(active_lora_id), | ||
| lora_int_id=active_lora_id, | ||
| lora_path="/dummy_lora_path", | ||
| ) | ||
| return original(request, *args, **kwargs) | ||
|
|
||
| self.openai_serving_chat._maybe_get_adapters = MethodType( | ||
| patched, | ||
| self.openai_serving_chat, | ||
| ) | ||
|
|
||
| # TODO(Charlie): revisit kwargs `return_tokens_as_token_ids`, | ||
| # `enable_prompt_tokens_details`, `enable_force_include_usage`. | ||
| self.openai_serving_completion = OpenAIServingCompletion( | ||
|
|
@@ -438,27 +458,55 @@ def _create_ray_prometheus_stat_loggers(self): | |
| ) | ||
| return None | ||
|
|
||
| async def abort_generation(self) -> None: | ||
| """Abort all running and waiting requests.""" | ||
| engine = self._get_engine() | ||
| unfinished_request_ids = self._get_unfinished_request_ids(engine.output_processor) | ||
| if unfinished_request_ids: | ||
| await engine.abort(unfinished_request_ids) | ||
| await self.reset_prefix_cache() # avoid KV-cache pollution | ||
| logger.info(f"abort_generation() finished, aborted {len(unfinished_request_ids)} requests") | ||
|
|
||
| async def _load_lora_from_disk(self, lora_path: str): | ||
| """Load LoRA adapters from disk using vLLM's native add_lora method.""" | ||
| lora_id = int(time.time_ns() % 0x7FFFFFFF) | ||
| lora_request = LoRARequest(lora_name=f"{lora_id}", lora_int_id=lora_id, lora_path=lora_path) | ||
| result = await self.llm.add_lora(lora_request) | ||
| return result | ||
| """Swap LoRA adapter: abort in-flight requests, remove old, add new, reset cache.""" | ||
| await self.abort_generation() | ||
devin-ai-integration[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if self._active_lora_id is not None: | ||
|
Comment on lines +471 to +474
|
||
| try: | ||
| await self.llm.remove_lora(self._active_lora_id) | ||
| logger.info(f"Removed old LoRA {self._active_lora_id}") | ||
| except Exception as e: | ||
| logger.error(f"Failed removing old LoRA: {e}") | ||
|
|
||
| new_id = uuid4().int & 0x7FFFFFFF | ||
|
|
||
| await self.llm.add_lora( | ||
| LoRARequest( | ||
| lora_name=str(new_id), | ||
| lora_int_id=new_id, | ||
| lora_path=lora_path, | ||
| ) | ||
| ) | ||
|
|
||
| self._active_lora_id = new_id | ||
|
|
||
| await self.reset_prefix_cache() | ||
|
|
||
| logger.info(f"Loaded new LoRA {new_id}") | ||
| return {"status": "ok", "lora_id": new_id} | ||
|
|
||
| async def _collect_outputs(self, prompt_token_ids, request_id: str, sampling_params: SamplingParams): | ||
| """Collect outputs for a single prompt.""" | ||
| # Check if LoRA is enabled and create LoRA request | ||
| # Check if LoRA is enabled and create LoRA request using tracked _active_lora_id | ||
| final_output = None | ||
| lora_request = None | ||
|
|
||
| if self._is_lora: | ||
| lora_int_ids = list(await self.llm.list_loras()) | ||
| if len(lora_int_ids) > 0: | ||
| lora_int_id = lora_int_ids[0] | ||
| # dummy_lora_path for placeholder (actual loading done in add_lora()) | ||
| lora_request = LoRARequest( | ||
| lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/dummy_lora_path" | ||
| ) | ||
| if self._is_lora and self._active_lora_id is not None: | ||
| lora_request = LoRARequest( | ||
| lora_name=str(self._active_lora_id), | ||
| lora_int_id=self._active_lora_id, | ||
| lora_path="/dummy_lora_path", | ||
| ) | ||
|
|
||
| async for request_output in self.llm.generate( | ||
| prompt=TokensPrompt(prompt_token_ids=prompt_token_ids), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🔴 LoRA adapter not injected for `/completions` OpenAI endpoint

The PR monkey-patches `_maybe_get_adapters` on `self.openai_serving_chat` to inject the active LoRA adapter for chat completions, but the same patch is not applied to `self.openai_serving_completion`. In vLLM 0.16.0, the completion serving path (`vllm/entrypoints/openai/completion/serving.py`) also calls `_maybe_get_adapters` to resolve adapter requests. When a LoRA-enabled engine serves `/completions` requests (via `vllm_engine.py:683`), the active LoRA adapter won't be applied, and inference will fall back to the base model weights.

Prompt for agents
Was this helpful? React with 👍 or 👎 to provide feedback.