@@ -133,7 +133,8 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
     if ngroups % tp_size == 0:
         return 0
 
-    return tp_size - ngroups % tp_size
+    # for n_groups == 1, this is exactly tp_size - n_groups
+    return tp_size - ngroups
 
 
 def mamba_v2_sharded_weight_loader(
@@ -153,7 +154,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
         boundary, loaded_boundary = 0, 0
 
         # - iterate over the shard specs
-        for full_dim, extra, ratio in shard_spec:
+        for full_dim, extra, duplicate_groups in shard_spec:
             # - full dim is the model dim (before TP).
             # - extra > 0, means there is expected overall increase
             #   of dimensions. This is so because of replication.
@@ -167,7 +168,12 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
             # - compute the rank into the loaded shard.
             # - if there is replication, different TP shards will
             #   take from the same rank.
-            rank = tp_rank // ratio
+            if duplicate_groups:
+                # NOTE: currently we only support duplication
+                # in the case where num_groups == 1
+                rank = 0
+            else:
+                rank = tp_rank
 
             # - leftmost boundary index into loaded weight.
             loaded_skip = rank * shard_size
@@ -233,12 +239,21 @@ def __init__(self,
         # - HOWEVER IF, world_size DOES NOT divide groups, then we need
         #   to allocate extra space in the shard, such that groups
         #   may be replicated to follow the head shard.
+        # - NOTE: currently for the world size DOES NOT divide groups
+        #   case, we only support the case when n_groups == 1
         self.tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
 
         assert num_heads % self.tp_size == 0, \
             "Tensor parallel world size must divide num heads."
 
+
+        assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
+            (
+                "If tensor parallel world size does not divide num_heads, "
+                "then num_groups must equal 1."
+            )
+
         self.ssm_state_size = ssm_state_size
         self.activation = activation
 
@@ -284,11 +299,10 @@ def __init__(self,
             self.n_groups * self.ssm_state_size,  # expected model size
             (self.n_groups - n_groups) *
             self.ssm_state_size,  # extra dims assigned
-            self.num_heads //
-            n_groups,  # ratio for mapping back to original group
+            n_groups == 1,  # if there was only one group
         )
-        intermediate_settings = (intermediate_size, 0, 1)
-        head_setings = (self.num_heads, 0, 1)
+        intermediate_settings = (intermediate_size, 0, False)
+        head_setings = (self.num_heads, 0, False)
 
         # - the weight already has a "weight_loader" attribute
         #   which set_weight_attrs will raise if we do not
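For reference, a minimal sketch (illustrative, not part of the commit) of how the padding from `extra_groups_for_head_shards` and the new `duplicate_groups` flag interact in the only supported replication case, `n_groups == 1`; the values `tp_size = 4` and `ssm_state_size = 128` are assumptions chosen for the example:

```python
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
    # same logic as the hunk above: pad only when TP does not divide the groups
    if ngroups % tp_size == 0:
        return 0
    # for n_groups == 1, this is exactly tp_size - n_groups
    return tp_size - ngroups


# assumed illustrative values
n_groups, tp_size, ssm_state_size = 1, 4, 128

# the single group is padded so that every head shard gets its own copy
padded_groups = n_groups + extra_groups_for_head_shards(n_groups, tp_size)
assert padded_groups == 4

# shard spec tuple as built in __init__: (full_dim, extra, duplicate_groups)
group_shard_settings = (
    padded_groups * ssm_state_size,               # expected model size: 512
    (padded_groups - n_groups) * ssm_state_size,  # extra dims assigned: 384
    n_groups == 1,                                # duplicate_groups: True
)

# with duplicate_groups=True every TP rank reads rank 0's slice of the
# checkpoint weight; otherwise each rank reads its own slice
for tp_rank in range(tp_size):
    rank = 0 if group_shard_settings[2] else tp_rank
    assert rank == 0
```

Because every TP rank loads the same group slice when `duplicate_groups` is set, a fixed `rank = 0` replaces the old `tp_rank // ratio` computation in the loader.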