Heterogeneous Speculative Decoding (CPU + GPU) #5065

Open · wants to merge 63 commits into base: main

Changes from all commits (63 commits)
aaece57
hete spec decode engine
jiqing-feng May 16, 2024
21fb773
compile ops for cuda and cpu
jiqing-feng May 24, 2024
5f02fdd
can run hete spec decode
jiqing-feng May 24, 2024
8febd81
add parameter cpu_draft_worker to run draft model on CPU
jiqing-feng May 27, 2024
d9af7a6
rm useless comments
jiqing-feng May 27, 2024
74fb5d5
merge main
jiqing-feng May 27, 2024
44acebe
fix conflict
jiqing-feng May 27, 2024
b4b8744
add copy comment
jiqing-feng May 27, 2024
cc7998e
rebase
jiqing-feng Jul 2, 2024
8f7ecf3
fix bug
jiqing-feng Jul 2, 2024
794613e
rebase
jiqing-feng Jul 3, 2024
52022a5
fix style
jiqing-feng Jul 3, 2024
fa40a93
rebbase
jiqing-feng Jul 3, 2024
aa4d556
fix style
jiqing-feng Jul 3, 2024
f7491eb
Merge branch 'main' into hete_spec_decode
jiqing-feng Jul 4, 2024
344f5d7
fix format
jiqing-feng Jul 11, 2024
53cf9b6
rebase
jiqing-feng Jul 11, 2024
23a4575
fix format
jiqing-feng Jul 11, 2024
2cab72f
rebase main
jiqing-feng Jul 30, 2024
185836b
rebase main
jiqing-feng Aug 23, 2024
0556d02
fix style
jiqing-feng Aug 23, 2024
2eb3201
fix diff
jiqing-feng Aug 23, 2024
345788e
fix arg
jiqing-feng Aug 23, 2024
49d5bdf
fix match
jiqing-feng Aug 23, 2024
12e6c5d
fix cmake
jiqing-feng Aug 26, 2024
e313329
fix cmake style
jiqing-feng Aug 26, 2024
32c9c9a
add numa in cuda dockerfile
jiqing-feng Aug 26, 2024
e1782d3
use low version gcc
jiqing-feng Aug 26, 2024
ff7efee
rm useless link
jiqing-feng Aug 26, 2024
1c5d8c4
enable TP
jiqing-feng Aug 26, 2024
59de387
rebase
jiqing-feng Aug 26, 2024
6fc9b3b
fix format
jiqing-feng Aug 26, 2024
ded4e78
rm erro cpu cache ops
jiqing-feng Aug 27, 2024
1eca335
fix cpu op import
jiqing-feng Aug 27, 2024
b986b4d
disable cpu TP model
jiqing-feng Aug 27, 2024
1db03a2
rebase main
jiqing-feng Aug 27, 2024
1033dcc
fix cpu worker core binding
jiqing-feng Aug 27, 2024
a0f172c
fix import cpu ops
jiqing-feng Aug 27, 2024
14df487
enable cpu TP
jiqing-feng Aug 30, 2024
7bbd35b
add cpu-draft-worker parameter
jiqing-feng Sep 3, 2024
c895b50
fix style
jiqing-feng Sep 3, 2024
91499e1
Merge branch 'main' into hete_spec_decode
jiqing-feng Sep 3, 2024
d7b742c
fix param name
jiqing-feng Sep 3, 2024
e016db9
fix cpu-draft-args
jiqing-feng Sep 10, 2024
0d58142
Merge branch 'main' into hete_spec_decode
jiqing-feng Sep 10, 2024
679664b
fix cmake list to avoid amd error
jiqing-feng Sep 10, 2024
729483e
fix ops name
jiqing-feng Sep 10, 2024
13e5e2a
fix distributed tests and disable distributed verified if cpu-draft-m…
jiqing-feng Sep 11, 2024
581c529
fix tests
jiqing-feng Sep 11, 2024
753e1d0
rebase
jiqing-feng Sep 19, 2024
c670338
skip build cpu if rocm and fix code style
jiqing-feng Sep 19, 2024
c3e9488
install onednn
jiqing-feng Sep 19, 2024
5d7233f
fix install onednn position
jiqing-feng Sep 19, 2024
8bfc4e6
ondnn install
jiqing-feng Sep 19, 2024
e01732e
fix cpu op
jiqing-feng Sep 19, 2024
07eb1a1
change dockerfile base image to ubuntu22.04
jiqing-feng Sep 19, 2024
27da2ee
fix cmake list
jiqing-feng Sep 19, 2024
da1728a
install libc6
jiqing-feng Sep 19, 2024
a883fce
revert dockerfile to ubuntu 20.04
jiqing-feng Sep 20, 2024
6aba90b
disable avx512 to pass cpu compile
jiqing-feng Sep 20, 2024
83bc114
reuse multi step worker for CPU
jiqing-feng Sep 20, 2024
d057f34
fix SDPA assert
jiqing-feng Sep 20, 2024
77e97e2
fix format
jiqing-feng Sep 20, 2024
21 changes: 13 additions & 8 deletions CMakeLists.txt
@@ -102,16 +102,21 @@ define_gpu_extension_target(

add_dependencies(default _core_C)

if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
NOT VLLM_TARGET_DEVICE STREQUAL "rocm" AND
NOT VLLM_TARGET_DEVICE STREQUAL "cpu")
return()
endif()

#
# Forward the non-CUDA device extensions to external CMake scripts.
# The CUDA device extensions need CPU CMake scripts to support Heterogeneous Speculative Decoding.
#
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
if (VLLM_TARGET_DEVICE STREQUAL "cpu")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
else()
return()
endif()
if (NOT HIP_FOUND AND CUDA_FOUND)
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
endif()

if (VLLM_TARGET_DEVICE STREQUAL "cpu")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
return()
endif()

19 changes: 17 additions & 2 deletions Dockerfile
@@ -17,7 +17,7 @@ ENV DEBIAN_FRONTEND=noninteractive
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl sudo \
&& apt-get install -y ccache software-properties-common git curl sudo wget numactl gcc-10 g++-10 libtcmalloc-minimal4 libnuma-dev libc6 \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
@@ -41,7 +41,6 @@ COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-cuda.txt


# cuda arch list used by torch
# can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2
@@ -83,6 +82,22 @@ ARG USE_SCCACHE
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
ARG SCCACHE_REGION_NAME=us-west-2
ARG SCCACHE_S3_NO_CREDENTIALS=0
ENV VLLM_CPU_DISABLE_AVX512="true"

# install oneDNN
RUN git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git

RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/pip \
cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \
-DONEDNN_BUILD_DOC=OFF \
-DONEDNN_BUILD_EXAMPLES=OFF \
-DONEDNN_BUILD_TESTS=OFF \
-DONEDNN_BUILD_GRAPH=OFF \
-DONEDNN_ENABLE_WORKLOAD=INFERENCE \
-DONEDNN_ENABLE_PRIMITIVE=MATMUL && \
cmake --build ./oneDNN/build --target install --config Release

# if USE_SCCACHE is set, use sccache to speed up compilation
RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$USE_SCCACHE" = "1" ]; then \
6 changes: 3 additions & 3 deletions cmake/cpu_extension.cmake
@@ -87,7 +87,7 @@ message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
list(APPEND LIBS dnnl numa)

#
# _C extension
# _C_cpu extension
#
set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp"
@@ -109,7 +109,7 @@ endif()
#

define_gpu_extension_target(
_C
_C_cpu
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
@@ -120,4 +120,4 @@ define_gpu_extension_target(
)

message(STATUS "Enabling C extension.")
add_dependencies(default _C)
add_dependencies(default _C_cpu)
1 change: 1 addition & 0 deletions setup.py
@@ -469,6 +469,7 @@ def _read_requirements(filename: str) -> List[str]:

if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))
ext_modules.append(CMakeExtension(name="vllm._C_cpu"))

package_data = {
"vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
71 changes: 45 additions & 26 deletions vllm/_custom_ops.py
@@ -17,6 +17,11 @@
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)

try:
import vllm._C_cpu
except ImportError as e:
logger.warning("Failed to import from vllm._C_cpu with %r", e)

if current_platform.is_rocm():
import vllm._rocm_C # noqa: F401

@@ -45,27 +50,33 @@ def wrapper(*args, **kwargs):

# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.silu_and_mul(out, x)
ops = torch.ops._C_cpu if x.device.type == "cpu" else torch.ops._C
ops.silu_and_mul(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_and_mul(out, x)
ops = torch.ops._C_cpu if x.device.type == "cpu" else torch.ops._C
ops.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul(out, x)
ops = torch.ops._C_cpu if x.device.type == "cpu" else torch.ops._C
ops.gelu_tanh_and_mul(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_fast(out, x)
ops = torch.ops._C_cpu if x.device.type == "cpu" else torch.ops._C
ops.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_new(out, x)
ops = torch.ops._C_cpu if x.device.type == "cpu" else torch.ops._C
ops.gelu_new(out, x)


def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_quick(out, x)
ops = torch.ops._C_cpu if x.device.type == "cpu" else torch.ops._C
ops.gelu_quick(out, x)


# page attention ops
@@ -90,12 +101,13 @@ def paged_attention_v1(
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v1(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step)
ops = torch.ops._C_cpu if query.device.type == "cpu" else torch.ops._C
ops.paged_attention_v1(out, query, key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens, block_size,
max_seq_len, alibi_slopes, kv_cache_dtype, k_scale,
v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step)


def paged_attention_v2(
@@ -122,12 +134,14 @@ def paged_attention_v2(
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v2(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
ops = torch.ops._C_cpu if query.device.type == "cpu" else torch.ops._C
ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, key_cache,
value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step)


def paged_attention_rocm(
@@ -163,8 +177,9 @@ def rotary_embedding(
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
torch.ops._C.rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox)
ops = torch.ops._C_cpu if query.device.type == "cpu" else torch.ops._C
ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
@@ -180,12 +195,14 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
torch.ops._C.rms_norm(out, input, weight, epsilon)
ops = torch.ops._C_cpu if input.device.type == "cpu" else torch.ops._C
ops.rms_norm(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
ops = torch.ops._C_cpu if input.device.type == "cpu" else torch.ops._C
ops.fused_add_rms_norm(input, residual, weight, epsilon)


def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
@@ -819,9 +836,9 @@ def reshape_and_cache(
k_scale: float,
v_scale: float,
) -> None:
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale)
ops = torch.ops._C_cpu_cache_ops if key.device.type == "cpu" else torch.ops._C_cache_ops
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale)


def reshape_and_cache_flash(
@@ -843,7 +860,9 @@ def reshape_and_cache_flash(
def copy_blocks(key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor],
block_mapping: torch.Tensor) -> None:
torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
ops = torch.ops._C_cpu_cache_ops if key_caches[
0].device.type == "cpu" else torch.ops._C_cache_ops
ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
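Each wrapper in this file now repeats the same device check to pick between the CUDA and CPU extension namespaces. A minimal sketch of that dispatch pattern, with the check factored into a helper; `_select_ops` is a hypothetical name (the PR inlines the conditional in every wrapper), and it assumes both extension libraries have been built and imported:

```python
# Sketch of the dispatch pattern used throughout vllm/_custom_ops.py in this
# PR: route to torch.ops._C_cpu when the tensor lives on CPU, otherwise to
# the default CUDA/ROCm extension torch.ops._C.
import torch


def _select_ops(t: torch.Tensor):
    """Return the custom-op namespace that matches the tensor's device."""
    return torch.ops._C_cpu if t.device.type == "cpu" else torch.ops._C


def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    _select_ops(x).silu_and_mul(out, x)
```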
5 changes: 4 additions & 1 deletion vllm/attention/layer.py
@@ -76,9 +76,12 @@ def __init__(
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
device = None
if hasattr(cache_config, "cpu_kvcache_space_bytes"):
device = "cpu"
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size, blocksparse_params
block_size, device, blocksparse_params
is not None)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
8 changes: 5 additions & 3 deletions vllm/attention/selector.py
@@ -95,6 +95,7 @@ def get_attn_backend(
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
device=None,
is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
@@ -107,7 +108,7 @@

backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size)
block_size, device)
if backend == _Backend.FLASH_ATTN:
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
@@ -123,7 +124,7 @@
ROCmFlashAttentionBackend)
return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA:
assert is_cpu(), RuntimeError(
assert is_cpu() or device == "cpu", RuntimeError(
"Torch SDPA backend is only used for the CPU device.")
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
@@ -158,6 +159,7 @@ def which_attn_to_use(
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
device=None,
) -> _Backend:
"""Returns which flash attention backend to use."""
# Default case.
@@ -178,7 +180,7 @@
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)

if is_cpu():
if is_cpu() or device == "cpu":
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA
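The selector change above is what lets a CPU draft worker inside a CUDA build fall through to the Torch SDPA backend. A simplified, self-contained stand-in for the new `device` check; the enum and helper below are toy versions, not the actual selector:

```python
# Toy reduction of the `is_cpu() or device == "cpu"` check added to
# which_attn_to_use(): a worker that passes device="cpu" gets the Torch SDPA
# backend even though the overall build targets CUDA.
from enum import Enum, auto
from typing import Optional


class Backend(Enum):
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


def pick_backend(platform_is_cpu: bool, device: Optional[str]) -> Backend:
    if platform_is_cpu or device == "cpu":
        return Backend.TORCH_SDPA
    return Backend.FLASH_ATTN


assert pick_backend(False, "cpu") is Backend.TORCH_SDPA   # CPU draft worker
assert pick_backend(False, None) is Backend.FLASH_ATTN    # GPU target model
```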
10 changes: 9 additions & 1 deletion vllm/config.py
@@ -1093,6 +1093,7 @@ def maybe_create_spec_config(
typical_acceptance_sampler_posterior_threshold: Optional[float],
typical_acceptance_sampler_posterior_alpha: Optional[float],
disable_logprobs: Optional[bool],
cpu_draft_worker: Optional[bool],
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.

@@ -1150,6 +1151,7 @@
If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams.
If not specified, it defaults to True.
cpu_draft_worker (Optional[bool]): Run draft model on CPU.

Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
@@ -1251,7 +1253,8 @@
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config,
speculative_draft_tensor_parallel_size, draft_hf_config))
speculative_draft_tensor_parallel_size, draft_hf_config,
cpu_draft_worker))

if num_speculative_tokens is None:
raise ValueError(
@@ -1280,6 +1283,7 @@
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
cpu_draft_worker=cpu_draft_worker,
)

@staticmethod
Expand Down Expand Up @@ -1322,6 +1326,7 @@ def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig,
cpu_draft_worker: Optional[bool],
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.

Expand Down Expand Up @@ -1374,6 +1379,7 @@ def __init__(
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
cpu_draft_worker: Optional[bool],
):
"""Create a SpeculativeConfig object.

@@ -1408,6 +1414,7 @@ def __init__(
returned.
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
cpu_draft_worker: Run draft model on CPU.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
@@ -1423,6 +1430,7 @@
typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs
self.disable_log_stats = disable_log_stats
self.cpu_draft_worker = cpu_draft_worker or False

self._verify_args()

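For completeness, a small sketch of how the new flag is normalized by the `cpu_draft_worker or False` line above and how a caller might branch on it; the dataclass is a stand-in, not the real SpeculativeConfig:

```python
# Stand-in illustrating the default handling added in SpeculativeConfig:
# an unset (None) cpu_draft_worker behaves the same as False, so existing
# GPU-only speculative decoding setups are unaffected.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SpecConfigSketch:
    cpu_draft_worker: bool


def make_spec_config(cpu_draft_worker: Optional[bool]) -> SpecConfigSketch:
    # Mirrors `self.cpu_draft_worker = cpu_draft_worker or False`.
    return SpecConfigSketch(cpu_draft_worker=cpu_draft_worker or False)


assert make_spec_config(None).cpu_draft_worker is False
draft_device = "cpu" if make_spec_config(True).cpu_draft_worker else "cuda"
assert draft_device == "cpu"
```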