From f5718c4050b9a221ca29b463a83521758fc0e90b Mon Sep 17 00:00:00 2001 From: oliveryuan Date: Mon, 17 Feb 2025 20:12:15 +0800 Subject: [PATCH 1/5] add step1 model --- setup.py | 50 ++- vllm/attention/backends/abstract.py | 1 + vllm/attention/backends/blocksparse_attn.py | 3 + vllm/attention/backends/flash_attn.py | 228 ++++++++--- vllm/attention/backends/flashinfer.py | 3 + vllm/attention/backends/hpu_attn.py | 3 + vllm/attention/backends/ipex_attn.py | 3 + vllm/attention/backends/pallas.py | 3 + vllm/attention/backends/rocm_flash_attn.py | 3 + vllm/attention/backends/torch_sdpa.py | 3 + vllm/attention/backends/xformers.py | 3 + vllm/attention/layer.py | 3 +- vllm/model_executor/models/registry.py | 1 + vllm/model_executor/models/step1.py | 418 ++++++++++++++++++++ 14 files changed, 655 insertions(+), 70 deletions(-) create mode 100644 vllm/model_executor/models/step1.py diff --git a/setup.py b/setup.py index a4043c43a7d5..f2d8ce73a6b8 100755 --- a/setup.py +++ b/setup.py @@ -11,6 +11,7 @@ from shutil import which from typing import Dict, List +import requests import torch from packaging.version import Version, parse from setuptools import Extension, find_packages, setup @@ -19,6 +20,26 @@ from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME +def fetch_vllm_library(url, save_path): + timeout_s = 30 + try: + save_dir = os.path.dirname(save_path) + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + + response = requests.get(url, timeout=timeout_s) + response.raise_for_status() # Raise an exception if the response status code is not 200 + + # Save the downloaded content to a file + with open(save_path, 'wb') as file: + file.write(response.content) + print(f"saved to {save_path}") + except requests.Timeout: + print(f"Request timed out, url is {url}") + except requests.RequestException as e: + print(f"Request failed: {e}") + + def load_module_from_path(module_name, path): spec = importlib.util.spec_from_file_location(module_name, path) module = importlib.util.module_from_spec(spec) @@ -207,7 +228,6 @@ def target_name(s: str) -> str: for ext in self.extensions: self.configure(ext) targets.append(target_name(ext.name)) - num_jobs, _ = self.compute_num_jobs() build_args = [ @@ -217,7 +237,30 @@ def target_name(s: str) -> str: *[f"--target={name}" for name in targets], ] - subprocess.check_call(["cmake", *build_args], cwd=self.build_temp) + if os.getenv('VLLM_BUILD_FROM_SOURCE') is not None: + subprocess.check_call(["cmake", *build_args], cwd=self.build_temp) + else: + print( + "if you want to build from source, please set env VLLM_BUILD_FROM_SOURCE=1" + ) + # The corresponding OSS storage path is s3://brain-deploy/data/stepcast/vllm/ + fetch_vllm_library( + "http://deploy.i.basemind.com/data/stepcast/vllm/0.7.2/_C.abi3.so", + "build/lib.linux-x86_64-cpython-310/vllm/_C.abi3.so") + fetch_vllm_library( + "http://deploy.i.basemind.com/data/stepcast/vllm/0.7.2/_moe_C.abi3.so", + "build/lib.linux-x86_64-cpython-310/vllm/_moe_C.abi3.so") + fetch_vllm_library( + "http://deploy.i.basemind.com/data/stepcast/vllm/0.7.2/cumem_allocator.abi3.so", + "build/lib.linux-x86_64-cpython-310/vllm/cumem_allocator.abi3.so") + fetch_vllm_library( + "http://deploy.i.basemind.com/data/stepcast/vllm/0.7.2/_vllm_fa3_C.abi3.so", + "build/lib.linux-x86_64-cpython-310/vllm/_vllm_fa3_C.abi3.so" ) + fetch_vllm_library( + "http://deploy.i.basemind.com/data/stepcast/vllm/0.7.2/_vllm_fa2_C.abi3.so", + "build/lib.linux-x86_64-cpython-310/vllm/_vllm_fa2_C.abi3.so" ) # Install the libraries for ext in self.extensions: @@ -236,7 +279,6 @@ def target_name(s: str) -> str: prefix = outdir if '.'
in ext.name: prefix = prefix.parent - # prefix here should actually be the same for all components install_args = [ "cmake", "--install", ".", "--prefix", prefix, "--component", @@ -629,7 +671,7 @@ def _read_requirements(filename: str) -> List[str]: } setup( - name="vllm", + name="step-vllm", version=get_vllm_version(), author="vLLM Team", license="Apache 2.0", diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 5f0a54013540..10259ee5dab7 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -258,6 +258,7 @@ def __init__( scale: float, num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, + alibi_sqrt: bool = False, sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", blocksparse_params: Optional[Dict[str, Any]] = None, diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 9765e7881ad9..b06d3986bbb5 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -300,6 +300,7 @@ def __init__( scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], + alibi_sqrt: bool, sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, @@ -313,6 +314,8 @@ def __init__( "sliding_window is invalid for blocksparse attention.") assert logits_soft_cap is None, ValueError( "logits_soft_cap is invalid for blocksparse attention.") + assert alibi_sqrt is False, ValueError( + "alibi_sqrt is invalid for blocksparse attention.") if "num_heads" not in blocksparse_params: blocksparse_params["num_heads"] = num_heads diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 6a82127acdf7..c9e4c379c3df 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -610,6 +610,7 @@ def __init__( scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], + alibi_sqrt: bool, sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, @@ -626,6 +627,7 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.alibi_sqrt = alibi_sqrt self.sliding_window = ((sliding_window - 1, 0) if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype @@ -767,21 +769,42 @@ def forward( key = key[:num_prefill_kv_tokens] value = value[:num_prefill_kv_tokens] - flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=q_seq_start_loc, - cu_seqlens_k=k_seq_start_loc, - max_seqlen_q=q_seq_len, - max_seqlen_k=k_seq_len, - softmax_scale=softmax_scale, - causal=_get_causal_option(attn_type), - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - out=prefill_output, - fa_version=self.fa_version, + if self.alibi_sqrt: + torch.ops.Optimus.varlen_fwd( + query, + key, + value, + prefill_output, + q_seq_start_loc, + k_seq_start_loc, + q_seq_len, + k_seq_len, + 0.0, + query.shape[-1]**(-0.5), + False, + True, + False, + None, + 0, + 1, + alibi_slopes, + ) + else: + flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, + softmax_scale=softmax_scale, + causal=_get_causal_option(attn_type), + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + 
out=prefill_output, + fa_version=self.fa_version, ) else: # prefix-enabled attention @@ -789,22 +812,45 @@ def forward( "Only decoder-only models support prefix caching") assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) - flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - seqused_k=prefill_meta.seq_lens_tensor, - max_seqlen_k=max_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=logits_soft_cap, - out=prefill_output, - fa_version=self.fa_version, + if self.alibi_sqrt: + torch.ops.Optimus.vllm_varlen_fwd( + query, + key_cache, + value_cache, + prefill_output, + prefill_meta.query_start_loc, + prefill_meta.seq_start_loc, + None, + prefill_meta.block_tables, + alibi_slopes, + prefill_meta.max_query_len, + max_seq_len, + 0, + query.shape[-1]**(-0.5), + False, + True, + window_size[0], + window_size[1], + False, + None, + ) + else: + flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + seqused_k=prefill_meta.seq_lens_tensor, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, + out=prefill_output, + fa_version=self.fa_version, ) if decode_meta := attn_metadata.decode_metadata: @@ -818,43 +864,95 @@ def forward( assert attn_type == AttentionType.DECODER, ( "Only decoder-only models support max_decode_query_len > 1" ) - flash_attn_varlen_func( - q=decode_query, - k=key_cache, - v=value_cache, - cu_seqlens_q=decode_meta.query_start_loc, - max_seqlen_q=decode_meta.max_decode_query_len, - seqused_k=decode_meta.seq_lens_tensor, - max_seqlen_k=decode_meta.max_decode_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - block_table=decode_meta.block_tables, - out=decode_output, - fa_version=self.fa_version, + if self.alibi_sqrt: + torch.ops.Optimus.vllm_varlen_fwd( + decode_query, + key_cache, + value_cache, + decode_output, + decode_meta.query_start_loc, + decode_meta.seq_start_loc, + None, + decode_meta.block_tables, + alibi_slopes, + decode_meta.max_query_len, + decode_meta.max_decode_seq_len, + 0, + query.shape[-1]**(-0.5), + False, + True, + window_size[0], + window_size[1], + False, + None, + ) + else: + flash_attn_varlen_func( + q=decode_query, + k=key_cache, + v=value_cache, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_q=decode_meta.max_decode_query_len, + seqused_k=decode_meta.seq_lens_tensor, + max_seqlen_k=decode_meta.max_decode_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + block_table=decode_meta.block_tables, + out=decode_output, + fa_version=self.fa_version, ) else: # Use flash_attn_with_kvcache for normal decoding. 
- ( - seq_lens_arg, - _, - block_tables_arg, - ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - flash_attn_with_kvcache( - q=decode_query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, - block_table=block_tables_arg, - cache_seqlens=seq_lens_arg, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - out=decode_output.unsqueeze(1), - fa_version=self.fa_version, + if self.alibi_sqrt: + ( + seq_lens_arg, + _, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, + attn_type) + torch.ops.Optimus.vllm_fwd_kvcache( + q=decode_query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + k=None, + v=None, + cache_seqlens=seq_lens_arg, + rotary_cos=None, + rotary_sin=None, + cache_batch_idx=None, + block_table=block_tables_arg, + alibi_slopes=alibi_slopes, + out=decode_output.unsqueeze(1), + softmax_scale=softmax_scale, + causal=True, + window_size_left=window_size[0], + window_size_right=window_size[1], + rotary_interleaved=True, + num_splits=0 + ) + else: + ( + seq_lens_arg, + _, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, + attn_type) + flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + block_table=block_tables_arg, + cache_seqlens=seq_lens_arg, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + out=decode_output.unsqueeze(1), + fa_version=self.fa_version, ) return output diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 715ed6748b84..787797729dcd 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -902,12 +902,15 @@ def __init__( scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], + alibi_sqrt: bool, sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, ) -> None: + assert alibi_sqrt is False, ValueError( + "alibi_sqrt is invalid for FlashInfer attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 1ad5e6e8e4e1..834f2e5f3546 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -102,6 +102,7 @@ def __init__( scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], + alibi_sqrt: bool, sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, @@ -109,6 +110,8 @@ def __init__( attn_type: str = AttentionType.DECODER, ) -> None: super(AttentionImpl, self).__init__() + assert alibi_sqrt is False, ValueError( + "alibi_sqrt is invalid for HPU attention.") self.kv_cache_dtype = kv_cache_dtype self.num_heads = num_heads self.head_size = head_size diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index b4879af4cf20..ca9076493e07 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -113,12 +113,15 @@ def __init__( scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], + alibi_sqrt: bool, sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = 
AttentionType.DECODER, ) -> None: + assert alibi_sqrt is False, ValueError( + "alibi_sqrt is invalid for IPEX attention.") if blocksparse_params is not None: raise ValueError( "IPEX backend does not support block-sparse attention.") diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index b61dfe63ddca..e361d52aba5a 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -99,12 +99,15 @@ def __init__( scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], + alibi_sqrt: bool, sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, ) -> None: + assert alibi_sqrt is False, ValueError( + "alibi_sqrt is invalid for Pallas attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 02bff57a62b7..1df3217669a7 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -455,12 +455,15 @@ def __init__( scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], + alibi_sqrt: bool, sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, ) -> None: + assert alibi_sqrt is False, ValueError( + "alibi_sqrt is invalid for ROCmFlashAttention.") if blocksparse_params is not None: raise ValueError( "ROCmFlashAttention does not support blocksparse attention.") diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 25fe6ed95c5d..58af878a47ef 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -395,12 +395,15 @@ def __init__( scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], + alibi_sqrt: bool, sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, ) -> None: + assert alibi_sqrt is False, ValueError( + "alibi_sqrt is invalid for TorchSDPA.") if blocksparse_params is not None: raise ValueError( "Torch SPDA does not support block-sparse attention.") diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 723a4558d0b3..b1008df4bc24 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -384,12 +384,15 @@ def __init__( scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], + alibi_sqrt: bool, sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, ) -> None: + assert alibi_sqrt is False, ValueError( + "alibi_sqrt is invalid for XFormers.") if blocksparse_params is not None: raise ValueError( "XFormers does not support block-sparse attention.") diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index e4df7ffc5885..5eec1a2711d4 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -37,6 +37,7 @@ def __init__( scale: float, num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, + alibi_sqrt: bool = False, cache_config: Optional[CacheConfig] = None, quant_config: 
Optional[QuantizationConfig] = None, blocksparse_params: Optional[Dict[str, Any]] = None, @@ -113,7 +114,7 @@ def __init__( use_mla=use_mla) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, + alibi_slopes, alibi_sqrt, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, **extra_impl_args) self.num_heads = num_heads diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 3b2a7069efc9..2b1dfc412a83 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -103,6 +103,7 @@ "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 + "Step1ForCausalLM": ("step1", "Step1ForCausalLM"), } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/step1.py b/vllm/model_executor/models/step1.py new file mode 100644 index 000000000000..0bc11cacd991 --- /dev/null +++ b/vllm/model_executor/models/step1.py @@ -0,0 +1,418 @@ +import math +from typing import Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) + + +def _get_alibi_slopes(n_heads): + n = 2**math.floor(math.log2(n_heads)) # nearest 2**n to n_heads + m0 = 2.0**(-8.0 / n) + slopes = np.power(m0, np.arange(1, n + 1)) + if n < n_heads: + m1 = 2.0**(-4.0 / n) + mm = np.power(m1, np.arange(1, 1 + 2 * (n_heads - n), 2)) + slopes = np.concatenate([slopes, mm]) + return slopes + + +def _get_ntk_alibi_slopes(max_pos_interp_ratio, slopes): + if max_pos_interp_ratio == 1.0: + return slopes + smax, smin = slopes.max(), slopes.min() + D0 = np.log2(smax) - np.log2(smin) + W1 = (np.log2(smax) - np.log2(slopes)) / D0 + ratios = np.power(max_pos_interp_ratio, W1) + return slopes / (ratios**0.5) + + +class Step1MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + 
hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class Step1Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + slopes: Optional[List[float]] = None, + max_pos_interp_ratio: float = 1.0, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # Create the alibi slopes and slice them. 
+ tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + if slopes is None: + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = _get_ntk_alibi_slopes(max_pos_interp_ratio, + alibi_slopes) + alibi_slopes = alibi_slopes[head_start:head_end] + else: + assert len(slopes) == self.total_num_heads + alibi_slopes = _get_ntk_alibi_slopes(max_pos_interp_ratio, + slopes).tolist() + alibi_slopes = slopes[head_start:head_end] + + scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + self.num_kv_heads, + alibi_slopes, + alibi_sqrt=True, + cache_config=cache_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Step1DecoderLayer(nn.Module): + + def __init__(self, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + config = model_config.hf_config + self.hidden_size = config.hidden_size + self.self_attn = Step1Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_attention_groups, + slopes=config.alibi_slopes, + max_pos_interp_ratio=config.max_pos_interp_ratio, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = Step1MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act="silu", + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + sampling_metadata: Optional[SamplingMetadata] = None + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Step1Model(nn.Module): + + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + assert lora_config is None + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + 
quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Step1DecoderLayer(model_config=vllm_config.model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Step1PretrainedModel(nn.Module, SupportsPP): + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + +class Step1ForCausalLM(Step1PretrainedModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + lora_config = vllm_config.lora_config + self.config = config + + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if 
not lora_config else lora_config.lora_vocab_padding_size, + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): + return Step1Model(vllm_config=vllm_config, prefix=prefix) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def sequence_flops(self, input_length, context_length): + output_flops = 1 * self.config.hidden_size * self.config.vocab_size * 2.0 / 1e12 + return super().sequence_flops(input_length, + context_length) + output_flops \ No newline at end of file From a6732012c5d3da2de74cc45d0769ef8f87c658d4 Mon Sep 17 00:00:00 2001 From: oliveryuan Date: Tue, 18 Feb 2025 00:02:35 +0800 Subject: [PATCH 2/5] update --- vllm/attention/backends/flash_attn.py | 37 +++++++++++++-------------- vllm/model_executor/models/step1.py | 2 +- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c9e4c379c3df..af7c6531cee6 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -787,7 +787,6 @@ def forward( None, 0, 1, - alibi_slopes, ) else: flash_attn_varlen_func( @@ -914,24 +913,24 @@ def forward( ) = get_seq_len_block_table_args(decode_meta, False, attn_type) torch.ops.Optimus.vllm_fwd_kvcache( - q=decode_query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, - k=None, - v=None, - cache_seqlens=seq_lens_arg, - rotary_cos=None, - rotary_sin=None, - cache_batch_idx=None, - block_table=block_tables_arg, - alibi_slopes=alibi_slopes, - out=decode_output.unsqueeze(1), - softmax_scale=softmax_scale, - causal=True, - window_size_left=window_size[0], - window_size_right=window_size[1], - rotary_interleaved=True, - num_splits=0 + decode_query.unsqueeze(1), + key_cache, + value_cache, + None, + None, + seq_lens_arg, + None, + None, + None, + block_tables_arg, + alibi_slopes, + decode_output.unsqueeze(1), + softmax_scale, + True, + window_size[0], + window_size[1], + True, + 0 ) else: ( diff --git a/vllm/model_executor/models/step1.py b/vllm/model_executor/models/step1.py index 0bc11cacd991..5a75f0c1cb02 100644 --- a/vllm/model_executor/models/step1.py +++ b/vllm/model_executor/models/step1.py @@ -212,7 +212,7 @@ def forward( hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - sampling_metadata: Optional[SamplingMetadata] = None + residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention if 
residual is None: From 3dab34cb423f82b20f737df0879b6ca663ab1031 Mon Sep 17 00:00:00 2001 From: oliveryuan Date: Tue, 18 Feb 2025 00:35:09 +0800 Subject: [PATCH 3/5] update --- vllm/attention/backends/flash_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index af7c6531cee6..3c796bad4273 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -825,7 +825,7 @@ def forward( prefill_meta.max_query_len, max_seq_len, 0, - query.shape[-1]**(-0.5), + softmax_scale, False, True, window_size[0], @@ -877,7 +877,7 @@ def forward( decode_meta.max_query_len, decode_meta.max_decode_seq_len, 0, - query.shape[-1]**(-0.5), + softmax_scale, False, True, window_size[0], From c3475c7a961387817f48ebb58c7bd391763a3888 Mon Sep 17 00:00:00 2001 From: oliveryuan Date: Tue, 18 Feb 2025 01:23:31 +0800 Subject: [PATCH 4/5] update --- vllm/model_executor/models/step1.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/step1.py b/vllm/model_executor/models/step1.py index 5a75f0c1cb02..d8ea67affd1d 100644 --- a/vllm/model_executor/models/step1.py +++ b/vllm/model_executor/models/step1.py @@ -239,6 +239,7 @@ class Step1Model(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + print(config) cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -410,9 +411,4 @@ def sample( sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def sequence_flops(self, input_length, context_length): - output_flops = 1 * self.config.hidden_size * self.config.vocab_size * 2.0 / 1e12 - return super().sequence_flops(input_length, - context_length) + output_flops \ No newline at end of file + return next_tokens \ No newline at end of file From 61f5a0113603e7e006857782c7921a7f5da54013 Mon Sep 17 00:00:00 2001 From: oliveryuan Date: Tue, 18 Feb 2025 14:53:57 +0800 Subject: [PATCH 5/5] misc: fix conf --- vllm/config.py | 2 ++ vllm/model_executor/models/step1.py | 20 +++++++++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 9ba497576124..b89c0a9677f0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -817,6 +817,8 @@ def get_total_num_kv_heads(self) -> int: "num_key_value_heads", # For ChatGLM: "multi_query_group_num", + # For Step1: + "num_attention_groups", ] for attr in attributes: num_kv_heads = getattr(self.hf_text_config, attr, None) diff --git a/vllm/model_executor/models/step1.py b/vllm/model_executor/models/step1.py index d8ea67affd1d..b6056bdd9bbe 100644 --- a/vllm/model_executor/models/step1.py +++ b/vllm/model_executor/models/step1.py @@ -31,6 +31,21 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +try: + import os + OPTIMUS_LIB_PATH = os.environ.get('OPTIMUS_LIB_PATH') + if torch.__version__ >= "2.5": + torch.ops.load_library(os.path.join(OPTIMUS_LIB_PATH, "liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so")) + elif torch.__version__ >= "2.3": + torch.ops.load_library(os.path.join(OPTIMUS_LIB_PATH, 'liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so')) + elif torch.__version__ >= "2.2": + torch.ops.load_library(os.path.join(OPTIMUS_LIB_PATH, 
'liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so')) + else: + raise ImportError("Failed to load optimus library for flash attn ops for step1 model") +except Exception: + raise ImportError("Failed to load optimus library for flash attn ops for step1 model") + + def _get_alibi_slopes(n_heads): n = 2**math.floor(math.log2(n_heads)) # nearest 2**n to n_heads m0 = 2.0**(-8.0 / n) @@ -188,8 +203,8 @@ def __init__(self, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_attention_groups, - slopes=config.alibi_slopes, - max_pos_interp_ratio=config.max_pos_interp_ratio, + slopes=config.alibi_slopes if getattr(config, 'alibi_slopes', None) else None, + max_pos_interp_ratio=getattr(config, 'max_pos_interp_ratio', 1.0), cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", @@ -239,7 +254,6 @@ class Step1Model(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - print(config) cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config