50 changes: 46 additions & 4 deletions setup.py
@@ -11,6 +11,7 @@
from shutil import which
from typing import Dict, List

import requests
import torch
from packaging.version import Version, parse
from setuptools import Extension, find_packages, setup
@@ -19,6 +20,26 @@
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME


def fetch_vllm_library(url, save_path):
timeout_s = 30
try:
save_dir = os.path.dirname(save_path)
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)

response = requests.get(url, timeout=timeout_s)
response.raise_for_status()  # raises an exception for non-success status codes

# Save the downloaded content to the file
with open(save_path, 'wb') as file:
file.write(response.content)
print(f"saved to {save_path}")
except requests.Timeout:
print(f"request timed out, url is {url}")
except requests.RequestException as e:
print(f"request failed: {e}")


def load_module_from_path(module_name, path):
spec = importlib.util.spec_from_file_location(module_name, path)
module = importlib.util.module_from_spec(spec)
@@ -207,7 +228,6 @@ def target_name(s: str) -> str:
for ext in self.extensions:
self.configure(ext)
targets.append(target_name(ext.name))

num_jobs, _ = self.compute_num_jobs()

build_args = [
@@ -217,7 +237,30 @@ def target_name(s: str) -> str:
*[f"--target={name}" for name in targets],
]

subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)
if os.getenv('VLLM_BUILD_FROM_SOURCE') is not None:
subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)
else:
print(
"if you want to build from source, please set env VLLM_BUILD_FROM_SOURCE=1"
)
# Corresponding OSS storage path: s3://brain-deploy/data/stepcast/vllm/
fetch_vllm_library(
"http://deploy.i.basemind.com/data/stepcast/vllm/0.7.2/_C.abi3.so",
"build/lib.linux-x86_64-cpython-310/vllm/_C.abi3.so")
fetch_vllm_library(
"http://deploy.i.basemind.com/data/stepcast/vllm/0.7.2/_moe_C.abi3.so",
"build/lib.linux-x86_64-cpython-310/vllm/_moe_C.abi3.so")
fetch_vllm_library(
"http://deploy.i.basemind.com/data/stepcast/vllm/0.7.2/cumem_allocator.abi3.so",
"build/lib.linux-x86_64-cpython-310/vllm/cumem_allocator.abi3.so")
fetch_vllm_library(
"http://deploy.i.basemind.com/data/stepcast/vllm/0.7.2/_vllm_fa3_C.abi3.so",
"build/lib.linux-x86_64-cpython-310/vllm/_vllm_fa3_C.abi3.so"
)
fetch_vllm_library(
"http://deploy.i.basemind.com/data/stepcast/vllm/0.7.2/_vllm_fa2_C.abi3.so",
"build/lib.linux-x86_64-cpython-310/vllm/_vllm_fa2_C.abi3.so"
)

# Install the libraries
for ext in self.extensions:
@@ -236,7 +279,6 @@ def target_name(s: str) -> str:
prefix = outdir
if '.' in ext.name:
prefix = prefix.parent

# prefix here should actually be the same for all components
install_args = [
"cmake", "--install", ".", "--prefix", prefix, "--component",
@@ -629,7 +671,7 @@ def _read_requirements(filename: str) -> List[str]:
}

setup(
name="vllm",
name="step-vllm",
version=get_vllm_version(),
author="vLLM Team",
license="Apache 2.0",
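Note: the prebuilt .so files above are fetched over plain HTTP and written into the build tree without any integrity check. Below is a minimal sketch of how a checksum verification could be layered on top of fetch_vllm_library; EXPECTED_SHA256, the verify_sha256 helper, and the hard-coded digest are illustrative assumptions, not part of this change.

import hashlib

# Hypothetical known-good digest; not part of this PR.
EXPECTED_SHA256 = "<known-good-digest>"

def verify_sha256(path: str, expected: str) -> None:
    # Hash the file in chunks to avoid loading large .so files into memory at once.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    if digest.hexdigest() != expected:
        raise RuntimeError(f"checksum mismatch for {path}")

so_path = "build/lib.linux-x86_64-cpython-310/vllm/_C.abi3.so"
fetch_vllm_library(
    "http://deploy.i.basemind.com/data/stepcast/vllm/0.7.2/_C.abi3.so", so_path)
verify_sha256(so_path, EXPECTED_SHA256)
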
1 change: 1 addition & 0 deletions vllm/attention/backends/abstract.py
@@ -258,6 +258,7 @@ def __init__(
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
alibi_sqrt: bool = False,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
3 changes: 3 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -300,6 +300,7 @@ def __init__(
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
alibi_sqrt: bool,
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
@@ -313,6 +314,8 @@
"sliding_window is invalid for blocksparse attention.")
assert logits_soft_cap is None, ValueError(
"logits_soft_cap is invalid for blocksparse attention.")
assert alibi_sqrt is False, ValueError(
"alibi_sqrt is invalid for blocksparse attention.")

if "num_heads" not in blocksparse_params:
blocksparse_params["num_heads"] = num_heads
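A side note on the guard style added above: `assert cond, ValueError(msg)` only uses the ValueError instance as the assertion message, so what actually propagates is an AssertionError, and the check is stripped entirely under `python -O`. A small self-contained sketch of that behavior (the values are illustrative):

# Sketch: what `assert x is False, ValueError(...)` actually raises.
alibi_sqrt = True
try:
    assert alibi_sqrt is False, ValueError(
        "alibi_sqrt is invalid for blocksparse attention.")
except AssertionError as err:
    # err.args[0] is the ValueError instance used as the message;
    # no ValueError is ever raised, and `python -O` removes the check.
    print(type(err).__name__, err)
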
227 changes: 162 additions & 65 deletions vllm/attention/backends/flash_attn.py
@@ -610,6 +610,7 @@ def __init__(
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
alibi_sqrt: bool,
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
@@ -626,6 +627,7 @@
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.alibi_sqrt = alibi_sqrt
self.sliding_window = ((sliding_window - 1,
0) if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
@@ -767,44 +769,87 @@ def forward(
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]

flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=q_seq_start_loc,
cu_seqlens_k=k_seq_start_loc,
max_seqlen_q=q_seq_len,
max_seqlen_k=k_seq_len,
softmax_scale=softmax_scale,
causal=_get_causal_option(attn_type),
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
if self.alibi_sqrt:
torch.ops.Optimus.varlen_fwd(
query,
key,
value,
prefill_output,
q_seq_start_loc,
k_seq_start_loc,
q_seq_len,
k_seq_len,
0.0,
query.shape[-1]**(-0.5),
False,
True,
False,
None,
0,
1,
)
else:
flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=q_seq_start_loc,
cu_seqlens_k=k_seq_start_loc,
max_seqlen_q=q_seq_len,
max_seqlen_k=k_seq_len,
softmax_scale=softmax_scale,
causal=_get_causal_option(attn_type),
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens_tensor,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
if self.alibi_sqrt:
torch.ops.Optimus.vllm_varlen_fwd(
query,
key_cache,
value_cache,
prefill_output,
prefill_meta.query_start_loc,
prefill_meta.seq_start_loc,
None,
prefill_meta.block_tables,
alibi_slopes,
prefill_meta.max_query_len,
max_seq_len,
0,
softmax_scale,
False,
True,
window_size[0],
window_size[1],
False,
None,
)
else:
flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens_tensor,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
)

if decode_meta := attn_metadata.decode_metadata:
@@ -818,43 +863,95 @@ def forward(
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1"
)
flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
seqused_k=decode_meta.seq_lens_tensor,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
out=decode_output,
fa_version=self.fa_version,
if self.alibi_sqrt:
torch.ops.Optimus.vllm_varlen_fwd(
decode_query,
key_cache,
value_cache,
decode_output,
decode_meta.query_start_loc,
decode_meta.seq_start_loc,
None,
decode_meta.block_tables,
alibi_slopes,
decode_meta.max_query_len,
decode_meta.max_decode_seq_len,
0,
softmax_scale,
False,
True,
window_size[0],
window_size[1],
False,
None,
)
else:
flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
seqused_k=decode_meta.seq_lens_tensor,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
out=decode_output,
fa_version=self.fa_version,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
(
seq_lens_arg,
_,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=block_tables_arg,
cache_seqlens=seq_lens_arg,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=decode_output.unsqueeze(1),
fa_version=self.fa_version,
if self.alibi_sqrt:
(
seq_lens_arg,
_,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False,
attn_type)
torch.ops.Optimus.vllm_fwd_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
None,
None,
seq_lens_arg,
None,
None,
None,
block_tables_arg,
alibi_slopes,
decode_output.unsqueeze(1),
softmax_scale,
True,
window_size[0],
window_size[1],
True,
0
)
else:
(
seq_lens_arg,
_,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False,
attn_type)
flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=block_tables_arg,
cache_seqlens=seq_lens_arg,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=decode_output.unsqueeze(1),
fa_version=self.fa_version,
)
return output

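The three call sites above (prefill, prefix-cached prefill, and decode) all branch on self.alibi_sqrt and dispatch to custom torch.ops.Optimus kernels instead of the stock flash-attn entry points. Those ops are presumably registered by the prebuilt binaries fetched in setup.py; the sketch below shows one way a caller could check that the op is actually available before enabling the flag. The helper itself is an assumption for illustration and is not part of this PR.

import torch

def optimus_varlen_fwd_available() -> bool:
    """Best-effort check that the custom Optimus varlen_fwd op is registered."""
    try:
        # Attribute lookup on a torch.ops namespace raises if the op was
        # never registered (e.g. the custom .so was not loaded).
        torch.ops.Optimus.varlen_fwd
        return True
    except (AttributeError, RuntimeError):
        return False
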
3 changes: 3 additions & 0 deletions vllm/attention/backends/flashinfer.py
@@ -902,12 +902,15 @@ def __init__(
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
alibi_sqrt: bool,
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
assert alibi_sqrt is False, ValueError(
"alibi_sqrt is invalid for FlashInfer attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)