
Commit c8b0dd9

Merge branch 'main' into chunked-prefill-vlm
Signed-off-by: Roger Wang <[email protected]>
ywang96 committed Nov 9, 2024
2 parents d386818 + 49d2a41 commit c8b0dd9
Showing 85 changed files with 2,365 additions and 658 deletions.
52 changes: 31 additions & 21 deletions .buildkite/run-cpu-test-ppc64le.sh
@@ -17,25 +17,35 @@ source /etc/environment
#docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN="$HF_TOKEN" --name cpu-test cpu-test

# Run basic model test
docker exec cpu-test bash -c "
pip install pytest matplotlib einops transformers_stream_generator
pytest -v -s tests/models -m \"not vlm\" \
--ignore=tests/models/test_embedding.py \
--ignore=tests/models/test_oot_registration.py \
--ignore=tests/models/test_registry.py \
--ignore=tests/models/test_jamba.py \
--ignore=tests/models/test_mamba.py \
--ignore=tests/models/test_danube3_4b.py" # Mamba kernels and Danube3-4B are not supported on CPU
function cpu_tests() {
# Run basic model test
docker exec cpu-test bash -c "
set -e
pip install pytest pytest-asyncio \
decord einops librosa peft Pillow sentence-transformers soundfile \
transformers_stream_generator matplotlib datamodel_code_generator
pip install torchvision --index-url https://download.pytorch.org/whl/cpu
# Embedding models are not supported for CPU yet
# pytest -v -s tests/models/embedding/language
pytest -v -s tests/models/encoder_decoder/language
pytest -v -s tests/models/decoder_only/language/test_models.py
pytest -v -s tests/models/decoder_only/audio_language -m cpu_model
pytest -v -s tests/models/decoder_only/vision_language -m cpu_model"

# online inference
docker exec cpu-test bash -c "
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
--backend vllm \
--dataset-name random \
--model facebook/opt-125m \
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer facebook/opt-125m"
# online inference
docker exec cpu-test bash -c "
set -e
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
--backend vllm \
--dataset-name random \
--model facebook/opt-125m \
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer facebook/opt-125m"
}

# All CPU tests are expected to finish in less than 25 mins.
export -f cpu_tests
timeout 25m bash -c "cpu_tests"
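
For context on the `-m cpu_model` selectors used in the script above, here is a minimal, hypothetical pytest sketch (it assumes the vLLM test suite registers a `cpu_model` marker in its pytest configuration): running `pytest -m cpu_model` collects only the tests carrying that marker.

import pytest

@pytest.mark.cpu_model
def test_small_model_on_cpu():
    # Collected by `pytest -m cpu_model`.
    assert 2 + 2 == 4

def test_gpu_only_path():
    # Not marked, so it is deselected when `-m cpu_model` is passed.
    assert True
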
90 changes: 52 additions & 38 deletions .buildkite/run-cpu-test.sh
@@ -19,41 +19,55 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
--cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2 cpu-test-avx2

# offline inference
docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"

# Run basic model test
docker exec cpu-test bash -c "
pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator
pytest -v -s tests/models/encoder_decoder/language
pytest -v -s tests/models/decoder_only/language \
--ignore=tests/models/test_fp8.py \
--ignore=tests/models/decoder_only/language/test_jamba.py \
--ignore=tests/models/decoder_only/language/test_mamba.py \
--ignore=tests/models/decoder_only/language/test_granitemoe.py \
--ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B are not supported on CPU

# Run compressed-tensor test
docker exec cpu-test bash -c "
pytest -s -v \
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"

# Run AWQ test
docker exec cpu-test bash -c "
pytest -s -v \
tests/quantization/test_ipex_quant.py"

# online inference
docker exec cpu-test bash -c "
export VLLM_CPU_KVCACHE_SPACE=10
export VLLM_CPU_OMP_THREADS_BIND=48-92
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
--backend vllm \
--dataset-name random \
--model facebook/opt-125m \
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer facebook/opt-125m"
function cpu_tests() {
# offline inference
docker exec cpu-test-avx2 bash -c "
set -e
python3 examples/offline_inference.py"

# Run basic model test
docker exec cpu-test bash -c "
set -e
pip install pytest pytest-asyncio \
decord einops librosa peft Pillow sentence-transformers soundfile \
transformers_stream_generator matplotlib datamodel_code_generator
pip install torchvision --index-url https://download.pytorch.org/whl/cpu
# Embedding models are not supported for CPU yet
# pytest -v -s tests/models/embedding/language
pytest -v -s tests/models/encoder_decoder/language
pytest -v -s tests/models/decoder_only/language/test_models.py
pytest -v -s tests/models/decoder_only/audio_language -m cpu_model
pytest -v -s tests/models/decoder_only/vision_language -m cpu_model"

# Run compressed-tensor test
docker exec cpu-test bash -c "
set -e
pytest -s -v \
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"

# Run AWQ test
docker exec cpu-test bash -c "
set -e
pytest -s -v \
tests/quantization/test_ipex_quant.py"

# online inference
docker exec cpu-test bash -c "
set -e
export VLLM_CPU_KVCACHE_SPACE=10
export VLLM_CPU_OMP_THREADS_BIND=48-92
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
--backend vllm \
--dataset-name random \
--model facebook/opt-125m \
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer facebook/opt-125m"
}

# All CPU tests are expected to finish in less than 25 mins.
export -f cpu_tests
timeout 25m bash -c "cpu_tests"
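
The `timeout 600 bash -c 'until curl ...'` lines above are a readiness probe: keep polling the OpenAI-compatible /v1/models endpoint until the server answers. A rough Python equivalent, a sketch assuming only the standard library and the same default port, could look like this:

import time
import urllib.request

def wait_for_server(url: str = "http://localhost:8000/v1/models",
                    timeout_s: float = 600.0) -> None:
    # Poll until the endpoint responds, or give up after timeout_s seconds.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=5):
                return  # server is up
        except OSError:
            time.sleep(1)
    raise TimeoutError(f"server at {url} did not come up within {timeout_s}s")
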
3 changes: 1 addition & 2 deletions .buildkite/test-pipeline.yaml
@@ -269,7 +269,6 @@ steps:
source_file_dependencies:
- benchmarks/
commands:
- pip install aiohttp
- bash run-benchmarks.sh

- label: Quantization Test # 33min
@@ -331,7 +330,7 @@ steps:
commands:
- pytest -v -s models/decoder_only/language --ignore=models/decoder_only/language/test_models.py

- label: Decoder-only Multi-Modal Models Test (Standard)
- label: Decoder-only Multi-Modal Models Test (Standard) # 26min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -191,6 +191,7 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
5 changes: 5 additions & 0 deletions Dockerfile
@@ -191,6 +191,11 @@ ADD . /vllm-workspace/
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-dev.txt

# enable fast downloads from hf (for testing)
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install hf_transfer
ENV HF_HUB_ENABLE_HF_TRANSFER 1

# Copy in the v1 package for testing (it isn't distributed yet)
COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1

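The two lines added above install hf_transfer and set HF_HUB_ENABLE_HF_TRANSFER, which tells huggingface_hub to route large downloads through the Rust-based transfer backend. A minimal sketch of the same setup from Python (the model name is simply reused from the CI scripts above as an example):

import os

# Must be set before huggingface_hub is imported, mirroring the ENV line above.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import snapshot_download

# Downloads the repo through the accelerated backend when hf_transfer is installed.
snapshot_download("facebook/opt-125m")
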
4 changes: 2 additions & 2 deletions benchmarks/backend_request_func.py
@@ -256,7 +256,7 @@ async def async_request_openai_completions(
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
first_valid_chunk_received = False
first_chunk_received = False
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
@@ -275,7 +275,7 @@
if data["choices"][0]["text"]:
timestamp = time.perf_counter()
# First token
if not first_valid_chunk_received:
if not first_chunk_received:
first_chunk_received = True
ttft = time.perf_counter() - st
output.ttft = ttft
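
The rename above fixes an inconsistency: the flag was initialized and checked as first_valid_chunk_received but set as first_chunk_received, so the guard never flipped and the time-to-first-token (TTFT) was recomputed on every chunk. A stripped-down sketch of the intended pattern, using a generic async iterator of text chunks rather than the real aiohttp response:

import time

async def measure_ttft(chunks):
    # `chunks` is any async iterator yielding streamed text fragments.
    st = time.perf_counter()
    first_chunk_received = False
    ttft = None
    generated = ""
    async for text in chunks:
        if text and not first_chunk_received:
            first_chunk_received = True
            ttft = time.perf_counter() - st  # latency to the first non-empty chunk
        generated += text
    return ttft, generated
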
165 changes: 4 additions & 161 deletions csrc/layernorm_kernels.cu
@@ -1,21 +1,13 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "type_convert.cuh"
#include "dispatch_utils.h"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>

using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

namespace vllm {
@@ -51,155 +43,6 @@ __global__ void rms_norm_kernel(
}
}

/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
operators/constructors are not consistently implemented by HIP/CUDA, so
a generic conversion via type casts cannot be implemented.
Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template <typename torch_type>
struct _typeConvert {
static constexpr bool exists = false;
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;

__device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) {
return __half22float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;

__device__ static inline float convert(hip_type x) {
return __bfloat162float(x);
}
__device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))

/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>;
using T1 = typename Converter::hip_type;
using T2 = typename Converter::packed_hip_type;
T1 data[width];

__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] += other.data[i];
}
return *this;
}

__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] *= other.data[i];
}
return *this;
}

__device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale;
temp_f.y *= scale;
T2 temp = Converter::convert(temp_f);
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp);
}
}
return *this;
}

__device__ float sum_squares() const {
float result = 0.0f;
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]);
result += x * x;
}
}
return result;
}
};

/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
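
For reference, the rms_norm_kernel touched above computes root-mean-square normalization; assuming it follows the standard RMSNorm definition, a plain PyTorch sketch of the per-row computation is:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize each row by its root-mean-square over the hidden dimension,
    # then apply the learned per-channel scale.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight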