Commit f5c163a

Added sharding for lora_b
Signed-off-by: Xiongfei Wei <[email protected]>
1 parent: aa8c235

1 file changed: +42 −2


tpu_commons/models/vllm/sharding.py

Lines changed: 42 additions & 2 deletions
@@ -128,7 +128,47 @@ def _shard_column_parallel_linear_lora(
 
 def _shard_qkv_parallel_linear_lora(layer: MergedQKVParallelLinearWithLoRA,
                                     mesh: Mesh) -> None:
-    _shard_base_linear_lora(layer, mesh)
+    # mesh=Mesh(axis_sizes=(1, 2), axis_names=('data', 'model'), axis_types=(Auto, Auto))
+    # NOTE: lora_a_stacked[i] has shape [max_loras, 1, num_out, num_in]
+    sharded_lora_a_tpu = torch.nn.ParameterList()
+    sharded_lora_b_tpu = torch.nn.ParameterList()
+    sharded_lora_bias_tpu = torch.nn.ParameterList()
+
+    assert layer.n_slices > 0, "layer.n_slices should be greater than 0"
+    mesh_lora_b_shape = (1, 1) + (mesh.shape['data'], mesh.shape['model'])
+    mesh_lora_b_axis = ('replica_num_lora', 'replica', 'data', 'model')
+    lora_b_mesh = jax.make_mesh(
+        mesh_lora_b_shape, mesh_lora_b_axis,
+        devices=mesh.devices[0])  # mesh.devices=[[device0, ..., device_n]]
+    lora_b_partition_spec = P(None, None, 'model', None)
+    lora_b_sharding = NamedSharding(lora_b_mesh, lora_b_partition_spec)
+
+    mesh_lora_bias_shape = (1, 1) + (mesh.shape['model'], )
+    mesh_lora_bias_axis = ('replica_num_lora', 'replica', 'model')
+    lora_bias_mesh = jax.make_mesh(
+        mesh_lora_bias_shape, mesh_lora_bias_axis,
+        devices=mesh.devices[0])  # mesh.devices=[[device0, ..., device_n]]
+    lora_bias_partition_spec = P(None, None, 'model')
+    lora_bias_sharding = NamedSharding(lora_bias_mesh,
+                                       lora_bias_partition_spec)
+
+    for i in range(layer.n_slices):
+        sharded_lora_a_tpu.append(
+            _shard_tensor_to_tpu_replicated(layer.lora_a_stacked[i], mesh))
+
+        sharded_lora_b_tpu.append(
+            _convert_to_torchax_and_shard(layer.lora_b_stacked[i],
+                                          lora_b_sharding))
+
+        if layer.lora_bias_stacked is not None:
+            sharded_lora_bias_tpu.append(
+                _convert_to_torchax_and_shard(layer.lora_bias_stacked[i],
+                                              lora_bias_sharding))
+
+    layer.lora_a_stacked = sharded_lora_a_tpu
+    layer.lora_b_stacked = sharded_lora_b_tpu
+    if layer.lora_bias_stacked is not None:
+        layer.lora_bias_stacked = sharded_lora_bias_tpu
 
 
 def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
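For context, here is a minimal standalone sketch, not part of the commit, of what the lora_b sharding spec above does. The mesh axis names mirror the ('replica_num_lora', 'replica', 'data', 'model') mesh built in the diff; the device layout, tensor sizes, and variable names are made up for illustration. Only the 'model' axis actually splits the tensor, along its num_out dimension (dim 2), while lora_a is left fully replicated.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 4-axis device mesh shaped like the one the commit builds for lora_b;
# here only the trailing 'model' axis has more than one device.
devices = np.asarray(jax.devices())
mesh = Mesh(devices.reshape(1, 1, 1, -1),
            ('replica_num_lora', 'replica', 'data', 'model'))

# Stand-in for lora_b_stacked[i]: [max_loras, 1, num_out, rank] (sizes made up;
# num_out must divide evenly across the 'model' axis).
max_loras, num_out, rank = 2, 1024, 16
lora_b = jnp.zeros((max_loras, 1, num_out, rank))

# P(None, None, 'model', None): replicate dims 0, 1 and 3, split dim 2 (num_out).
lora_b_sharding = NamedSharding(mesh, P(None, None, 'model', None))
lora_b = jax.device_put(lora_b, lora_b_sharding)

# Each device now holds a [max_loras, 1, num_out // model_size, rank] slice.
print(lora_b.sharding)
print(lora_b.addressable_shards[0].data.shape)
```

This follows the usual tensor-parallel LoRA layout: A (input to rank) is small and replicated, while B (rank to output) is split on the same output dimension as the column-parallel base weight, so each device produces only its own slice of the output.
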
@@ -152,7 +192,7 @@ def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
 def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
     for path, module in model.named_modules():
         for module_type, sharding_func in MODULE_TYPE_TO_SHARDING_FUNC:
-            if isinstance(module, module_type):
+            if type(module) is module_type:
                 logger.debug("shard %s with %s", path, sharding_func)
                 sharding_func(module, mesh)
                 break
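The switch from isinstance to an exact type check changes dispatch semantics: with isinstance, an instance of a subclass also matches its base class's entry in MODULE_TYPE_TO_SHARDING_FUNC, so whichever compatible entry comes first in the list wins. The toy sketch below (class and registry names are made up, not from the repo) shows the difference.

```python
# Toy registry with a base class and a subclass, like a generic LoRA linear
# layer and a more specialized merged variant (names are hypothetical).
class BaseLoRALinear:
    pass

class MergedLoRALinear(BaseLoRALinear):
    pass

REGISTRY = [
    (BaseLoRALinear, "shard_base_linear"),
    (MergedLoRALinear, "shard_merged_linear"),
]

def pick_sharding_func(module, exact):
    for module_type, func in REGISTRY:
        matched = type(module) is module_type if exact else isinstance(module, module_type)
        if matched:
            return func
    return None

layer = MergedLoRALinear()
print(pick_sharding_func(layer, exact=False))  # shard_base_linear: subclass matched the base entry first
print(pick_sharding_func(layer, exact=True))   # shard_merged_linear: only the exact class matches
```

With the exact-type check, dispatch no longer depends on registry order, and a specialized LoRA layer can only be handled by the sharding function registered for its own class.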
