@@ -128,7 +128,47 @@ def _shard_column_parallel_linear_lora(
 
 def _shard_qkv_parallel_linear_lora(layer: MergedQKVParallelLinearWithLoRA,
                                     mesh: Mesh) -> None:
-    _shard_base_linear_lora(layer, mesh)
+    # mesh=Mesh(axis_sizes=(1, 2), axis_names=('data', 'model'), axis_types=(Auto, Auto))
+    # NOTE: lora_a_stacked[i] has shape [max_loras, 1, num_out, num_in]
+    sharded_lora_a_tpu = torch.nn.ParameterList()
+    sharded_lora_b_tpu = torch.nn.ParameterList()
+    sharded_lora_bias_tpu = torch.nn.ParameterList()
+
+    assert layer.n_slices > 0, "layer.n_slices should be greater than 0"
+    mesh_lora_b_shape = (1, 1) + (mesh.shape['data'], mesh.shape['model'])
+    mesh_lora_b_axis = ('replica_num_lora', 'replica', 'data', 'model')
+    lora_b_mesh = jax.make_mesh(
+        mesh_lora_b_shape, mesh_lora_b_axis,
+        devices=mesh.devices[0])  # mesh.devices=[[device0, ..device_n]]
+    lora_b_partition_spec = P(None, None, 'model', None)
+    lora_b_sharding = NamedSharding(lora_b_mesh, lora_b_partition_spec)
+
+    mesh_lora_bias_shape = (1, 1) + (mesh.shape['model'], )
+    mesh_lora_bias_axis = ('replica_num_lora', 'replica', 'model')
+    lora_bias_mesh = jax.make_mesh(
+        mesh_lora_bias_shape, mesh_lora_bias_axis,
+        devices=mesh.devices[0])  # mesh.devices=[[device0, ..device_n]]
+    lora_bias_partition_spec = P(None, None, 'model')
+    lora_bias_sharding = NamedSharding(lora_bias_mesh,
+                                       lora_bias_partition_spec)
+
+    for i in range(layer.n_slices):
+        sharded_lora_a_tpu.append(
+            _shard_tensor_to_tpu_replicated(layer.lora_a_stacked[i], mesh))
+
+        sharded_lora_b_tpu.append(
+            _convert_to_torchax_and_shard(layer.lora_b_stacked[i],
+                                          lora_b_sharding))
+
+        if layer.lora_bias_stacked is not None:
+            sharded_lora_bias_tpu.append(
+                _convert_to_torchax_and_shard(layer.lora_bias_stacked[i],
+                                              lora_bias_sharding))
+
+    layer.lora_a_stacked = sharded_lora_a_tpu
+    layer.lora_b_stacked = sharded_lora_b_tpu
+    if layer.lora_bias_stacked is not None:
+        layer.lora_bias_stacked = sharded_lora_bias_tpu
 
 
 def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
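The new `_shard_qkv_parallel_linear_lora` replicates each `lora_a_stacked[i]` across devices and shards each `lora_b_stacked[i]` along its output dimension (dim 2) over the `'model'` mesh axis. Below is a minimal, self-contained sketch of that lora_b sharding using plain JAX arrays; vLLM's torchax helpers (`_convert_to_torchax_and_shard`, `_shard_tensor_to_tpu_replicated`) are not reproduced, and the device count and tensor sizes are illustrative assumptions:

```python
# Sketch: shard a stacked LoRA-B tensor of shape [max_loras, 1, num_out, num_in]
# so that dim 2 (num_out) is split across the 'model' mesh axis, matching
# P(None, None, 'model', None) from the hunk above. The 'data' axis is taken
# to have size 1, as in the mesh comment (axis_sizes=(1, 2)).
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n_model = len(devices)  # put every local device on the 'model' axis

# 4-D mesh mirroring (1, 1, data, model) with data == 1.
mesh = Mesh(np.array(devices).reshape(1, 1, 1, n_model),
            ('replica_num_lora', 'replica', 'data', 'model'))

max_loras, num_out, num_in = 4, 8 * n_model, 16
lora_b = jnp.zeros((max_loras, 1, num_out, num_in), dtype=jnp.bfloat16)

lora_b_sharding = NamedSharding(mesh, P(None, None, 'model', None))
lora_b_sharded = jax.device_put(lora_b, lora_b_sharding)
# Each device now holds a [max_loras, 1, num_out // n_model, num_in] shard,
# while lora_a (not shown) would be replicated on every device.
```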
@@ -152,7 +192,7 @@ def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
 def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
     for path, module in model.named_modules():
         for module_type, sharding_func in MODULE_TYPE_TO_SHARDING_FUNC:
-            if isinstance(module, module_type):
+            if type(module) is module_type:
                 logger.debug("shard %s with %s", path, sharding_func)
                 sharding_func(module, mesh)
                 break
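The second hunk replaces `isinstance()` with an exact `type()` check when walking `MODULE_TYPE_TO_SHARDING_FUNC`: with `isinstance()`, a LoRA subclass can match a parent class listed earlier in the table and never reach its own, more specific sharding function. A standalone illustration of the difference (the class names below are placeholders, not vLLM's actual layer hierarchy):

```python
# Sketch of why exact type matching matters for a (type -> sharding_func)
# dispatch table; the classes here are stand-ins, not vLLM's layers.
class BaseLinearLoRA: ...
class QKVLinearLoRA(BaseLinearLoRA): ...   # subclass with its own sharding rule

def shard_base(layer, mesh): print("base sharding")
def shard_qkv(layer, mesh): print("qkv sharding")

DISPATCH = [
    (BaseLinearLoRA, shard_base),  # parent listed first
    (QKVLinearLoRA, shard_qkv),
]

layer, mesh = QKVLinearLoRA(), None

# isinstance() matches the parent entry first, so the subclass-specific
# function is never reached:
for cls, fn in DISPATCH:
    if isinstance(layer, cls):
        fn(layer, mesh)  # -> "base sharding"
        break

# type(...) is ... only fires on an exact match, so the subclass gets its own rule:
for cls, fn in DISPATCH:
    if type(layer) is cls:
        fn(layer, mesh)  # -> "qkv sharding"
        break
```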