Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/gpu_ci_run_skyrl_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ uv run --directory . --isolated --extra dev --extra fsdp pytest -s tests/backend
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_policy_local_engines_e2e.py
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py -m "not megatron"
3 changes: 3 additions & 0 deletions ci/gpu_ci_run_skyrl_train_megatron.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
# Run all megatron tests
uv run --directory . --isolated --extra dev --extra megatron pytest -s tests/backends/skyrl_train/gpu/gpu_ci -m "megatron"

# Run megatron LoRA tests with new inference layer
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra megatron pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py -m "megatron"

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ override-dependencies = [
"transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'",
"megatron-core==0.16.0; sys_platform == 'linux'",
"ml_dtypes>=0.5.0; sys_platform == 'linux'",
"transformers>=4.56.1,<5; sys_platform == 'linux'",
]

[tool.uv.extra-build-dependencies]
Expand Down Expand Up @@ -251,8 +252,8 @@ torchvision = [
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
]
# pin megatron bridge commit to fix for MoE + LoRA merging. Update this when an official release is cut
megatron-bridge = {git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge", rev = "02b5fccab5e5b21856d36c2e357839e0123b4b8f", marker = "sys_platform == 'linux'"}
# pin megatron bridge commit for LoRA adapter export support. Update this when an official release is cut
megatron-bridge = {git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge", rev = "f78c65f9", marker = "sys_platform == 'linux'"}
harbor = { git = "https://github.com/laude-institute/harbor", rev = "8c040e1bb010201fd3c75bee3dede2407b9f57cd" }

[tool.black]
Expand Down
41 changes: 39 additions & 2 deletions skyrl/backends/skyrl_train/inference_servers/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
import threading
import time
from contextlib import asynccontextmanager
from typing import List, Optional

import httpx
Expand Down Expand Up @@ -106,8 +107,17 @@ def _get_server_for_request(self, request: Request) -> str:

def _build_app(self) -> FastAPI:
"""Build the FastAPI app with proxy routes."""

@asynccontextmanager
async def lifespan(app: FastAPI):
yield
if self._client:
await self._client.aclose()
self._client = None

app = FastAPI(
title="SkyRL Inference Router",
lifespan=lifespan,
docs_url=None,
redoc_url=None,
openapi_url=None,
Expand Down Expand Up @@ -242,9 +252,36 @@ def _wait_until_healthy(
raise RuntimeError(f"Router failed to start within {timeout}s")

def shutdown(self) -> None:
"""Shutdown the router gracefully."""
"""Shutdown the router and ensure the port is released."""
logger.info("Shutting down router...")

if self._server:
self._server.should_exit = True

if self._server_thread:
self._server_thread.join(timeout=5)
self._server_thread.join(timeout=10)

if self._server_thread.is_alive():
logger.warning("Router thread did not exit gracefully, forcing server socket closure")
self._force_close_server_sockets()

self._server = None
self._server_thread = None
self._app = None

def _force_close_server_sockets(self) -> None:
    """Force-close the underlying server sockets to release the port.

    Last-resort cleanup for when the server thread does not exit within
    the join timeout: closes each listening socket held by the server
    object, then probes the port and logs a warning if something is
    still accepting connections on it.
    """
    # ``servers`` is an internal attribute of the underlying server
    # implementation (presumably uvicorn's Server — confirm against the
    # pinned version); guard with hasattr in case it is absent.
    if self._server and hasattr(self._server, "servers"):
        for server in self._server.servers:
            server.close()
    import socket

    # Best-effort diagnostic probe: attempt a TCP connect to see whether
    # the port was actually released.
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.settimeout(1)
        # 0.0.0.0 is a bind address, not a connectable one — probe via
        # loopback in that case.
        result = s.connect_ex((self._host if self._host != "0.0.0.0" else "127.0.0.1", self._port))
        s.close()
        if result == 0:
            # connect succeeded -> something is still listening on the port.
            logger.warning(f"Port {self._port} still in use after forced close")
    except Exception:
        # The probe is diagnostics only; never let it break shutdown.
        pass
36 changes: 34 additions & 2 deletions skyrl/backends/skyrl_train/inference_servers/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
import logging
from argparse import Namespace

from skyrl.backends.skyrl_train.weight_sync import get_transfer_strategy
from skyrl.train.config import SkyRLTrainConfig, get_config_as_dict

logger = logging.getLogger(__name__)


def _uses_lora_weight_sync(cfg: SkyRLTrainConfig) -> bool:
"""Return True when the trainer syncs LoRA adapters (not merged weights).

FSDP always syncs LoRA adapters when ``lora.rank > 0``.
Megatron merges LoRA into the base weights by default
(``merge_lora=True``), so the inference engine should not enable LoRA.
"""
if cfg.trainer.policy.model.lora.rank <= 0:
return False
if cfg.trainer.strategy == "megatron":
return not cfg.trainer.policy.megatron_config.lora_config.merge_lora
return True


# TODO: Add a test for validation
def build_vllm_cli_args(cfg: SkyRLTrainConfig) -> Namespace:
Expand Down Expand Up @@ -48,13 +65,28 @@ def build_vllm_cli_args(cfg: SkyRLTrainConfig) -> Namespace:
for key, value in overrides.items():
setattr(args, key, value)

# Add LoRA params if enabled
if cfg.trainer.policy.model.lora.rank > 0:
# Enable LoRA on the inference engine only when the trainer will sync
# LoRA adapters (not merged weights). Megatron merges by default
# (merge_lora=True), so the inference engine must NOT have LoRA wrapping.
if _uses_lora_weight_sync(cfg):
args.enable_lora = True
args.max_lora_rank = cfg.trainer.policy.model.lora.rank
args.max_loras = 1
args.fully_sharded_loras = ie_cfg.fully_sharded_loras

if not cfg.trainer.placement.colocate_all:
lora_path = cfg.trainer.policy.model.lora.lora_sync_path
logger.warning(
"LoRA weight sync is enabled but training and inference are not "
"colocated (placement.colocate_all=false). The trainer saves LoRA "
"adapters to disk for the inference engine to load, so both must "
"share a filesystem. Set trainer.policy.model.lora.lora_sync_path "
"to a shared mount (current value: %s).",
lora_path,
)
else:
args.enable_lora = False

# Add any extra engine_init_kwargs
engine_kwargs = get_config_as_dict(ie_cfg.engine_init_kwargs)
for key, value in engine_kwargs.items():
Expand Down
71 changes: 64 additions & 7 deletions skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
TrainingOutputBatch,
)
from skyrl.backends.skyrl_train.utils.profiler import Profiler
from skyrl.backends.skyrl_train.weight_sync import WeightChunk, WeightExtractor
from skyrl.backends.skyrl_train.weight_sync import (
LoraLoadRequest,
WeightChunk,
WeightExtractor,
)
from skyrl.backends.skyrl_train.workers.megatron.megatron_model_wrapper import (
MegatronModelWrapper,
)
Expand Down Expand Up @@ -802,6 +806,55 @@ async def init_weight_sync_state(self, inference_engine_client, inference_engine
training_dtype=torch.bfloat16 if self.cfg.bf16 else torch.float32,
)

async def _save_lora_adapters_and_sync(self, lora_sync_path, inference_engine_client):
"""Export LoRA adapter weights via Megatron-Bridge and tell the inference engine to load them.

All ranks participate in the collective export (TP/PP/EP gathering is
handled internally by the bridge). Only rank 0 writes to disk and
sends the ``LoraLoadRequest``.
"""
import json

from megatron.bridge.models.conversion.peft_bridge import (
build_adapter_config_dict,
infer_target_modules_from_adapter_weights,
)
from safetensors.torch import save_file
Comment on lines +816 to +822
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Imports should generally be at the top of the file as per PEP 8 style guidelines. This improves readability and makes dependencies clear. The imports for json, megatron.bridge, and safetensors.torch here, as well as for RemoteInferenceClient on line 846, are local to this method. Unless there's a specific reason for lazy loading (like avoiding circular dependencies, which doesn't seem to be the case here), please move these imports to the top of the file.


adapter_state = {}
for name, tensor in self.bridge.export_adapter_weights(self.actor_module, cpu=True, show_progress=False):
adapter_state[f"base_model.model.{name}"] = tensor.clone().float()

if torch.distributed.get_rank() == 0:
os.makedirs(lora_sync_path, exist_ok=True)

target_modules = infer_target_modules_from_adapter_weights(adapter_state.keys())
base_model_name_or_path = str(
getattr(self.bridge.hf_pretrained, "model_name_or_path", "")
or getattr(self.bridge.hf_pretrained, "name_or_path", "")
)
adapter_config = build_adapter_config_dict(
self.lora_cls,
target_modules=target_modules,
base_model_name_or_path=base_model_name_or_path,
)

save_file(adapter_state, os.path.join(lora_sync_path, "adapter_model.safetensors"))
with open(os.path.join(lora_sync_path, "adapter_config.json"), "w", encoding="utf-8") as f:
json.dump(adapter_config, f, ensure_ascii=False, indent=4)

from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import (
RemoteInferenceClient,
)

if isinstance(inference_engine_client, RemoteInferenceClient):
await inference_engine_client.update_lora_from_disk(lora_sync_path)
else:
lora_request = LoraLoadRequest(lora_path=lora_sync_path)
await inference_engine_client.update_named_weights(lora_request)

torch.distributed.barrier()

async def broadcast_to_inference_engines(
self, inference_engine_client: "InferenceEngineInterface", inference_engine_cfg: "InferenceEngineConfig"
):
Expand All @@ -814,12 +867,16 @@ async def broadcast_to_inference_engines(

torch.cuda.empty_cache()

# Extract and send weights using the sender created at init time
weight_metadata = self.weight_extractor.get_weight_metadata(generator_dtype)
await self._weight_transfer_sender.send_chunks(
self.weight_extractor.extract_weights(generator_dtype),
weight_metadata=weight_metadata,
)
if self._is_lora and not self.cfg.policy.megatron_config.lora_config.merge_lora:
lora_sync_path = self.cfg.policy.model.lora.lora_sync_path
await self._save_lora_adapters_and_sync(lora_sync_path, inference_engine_client)
else:
# Extract and send weights using the sender created at init time
weight_metadata = self.weight_extractor.get_weight_metadata(generator_dtype)
await self._weight_transfer_sender.send_chunks(
self.weight_extractor.extract_weights(generator_dtype),
weight_metadata=weight_metadata,
)

if cache_reset_task is not None:
await cache_reset_task
Expand Down
1 change: 1 addition & 0 deletions skyrl/train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class MegatronTorchProfilerConfig(BaseConfig):
@dataclass
class MegatronLoraConfig(BaseConfig):
    """LoRA options specific to the Megatron training strategy."""

    # Adapter flavor; "lora" is the default type.
    lora_type: str = "lora"
    # When True, LoRA adapters are merged into the base model weights before
    # weight sync, so the inference engine receives plain weights; when
    # False, the adapters are exported and synced separately.
    merge_lora: bool = True


DEFAULT_MEGATRON_OPTIMIZER_KWARGS = {
Expand Down
69 changes: 52 additions & 17 deletions tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""
# Run tests (requires fsdp extra):
uv run --isolated --extra dev --extra fsdp pytest tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py
# Run FSDP tests:
uv run --isolated --extra dev --extra fsdp pytest tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py -k "fsdp"

# Run Megatron tests:
uv run --isolated --extra dev --extra megatron pytest tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py -k "megatron"
"""

import asyncio
Expand All @@ -22,17 +25,32 @@
MODEL = "Qwen/Qwen2.5-0.5B-Instruct"


def get_test_actor_config(enable_lora: bool = False) -> SkyRLTrainConfig:
def get_test_actor_config(
strategy: str = "fsdp",
enable_lora: bool = False,
colocate_all: bool = False,
weight_sync_backend: str = "nccl",
tp_size: int = 2,
merge_lora: bool = True,
) -> SkyRLTrainConfig:
"""Get base config with test-specific overrides."""
cfg = SkyRLTrainConfig()
cfg.trainer.policy.model.path = MODEL
cfg.trainer.critic.model.path = ""
cfg.trainer.strategy = strategy
cfg.trainer.placement.colocate_all = colocate_all
cfg.trainer.placement.policy_num_gpus_per_node = 2
cfg.generator.inference_engine.async_engine = True
cfg.generator.inference_engine.num_engines = 1
cfg.generator.inference_engine.run_engines_locally = True
cfg.generator.inference_engine.weight_sync_backend = weight_sync_backend
cfg.generator.inference_engine.tensor_parallel_size = tp_size

if strategy == "megatron":
cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2
cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1
cfg.trainer.policy.megatron_config.lora_config.merge_lora = merge_lora

# LoRA configuration
if enable_lora:
cfg.trainer.policy.model.lora = SkyRLLoraConfig(
rank=32,
Expand All @@ -45,29 +63,46 @@ def get_test_actor_config(enable_lora: bool = False) -> SkyRLTrainConfig:


@pytest.mark.parametrize(
("colocate_all", "weight_sync_backend", "strategy", "tp_size"),
("colocate_all", "weight_sync_backend", "strategy", "tp_size", "merge_lora"),
[
pytest.param(False, "nccl", "fsdp", 2),
pytest.param(True, "nccl", "fsdp", 2),
pytest.param(False, "nccl", "fsdp2", 2),
pytest.param(True, "nccl", "fsdp2", 2),
pytest.param(False, "nccl", "fsdp", 2, True),
pytest.param(True, "nccl", "fsdp", 2, True),
pytest.param(False, "nccl", "fsdp2", 2, True),
pytest.param(True, "nccl", "fsdp2", 2, True),
pytest.param(False, "nccl", "megatron", 2, True, marks=pytest.mark.megatron),
pytest.param(True, "nccl", "megatron", 2, True, marks=pytest.mark.megatron),
pytest.param(False, "nccl", "megatron", 2, False, marks=pytest.mark.megatron),
pytest.param(True, "nccl", "megatron", 2, False, marks=pytest.mark.megatron),
],
ids=[
"no_colocate_nccl_fsdp",
"colocate_nccl_fsdp",
"no_colocate_nccl_fsdp2",
"colocate_nccl_fsdp2",
"no_colocate_nccl_megatron_merged",
"colocate_nccl_megatron_merged",
"no_colocate_nccl_megatron_adapter",
"colocate_nccl_megatron_adapter",
],
)
def test_policy_local_engines_e2e(ray_init_fixture, colocate_all, weight_sync_backend, strategy, tp_size):
def test_policy_local_engines_e2e(ray_init_fixture, colocate_all, weight_sync_backend, strategy, tp_size, merge_lora):
"""
Tests initalizing the policy actor group and inference engine, syncing weights, and performing generation.
"""
cfg = get_test_actor_config(enable_lora=True)
cfg.trainer.placement.colocate_all = colocate_all
cfg.generator.inference_engine.weight_sync_backend = weight_sync_backend
cfg.trainer.strategy = strategy
cfg.generator.inference_engine.tensor_parallel_size = tp_size
cfg = get_test_actor_config(
strategy=strategy,
enable_lora=True,
colocate_all=colocate_all,
weight_sync_backend=weight_sync_backend,
tp_size=tp_size,
merge_lora=merge_lora,
)

# Only enable LoRA on the vLLM side when adapters are loaded separately.
# When merge_lora=True the bridge merges LoRA into the full weights, so
# vLLM receives plain weights and must NOT have enable_lora (which wraps
# modules and changes named_parameters(), breaking load_weights).
needs_vllm_lora = not (strategy == "megatron" and merge_lora)

# If colocate is True, this will load the engine, sleep, and wake up the engine
with InferenceEngineState.create(
Expand All @@ -77,8 +112,8 @@ def test_policy_local_engines_e2e(ray_init_fixture, colocate_all, weight_sync_ba
async_engine=cfg.generator.inference_engine.async_engine,
tp_size=cfg.generator.inference_engine.tensor_parallel_size,
colocate_all=cfg.trainer.placement.colocate_all,
sleep_level=1, # since we explicitly sync weights
enable_lora=True, # Enable LoRA for this test
sleep_level=1 if needs_vllm_lora else 2,
enable_lora=needs_vllm_lora,
) as engines:
client, pg = engines.client, engines.pg

Expand Down
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hao-aaron can you add this to the megatron CI here and ensure that tests pass with _SKYRL_USE_NEW_INFERENCE=1 ?

uv run --directory . --isolated --extra dev --extra megatron pytest -s tests/backends/skyrl_train/gpu/gpu_ci -m "megatron"

Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def get_test_training_batch(batch_size=4) -> TrainingInputBatch:
[
pytest.param(True, 4, 2, 2, 1, None, False, marks=_skip_new_inference, id="colocate_all"),
pytest.param(False, 2, 2, 1, 1, None, False, id="non_colocated"),
pytest.param(True, 4, 2, 2, 1, None, True, marks=_skip_new_inference, id="colocate_all_lora"),
],
)
@pytest.mark.megatron
Expand Down
Loading
Loading