9
9
from ...utils import check_outputs_equal
10
10
11
11
# This test is for the hybrid models
12
- MODELS = ["ai21labs/Jamba-tiny-dev" ]
12
+ MODELS = ["ai21labs/Jamba-tiny-dev" , "Zyphra/Zamba2-1.2B-instruct" ]
13
13
# Bamba at Fp32 is too big for the CI (L4 GPU).
14
14
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
15
15
@@ -27,17 +27,19 @@ def test_models(
27
27
) -> None :
28
28
29
29
# numeric error produces different generation
30
- if ' Bamba' in model :
30
+ if " Bamba" in model :
31
31
example_prompts .pop (3 )
32
32
33
- with hf_runner (
34
- model ,
35
- dtype = dtype ,
36
- model_kwargs = {
37
- "use_mamba_kernels" :
38
- False , # mamba kernels are not installed so HF
39
- # don't use them
40
- }) as hf_model :
33
+ model_kwargs = {
34
+ "use_mamba_kernels" : False , # mamba kernels are not installed so HF
35
+ # don't use them
36
+ }
37
+ if "Zamba2" in model :
38
+ # Zamba2 HF implementation automatically checks if mamba kernels are
39
+ # installed
40
+ model_kwargs = {}
41
+
42
+ with hf_runner (model , dtype = dtype , model_kwargs = model_kwargs ) as hf_model :
41
43
hf_outputs = hf_model .generate_greedy (example_prompts , max_tokens )
42
44
43
45
with vllm_runner (model , dtype = dtype ) as vllm_model :
@@ -112,26 +114,31 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
112
114
def test_mamba_prefill_chunking (hf_runner , vllm_runner , example_prompts ,
113
115
model : str , dtype : str ,
114
116
max_tokens : int ) -> None :
115
- # numeric error during prefill chucking produces different generation
117
+ # numeric error during prefill chunking produces different generation
116
118
# compared to w/o prefill chunking for those examples, removed them for now
117
- if ' Jamba' in model :
119
+ if " Jamba" in model :
118
120
example_prompts .pop (7 )
119
121
example_prompts .pop (2 )
120
122
example_prompts .pop (1 )
121
- elif ' Bamba' in model :
123
+ elif " Bamba" in model :
122
124
example_prompts .pop (6 )
123
125
example_prompts .pop (3 )
124
126
example_prompts .pop (2 )
125
127
dtype = "half" # use a different dtype for Bamba
126
-
127
- with hf_runner (
128
- model ,
129
- dtype = dtype ,
130
- model_kwargs = {
131
- "use_mamba_kernels" :
132
- False , # mamba kernels are not installed so HF
133
- # don't use them
134
- }) as hf_model :
128
+ elif "Zamba2" in model :
129
+ example_prompts .pop (7 )
130
+ dtype = "half"
131
+
132
+ model_kwargs = {
133
+ "use_mamba_kernels" : False , # mamba kernels are not installed so HF
134
+ # don't use them
135
+ }
136
+ if "Zamba2" in model :
137
+ # Zamba2 HF implementation automatically checks if mamba kernels are
138
+ # installed
139
+ model_kwargs = {}
140
+
141
+ with hf_runner (model , dtype = dtype , model_kwargs = model_kwargs ) as hf_model :
135
142
non_chunked = hf_model .generate_greedy (example_prompts , max_tokens )
136
143
137
144
with vllm_runner (model ,
0 commit comments