@@ -30,15 +30,15 @@ git submodule sync && git submodule update --init --recursive
 #printf "Installing PyTorch with cu128"
 #if [[ "$TORCH_VERSION" == "nightly" ]]; then
 # if [ "${CU_VERSION:-}" == cpu ] ; then
-# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
+# pip install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
 # else
-# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
+# pip install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
 # fi
 #elif [[ "$TORCH_VERSION" == "stable" ]]; then
 # if [ "${CU_VERSION:-}" == cpu ] ; then
-# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
+# pip install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
 # else
-# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
+# pip install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
 # fi
 #else
 # printf "Failed to install pytorch"
@@ -47,9 +47,10 @@ git submodule sync && git submodule update --init --recursive
 
 # install tensordict
 if [[ "$RELEASE" == 0 ]]; then
-pip3 install git+https://github.com/pytorch/tensordict.git
+pip install "pybind11[global]" ninja
+pip install git+https://github.com/pytorch/tensordict.git
 else
-pip3 install tensordict
+pip install tensordict
 fi
 
 # smoke test
@@ -23,14 +23,4 @@ lib_dir="${env_dir}/lib"
 
 conda deactivate && conda activate ./env
 
-python -c "import transformers, datasets"
-
-pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips
-
-python examples/rlhf/train_rlhf.py \
-sys.device=cuda:0 sys.ref_device=cuda:0 \
-model.name_or_path=gpt2 train.max_epochs=2 \
-data.batch_size=2 train.ppo.ppo_batch_size=2 \
-train.ppo.ppo_num_epochs=1 reward_model.name_or_path= \
-train.ppo.episode_length=8 train.ppo.num_rollouts_per_epoch=4 \
-data.block_size=110 io.logger=csv
+pytest test/llm -vvv --instafail --durations 600 --capture no --error-for-skips
@@ -6,28 +6,19 @@
 # Do not install PyTorch and torchvision here, otherwise they also get cached.
 
 set -e
-apt-get update && apt-get upgrade -y && apt-get install -y git cmake
+export DEBIAN_FRONTEND=noninteractive
+export TZ=UTC
+apt-get update
+apt-get install -yq --no-install-recommends git wget unzip curl patchelf
 # Avoid error: "fatal: unsafe repository"
 git config --global --add safe.directory '*'
-apt-get install -y wget \
-gcc \
-g++ \
-unzip \
-curl \
-patchelf \
-libosmesa6-dev \
-libgl1-mesa-glx \
-libglfw3 \
-swig3.0 \
-libglew-dev \
-libglvnd0 \
-libgl1 \
-libglx0 \
-libegl1 \
-libgles2
+# The base PyTorch devel image provides compilers, CMake >= 3.22, and most build deps.
+# Install only minimal utilities not guaranteed to be present.
 
-# Upgrade specific package
-apt-get upgrade -y libstdc++6
+# CMake available in the PyTorch devel image (Ubuntu 22.04) is sufficient.
+
+# Cleanup APT cache
+apt-get clean && rm -rf /var/lib/apt/lists/*
 
 this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 root_dir="$(git rev-parse --show-toplevel)"
21 changes: 11 additions & 10 deletions .github/workflows/test-linux-llm.yml
@@ -21,17 +21,18 @@ permissions:
 
 jobs:
 unittests:
+if: ${{ github.event_name == 'push' || (github.event_name == 'pull_request' && contains(join(github.event.pull_request.labels.*.name, ', '), 'llm/')) }}
 strategy:
 matrix:
-python_version: ["3.9"]
-cuda_arch_version: ["12.8"]
+python_version: ["3.12"]
+cuda_arch_version: ["12.9"]
 uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
 with:
 repository: pytorch/rl
-runner: "linux.g5.4xlarge.nvidia.gpu"
+runner: "linux.g6.4xlarge.experimental.nvidia.gpu"
 # gpu-arch-type: cuda
 # gpu-arch-version: "11.7"
-docker-image: "nvidia/cudagl:11.4.0-base"
+docker-image: "pytorch/pytorch:2.8.0-cuda12.9-cudnn9-devel"
 timeout: 120
 script: |
 if [[ "${{ github.ref }}" =~ release/* ]]; then
@@ -43,14 +44,14 @@ jobs:
 fi
 
 set -euo pipefail
-export PYTHON_VERSION="3.9"
-export CU_VERSION="cu117"
+export PYTHON_VERSION="3.12"
+export CU_VERSION="cu129"
 export TAR_OPTIONS="--no-same-owner"
 export UPLOAD_CHANNEL="nightly"
 export TF_CPP_MIN_LOG_LEVEL=0
 export TD_GET_DEFAULTS_TO_NONE=1
 
-bash .github/unittest/linux_libs/scripts_llm/setup_env.sh
-bash .github/unittest/linux_libs/scripts_llm/install.sh
-bash .github/unittest/linux_libs/scripts_llm/run_test.sh
-bash .github/unittest/linux_libs/scripts_llm/post_process.sh
+bash .github/unittest/llm/scripts_llm/setup_env.sh
+bash .github/unittest/llm/scripts_llm/install.sh
+bash .github/unittest/llm/scripts_llm/run_test.sh
+bash .github/unittest/llm/scripts_llm/post_process.sh
2 changes: 1 addition & 1 deletion test/llm/test_updaters.py
@@ -2,7 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
+from __future__ import annotations
 
 import argparse
 import gc
64 changes: 46 additions & 18 deletions torchrl/modules/llm/backends/vllm/vllm_async.py
@@ -20,12 +20,8 @@
 from concurrent.futures import ThreadPoolExecutor, wait
 from typing import Any, Literal, TYPE_CHECKING
 
-import ray
-
 import torch
 
-from ray.util.placement_group import placement_group, remove_placement_group
-from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from torchrl._utils import logger as torchrl_logger
 
 # Import RLvLLMEngine and shared utilities
@@ -51,6 +47,25 @@
 _has_vllm = False
 
 
+def _get_ray():
+    """Import Ray on demand to avoid global import side-effects.
+
+    Returns:
+        ModuleType: The imported Ray module.
+
+    Raises:
+        ImportError: If Ray is not installed.
+    """
+    try:
+        import ray  # type: ignore
+
+        return ray
+    except Exception as e:  # pragma: no cover - surfaced to callers
+        raise ImportError(
+            "ray is not installed. Please install it with `pip install ray`."
+        ) from e
+
+
 class _AsyncvLLMWorker:
     """Async vLLM worker extension for Ray with weight update capabilities."""
 
@@ -264,7 +279,7 @@ async def generate(
 "vllm is not installed. Please install it with `pip install vllm`."
 )
 
-from vllm import RequestOutput, SamplingParams, TokensPrompt
+from vllm import SamplingParams, TokensPrompt
 
 # Track whether input was originally a single prompt
 single_prompt_input = False
@@ -471,11 +486,7 @@ def _gpus_per_replica(engine_args: AsyncEngineArgs) -> int:
 )
 
 
-# Create Ray remote versions
-if ray is not None and _has_vllm:
-    _AsyncLLMEngineActor = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)
-else:
-    _AsyncLLMEngineActor = None
+# Ray actor wrapper is created lazily in __init__ to avoid global Ray import.
 
 
 class AsyncVLLM(RLvLLMEngine):
@@ -580,17 +591,18 @@ def __init__(
 raise ImportError(
 "vllm is not installed. Please install it with `pip install vllm`."
 )
-if ray is None:
-    raise ImportError(
-        "ray is not installed. Please install it with `pip install ray`."
-    )
+# Lazily import ray only when constructing the actor class to avoid global import
 
 # Enable prefix caching by default for better performance
 engine_args.enable_prefix_caching = enable_prefix_caching
 
 self.engine_args = engine_args
 self.num_replicas = num_replicas
-self.actor_class = actor_class or _AsyncLLMEngineActor
+if actor_class is None:
+    ray = _get_ray()
+    self.actor_class = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)
+else:
+    self.actor_class = actor_class
 self.actors: list = []
 self._launched = False
 self._service_id = uuid.uuid4().hex[
@@ -605,6 +617,11 @@ def _launch(self):
 torchrl_logger.warning("AsyncVLLMEngineService already launched")
 return
 
+# Local imports to avoid global Ray dependency
+ray = _get_ray()
+from ray.util.placement_group import placement_group
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
 torchrl_logger.info(
 f"Launching {self.num_replicas} async vLLM engine actors..."
 )
@@ -944,6 +961,7 @@ def generate(
 Returns:
 RequestOutput | list[RequestOutput]: Generated outputs from vLLM.
 """
+ray = _get_ray()
 # Check if this is a batch request
 if self._is_batch(prompts, prompt_token_ids):
 # Handle batched input by unbinding and sending individual requests
@@ -1068,6 +1086,9 @@ def shutdown(self):
 f"Shutting down {len(self.actors)} async vLLM engine actors..."
 )
 
+ray = _get_ray()
+from ray.util.placement_group import remove_placement_group
+
 # Kill all actors
 for i, actor in enumerate(self.actors):
 try:
@@ -1260,6 +1281,7 @@ def _update_weights_with_nccl_broadcast_simple(
 )
 
 updated_weights = 0
+ray = _get_ray()
 with torch.cuda.device(0): # Ensure we're on the correct CUDA device
 for name, weight in gpu_weights.items():
 # Convert dtype to string name (like periodic-mono)
@@ -1336,6 +1358,7 @@ def get_num_unfinished_requests(
 "AsyncVLLM service must be launched before getting request counts"
 )
 
+ray = _get_ray()
 if actor_index is not None:
 if not (0 <= actor_index < len(self.actors)):
 raise IndexError(
@@ -1366,6 +1389,7 @@ def get_cache_usage(self, actor_index: int | None = None) -> float | list[float]
 "AsyncVLLM service must be launched before getting cache usage"
 )
 
+ray = _get_ray()
 if actor_index is not None:
 if not (0 <= actor_index < len(self.actors)):
 raise IndexError(
@@ -1678,6 +1702,7 @@ def _select_by_requests(self) -> int:
 futures = [
 actor.get_num_unfinished_requests.remote() for actor in self.actors
 ]
+ray = _get_ray()
 request_counts = ray.get(futures)
 
 # Find the actor with minimum pending requests
@@ -1705,6 +1730,7 @@ def _select_by_cache_usage(self) -> int:
 else:
 # Query actors directly
 futures = [actor.get_cache_usage.remote() for actor in self.actors]
+ray = _get_ray()
 cache_usages = ray.get(futures)
 
 # Find the actor with minimum cache usage
@@ -1844,7 +1870,8 @@ def _is_actor_overloaded(self, actor_index: int) -> bool:
 futures = [
 actor.get_num_unfinished_requests.remote() for actor in self.actors
 ]
-request_counts = ray.get(futures)
+ray = _get_ray()
+request_counts = ray.get(futures)
 
 if not request_counts:
 return False
@@ -1893,8 +1920,9 @@ def get_stats(self) -> dict[str, Any]:
 cache_futures = [
 actor.get_cache_usage.remote() for actor in self.actors
 ]
-request_counts = ray.get(request_futures)
-cache_usages = ray.get(cache_futures)
+ray = _get_ray()
+request_counts = ray.get(request_futures)
+cache_usages = ray.get(cache_futures)
 
 for i, (requests, cache_usage) in enumerate(
 zip(request_counts, cache_usages)
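Note on the `vllm_async.py` changes above: the module-level `import ray` and `ray.util` imports are replaced by a `_get_ray()` helper plus method-local imports, so importing the module no longer requires Ray to be installed. A minimal, generic sketch of this lazy optional-dependency pattern (illustrative names only, not TorchRL source):

```python
import importlib
from types import ModuleType


def _get_optional(name: str) -> ModuleType:
    """Import an optional dependency on first use, with a helpful error if missing."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ImportError(
            f"{name} is not installed. Please install it with `pip install {name}`."
        ) from exc


class Service:
    """Example class that only touches the optional dependency when it does work."""

    def version(self) -> str:
        # The dependency is imported here, at call time, so importing or
        # constructing Service never requires it to be installed.
        ray = _get_optional("ray")
        return ray.__version__
```

The diff applies the same idea to the `ray.util` helpers (`placement_group`, `PlacementGroupSchedulingStrategy`, `remove_placement_group`), which are now imported inside the methods that use them.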