@@ -1,11 +1,12 @@
 import abc
 import dataclasses
+import logging
 import typing
 
 import torch
 
-from fast_llm.config import DEFAULT
+from fast_llm.config import DEFAULT, MISSING
 from fast_llm.engine.checkpoint.config import CheckpointFormat
 from fast_llm.engine.checkpoint.external import (
     AutoStateDictCheckpointHandler,
@@ -23,11 +24,12 @@
 from fast_llm.functional.config import ActivationType
 from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex
 from fast_llm.layers.common.config import NormalizationType
-from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType
+from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig
 from fast_llm.models.gpt.config import (
     GPTArchitectureConfig,
     GPTModelConfig,
     LlamaGPTHuggingfaceCheckpointFormat,
+    Qwen2GPTHuggingfaceCheckpointFormat,
     MistralGPTHuggingfaceCheckpointFormat,
     MixtralGPTHuggingfaceCheckpointFormat,
     Starcoder2GPTHuggingfaceCheckpointFormat,
@@ -39,6 +41,8 @@
 if typing.TYPE_CHECKING:
     pass
 
+logger = logging.getLogger(__name__)
+
 
 class QueryWeightConverter(WeightConverter):
     # Hf uses the real format for rotary embeddings.
@@ -156,11 +160,14 @@ def _create_config_converters(cls) -> list[ParamConverter]:
     def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
         pass
 
-    def _create_weight_converters(self) -> list[WeightConverter]:
+
+    def _create_weight_converters(
+        self,
+    ) -> list[WeightConverter]:
         converters = []
         num_layers = self._model.config.base_model.transformer.num_layers
         norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm
-        linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases
+        transformer_config: TransformerConfig = self._model.config.base_model.transformer
 
         # Embedding and output
         if self._model.config.base_model.tie_word_embeddings:
@@ -180,17 +187,19 @@ def _create_weight_converters(self) -> list[WeightConverter]:
             converters += self._get_weight_and_bias_converters(
                 f"layers.{i+1}.self_attn.query",
                 f"model.layers.{i}.self_attn.q_proj",
-                linear_bias,
+                transformer_config.add_attn_qkv_bias,
                 QueryWeightConverter,
             )
             converters += self._get_weight_and_bias_converters(
                 f"layers.{i+1}.self_attn.key_value",
                 (f"model.layers.{i}.self_attn.k_proj", f"model.layers.{i}.self_attn.v_proj"),
-                linear_bias,
+                transformer_config.add_attn_qkv_bias,
                 KeyValueWeightConverter,
             )
             converters += self._get_weight_and_bias_converters(
-                f"layers.{i+1}.self_attn.dense", f"model.layers.{i}.self_attn.o_proj", linear_bias
+                f"layers.{i+1}.self_attn.dense",
+                f"model.layers.{i}.self_attn.o_proj",
+                transformer_config.add_attn_dense_bias,
             )
 
             # Norm
@@ -256,13 +265,16 @@ def _create_config_converters(cls) -> list[ParamConverter]:
         ]
 
     def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
-        linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases
+        transformer_config: TransformerConfig = self._model.config.base_model.transformer
         return [
             *self._get_weight_and_bias_converters(
-                f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", linear_bias
+                f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", transformer_config.add_mlp_bias
             ),
             *self._get_weight_and_bias_converters(
-                f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.c_proj", linear_bias, MLPLayer2Converter
+                f"{fast_llm_prefix}.mlp.layer_2",
+                f"{hf_prefix}.mlp.c_proj",
+                transformer_config.add_mlp_bias,
+                MLPLayer2Converter,
             ),
         ]
@@ -352,18 +364,91 @@ def _create_config_converters(cls) -> list[ParamConverter]:
         ]
 
     def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
-        linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases
+        transformer_config: TransformerConfig = self._model.config.base_model.transformer
+        return [
+            *self._get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.mlp.layer_1",
+                (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
+                transformer_config.add_mlp_bias,
+                SplitWeightConverter,
+            ),
+            *self._get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.mlp.layer_2",
+                f"{hf_prefix}.mlp.down_proj",
+                transformer_config.add_mlp_bias,
+                MLPLayer2Converter,
+            ),
+        ]
+
+
+@dataclasses.dataclass
+class IgnoreImportQwen2SlidingWindowParamsConverter(ParamConverter):
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_names), 0)
+        Assert.eq(len(self.export_names), 0)
+        self.export_names = (("use_sliding_window",), ("sliding_window",), ("max_window_layers",))
+
+    def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
+        return (MISSING, MISSING, MISSING)
+
+    def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
+        # Default value for use_sliding_window in Qwen2 HF config is False.
+        if export_values[0] != MISSING and export_values[0] == True:
+            logger.warning(
+                f"The configuration parameters `{self.export_names[0]}={export_values[0]}`,"
+                f" `{self.export_names[1]}={export_values[1]}`, `{self.export_names[2]}={export_values[2]}`"
+                f" are ignored during conversion."
+                f" If you intend to use them in Fast-LLM, make sure to set them explicitly in the model configuration."
+            )
+        return ()
+
+
+class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler):
+    format: typing.ClassVar[type[CheckpointFormat]] = Qwen2GPTHuggingfaceCheckpointFormat
+
+    @classmethod
+    def _create_config_converters(cls) -> list[ParamConverter]:
+        return super()._create_config_converters() + [
+            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Qwen2ForCausalLM"]),
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),)
+            ),
+            ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True),
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv"
+            ),
+            RopeScalingParamConverter(
+                fast_llm_names=(
+                    ("transformer", "rotary", "type"),
+                    ("transformer", "rotary", "scale_factor"),
+                    ("transformer", "rotary", "low_frequency_factor"),
+                    ("transformer", "rotary", "high_frequency_factor"),
+                    ("transformer", "rotary", "original_context_length"),
+                    ("transformer", "rotary", "attention_factor"),
+                    ("transformer", "rotary", "beta_fast"),
+                    ("transformer", "rotary", "beta_slow"),
+                ),
+                export_names=(("rope_scaling",),),
+            ),
+            IgnoreImportQwen2SlidingWindowParamsConverter(),
+        ]
+
+    def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
+        transformer_config: TransformerConfig = self._model.config.base_model.transformer
         return [
             *self._get_weight_and_bias_converters(
                 f"{fast_llm_prefix}.mlp.layer_1",
                 (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
-                linear_bias,
+                transformer_config.add_mlp_bias,
                 SplitWeightConverter,
             ),
             *self._get_weight_and_bias_converters(
                 f"{fast_llm_prefix}.mlp.layer_2",
                 f"{hf_prefix}.mlp.down_proj",
-                linear_bias,
+                transformer_config.add_mlp_bias,
                 MLPLayer2Converter,
             ),
         ]
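
Note: a hypothetical walk-through of the new sliding-window converter, reading MISSING as Fast-LLM's sentinel for "leave this key out of the exported config" (that reading of the sentinel is an assumption):

    conv = IgnoreImportQwen2SlidingWindowParamsConverter()
    conv.export_params(())  # -> (MISSING, MISSING, MISSING): the three HF keys are never exported
    conv.import_params((True, 4096, 28))  # use_sliding_window=True: logs the warning, returns ()
    conv.import_params((MISSING, MISSING, MISSING))  # keys absent in the HF config: silently returns ()
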
@@ -439,6 +524,7 @@ class AutoGPTHuggingfaceCheckpointHandler(
     handler_map = {
         Starcoder2GPTHuggingfaceCheckpointFormat.name: Starcoder2HuggingfaceCheckpointHandler,
         LlamaGPTHuggingfaceCheckpointFormat.name: LlamaHuggingfaceCheckpointHandler,
+        Qwen2GPTHuggingfaceCheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler,
         MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler,
         MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler,
     }
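
With the new handler_map entry, the auto handler can resolve the Qwen2 format by name. A minimal check, using only names visible in this diff:

    handler_cls = AutoGPTHuggingfaceCheckpointHandler.handler_map[Qwen2GPTHuggingfaceCheckpointFormat.name]
    assert handler_cls is Qwen2HuggingfaceCheckpointHandler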