diff --git a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/paxml_converters.py b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/paxml_converters.py
index 7737992ad..4ca508922 100644
--- a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/paxml_converters.py
+++ b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/paxml_converters.py
@@ -38,18 +38,24 @@ def _generate_ckpt_map(self):
         hidden_dim = num_of_head * head_dim
         mlp_intermediate_dim = self.model_config.mlp_intermediate_dim
 
-        for i in range(self.model_config.num_of_layer):
-            ckpt_map.update({
-                f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w":
-                    self._get_convert_pkg(
-                        f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
-                        (hidden_dim, mlp_intermediate_dim), 0,
-                        extra_src_paths = [f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1_gate.linear.w"],
-                        stack_dim = -2) if self.use_gated_act else \
-                    self._get_convert_pkg(
-                        f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
-                        (hidden_dim, mlp_intermediate_dim), 0,
-                        lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),
+        for i in range(self.model_config.num_of_layer):
+            ckpt_map_for_ffn1 = {}
+            if self.use_gated_act:
+                ckpt_map_for_ffn1[f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1_gate.linear.w"] = \
+                    self._get_convert_pkg(
+                        f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
+                        (hidden_dim, mlp_intermediate_dim), 0,
+                        extra_src_paths = [f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w"],
+                        stack_dim = -2)
+            else:
+                ckpt_map_for_ffn1[f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w"] = \
+                    self._get_convert_pkg(
+                        f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
+                        (hidden_dim, mlp_intermediate_dim), 0,
+                        lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1])))
+
+            ckpt_map.update({
+                **ckpt_map_for_ffn1,
                 f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer2.linear.w":
                     self._get_convert_pkg(
                         f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wo_kernel",
@@ -313,17 +319,28 @@ def _generate_ckpt_map(self):
         hidden_dim = num_of_head * head_dim
         mlp_intermediate_dim = self.model_config.mlp_intermediate_dim
 
+        ckpt_map_for_ffn1 = {}
+        if self.use_gated_act:
+            ckpt_map_for_ffn1['lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1_gate.linear.w'] = \
+                self._get_convert_pkg(
+                    'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_kernel',
+                    (hidden_dim, mlp_intermediate_dim), 0,
+                    extra_src_paths = ['lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w'],
+                    stack_dim = -2)
+        else:
+            ckpt_map_for_ffn1['lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w'] = \
+                self._get_convert_pkg(
+                    'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_kernel',
+                    (num_of_layer, hidden_dim, mlp_intermediate_dim), 1,
+                    lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1])))
+
         ckpt_map.update({
             'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.bias.b':
                 self._get_convert_pkg(
                     'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_bias',
                     (num_of_layer, mlp_intermediate_dim), None,
                     lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),
-            'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w':
-                self._get_convert_pkg(
-                    'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_kernel',
-                    (num_of_layer, hidden_dim, mlp_intermediate_dim), 1,
-                    lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),
+            **ckpt_map_for_ffn1,
             'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b':
                 self._get_convert_pkg(
                     'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wo_bias',
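
For context, below is a minimal standalone sketch of the dict-splat pattern the refactor relies on: the ffn_layer1 entry is built in a small helper dict whose key depends on the gated-activation flag, then merged into the main map with **. The names make_pkg, build_ckpt_map, use_gated_act and num_of_layer are hypothetical stand-ins; this is not the converter's real _get_convert_pkg API.

# Illustrative sketch only; make_pkg stands in for self._get_convert_pkg.
def make_pkg(src_path, shape):
    # Record the mapping target; the real converter returns a richer package.
    return {"src": src_path, "shape": shape}

def build_ckpt_map(num_of_layer, use_gated_act):
    ckpt_map = {}
    for i in range(num_of_layer):
        ffn1 = {}
        if use_gated_act:
            # Gated activation: the gate kernel is the destination key.
            ffn1[f"x_layers_{i}.ffn_layer1_gate.linear.w"] = \
                make_pkg(f"x_layers_{i}.mlp.wi_kernel", ("hidden", "mlp"))
        else:
            ffn1[f"x_layers_{i}.ffn_layer1.linear.w"] = \
                make_pkg(f"x_layers_{i}.mlp.wi_kernel", ("hidden", "mlp"))
        # Splat the per-branch entry into the flat update() literal.
        ckpt_map.update({
            **ffn1,
            f"x_layers_{i}.ffn_layer2.linear.w":
                make_pkg(f"x_layers_{i}.mlp.wo_kernel", ("mlp", "hidden")),
        })
    return ckpt_map

# Example: two layers with gated activation enabled.
print(build_ckpt_map(2, use_gated_act=True))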