Commit 23006dc: Qwen2 converter (#163)

1 parent 491451f
3 files changed: +123 -13 lines changed

fast_llm/models/gpt/config.py

Lines changed: 4 additions & 0 deletions
@@ -35,6 +35,9 @@ class Starcoder2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
 class LlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "llama"
 
+class Qwen2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
+    name: typing.ClassVar[str] = "qwen2"
+
 
 class MistralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "mistral"
@@ -98,6 +101,7 @@ class GPTModelConfig(FastLLMModelConfig):
         AutoGPTHuggingfaceCheckpointFormat,
         Starcoder2GPTHuggingfaceCheckpointFormat,
         LlamaGPTHuggingfaceCheckpointFormat,
+        Qwen2GPTHuggingfaceCheckpointFormat,
         MistralGPTHuggingfaceCheckpointFormat,
         MixtralGPTHuggingfaceCheckpointFormat,
     )
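The format class itself is just a name tag; the `checkpoint_formats` tuple in `GPTModelConfig` registers it so checkpoints can be addressed as `qwen2` and resolved by name (see the handler map in `conversion.py` below). A minimal, self-contained sketch of that name-based lookup, with stub classes standing in for the real Fast-LLM ones (illustration only, not the actual registry code):

import typing


# Stubs standing in for the real Fast-LLM classes.
class GPTHuggingfaceCheckpointFormat:
    name: typing.ClassVar[str]


class LlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
    name: typing.ClassVar[str] = "llama"


class Qwen2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
    name: typing.ClassVar[str] = "qwen2"


# Mirrors the checkpoint_formats tuple extended in the second hunk above.
checkpoint_formats = (LlamaGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat)


def get_format(name: str) -> type[GPTHuggingfaceCheckpointFormat]:
    # Linear scan by name; the real code may index this differently.
    return next(f for f in checkpoint_formats if f.name == name)


assert get_format("qwen2") is Qwen2GPTHuggingfaceCheckpointFormat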

fast_llm/models/gpt/conversion.py

Lines changed: 99 additions & 13 deletions
@@ -1,10 +1,11 @@
 import abc
 import dataclasses
+import logging
 import typing
 
 import torch
 
-from fast_llm.config import DEFAULT
+from fast_llm.config import DEFAULT, MISSING
 from fast_llm.engine.checkpoint.config import CheckpointFormat
 from fast_llm.engine.checkpoint.external import (
     AutoStateDictCheckpointHandler,
@@ -23,11 +24,12 @@
 from fast_llm.functional.config import ActivationType
 from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex
 from fast_llm.layers.common.config import NormalizationType
-from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType
+from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig
 from fast_llm.models.gpt.config import (
     GPTArchitectureConfig,
     GPTModelConfig,
     LlamaGPTHuggingfaceCheckpointFormat,
+    Qwen2GPTHuggingfaceCheckpointFormat,
     MistralGPTHuggingfaceCheckpointFormat,
     MixtralGPTHuggingfaceCheckpointFormat,
     Starcoder2GPTHuggingfaceCheckpointFormat,
@@ -39,6 +41,8 @@
 if typing.TYPE_CHECKING:
     pass
 
+logger = logging.getLogger(__name__)
+
 
 class QueryWeightConverter(WeightConverter):
     # Hf uses the real format for rotary embeddings.
@@ -156,11 +160,14 @@ def _create_config_converters(cls) -> list[ParamConverter]:
     def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
         pass
 
-    def _create_weight_converters(self) -> list[WeightConverter]:
+
+    def _create_weight_converters(
+        self,
+    ) -> list[WeightConverter]:
         converters = []
         num_layers = self._model.config.base_model.transformer.num_layers
         norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm
-        linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases
+        transformer_config: TransformerConfig = self._model.config.base_model.transformer
 
         # Embedding and output
         if self._model.config.base_model.tie_word_embeddings:
@@ -180,17 +187,19 @@ def _create_weight_converters(self) -> list[WeightConverter]:
             converters += self._get_weight_and_bias_converters(
                 f"layers.{i+1}.self_attn.query",
                 f"model.layers.{i}.self_attn.q_proj",
-                linear_bias,
+                transformer_config.add_attn_qkv_bias,
                 QueryWeightConverter,
             )
             converters += self._get_weight_and_bias_converters(
                 f"layers.{i+1}.self_attn.key_value",
                 (f"model.layers.{i}.self_attn.k_proj", f"model.layers.{i}.self_attn.v_proj"),
-                linear_bias,
+                transformer_config.add_attn_qkv_bias,
                 KeyValueWeightConverter,
            )
             converters += self._get_weight_and_bias_converters(
-                f"layers.{i+1}.self_attn.dense", f"model.layers.{i}.self_attn.o_proj", linear_bias
+                f"layers.{i+1}.self_attn.dense",
+                f"model.layers.{i}.self_attn.o_proj",
+                transformer_config.add_attn_dense_bias,
             )
 
             # Norm
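Note how the single `linear_bias` flag gives way to per-sublayer accessors (`add_attn_qkv_bias`, `add_attn_dense_bias`, `add_mlp_bias`), which is what lets Qwen2 keep biases on the q/k/v projections only. A hedged sketch of how these flags could resolve from `add_linear_biases` (assumed behavior, not the actual `TransformerConfig` implementation):

import dataclasses


@dataclasses.dataclass
class BiasFlagsSketch:
    # True/False apply everywhere; "only_attn_qkv" matches the Qwen2 layout.
    add_linear_biases: bool | str = True

    @property
    def add_attn_qkv_bias(self) -> bool:
        return self.add_linear_biases in (True, "only_attn_qkv")

    @property
    def add_attn_dense_bias(self) -> bool:
        return self.add_linear_biases is True

    @property
    def add_mlp_bias(self) -> bool:
        return self.add_linear_biases is True


qwen2_like = BiasFlagsSketch(add_linear_biases="only_attn_qkv")
assert qwen2_like.add_attn_qkv_bias and not qwen2_like.add_attn_dense_bias and not qwen2_like.add_mlp_bias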
@@ -256,13 +265,16 @@ def _create_config_converters(cls) -> list[ParamConverter]:
         ]
 
     def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
-        linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases
+        transformer_config: TransformerConfig = self._model.config.base_model.transformer
         return [
             *self._get_weight_and_bias_converters(
-                f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", linear_bias
+                f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", transformer_config.add_mlp_bias
             ),
             *self._get_weight_and_bias_converters(
-                f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.c_proj", linear_bias, MLPLayer2Converter
+                f"{fast_llm_prefix}.mlp.layer_2",
+                f"{hf_prefix}.mlp.c_proj",
+                transformer_config.add_mlp_bias,
+                MLPLayer2Converter,
             ),
         ]
 
@@ -352,18 +364,91 @@ def _create_config_converters(cls) -> list[ParamConverter]:
         ]
 
     def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
-        linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases
+        transformer_config: TransformerConfig = self._model.config.base_model.transformer
+        return [
+            *self._get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.mlp.layer_1",
+                (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
+                transformer_config.add_mlp_bias,
+                SplitWeightConverter,
+            ),
+            *self._get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.mlp.layer_2",
+                f"{hf_prefix}.mlp.down_proj",
+                transformer_config.add_mlp_bias,
+                MLPLayer2Converter,
+            ),
+        ]
+
+
+@dataclasses.dataclass
+class IgnoreImportQwen2SlidingWindowParamsConverter(ParamConverter):
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_names), 0)
+        Assert.eq(len(self.export_names), 0)
+        self.export_names = (("use_sliding_window",), ("sliding_window",), ("max_window_layers",))
+
+    def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
+        return (MISSING, MISSING, MISSING)
+
+    def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
+        # Default value for use_sliding_window in Qwen2 HF config is False
+        if export_values[0] != MISSING and export_values[0] == True:
+            logger.warning(
+                f"The configuration parameters `{self.export_names[0]}={export_values[0]}`,"
+                f" `{self.export_names[1]}={export_values[1]}`, `{self.export_names[2]}={export_values[2]}`"
+                f" are ignored during conversion."
+                f" If you intend to use them in Fast-LLM, make sure to set them explicitly in the model configuration."
+            )
+        return ()
+
+
+class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler):
+    format: typing.ClassVar[type[CheckpointFormat]] = Qwen2GPTHuggingfaceCheckpointFormat
+
+    @classmethod
+    def _create_config_converters(cls) -> list[ParamConverter]:
+        return super()._create_config_converters() + [
+            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Qwen2ForCausalLM"]),
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),)
+            ),
+            ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True),
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv"
+            ),
+            RopeScalingParamConverter(
+                fast_llm_names=(
+                    ("transformer", "rotary", "type"),
+                    ("transformer", "rotary", "scale_factor"),
+                    ("transformer", "rotary", "low_frequency_factor"),
+                    ("transformer", "rotary", "high_frequency_factor"),
+                    ("transformer", "rotary", "original_context_length"),
+                    ("transformer", "rotary", "attention_factor"),
+                    ("transformer", "rotary", "beta_fast"),
+                    ("transformer", "rotary", "beta_slow"),
+                ),
+                export_names=(("rope_scaling",),),
+            ),
+            IgnoreImportQwen2SlidingWindowParamsConverter(),
+        ]
+
+    def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
+        transformer_config: TransformerConfig = self._model.config.base_model.transformer
         return [
             *self._get_weight_and_bias_converters(
                 f"{fast_llm_prefix}.mlp.layer_1",
                 (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
-                linear_bias,
+                transformer_config.add_mlp_bias,
                 SplitWeightConverter,
             ),
             *self._get_weight_and_bias_converters(
                 f"{fast_llm_prefix}.mlp.layer_2",
                 f"{hf_prefix}.mlp.down_proj",
-                linear_bias,
+                transformer_config.add_mlp_bias,
                 MLPLayer2Converter,
             ),
         ]
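The sliding-window converter deserves a second look: on export it emits the `MISSING` sentinel for all three keys (so they are omitted from the HF config), and on import it returns an empty tuple, warning only when a `use_sliding_window=True` setting would be silently discarded. A self-contained sketch of that behavior, with stand-ins for `MISSING` and the logger:

import logging

logging.basicConfig()
logger = logging.getLogger("sketch")

MISSING = object()  # stand-in for fast_llm.config.MISSING


def import_sliding_window_params(export_values: tuple) -> tuple:
    # Mirrors IgnoreImportQwen2SlidingWindowParamsConverter.import_params:
    # nothing is imported; a warning fires only if sliding window was enabled.
    if export_values[0] is not MISSING and export_values[0] is True:
        logger.warning("use_sliding_window/sliding_window/max_window_layers are ignored during conversion.")
    return ()


assert import_sliding_window_params((False, 4096, 28)) == ()  # silent
assert import_sliding_window_params((True, 4096, 28)) == ()   # warns, still imports nothing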
@@ -439,6 +524,7 @@ class AutoGPTHuggingfaceCheckpointHandler(
     handler_map = {
         Starcoder2GPTHuggingfaceCheckpointFormat.name: Starcoder2HuggingfaceCheckpointHandler,
         LlamaGPTHuggingfaceCheckpointFormat.name: LlamaHuggingfaceCheckpointHandler,
+        Qwen2GPTHuggingfaceCheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler,
         MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler,
         MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler,
     }
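For the `RopeScalingParamConverter` registered in the Qwen2 handler above, the eight `fast_llm_names` line up with the fields HF checkpoints carry in their `rope_scaling` dict. A hedged sketch of the import direction, using key names from HF's Llama/YaRN conventions (the real converter, defined elsewhere in this file, may differ):

def import_rope_scaling_sketch(rope_scaling: dict | None) -> tuple:
    # Returns values in the order of the fast_llm_names tuple above:
    # (type, scale_factor, low_frequency_factor, high_frequency_factor,
    #  original_context_length, attention_factor, beta_fast, beta_slow)
    if not rope_scaling:
        return ("default", None, None, None, None, None, None, None)
    return (
        rope_scaling.get("rope_type", rope_scaling.get("type")),
        rope_scaling.get("factor"),
        rope_scaling.get("low_freq_factor"),
        rope_scaling.get("high_freq_factor"),
        rope_scaling.get("original_max_position_embeddings"),
        rope_scaling.get("attention_factor"),
        rope_scaling.get("beta_fast"),
        rope_scaling.get("beta_slow"),
    )


assert import_rope_scaling_sketch({"type": "yarn", "factor": 4.0})[0] == "yarn"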

tests/common.py

Lines changed: 20 additions & 0 deletions
@@ -14,6 +14,7 @@
 from fast_llm.data.dataset.gpt.sampled import GPTSample
 from fast_llm.models.gpt.config import (
     LlamaGPTHuggingfaceCheckpointFormat,
+    Qwen2GPTHuggingfaceCheckpointFormat,
     MistralGPTHuggingfaceCheckpointFormat,
     MixtralGPTHuggingfaceCheckpointFormat,
     Starcoder2GPTHuggingfaceCheckpointFormat,
@@ -155,6 +156,18 @@
 ]
 CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"]
 
+# Megatron does not support per sub layer biases
+CONFIG_QWEN2_MEGATRON = None
+CONFIG_QWEN2_FAST_LLM = CONFIG_SC2_FAST_LLM + [
+    "model.base_model.transformer.gated=True",
+    "model.base_model.transformer.activation_type=silu",
+    "model.base_model.transformer.add_linear_biases=only_attn_qkv",
+    "model.base_model.transformer.normalization.type=rms_norm",
+    "model.base_model.transformer.ffn_hidden_size=1024",
+    "model.base_model.tie_word_embeddings=False",
+]
+CONFIG_QWEN2_COMMON = CONFIG_QWEN2_FAST_LLM + ["model.distributed.training_dtype=bf16"]
+
 # Yarn-style Rotary Embeddings
 CONFIG_LLAMA_YARN_MEGATRON = None
 CONFIG_LLAMA_YARN_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [
@@ -202,6 +215,13 @@
         CONFIG_LLAMA3_COMMON,
         LlamaGPTHuggingfaceCheckpointFormat,
     ),
+    "qwen2": (
+        "gpt",
+        CONFIG_QWEN2_FAST_LLM,
+        CONFIG_QWEN2_MEGATRON,
+        CONFIG_QWEN2_COMMON,
+        Qwen2GPTHuggingfaceCheckpointFormat,
+    ),
     "llama-yarn": (
         "gpt",
         CONFIG_LLAMA_YARN_FAST_LLM,
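These test configs follow a layering pattern: each entry is a dotted `key=value` override, and the Qwen2 list appends its settings to the Starcoder2 base (`CONFIG_SC2_FAST_LLM`), with later entries winning; `CONFIG_QWEN2_MEGATRON = None` marks the Megatron comparison as unavailable. A small illustrative sketch of the layering (the base values below are made up, not the actual test harness):

# Assumed base list; the real CONFIG_SC2_FAST_LLM has more entries.
CONFIG_SC2_FAST_LLM = [
    "model.base_model.transformer.num_layers=2",
    "model.base_model.transformer.normalization.type=layer_norm",
]
CONFIG_QWEN2_FAST_LLM = CONFIG_SC2_FAST_LLM + [
    "model.base_model.transformer.normalization.type=rms_norm",  # overrides the base
]


def resolve(overrides: list[str]) -> dict[str, str]:
    # dict() keeps the last value per key, so later overrides win.
    return dict(item.split("=", 1) for item in overrides)


assert resolve(CONFIG_QWEN2_FAST_LLM)["model.base_model.transformer.normalization.type"] == "rms_norm"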
