
Commit 2bf686d

fabianlim authored and Akshat-Tripathi committed
Correction to TP logic for Mamba Mixer 2 when Num Groups not divisible by TP Size (vllm-project#13660)
1 parent 623a414 commit 2bf686d

1 file changed

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 21 additions & 7 deletions
@@ -133,7 +133,8 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
     if ngroups % tp_size == 0:
         return 0
 
-    return tp_size - ngroups % tp_size
+    # for n_groups == 1, this is exactly tp_size - n_groups
+    return tp_size - ngroups
 
 
 def mamba_v2_sharded_weight_loader(
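The corrected helper now assumes the only non-divisible case it has to handle is a single group. As a quick sanity check (not part of the commit; the inline comments restate the commit's own note), a standalone sketch of the new behaviour:

def extra_groups_for_head_shards(ngroups: int, tp_size: int) -> int:
    # number of extra (replicated) groups to allocate so that every
    # head shard has a group to read from; per the commit's comment,
    # replication is only supported when ngroups == 1
    if ngroups % tp_size == 0:
        return 0
    # for n_groups == 1, this is exactly tp_size - n_groups
    return tp_size - ngroups

# a single group shared across 4 TP ranks needs 3 replicated copies
assert extra_groups_for_head_shards(1, 4) == 3
# a group count already divisible by the TP size needs none
assert extra_groups_for_head_shards(8, 4) == 0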
@@ -153,7 +154,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
         boundary, loaded_boundary = 0, 0
 
         # - iterate over the shard specs
-        for full_dim, extra, ratio in shard_spec:
+        for full_dim, extra, duplicate_groups in shard_spec:
             # - full dim is the model dim (before TP).
             # - extra > 0, means there is expected overall increase
             #   of dimensions. This is so because of replication.
@@ -167,7 +168,12 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
             # - compute the rank into the loaded shard.
             # - if there is replication, different TP shards will
             #   take from the same rank.
-            rank = tp_rank // ratio
+            if duplicate_groups:
+                # NOTE: currently we only support duplication
+                # in the case where num_groups == 1
+                rank = 0
+            else:
+                rank = tp_rank
 
             # - leftmost boundary index into loaded weight.
             loaded_skip = rank * shard_size
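With the third field of each shard spec now a boolean, the loader's rank mapping reduces to a simple branch. A minimal sketch of which checkpoint shard each TP rank reads (the helper name is illustrative, not from the file; only duplicate_groups and tp_rank come from the diff):

def pick_source_rank(tp_rank: int, duplicate_groups: bool) -> int:
    if duplicate_groups:
        # only the num_groups == 1 case is supported, so every TP rank
        # copies the single group stored at shard 0 of the checkpoint
        return 0
    # otherwise each TP rank reads its own shard, as before
    return tp_rank

# with duplication, all 4 ranks load the same (only) group shard
assert [pick_source_rank(r, True) for r in range(4)] == [0, 0, 0, 0]
# without duplication, ranks load disjoint shards
assert [pick_source_rank(r, False) for r in range(4)] == [0, 1, 2, 3]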
@@ -233,12 +239,21 @@ def __init__(self,
         # - HOWEVER IF, world_size DOES NOT divide groups, then we need
         #   to allocate extra space in the shard, such that groups
         #   may be replicated to follow the head shard.
+        # - NOTE: currently for the world size DOES NOT divide groups
+        #   case, we only support the case when n_groups == 1
         self.tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
 
         assert num_heads % self.tp_size == 0, \
             "Tensor parallel world size must divide num heads."
 
+
+        assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
+            (
+                "If tensor parallel world size does not divide num_heads, "
+                "then num_groups must equal 1."
+            )
+
         self.ssm_state_size = ssm_state_size
         self.activation = activation
 
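The new assertion admits exactly two cases: the TP world size divides n_groups, or there is a single group. A tiny illustration (helper name and values assumed, not from the commit):

def tp_config_supported(n_groups: int, tp_size: int) -> bool:
    return (n_groups % tp_size) == 0 or n_groups == 1

assert tp_config_supported(8, 4)       # divisible: regular group sharding
assert tp_config_supported(1, 4)       # single group: replicated per rank
assert not tp_config_supported(3, 4)   # partial replication: unsupported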
@@ -284,11 +299,10 @@ def __init__(self,
             self.n_groups * self.ssm_state_size,  # expected model size
             (self.n_groups - n_groups) *
             self.ssm_state_size,  # extra dims assigned
-            self.num_heads //
-            n_groups,  # ratio for mapping back to original group
+            n_groups == 1,  # if there was only one group
         )
-        intermediate_settings = (intermediate_size, 0, 1)
-        head_setings = (self.num_heads, 0, 1)
+        intermediate_settings = (intermediate_size, 0, False)
+        head_setings = (self.num_heads, 0, False)
 
         # - the weight already has a "weight_loader" attribute
         #   which set_weight_attrs will raise if we do not
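For concreteness, a worked example of what the three shard-spec tuples evaluate to when a single group is padded out across tp_size = 4 ranks (all values here are assumed for illustration and are not taken from the commit):

n_groups, tp_size = 1, 4
ssm_state_size, num_heads, intermediate_size = 128, 32, 4096

# extra_groups_for_head_shards(1, 4) == 3, so the padded group count is 4
padded_n_groups = n_groups + (tp_size - n_groups)

group_shard_settings = (
    padded_n_groups * ssm_state_size,               # expected model size: 512
    (padded_n_groups - n_groups) * ssm_state_size,  # extra dims assigned: 384
    n_groups == 1,                                  # duplicate the group: True
)
intermediate_settings = (intermediate_size, 0, False)
head_settings = (num_heads, 0, False)

The third element of each tuple is what the loader above receives as duplicate_groups.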
