
Commit 452e8fd

yury-tokpanov, Quentin-Anthony, tlrmchlsmth, and DarkLight1337 authored
[MODEL] Add support for Zamba2 models (vllm-project#13185)
Signed-off-by: Yury Tokpanov <[email protected]>
Signed-off-by: Quentin Anthony <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent 8b793f7 commit 452e8fd

File tree: 9 files changed, +1082 / −27 lines


docs/source/models/supported_models.md (+5)

@@ -477,6 +477,11 @@ See [this page](#generative-models) for more information on how to use generative models.
   * `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
   * ✅︎
   * ✅︎
+- * `Zamba2ForCausalLM`
+  * Zamba2
+  * `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
+  *
+  *
 :::

 :::{note}
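For readers trying out the newly documented models, here is a minimal offline-inference sketch using vLLM's `LLM` API. It assumes a vLLM build that includes this commit and a `transformers` version satisfying the 4.49 pin added in tests/models/registry.py below; the prompt and sampling settings are illustrative only and not part of the change.

```python
# Minimal sketch, not part of this commit: load one of the Zamba2 checkpoints
# listed above and run greedy generation offline.
from vllm import LLM, SamplingParams

llm = LLM(model="Zyphra/Zamba2-1.2B-instruct")  # any listed Zamba2 checkpoint
params = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(
    ["Briefly explain what a hybrid Mamba/attention model is."], params)
print(outputs[0].outputs[0].text)
```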

tests/models/decoder_only/language/test_hybrid.py (+29 −22)

@@ -9,7 +9,7 @@
 from ...utils import check_outputs_equal

 # This test is for the hybrid models
-MODELS = ["ai21labs/Jamba-tiny-dev"]
+MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
 # Bamba at Fp32 is too big for the CI (L4 GPU).
 # MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]

@@ -27,17 +27,19 @@ def test_models(
 ) -> None:

     # numeric error produces different generation
-    if 'Bamba' in model:
+    if "Bamba" in model:
         example_prompts.pop(3)

-    with hf_runner(
-            model,
-            dtype=dtype,
-            model_kwargs={
-                "use_mamba_kernels":
-                False,  # mamba kernels are not installed so HF
-                # don't use them
-            }) as hf_model:
+    model_kwargs = {
+        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
+        # don't use them
+    }
+    if "Zamba2" in model:
+        # Zamba2 HF implementation automatically checks if mamba kernels are
+        # installed
+        model_kwargs = {}
+
+    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

     with vllm_runner(model, dtype=dtype) as vllm_model:

@@ -112,26 +114,31 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
 def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
                                 model: str, dtype: str,
                                 max_tokens: int) -> None:
-    # numeric error during prefill chucking produces different generation
+    # numeric error during prefill chunking produces different generation
     # compared to w/o prefill chunking for those examples, removed them for now
-    if 'Jamba' in model:
+    if "Jamba" in model:
         example_prompts.pop(7)
         example_prompts.pop(2)
         example_prompts.pop(1)
-    elif 'Bamba' in model:
+    elif "Bamba" in model:
         example_prompts.pop(6)
         example_prompts.pop(3)
         example_prompts.pop(2)
         dtype = "half"  # use a different dtype for Bamba
-
-    with hf_runner(
-            model,
-            dtype=dtype,
-            model_kwargs={
-                "use_mamba_kernels":
-                False,  # mamba kernels are not installed so HF
-                # don't use them
-            }) as hf_model:
+    elif "Zamba2" in model:
+        example_prompts.pop(7)
+        dtype = "half"
+
+    model_kwargs = {
+        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
+        # don't use them
+    }
+    if "Zamba2" in model:
+        # Zamba2 HF implementation automatically checks if mamba kernels are
+        # installed
+        model_kwargs = {}
+
+    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
         non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)

     with vllm_runner(model,
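The change above centralizes the HF-side kwargs: Jamba and Bamba still pass `use_mamba_kernels=False` because the optional mamba kernels are not installed in CI, while the Zamba2 HF implementation detects missing kernels on its own. A standalone sketch of that dispatch (the helper name is ours, not part of the test):

```python
# Hypothetical helper mirroring the model_kwargs selection in the diff above.
def hf_model_kwargs(model: str) -> dict:
    if "Zamba2" in model:
        # Zamba2's HF implementation checks for mamba kernels itself.
        return {}
    # Other hybrid models must explicitly opt out of the (uninstalled) kernels.
    return {"use_mamba_kernels": False}

assert hf_model_kwargs("Zyphra/Zamba2-1.2B-instruct") == {}
assert hf_model_kwargs("ai21labs/Jamba-tiny-dev") == {"use_mamba_kernels": False}
```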

tests/models/registry.py (+2)

@@ -195,6 +195,8 @@ def check_available_online(
     "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
                                          is_available_online=False,
                                          trust_remote_code=True),
+    "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct",
+                                         min_transformers_version="4.49"),
     # [Encoder-decoder]
     "BartModel": _HfExamplesInfo("facebook/bart-base"),
     "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),

vllm/config.py (+14)

@@ -821,6 +821,11 @@ def get_head_size(self) -> int:
         if qk_rope_head_dim and qk_nope_head_dim:
             return qk_rope_head_dim + qk_nope_head_dim

+        if hasattr(self.hf_text_config,
+                   "model_type") and (self.hf_text_config.model_type
+                                      == "zamba2"):
+            return self.hf_text_config.attention_head_dim
+
         if self.is_attention_free:
             return 0

@@ -944,6 +949,15 @@ def get_num_layers_by_block_type(
                                 "cannot determine the num of "
                                 f"{block_type.value} layers")

+        if hasattr(self.hf_text_config,
+                   "model_type") and (self.hf_text_config.model_type
+                                      == "zamba2"):
+            if attn_block_type:
+                return sum(t == "hybrid"
+                           for t in layers_block_type_value[start:end])
+            else:
+                return self.get_num_layers(parallel_config)
+
         return sum(t == block_type.value
                    for t in layers_block_type_value[start:end])
vllm/model_executor/layers/mamba/mamba_mixer2.py (−1)

@@ -245,7 +245,6 @@ def __init__(self,
         assert num_heads % self.tp_size == 0, \
             "Tensor parallel world size must divide num heads."

-
         assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
             (
                 "If tensor parallel world size does not divide num_heads, "

vllm/model_executor/models/bamba.py (−2)

@@ -38,8 +38,6 @@
     make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)

-KVCache = Tuple[torch.Tensor, torch.Tensor]
-

 class BambaMLP(nn.Module):

vllm/model_executor/models/jamba.py (−2)

@@ -36,8 +36,6 @@
     make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)

-KVCache = Tuple[torch.Tensor, torch.Tensor]
-

 class JambaMoE(nn.Module):

vllm/model_executor/models/registry.py (+1)

@@ -105,6 +105,7 @@
     "SolarForCausalLM": ("solar", "SolarForCausalLM"),
     "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
+    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
     # [Encoder-decoder]
     "BartModel": ("bart", "BartForConditionalGeneration"),
     "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
