From e9a64d2a59c0f0bd4cf54d18a09bbe83d2dafe27 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 3 Mar 2025 16:21:58 -0500 Subject: [PATCH 1/7] Option to configure layers independently --- fast_llm/layers/language_model/config.py | 59 +++++++++++++++++++++++- fast_llm/models/gpt/model.py | 2 +- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 8e3a467c..94693959 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,6 +1,6 @@ import typing -from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames @@ -98,6 +98,61 @@ def from_flat_dict( return super().from_flat_dict(default, strict) +@config_class() +class SliceConfig(Config): + begin: int = 0 + end: int | None = None + step: int = 1 + + def in_range(self, index) -> bool: + return ( + index >= self.begin and (self.end is None or index <= self.end) and ((index - self.begin) % self.step == 0) + ) + + +@config_class() +class TransformerLayerConfig(Config): + layer_ranges: list[SliceConfig] = Field( + default_factory=SliceConfig, + desc="Layer range.", + hint=FieldHint.core, + ) + updates: dict[str, typing.Any] = Field( + default_factory=dict, + ) + config: TransformerConfig = Field(init=False) + + def setup(self, default: TransformerConfig) -> None: + self.config = TransformerConfig.from_dict(default, self.updates) + + def _validate(self) -> None: + assert hasattr(self, "config") + assert len(self.layer_ranges) > 0 + + def in_range(self, index) -> bool: + return any(layer_range.in_range(index) for layer_range in self.layer_ranges) + + +@config_class() +class TransformerLayersConfig(Config): + layers: list[TransformerLayerConfig] = Field(default_factory=list) + default: TransformerConfig = Field(init=False) + + def setup(self, default: TransformerConfig) -> None: + self.default = default + for layer in self.layers: + layer.setup(default) + + def _validate(self) -> None: + assert hasattr(self, "default") + + def get_layer_config(self, index: int) -> TransformerConfig: + for layer in self.layers: + if layer.in_range(index): + return layer.config + return self.default + + @config_class() class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): """ @@ -112,6 +167,7 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) + layers: TransformerLayersConfig = Field(default_factory=TransformerLayersConfig) init_method_std_embed: float = Field( default=None, desc="Initialization scale for the vocabulary embedding and output weights (logits).", @@ -175,6 +231,7 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): ) def _validate(self) -> None: + self.layers.setup(self.transformer) if self.transformer.init_method_std is None: self.transformer.init_method_std = self.transformer.hidden_size**-0.5 if self.init_method_std_embed is None: diff --git a/fast_llm/models/gpt/model.py 
b/fast_llm/models/gpt/model.py index 8aa68333..1f4a26b8 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -70,7 +70,7 @@ def get_layers(self) -> list[Layer]: LanguageModelEmbedding(self._config, self._tensor_space), *[ TransformerLayer( - self._config.transformer, + self._config.layers.get_layer_config(i), self._tensor_space, layer_index=i + 1, ) From 39e04884e166daefbe1c54406d04539d7f72bd79 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 4 Mar 2025 14:19:24 -0500 Subject: [PATCH 2/7] fixes --- Megatron-LM | 1 - docs/developer_guide/conversion.md | 2 +- fast_llm/config.py | 1 + fast_llm/engine/config_utils/tensor_space.py | 18 +++- fast_llm/layers/language_model/config.py | 97 ++++++------------- fast_llm/layers/language_model/embedding.py | 4 +- fast_llm/layers/language_model/head.py | 4 +- fast_llm/layers/transformer/attention.py | 8 +- fast_llm/layers/transformer/config.py | 89 ++++++++++++++++- .../layers/transformer/mixture_of_experts.py | 4 +- fast_llm/layers/transformer/mlp.py | 6 +- fast_llm/layers/transformer/preprocessing.py | 7 +- fast_llm/layers/transformer/transformer.py | 4 +- fast_llm/models/gpt/conversion.py | 30 +++--- fast_llm/models/gpt/megatron.py | 10 +- fast_llm/models/gpt/model.py | 29 +++--- fast_llm/models/gpt/trainer.py | 2 +- tests/test_attention.py | 18 ++-- tests/test_config.py | 90 ++++++++++------- tests/test_mlp.py | 18 ++-- 20 files changed, 253 insertions(+), 189 deletions(-) delete mode 160000 Megatron-LM diff --git a/Megatron-LM b/Megatron-LM deleted file mode 160000 index cb6baf17..00000000 --- a/Megatron-LM +++ /dev/null @@ -1 +0,0 @@ -Subproject commit cb6baf171d064db6c2fee52f32dc1b51a2e6538d diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 0620beae..76d8bfa3 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -232,7 +232,7 @@ Continuing our `AwesomeModel` handler example, we define: def _create_weight_converters(self) -> list[WeightConverter]: converters = [] # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.layers.default.num_layers # A simple renaming example, for the word embeddings. 
converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) diff --git a/fast_llm/config.py b/fast_llm/config.py index f1c88965..7090291a 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -733,6 +733,7 @@ def _from_dict( if strict and default: out._unknown_fields = default.copy() if _AUTO_VALIDATE: + print("WKIUEFNW", out.to_serialized()) out.validate() return out diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 8bc86b73..73e11095 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -113,10 +113,12 @@ class TensorSpace: _is_setup: bool = False _distributed: "Distributed" - def __init__(self, distributed_config: DistributedConfig): + def __init__(self, distributed_config: DistributedConfig, _parent: "TensorSpace|None" = None): self._distributed_config = distributed_config self._tensor_dims: dict[str, TensorDim] = {} self.add_tensor_dim(TensorDim(DefaultDimNames.scalar, 1)) + self._parent = _parent + self._sub_spaces: dict[str, TensorSpace] = {} def setup(self, distributed: "Distributed") -> None: assert distributed.config is self._distributed_config @@ -146,5 +148,17 @@ def add_tensor_dim(self, dim: TensorDim) -> None: Assert.eq(dim.parallel_dim, self._distributed_config.distributed_dims[dim.parallel_dim.name]) self._tensor_dims[dim.name] = dim + def add_sub_space(self, name: str) -> "TensorSpace": + self._sub_spaces[name] = TensorSpace(self._distributed_config, _parent=self) + return self._sub_spaces[name] + + def get_sub_space(self, name: str) -> "TensorSpace": + return self._sub_spaces[name] + def get_tensor_dim(self, name: str) -> TensorDim: - return self._tensor_dims[name] + if name in self._tensor_dims: + return self._tensor_dims[name] + elif self._parent is not None: + return self._parent.get_tensor_dim(name) + else: + raise KeyError(name) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 94693959..11d117be 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,6 +1,6 @@ import typing -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames @@ -32,11 +32,12 @@ class LanguageModelKwargs: @config_class() class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): - transformer: TransformerArchitectureConfig = Field( - default_factory=TransformerArchitectureConfig, - desc="Configuration for the transformer architecture.", - hint=FieldHint.core, - ) + # transformer: TransformerLayerArchitectureConfig = Field( + # default_factory=TransformerLayerArchitectureConfig, + # desc="Configuration for the transformer architecture.", + # hint=FieldHint.core, + # ) + layers: TransformerConfig = Field(default_factory=TransformerArchitectureConfig) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -60,11 +61,12 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): def _validate(self) -> None: if self.use_position_embeddings is None: - 
self.use_position_embeddings = not self.transformer.rotary.enabled + self.use_position_embeddings = not self.layers.default.rotary.enabled super()._validate() def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.transformer.setup_tensor_space(tensor_space) + assert self._validated + self.layers.setup_tensor_space(tensor_space) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions @@ -97,60 +99,16 @@ def from_flat_dict( cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") return super().from_flat_dict(default, strict) - -@config_class() -class SliceConfig(Config): - begin: int = 0 - end: int | None = None - step: int = 1 - - def in_range(self, index) -> bool: - return ( - index >= self.begin and (self.end is None or index <= self.end) and ((index - self.begin) % self.step == 0) - ) - - -@config_class() -class TransformerLayerConfig(Config): - layer_ranges: list[SliceConfig] = Field( - default_factory=SliceConfig, - desc="Layer range.", - hint=FieldHint.core, - ) - updates: dict[str, typing.Any] = Field( - default_factory=dict, - ) - config: TransformerConfig = Field(init=False) - - def setup(self, default: TransformerConfig) -> None: - self.config = TransformerConfig.from_dict(default, self.updates) - - def _validate(self) -> None: - assert hasattr(self, "config") - assert len(self.layer_ranges) > 0 - - def in_range(self, index) -> bool: - return any(layer_range.in_range(index) for layer_range in self.layer_ranges) - - -@config_class() -class TransformerLayersConfig(Config): - layers: list[TransformerLayerConfig] = Field(default_factory=list) - default: TransformerConfig = Field(init=False) - - def setup(self, default: TransformerConfig) -> None: - self.default = default - for layer in self.layers: - layer.setup(default) - - def _validate(self) -> None: - assert hasattr(self, "default") - - def get_layer_config(self, index: int) -> TransformerConfig: - for layer in self.layers: - if layer.in_range(index): - return layer.config - return self.default + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + # TODO v0.x: Remove backward compatibility. 
+ cls._handle_renamed_field(default, "transformer", ("layers", "default")) + return super()._from_dict(default, strict, flat) @config_class() @@ -166,8 +124,8 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig - transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) - layers: TransformerLayersConfig = Field(default_factory=TransformerLayersConfig) + # transformer: TransformerLayerConfig = FieldUpdate(default_factory=TransformerLayerConfig) + layers: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) init_method_std_embed: float = Field( default=None, desc="Initialization scale for the vocabulary embedding and output weights (logits).", @@ -231,15 +189,14 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): ) def _validate(self) -> None: - self.layers.setup(self.transformer) - if self.transformer.init_method_std is None: - self.transformer.init_method_std = self.transformer.hidden_size**-0.5 + if self.layers.default.init_method_std is None: + self.layers.default.init_method_std = self.layers.default.hidden_size**-0.5 if self.init_method_std_embed is None: - self.init_method_std_embed = self.transformer.init_method_std + self.init_method_std_embed = self.layers.default.init_method_std if self.init_method_max_embed is None: - self.init_method_max_embed = self.transformer.init_method_max + self.init_method_max_embed = self.layers.default.init_method_max if self.init_method_min_embed is None: - self.init_method_min_embed = self.transformer.init_method_min + self.init_method_min_embed = self.layers.default.init_method_min if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, self.init_method_max_embed) super()._validate() diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 67e7eb53..6a25e386 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -37,13 +37,13 @@ def __init__( self._tensor_space = tensor_space self._residual_dtype = ( self._distributed_config.optimization_dtype - if config.transformer.full_precision_residual + if config.layers.default.full_precision_residual else self._distributed_config.training_dtype ).torch self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings - self._dropout_p = config.transformer.hidden_dropout + self._dropout_p = config.layers.default.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 4c03e393..f9b3dbf3 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -40,7 +40,7 @@ def __init__( tensor_space: TensorSpace, ): super().__init__(config) - self._debug_transformer = config.transformer.debug_transformer + self._debug_transformer = config.layers.default.debug_transformer self._tie_word_embeddings = config.tie_word_embeddings self._tensor_space = tensor_space @@ -56,7 +56,7 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self.final_norm 
= config.transformer.normalization.get_layer(hidden_dim) + self.final_norm = config.layers.default.normalization.get_layer(hidden_dim) self._logits_scale_factor = config.logits_scale_factor self._z_loss_factor = config.logit_z_loss diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index f64de9f1..fe982bed 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -10,11 +10,7 @@ from fast_llm.functional.rotary import apply_rotary_embeddings from fast_llm.functional.triton.rotary import triton_rotary_autograd_ from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, -) +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, TransformerLayerConfig from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -69,7 +65,7 @@ class Attention(torch.nn.Module): def __init__( self, - config: TransformerConfig, + config: TransformerLayerConfig, tensor_space: TensorSpace, layer_index, ): diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index cf985392..18bfeeb9 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -4,7 +4,7 @@ import typing import warnings -from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace @@ -156,7 +156,7 @@ class AddLinearBiasChoices(str, enum.Enum): @config_class() -class TransformerArchitectureConfig(BaseModelArchitectureConfig): +class TransformerLayerArchitectureConfig(BaseModelArchitectureConfig): _abstract = False normalization: NormalizationArchitectureConfig = Field( default_factory=NormalizationArchitectureConfig, @@ -306,6 +306,7 @@ def _from_dict( return super()._from_dict(default, strict, flat) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + assert self._validated tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension @@ -367,7 +368,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: @config_class() -class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): +class TransformerLayerConfig(TransformerLayerArchitectureConfig, BaseModelConfig): normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig) rotary: RotaryConfig = FieldUpdate(default_factory=RotaryConfig) # Default: hidden_size**-0.5 @@ -623,3 +624,85 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: Assert.is_(self.window_size, None) return use_flash_attention + + +@config_class() +class SliceConfig(Config): + begin: int = Field(default=0) + end: int | None = Field(default=None) + step: int = Field(default=1) + + def in_range(self, index) -> bool: + return ( + index >= self.begin and (self.end is None or index <= self.end) and ((index - self.begin) % self.step == 0) + ) + + 
+@config_class() +class TransformerLayerRangeArchitectureConfig(BaseModelArchitectureConfig): + _abstract = False + layer_ranges: list[SliceConfig] = Field( + default_factory=SliceConfig, + desc="Layer range.", + hint=FieldHint.core, + ) + updates: dict[str, typing.Any] = Field( + default_factory=dict, + ) + config: TransformerLayerArchitectureConfig = Field(init=False) + + def setup(self, default: TransformerLayerArchitectureConfig) -> None: + self.config = TransformerLayerArchitectureConfig.from_dict(default, self.updates) + + def _validate(self) -> None: + assert hasattr(self, "config") + assert len(self.layer_ranges) > 0 + self.config.validate() + + def in_range(self, index) -> bool: + return any(layer_range.in_range(index) for layer_range in self.layer_ranges) + + +@config_class() +class TransformerLayerRangeConfig(TransformerLayerRangeArchitectureConfig, BaseModelConfig): + pass + + +@config_class() +class TransformerArchitectureConfig(BaseModelArchitectureConfig): + _abstract = False + layers: list[TransformerLayerRangeArchitectureConfig] = Field(default_factory=list) + default: TransformerLayerArchitectureConfig = Field(default_factory=TransformerLayerArchitectureConfig) + + def _validate(self) -> None: + for layer in self.layers: + layer.setup(self.default) + # TODO: Improve this? + Assert.eq(layer.config.hidden_size, self.default.hidden_size) + Assert.eq(layer.config.num_layers, self.default.num_layers) + super()._validate() + + def get_layer_config_and_tensor_space( + self, index: int, tensor_space: TensorSpace + ) -> tuple[TransformerLayerArchitectureConfig, TensorSpace]: + for i, layer in enumerate(self.layers): + if layer.in_range(index): + return layer.config, tensor_space.get_sub_space(f"transformer_layers_{i}") + return self.default, tensor_space + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + assert self._validated + self.default.setup_tensor_space(tensor_space) + for i, layer in enumerate(self.layers): + layer.config.setup_tensor_space(tensor_space.add_sub_space(f"transformer_layers_{i}")) + + +@config_class() +class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): + layers: list[TransformerLayerRangeConfig] = FieldUpdate() + default: TransformerLayerConfig = FieldUpdate(default_factory=TransformerLayerConfig) + + def get_layer_config_and_tensor_space( + self, index: int, tensor_space: TensorSpace + ) -> tuple[TransformerLayerConfig, TensorSpace]: + return super().get_layer_config_and_tensor_space(index, tensor_space) diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 85c6686f..c4405174 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -13,9 +13,9 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.transformer.config import ( RoutingType, - TransformerConfig, TransformerDimNames, TransformerKwargs, + TransformerLayerConfig, TransformerLossNames, ) from fast_llm.layers.transformer.mlp import MLPBase @@ -40,7 +40,7 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerLayerConfig, tensor_space: TensorSpace, name: str = "mlp"): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
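
Note on the classes added to `fast_llm/layers/transformer/config.py` in this patch: `SliceConfig`, `TransformerLayerRangeArchitectureConfig` and `TransformerArchitectureConfig` are what make per-layer overrides possible. Each override entry carries a list of index ranges plus a partial `updates` dict, and `get_layer_config_and_tensor_space` returns the first matching override, falling back to `default` otherwise. The sketch below re-implements that selection logic without any Fast-LLM imports so it can be run on its own; `LayerSlice`, `LayerOverride`, `pick_layer_config` and the flat dict merge are illustrative stand-ins, while the real code merges `updates` through `Config.from_dict` and also hands back a per-range tensor sub-space.

```python
# Standalone sketch of the per-layer config resolution introduced in this patch.
# Names and the flat dict merge are illustrative only; the real code goes through
# TransformerArchitectureConfig.get_layer_config_and_tensor_space.
from dataclasses import dataclass


@dataclass
class LayerSlice:
    begin: int = 0
    end: int | None = None  # inclusive, unlike Python's range/slice
    step: int = 1

    def in_range(self, index: int) -> bool:
        return (
            index >= self.begin
            and (self.end is None or index <= self.end)
            and (index - self.begin) % self.step == 0
        )


@dataclass
class LayerOverride:
    layer_ranges: list[LayerSlice]
    updates: dict  # partial config, merged over the default

    def in_range(self, index: int) -> bool:
        return any(r.in_range(index) for r in self.layer_ranges)


def pick_layer_config(index: int, default: dict, overrides: list[LayerOverride]) -> dict:
    # First matching override wins; otherwise fall back to the shared default.
    for override in overrides:
        if override.in_range(index):
            return {**default, **override.updates}
    return default


default = {"window_size": None, "num_experts": 1}
overrides = [LayerOverride([LayerSlice(begin=1, step=2)], {"window_size": 512})]
# Indices are the zero-based positions from the `range(num_layers)` loop in get_layers.
assert pick_layer_config(0, default, overrides)["window_size"] is None
assert pick_layer_config(3, default, overrides)["window_size"] == 512
```

Since `end` is inclusive and indices step from `begin`, an override with `{begin: 1, step: 2}` selects every odd-indexed layer, while an empty `layers` list leaves every layer on `default`.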
diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index adc6242d..5a42bdb9 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -8,13 +8,13 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerLayerConfig from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerLayerConfig, tensor_space: TensorSpace, name: str = "mlp"): super().__init__() self._name = name @@ -60,7 +60,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerLayerConfig, tensor_space: TensorSpace, name: str = "mlp"): Assert.eq(config.num_experts, 1) super().__init__(config, tensor_space, name) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index a509ce6a..86578f09 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -9,9 +9,9 @@ from fast_llm.layers.transformer.config import ( RotaryConfig, RotaryEmbeddingType, - TransformerConfig, TransformerDimNames, TransformerKwargs, + TransformerLayerConfig, ) from fast_llm.tensor import TensorMeta @@ -49,7 +49,6 @@ def apply_yarn_scaling(config: RotaryConfig, frequencies: torch.Tensor, kv_chann base = config.theta partial_rotary_factor = 1.0 dim = int(kv_channels * partial_rotary_factor) - max_position_embeddings = sequence_length factor = config.scale_factor attention_factor = config.attention_factor @@ -75,7 +74,6 @@ def linear_ramp_factor(min, max, dim): ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func - # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs # to expand the possible context length. In other words, interpolation = apply scaling factor. 
# pos_freqs = base ** (torch.arange(0, dim, 2).float().to(frequencies.device) / dim) @@ -99,7 +97,6 @@ def linear_ramp_factor(min, max, dim): return inv_freq, attention_factor - def get_rotary_frequencies( config: RotaryConfig, sequence_length, @@ -202,7 +199,7 @@ class BackupAttentionPreprocessor: def __init__( self, - config: TransformerConfig, + config: TransformerLayerConfig, tensor_space: TensorSpace, ): self._config = config diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 4780dd3a..bd351358 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,7 +8,7 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, TransformerLayerConfig from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage @@ -24,7 +24,7 @@ class TransformerLayer(Layer): def __init__( self, - config: TransformerConfig, + config: TransformerLayerConfig, tensor_space: TensorSpace, layer_index: int, ): diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 45206ff1..d9a19b9c 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -48,16 +48,16 @@ def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight - if self._config.transformer.rotary.complex_format: - query = convert_rotary_complex_to_real(query[:], self._config.transformer.kv_channels, 0) + if self._config.layers.default.rotary.complex_format: + query = convert_rotary_complex_to_real(query[:], self._config.layers.default.kv_channels, 0) return (query,) def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight - if self._config.transformer.rotary.complex_format: - query = convert_rotary_real_to_complex(query[:], self._config.transformer.kv_channels, 0) + if self._config.layers.default.rotary.complex_format: + query = convert_rotary_real_to_complex(query[:], self._config.layers.default.kv_channels, 0) return (query,) @@ -70,16 +70,16 @@ def export_weight( ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (key_value,) = weight key, value = key_value[:].chunk(2) - if self._config.transformer.rotary.complex_format: - key = convert_rotary_complex_to_real(key, self._config.transformer.kv_channels, 0) + if self._config.layers.default.rotary.complex_format: + key = convert_rotary_complex_to_real(key, self._config.layers.default.kv_channels, 0) return key, value def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
) -> tuple[torch.Tensor | SafeTensorSlice, ...]: key, value = weight - if self._config.transformer.rotary.complex_format: - key = convert_rotary_real_to_complex(key[:], self._config.transformer.kv_channels, 0) + if self._config.layers.default.rotary.complex_format: + key = convert_rotary_real_to_complex(key[:], self._config.layers.default.kv_channels, 0) key_value = torch.cat([key[:], value[:]]) return (key_value,) @@ -158,9 +158,11 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_weight_converters(self) -> list[WeightConverter]: converters = [] - num_layers = self._model.config.base_model.transformer.num_layers - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm - linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + num_layers = self._model.config.base_model.layers.default.num_layers + norm_bias: bool = ( + self._model.config.base_model.layers.default.normalization.type == NormalizationType.layer_norm + ) + linear_bias: bool = self._model.config.base_model.layers.default.add_linear_biases # Embedding and output if self._model.config.base_model.tie_word_embeddings: @@ -256,7 +258,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + linear_bias: bool = self._model.config.base_model.layers.default.add_linear_biases return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", linear_bias @@ -352,7 +354,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + linear_bias: bool = self._model.config.base_model.layers.default.add_linear_biases return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", @@ -413,7 +415,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - num_experts = self._model.config.base_model.transformer.num_experts + num_experts = self._model.config.base_model.layers.default.num_experts return [ WeightConverter(f"{fast_llm_prefix}.mlp.router.weight", f"{hf_prefix}.block_sparse_moe.gate.weight"), SplitWeightConverter( diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index 842a064e..975fd0c7 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.config import TransformerLayerConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -12,7 +12,7 @@ def get_init_megatron( - meta: "ParameterMeta", config: TransformerConfig + meta: "ParameterMeta", config: TransformerLayerConfig ) -> typing.Callable[["torch.Tensor", "Distributed"], "torch.Tensor"]: def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): Assert.eq(distributed.config.world_size, 1) @@ -49,7 +49,7 @@ def set_megatron_distributed_seeds(config: "DistributedConfig") -> None: def _init_attention_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + config: 
TransformerLayerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": # Megatron combines q and kv and inverts the initialization order of qkv and dense layers. # It also always treats the tensors as tensor-parallel and uses a different rotary embedding format. @@ -114,7 +114,7 @@ def _init_position_embeddings_megatron( def _init_transposed_mlp_weight_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + config: TransformerLayerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": import torch @@ -138,7 +138,7 @@ def _init_moe_router_megatron( def _init_moe_mlp_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + config: TransformerLayerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": assert meta.param_init_method is not None generator = distributed.tp_init_generator if meta.is_tensor_parallel else distributed.pp_init_generator diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 1f4a26b8..da0eaa8d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -49,20 +49,20 @@ def __init__( distributed_config: DistributedConfig, ): super().__init__(config, distributed_config) - self._use_flash_attention = self._config.transformer.do_use_flash_attention(distributed_config) + self._use_flash_attention = self._config.layers.default.do_use_flash_attention(distributed_config) if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) param.init_parameter = get_init_megatron(param, self._config.transformer) # Noqa if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor = PositionEmbeddingPreprocessor(self._config, self._tensor_space) - if self._config.transformer.rotary.enabled: + if self._config.layers.default.rotary.enabled: self._rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( - self._config.transformer.rotary, self._tensor_space + self._config.layers.default.rotary, self._tensor_space ) if not self._use_flash_attention: self._backup_attention_preprocessor = BackupAttentionPreprocessor( - self._config.transformer, self._tensor_space + self._config.layers.default, self._tensor_space ) def get_layers(self) -> list[Layer]: @@ -70,11 +70,10 @@ def get_layers(self) -> list[Layer]: LanguageModelEmbedding(self._config, self._tensor_space), *[ TransformerLayer( - self._config.layers.get_layer_config(i), - self._tensor_space, + *self._config.layers.get_layer_config_and_tensor_space(i, self._tensor_space), layer_index=i + 1, ) - for i in range(self._config.transformer.num_layers) + for i in range(self._config.layers.default.num_layers) ], LanguageModelHead(self._config, self._tensor_space), ] @@ -175,7 +174,7 @@ def preprocess_meta( ) if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess_meta(kwargs) - if self._config.transformer.rotary.enabled: + if self._config.layers.default.rotary.enabled: self._rotary_embedding_preprocessor.preprocess_meta(kwargs) if not self._use_flash_attention: self._backup_attention_preprocessor.preprocess_meta(kwargs) @@ -214,7 +213,7 @@ def preprocess( if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.create_tensors(sequence_length) - if 
self._config.transformer.rotary.enabled: + if self._config.layers.default.rotary.enabled: self._rotary_embedding_preprocessor.create_tensors(sequence_length) if not self._use_flash_attention: self._backup_attention_preprocessor.create_tensors(sequence_length) @@ -257,7 +256,7 @@ def preprocess( kwargs[LanguageModelKwargs.labels] = labels if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess(kwargs) - if self._config.transformer.rotary.enabled: + if self._config.layers.default.rotary.enabled: self._rotary_embedding_preprocessor.preprocess(kwargs) if not self._use_flash_attention: self._backup_attention_preprocessor.preprocess(kwargs) @@ -290,22 +289,22 @@ def loss_defs(self) -> list[LossDef]: LossDef(name=LanguageModelLossNames.language_model_loss, formatted_name="language model loss", count=1) ] if ( - self._config.transformer.num_experts > 1 - and self._config.transformer.expert_routing_type == RoutingType.topk + self._config.layers.default.num_experts > 1 + and self._config.layers.default.expert_routing_type == RoutingType.topk ): loss_defs.append( LossDef( name=TransformerLossNames.load_balancing_loss, formatted_name="load balancing loss", - count=self._config.transformer.num_layers, + count=self._config.layers.default.num_layers, ) ) - if self._config.transformer.expert_z_loss_coefficient: + if self._config.layers.default.expert_z_loss_coefficient: loss_defs.append( LossDef( name=TransformerLossNames.router_z_loss, formatted_name="router z loss", - count=self._config.transformer.num_layers, + count=self._config.layers.default.num_layers, ) ) if self._config.logit_z_loss: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 7b03a7b4..21e08229 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -26,7 +26,7 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, # TODO: Do in model, automate/generalize, get other stats """Get tflop/s/GPU from global-batch-size and elapsed-time""" checkpoint_activations_factor = 3 if phase == PhaseType.training else 1 - transformer_config = self._config.model.base_model.transformer + transformer_config = self._config.model.base_model.layers.default sequence_length = self._config.batch.sequence_length tokens = self._config.batch.batch_size * sequence_length diff --git a/tests/test_attention.py b/tests/test_attention.py index c8b91d76..0d90246e 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -1,8 +1,9 @@ import unittest.mock -from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.engine.distributed.config import DistributedConfig + from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.transformer.attention import Attention +from fast_llm.layers.transformer.config import TransformerLayerConfig def test_decide_window_size(): @@ -10,23 +11,23 @@ def test_decide_window_size(): attention._decide_window_size = Attention._decide_window_size.__get__(attention) # Attach real method # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) - attention._config = TransformerConfig(window_size=512, max_window_layers=2) + attention._config = TransformerLayerConfig(window_size=512, max_window_layers=2) attention._layer_index = 2 assert attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None 
(layer_index < max_window_layers) - attention._config = TransformerConfig(window_size=512, max_window_layers=2) + attention._config = TransformerLayerConfig(window_size=512, max_window_layers=2) attention._layer_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) - attention._config = TransformerConfig(window_size=512, max_window_layers=None) + attention._config = TransformerLayerConfig(window_size=512, max_window_layers=None) assert attention._decide_window_size() == 512 def test_attention_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, + transformer_conf = TransformerLayerConfig( + num_layers=2, num_attention_heads=2, hidden_size=16, ) @@ -35,4 +36,3 @@ def test_attention_constructor(): transformer_conf.setup_tensor_space(tensor_space) Attention(transformer_conf, tensor_space, 1) - diff --git a/tests/test_config.py b/tests/test_config.py index 7141812a..15578463 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,19 +1,18 @@ import pathlib -import pytest import subprocess import unittest.mock -import yaml +import pytest +import yaml +from fast_llm.config import ValidationError +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerArchitectureConfig, AddLinearBiasChoices, + TransformerLayerArchitectureConfig, + TransformerLayerConfig, ) -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.config import ValidationError - from fast_llm.models.auto import trainer_registry @@ -69,22 +68,22 @@ def test_do_use_flash_attention(): mock_distributed_config = unittest.mock.Mock(spec=DistributedConfig) # Test case 1: use_flash_attention is True and training_dtype is float16 - config = TransformerConfig(use_flash_attention=True, window_size=None) + config = TransformerLayerConfig(use_flash_attention=True, window_size=None) mock_distributed_config.training_dtype = DataType.float16 assert config.do_use_flash_attention(mock_distributed_config) is True # Test case 2: use_flash_attention is False - config = TransformerConfig(use_flash_attention=False, window_size=None) + config = TransformerLayerConfig(use_flash_attention=False, window_size=None) mock_distributed_config.training_dtype = DataType.float16 assert config.do_use_flash_attention(mock_distributed_config) is False # Test case 3: use_flash_attention is True but training_dtype is not float16 or bfloat16 - config = TransformerConfig(use_flash_attention=True, window_size=None) + config = TransformerLayerConfig(use_flash_attention=True, window_size=None) mock_distributed_config.training_dtype = DataType.float32 assert config.do_use_flash_attention(mock_distributed_config) is False # Test case 4: use_flash_attention is False and window_size is not None - config = TransformerConfig(use_flash_attention=False, window_size=512) + config = TransformerLayerConfig(use_flash_attention=False, window_size=512) mock_distributed_config.training_dtype = DataType.float32 with pytest.raises(AssertionError): config.do_use_flash_attention(mock_distributed_config) @@ -92,50 +91,71 @@ def test_do_use_flash_attention(): def test_add_linear_biases_valid_values(): # Valid boolean values - assert TransformerArchitectureConfig(add_linear_biases=True).add_linear_biases is True - assert 
TransformerArchitectureConfig(add_linear_biases=False).add_linear_biases is False + assert TransformerLayerArchitectureConfig(add_linear_biases=True).add_linear_biases is True + assert TransformerLayerArchitectureConfig(add_linear_biases=False).add_linear_biases is False # Valid enum values - assert TransformerArchitectureConfig(add_linear_biases="nowhere").add_linear_biases == AddLinearBiasChoices.nowhere assert ( - TransformerArchitectureConfig(add_linear_biases="everywhere").add_linear_biases + TransformerLayerArchitectureConfig(add_linear_biases="nowhere").add_linear_biases + == AddLinearBiasChoices.nowhere + ) + assert ( + TransformerLayerArchitectureConfig(add_linear_biases="everywhere").add_linear_biases == AddLinearBiasChoices.everywhere ) assert ( - TransformerArchitectureConfig(add_linear_biases="only_attn_qkv").add_linear_biases == AddLinearBiasChoices.only_attn_qkv + TransformerLayerArchitectureConfig(add_linear_biases="only_attn_qkv").add_linear_biases + == AddLinearBiasChoices.only_attn_qkv ) def test_add_linear_biases_invalid_values(): with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases="invalid_value") + TransformerLayerArchitectureConfig(add_linear_biases="invalid_value") with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases=123) + TransformerLayerArchitectureConfig(add_linear_biases=123) with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases=None) + TransformerLayerArchitectureConfig(add_linear_biases=None) def test_add_mlp_bias(): - assert TransformerArchitectureConfig(add_linear_biases=True).add_mlp_bias is True - assert TransformerArchitectureConfig(add_linear_biases=False).add_mlp_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_mlp_bias is True - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_mlp_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_mlp_bias is False + assert TransformerLayerArchitectureConfig(add_linear_biases=True).add_mlp_bias is True + assert TransformerLayerArchitectureConfig(add_linear_biases=False).add_mlp_bias is False + assert TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_mlp_bias is True + assert TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_mlp_bias is False + assert ( + TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_mlp_bias is False + ) def test_add_attn_qkv_bias(): - assert TransformerArchitectureConfig(add_linear_biases=True).add_attn_qkv_bias is True - assert TransformerArchitectureConfig(add_linear_biases=False).add_attn_qkv_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_qkv_bias is True - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_qkv_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_qkv_bias is True + assert TransformerLayerArchitectureConfig(add_linear_biases=True).add_attn_qkv_bias is True + assert TransformerLayerArchitectureConfig(add_linear_biases=False).add_attn_qkv_bias is False + assert ( + TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_qkv_bias is True + ) + assert ( + 
TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_qkv_bias is False + ) + assert ( + TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_qkv_bias + is True + ) def test_add_attn_dense_bias(): - assert TransformerArchitectureConfig(add_linear_biases=True).add_attn_dense_bias is True - assert TransformerArchitectureConfig(add_linear_biases=False).add_attn_dense_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_dense_bias is True - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_dense_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_dense_bias is False + assert TransformerLayerArchitectureConfig(add_linear_biases=True).add_attn_dense_bias is True + assert TransformerLayerArchitectureConfig(add_linear_biases=False).add_attn_dense_bias is False + assert ( + TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_dense_bias + is True + ) + assert ( + TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_dense_bias is False + ) + assert ( + TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_dense_bias + is False + ) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index bcfbaf69..7fea9ba5 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -1,12 +1,12 @@ -from fast_llm.layers.transformer.mlp import MLP -from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.transformer.config import TransformerLayerConfig +from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP +from fast_llm.layers.transformer.mlp import MLP def test_mlp_constructor(): - transformer_conf = TransformerConfig( + transformer_conf = TransformerLayerConfig( num_layers=2, num_attention_heads=2, hidden_size=16, @@ -19,12 +19,8 @@ def test_mlp_constructor(): def test_moe_mlp_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - num_experts=2, - add_linear_biases=False + transformer_conf = TransformerLayerConfig( + num_layers=2, num_attention_heads=2, hidden_size=16, num_experts=2, add_linear_biases=False ) distributed_config = DistributedConfig() tensor_space = TensorSpace(distributed_config=distributed_config) From dcc935cc5befaf791a7360047cf28cb18f5023c9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 4 Mar 2025 16:34:39 -0500 Subject: [PATCH 3/7] fixes --- fast_llm/config.py | 1 - fast_llm/layers/language_model/config.py | 2 +- fast_llm/models/gpt/conversion.py | 118 ++++++++++++----------- fast_llm/models/gpt/model.py | 2 +- 4 files changed, 65 insertions(+), 58 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 7090291a..f1c88965 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -733,7 +733,6 @@ def _from_dict( if strict and default: out._unknown_fields = default.copy() if _AUTO_VALIDATE: - print("WKIUEFNW", out.to_serialized()) out.validate() return out diff --git a/fast_llm/layers/language_model/config.py 
b/fast_llm/layers/language_model/config.py index 11d117be..35970842 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -37,7 +37,7 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): # desc="Configuration for the transformer architecture.", # hint=FieldHint.core, # ) - layers: TransformerConfig = Field(default_factory=TransformerArchitectureConfig) + layers: TransformerArchitectureConfig = Field(default_factory=TransformerArchitectureConfig) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 5ec0ad26..cdae7ca7 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -24,7 +24,7 @@ from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.common.config import NormalizationType -from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig +from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType from fast_llm.models.gpt.config import ( GPTArchitectureConfig, GPTModelConfig, @@ -118,32 +118,32 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), RenameParamConverter( - fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) + fast_llm_names=(("layers", "default", "rotary", "theta"),), export_names=(("rope_theta",),) ), MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), + fast_llm_names=(("layers", "default", "activation_type"),), export_names=(("hidden_act",),), fast_llm_value=ActivationType.from_hf_name, export_value=lambda activation_type: activation_type.hf_name, ), RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), + fast_llm_names=(("layers", "default", "num_layers"),), export_names=(("num_hidden_layers",),), ), RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), + fast_llm_names=(("layers", "default", "hidden_size"),), export_names=(("hidden_size",),), ), RenameParamConverter( - fast_llm_names=(("transformer", "num_attention_heads"),), + fast_llm_names=(("layers", "default", "num_attention_heads"),), export_names=(("num_attention_heads",),), ), RenameParamConverter( - fast_llm_names=(("transformer", "head_groups"),), + fast_llm_names=(("layers", "default", "head_groups"),), export_names=(("num_key_value_heads",),), ), RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), + fast_llm_names=(("layers", "default", "ffn_hidden_size"),), export_names=(("intermediate_size",),), ), RenameParamConverter( @@ -168,7 +168,7 @@ def _create_weight_converters( norm_bias: bool = ( self._model.config.base_model.layers.default.normalization.type == NormalizationType.layer_norm ) - transformer_config: TransformerConfig = self._model.config.base_model.layers.default + layer_config = self._model.config.base_model.layers.default # Embedding and output if self._model.config.base_model.tie_word_embeddings: @@ -188,19 +188,19 @@ def _create_weight_converters( converters += self._get_weight_and_bias_converters( f"layers.{i+1}.self_attn.query", f"model.layers.{i}.self_attn.q_proj", - 
transformer_config.add_attn_qkv_bias, + layer_config.add_attn_qkv_bias, QueryWeightConverter, ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.self_attn.key_value", (f"model.layers.{i}.self_attn.k_proj", f"model.layers.{i}.self_attn.v_proj"), - transformer_config.add_attn_qkv_bias, + layer_config.add_attn_qkv_bias, KeyValueWeightConverter, ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.self_attn.dense", f"model.layers.{i}.self_attn.o_proj", - transformer_config.add_attn_dense_bias, + layer_config.add_attn_dense_bias, ) # Norm @@ -253,28 +253,31 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Starcoder2ForCausalLM"]), ConstantImportParamConverter( - fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.default + fast_llm_names=(("layers", "default", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.default ), ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.layer_norm + fast_llm_names=(("layers", "default", "normalization", "type"),), + fast_llm_value=NormalizationType.layer_norm, ), RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) + fast_llm_names=(("layers", "default", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) + ), + ConstantImportParamConverter(fast_llm_names=(("layers", "default", "gated"),), fast_llm_value=False), + ConstantImportParamConverter( + fast_llm_names=(("layers", "default", "add_linear_biases"),), fast_llm_value=True ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=False), - ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=True), ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.layers.default + layer_config = self._model.config.base_model.layers.default return [ *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", transformer_config.add_mlp_bias + f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", layer_config.add_mlp_bias ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.c_proj", - transformer_config.add_mlp_bias, + layer_config.add_mlp_bias, MLPLayer2Converter, ), ] @@ -285,27 +288,30 @@ class CommonLlamaHuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + fast_llm_names=(("layers", "default", "normalization", "type"),), + fast_llm_value=NormalizationType.rms_norm, ), RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) + fast_llm_names=(("layers", "default", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) ), RenameParamConverter( - fast_llm_names=(("transformer", "kv_channels"),), - export_names=(("head_dim"),), + fast_llm_names=(("layers", "default", "kv_channels"),), + export_names=(("head_dim",),), + ), + 
ConstantImportParamConverter(fast_llm_names=(("layers", "default", "gated"),), fast_llm_value=True), + ConstantImportParamConverter( + fast_llm_names=(("layers", "default", "add_linear_biases"),), fast_llm_value=False ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), - ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), RopeScalingParamConverter( fast_llm_names=( - ("transformer", "rotary", "type"), - ("transformer", "rotary", "scale_factor"), - ("transformer", "rotary", "low_frequency_factor"), - ("transformer", "rotary", "high_frequency_factor"), - ("transformer", "rotary", "original_context_length"), - ("transformer", "rotary", "attention_factor"), - ("transformer", "rotary", "beta_fast"), - ("transformer", "rotary", "beta_slow"), + ("layers", "default", "rotary", "type"), + ("layers", "default", "rotary", "scale_factor"), + ("layers", "default", "rotary", "low_frequency_factor"), + ("layers", "default", "rotary", "high_frequency_factor"), + ("layers", "default", "rotary", "original_context_length"), + ("layers", "default", "rotary", "attention_factor"), + ("layers", "default", "rotary", "beta_fast"), + ("layers", "default", "rotary", "beta_slow"), ), export_names=(("rope_scaling",),), ), @@ -365,18 +371,18 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.layers.default + layer_config = self._model.config.base_model.layers.default return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + layer_config.add_mlp_bias, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + layer_config.add_mlp_bias, MLPLayer2Converter, ), ] @@ -412,25 +418,26 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Qwen2ForCausalLM"]), ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + fast_llm_names=(("layers", "default", "normalization", "type"),), + fast_llm_value=NormalizationType.rms_norm, ), RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) + fast_llm_names=(("layers", "default", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("layers", "default", "gated"),), fast_llm_value=True), ConstantImportParamConverter( - fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv" + fast_llm_names=(("layers", "default", "add_linear_biases"),), fast_llm_value="only_attn_qkv" ), RopeScalingParamConverter( fast_llm_names=( - ("transformer", "rotary", "type"), - ("transformer", "rotary", "scale_factor"), - ("transformer", "rotary", "low_frequency_factor"), - ("transformer", "rotary", "high_frequency_factor"), - ("transformer", "rotary", "original_context_length"), - ("transformer", "rotary", "attention_factor"), - ("transformer", 
"rotary", "beta_fast"), - ("transformer", "rotary", "beta_slow"), + ("layers", "default", "rotary", "type"), + ("layers", "default", "rotary", "scale_factor"), + ("layers", "default", "rotary", "low_frequency_factor"), + ("layers", "default", "rotary", "high_frequency_factor"), + ("layers", "default", "rotary", "original_context_length"), + ("layers", "default", "rotary", "attention_factor"), + ("layers", "default", "rotary", "beta_fast"), + ("layers", "default", "rotary", "beta_slow"), ), export_names=(("rope_scaling",),), ), @@ -438,18 +445,18 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + layer_config = self._model.config.base_model.layers.default return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + layer_config.add_mlp_bias, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + layer_config.add_mlp_bias, MLPLayer2Converter, ), ] @@ -487,13 +494,14 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MixtralForCausalLM"]), ConstantImportParamConverter( - fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk + fast_llm_names=(("layers", "default", "expert_routing_type"),), fast_llm_value=RoutingType.topk ), RenameParamConverter( - fast_llm_names=(("transformer", "num_experts"),), export_names=(("num_local_experts",),) + fast_llm_names=(("layers", "default", "num_experts"),), export_names=(("num_local_experts",),) ), RenameParamConverter( - fast_llm_names=(("transformer", "num_experts_per_token"),), export_names=(("num_experts_per_tok",),) + fast_llm_names=(("layers", "default", "num_experts_per_token"),), + export_names=(("num_experts_per_tok",),), ), IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), ] diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index da0eaa8d..bad1016f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -53,7 +53,7 @@ def __init__( if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) - param.init_parameter = get_init_megatron(param, self._config.transformer) # Noqa + param.init_parameter = get_init_megatron(param, self._config.layers.default) # Noqa if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor = PositionEmbeddingPreprocessor(self._config, self._tensor_space) if self._config.layers.default.rotary.enabled: From a3a1b2c148044766476953769c796db7ac402c7d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 4 Mar 2025 19:07:42 -0500 Subject: [PATCH 4/7] Cleanup, misc --- fast_llm/layers/language_model/config.py | 7 ----- fast_llm/layers/transformer/attention.py | 6 ++-- fast_llm/layers/transformer/config.py | 14 +++++++-- fast_llm/layers/transformer/preprocessing.py | 32 ++++++++++++++------ fast_llm/models/gpt/model.py | 22 ++++++-------- 5 files changed, 46 insertions(+), 35 deletions(-) diff --git a/fast_llm/layers/language_model/config.py 
b/fast_llm/layers/language_model/config.py index 35970842..b4ba271f 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -32,11 +32,6 @@ class LanguageModelKwargs: @config_class() class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): - # transformer: TransformerLayerArchitectureConfig = Field( - # default_factory=TransformerLayerArchitectureConfig, - # desc="Configuration for the transformer architecture.", - # hint=FieldHint.core, - # ) layers: TransformerArchitectureConfig = Field(default_factory=TransformerArchitectureConfig) max_position_embeddings: int = Field( default=2048, @@ -65,7 +60,6 @@ def _validate(self) -> None: super()._validate() def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - assert self._validated self.layers.setup_tensor_space(tensor_space) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -124,7 +118,6 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig - # transformer: TransformerLayerConfig = FieldUpdate(default_factory=TransformerLayerConfig) layers: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) init_method_std_embed: float = Field( default=None, diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index fe982bed..752bd8eb 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -72,7 +72,7 @@ def __init__( super().__init__() self._config = config self._tensor_space = tensor_space - Assert.in_range_incl(layer_index, 1, self._config.num_layers) + Assert.in_range(layer_index, 0, self._config.num_layers) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug_transformer = self._config.debug_transformer @@ -157,10 +157,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / (self._layer_index + 1), ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * (self._layer_index + 1) attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 18bfeeb9..b2884fc4 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -306,7 +306,6 @@ def _from_dict( return super()._from_dict(default, strict, flat) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - assert self._validated tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension @@ -665,7 +664,7 @@ def in_range(self, index) -> bool: @config_class() class TransformerLayerRangeConfig(TransformerLayerRangeArchitectureConfig, BaseModelConfig): - pass + config: TransformerLayerConfig = FieldUpdate(init=False) @config_class() @@ -677,9 +676,12 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig): def _validate(self) -> None: for layer in self.layers: layer.setup(self.default) - # TODO: Improve this? + # Hidden layers must match Assert.eq(layer.config.hidden_size, self.default.hidden_size) + # TODO: Move elsewhere? 
Kept here because used in a few places like default initialization. Assert.eq(layer.config.num_layers, self.default.num_layers) + # TODO: Rotary preprocessor doesn't support variations across layers. + Assert.eq(layer.config.rotary.to_serialized(), self.default.rotary.to_serialized()) super()._validate() def get_layer_config_and_tensor_space( @@ -702,6 +704,12 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): layers: list[TransformerLayerRangeConfig] = FieldUpdate() default: TransformerLayerConfig = FieldUpdate(default_factory=TransformerLayerConfig) + def _validate(self) -> None: + for layer in self.layers: + # Hidden layers must match + Assert.eq(layer.config.full_precision_residual, self.default.full_precision_residual) + super()._validate() + def get_layer_config_and_tensor_space( self, index: int, tensor_space: TensorSpace ) -> tuple[TransformerLayerConfig, TensorSpace]: diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 86578f09..82538257 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -9,16 +9,16 @@ from fast_llm.layers.transformer.config import ( RotaryConfig, RotaryEmbeddingType, + TransformerConfig, TransformerDimNames, TransformerKwargs, - TransformerLayerConfig, ) from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) -def apply_llama3_scaling(config: RotaryConfig, frequencies: torch.Tensor) -> torch.Tensor: +def apply_llama3_scaling(config: RotaryConfig, frequencies: torch.Tensor) -> tuple[torch.Tensor, float]: """ Llama3 scaling: https://github.com/meta-llama/llama-models/blob/baf7b01b6e62bc7126c7b558d2b67d4533142680/models/llama3/reference_impl/model.py#L45-L67 """ @@ -40,7 +40,7 @@ def apply_llama3_scaling(config: RotaryConfig, frequencies: torch.Tensor) -> tor return torch.tensor(new_frequencies, dtype=frequencies.dtype, device=frequencies.device), 1.0 -def apply_yarn_scaling(config: RotaryConfig, frequencies: torch.Tensor, kv_channels, sequence_length) -> torch.Tensor: +def apply_yarn_scaling(config: RotaryConfig, frequencies: torch.Tensor, kv_channels) -> tuple[torch.Tensor, float]: """ Yarn scaling: https://github.com/huggingface/transformers/blob/006d9249ec0270ff6c4d3840979d23fe94bdc763/src/transformers/modeling_rope_utils.py#L163 @@ -115,7 +115,7 @@ def get_rotary_frequencies( if config.type == RotaryEmbeddingType.llama3: frequencies, attention_scaling = apply_llama3_scaling(config, frequencies) elif config.type == RotaryEmbeddingType.yarn: - frequencies, attention_scaling = apply_yarn_scaling(config, frequencies, kv_channels, sequence_length) + frequencies, attention_scaling = apply_yarn_scaling(config, frequencies, kv_channels) else: attention_scaling = 1.0 angles = torch.outer(positions, frequencies) @@ -199,16 +199,26 @@ class BackupAttentionPreprocessor: def __init__( self, - config: TransformerLayerConfig, + config: TransformerConfig, tensor_space: TensorSpace, ): - self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - assert not self._config.do_use_flash_attention(self._distributed_config) + all_configs = [config.default] + config.layers + self._enabled = not all( + layer_config.do_use_flash_attention(self._distributed_config) for layer_config in all_configs + ) + if self._enabled: + window_sizes = {layer_config.window_size for layer_config in all_configs} + if len(window_sizes) != 1: + raise ValueError("Variable window size 
not supported for backup attention.") + self._window_size = window_sizes.pop() + self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) def create_tensors(self, sequence_length: int) -> None: + if not self._enabled: + return if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -218,8 +228,8 @@ def create_tensors(self, sequence_length: int) -> None: dtype=torch.bool, device=self._tensor_space.distributed.device, ).tril_() - if self._config.window_size is not None: - self._mask.triu_(-self._config.window_size + 1) + if self._window_size is not None: + self._mask.triu_(-self._window_size + 1) self._mask_value = torch.full( [], torch.finfo(self._distributed_config.training_dtype.torch).min, @@ -228,6 +238,8 @@ def create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + if not self._enabled: + return sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size kwargs[TransformerKwargs.attention_mask] = self._mask[ None, None, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k, None, :sequence_k @@ -235,6 +247,8 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: kwargs[TransformerKwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + if not self._enabled: + return kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( ( self._scalar_dim, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bad1016f..0fd4b522 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -60,20 +60,19 @@ def __init__( self._rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( self._config.layers.default.rotary, self._tensor_space ) - if not self._use_flash_attention: - self._backup_attention_preprocessor = BackupAttentionPreprocessor( - self._config.layers.default, self._tensor_space - ) + self._backup_attention_preprocessor = BackupAttentionPreprocessor( + self._config.layers.default, self._tensor_space + ) def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ TransformerLayer( - *self._config.layers.get_layer_config_and_tensor_space(i, self._tensor_space), - layer_index=i + 1, + *self._config.layers.get_layer_config_and_tensor_space(layer_index, self._tensor_space), + layer_index=layer_index, ) - for i in range(self._config.layers.default.num_layers) + for layer_index in range(self._config.layers.default.num_layers) ], LanguageModelHead(self._config, self._tensor_space), ] @@ -176,8 +175,7 @@ def preprocess_meta( self._position_embedding_preprocessor.preprocess_meta(kwargs) if self._config.layers.default.rotary.enabled: self._rotary_embedding_preprocessor.preprocess_meta(kwargs) - if not self._use_flash_attention: - self._backup_attention_preprocessor.preprocess_meta(kwargs) + self._backup_attention_preprocessor.preprocess_meta(kwargs) preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -215,8 +213,7 @@ def preprocess( self._position_embedding_preprocessor.create_tensors(sequence_length) if self._config.layers.default.rotary.enabled: self._rotary_embedding_preprocessor.create_tensors(sequence_length) - if not self._use_flash_attention: - self._backup_attention_preprocessor.create_tensors(sequence_length) + self._backup_attention_preprocessor.create_tensors(sequence_length) preprocessed = [] presents = None @@ -258,8 +255,7 @@ def 
preprocess( self._position_embedding_preprocessor.preprocess(kwargs) if self._config.layers.default.rotary.enabled: self._rotary_embedding_preprocessor.preprocess(kwargs) - if not self._use_flash_attention: - self._backup_attention_preprocessor.preprocess(kwargs) + self._backup_attention_preprocessor.preprocess(kwargs) preprocessed.append((tokens, kwargs)) return preprocessed From 8af56c9f8fe42f90d712ffc19b9eb6b11a3f3bc1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 4 Mar 2025 20:08:02 -0500 Subject: [PATCH 5/7] fixes --- Megatron-LM | 1 + fast_llm/config.py | 1 + fast_llm/layers/transformer/config.py | 2 +- fast_llm/models/gpt/conversion.py | 2 ++ fast_llm/models/gpt/model.py | 5 +---- 5 files changed, 6 insertions(+), 5 deletions(-) create mode 160000 Megatron-LM diff --git a/Megatron-LM b/Megatron-LM new file mode 160000 index 00000000..fe1f23cf --- /dev/null +++ b/Megatron-LM @@ -0,0 +1 @@ +Subproject commit fe1f23cf029d088c30a86989562b671af8967129 diff --git a/fast_llm/config.py b/fast_llm/config.py index f1c88965..089e6508 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -888,6 +888,7 @@ def __init_subclass__(cls): valid=value.pop("valid", base_class_field.valid), default=value.pop("default", base_class_field.default), default_factory=value.pop("default_factory", base_class_field.default_factory), + init=value.pop("init", base_class_field.init), repr=value.pop("repr", base_class_field.repr), hash=value.pop("hash", base_class_field.hash), compare=value.pop("compare", base_class_field.compare), diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index b2884fc4..b816f712 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -664,7 +664,7 @@ def in_range(self, index) -> bool: @config_class() class TransformerLayerRangeConfig(TransformerLayerRangeArchitectureConfig, BaseModelConfig): - config: TransformerLayerConfig = FieldUpdate(init=False) + config: TransformerLayerConfig = FieldUpdate() @config_class() diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index cdae7ca7..8d190061 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -116,6 +116,8 @@ class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + # Variable layer config not supported. 
+ ConstantImportParamConverter(fast_llm_names=(("layers", "layers"),), fast_llm_value=[]), ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), RenameParamConverter( fast_llm_names=(("layers", "default", "rotary", "theta"),), export_names=(("rope_theta",),) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 0fd4b522..f522ebb8 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -49,7 +49,6 @@ def __init__( distributed_config: DistributedConfig, ): super().__init__(config, distributed_config) - self._use_flash_attention = self._config.layers.default.do_use_flash_attention(distributed_config) if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) @@ -60,9 +59,7 @@ def __init__( self._rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( self._config.layers.default.rotary, self._tensor_space ) - self._backup_attention_preprocessor = BackupAttentionPreprocessor( - self._config.layers.default, self._tensor_space - ) + self._backup_attention_preprocessor = BackupAttentionPreprocessor(self._config.layers, self._tensor_space) def get_layers(self) -> list[Layer]: return [ From 970b5789f9ba6c09e69150e423e0932c80eab8ea Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 4 Mar 2025 21:53:11 -0500 Subject: [PATCH 6/7] misc --- fast_llm/layers/transformer/config.py | 46 +++++++++++++++----- fast_llm/layers/transformer/preprocessing.py | 2 +- fast_llm/models/gpt/conversion.py | 1 + tests/test_config.py | 35 +-------------- tests/test_transformer.py | 42 ++++++++++++++++++ 5 files changed, 80 insertions(+), 46 deletions(-) create mode 100644 tests/test_transformer.py diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index b816f712..1401c490 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -618,30 +618,49 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: DataType.bfloat16, ) - # Config parameter `window_size` only can be used with flash attention if not use_flash_attention: - Assert.is_(self.window_size, None) + assert self.max_window_layers is None return use_flash_attention @config_class() -class SliceConfig(Config): - begin: int = Field(default=0) - end: int | None = Field(default=None) - step: int = Field(default=1) +class RangeConfig(Config): + """ + A configuration that defines a range of values, to be used for example in python `slice` or `range`. + """ + + # TODO: Not specific to transformers, move elsewhere? + begin: int = Field( + default=0, + desc="The beginning of the range.", + hint=FieldHint.optional, + ) + end: int | None = Field( + default=None, + desc="The end of the range (excluded).", + hint=FieldHint.optional, + ) + step: int = Field( + default=1, + desc="The step for the range.", + hint=FieldHint.optional, + ) def in_range(self, index) -> bool: + """ + Checks whether `index` is in `range(begin, end, step)`. 
+ """ return ( - index >= self.begin and (self.end is None or index <= self.end) and ((index - self.begin) % self.step == 0) + index >= self.begin and (self.end is None or index < self.end) and ((index - self.begin) % self.step == 0) ) @config_class() class TransformerLayerRangeArchitectureConfig(BaseModelArchitectureConfig): _abstract = False - layer_ranges: list[SliceConfig] = Field( - default_factory=SliceConfig, + layer_ranges: list[RangeConfig] = Field( + default_factory=RangeConfig, desc="Layer range.", hint=FieldHint.core, ) @@ -651,7 +670,10 @@ class TransformerLayerRangeArchitectureConfig(BaseModelArchitectureConfig): config: TransformerLayerArchitectureConfig = Field(init=False) def setup(self, default: TransformerLayerArchitectureConfig) -> None: - self.config = TransformerLayerArchitectureConfig.from_dict(default, self.updates) + # Create the full config from the default and updates. + # We use `default.from_dict` so we also have the appropriate class in `TransformerLayerRangeConfig`. + # For the architecture class we need to set `strict=False` because of possible non-architecture parameters. + self.config = default.from_dict(default, self.updates, strict=True) # isinstance(self, BaseModelConfig)) def _validate(self) -> None: assert hasattr(self, "config") @@ -705,10 +727,12 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): default: TransformerLayerConfig = FieldUpdate(default_factory=TransformerLayerConfig) def _validate(self) -> None: + super()._validate() for layer in self.layers: # Hidden layers must match Assert.eq(layer.config.full_precision_residual, self.default.full_precision_residual) - super()._validate() + if self.layers: + warnings.warn("Variable layer configuration is experimental. Use with caution.") def get_layer_config_and_tensor_space( self, index: int, tensor_space: TensorSpace diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 82538257..aa2c33af 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -204,7 +204,7 @@ def __init__( ): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - all_configs = [config.default] + config.layers + all_configs = [config.default] + [layer.config for layer in config.layers] self._enabled = not all( layer_config.do_use_flash_attention(self._distributed_config) for layer_config in all_configs ) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 8d190061..ba53dd40 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -117,6 +117,7 @@ class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ # Variable layer config not supported. + # TODO: Find a way to support variable non-architecture parameters. 
ConstantImportParamConverter(fast_llm_names=(("layers", "layers"),), fast_llm_value=[]), ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), RenameParamConverter( diff --git a/tests/test_config.py b/tests/test_config.py index 15578463..20958906 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,18 +1,11 @@ import pathlib import subprocess -import unittest.mock import pytest import yaml from fast_llm.config import ValidationError -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.transformer.config import ( - AddLinearBiasChoices, - TransformerLayerArchitectureConfig, - TransformerLayerConfig, -) +from fast_llm.layers.transformer.config import AddLinearBiasChoices, TransformerLayerArchitectureConfig from fast_llm.models.auto import trainer_registry @@ -63,32 +56,6 @@ def test_validate_example_config(): trainer_registry["gpt"].from_dict(fast_llm_config_dict) -def test_do_use_flash_attention(): - # Create a mock DistributedConfig - mock_distributed_config = unittest.mock.Mock(spec=DistributedConfig) - - # Test case 1: use_flash_attention is True and training_dtype is float16 - config = TransformerLayerConfig(use_flash_attention=True, window_size=None) - mock_distributed_config.training_dtype = DataType.float16 - assert config.do_use_flash_attention(mock_distributed_config) is True - - # Test case 2: use_flash_attention is False - config = TransformerLayerConfig(use_flash_attention=False, window_size=None) - mock_distributed_config.training_dtype = DataType.float16 - assert config.do_use_flash_attention(mock_distributed_config) is False - - # Test case 3: use_flash_attention is True but training_dtype is not float16 or bfloat16 - config = TransformerLayerConfig(use_flash_attention=True, window_size=None) - mock_distributed_config.training_dtype = DataType.float32 - assert config.do_use_flash_attention(mock_distributed_config) is False - - # Test case 4: use_flash_attention is False and window_size is not None - config = TransformerLayerConfig(use_flash_attention=False, window_size=512) - mock_distributed_config.training_dtype = DataType.float32 - with pytest.raises(AssertionError): - config.do_use_flash_attention(mock_distributed_config) - - def test_add_linear_biases_valid_values(): # Valid boolean values assert TransformerLayerArchitectureConfig(add_linear_biases=True).add_linear_biases is True diff --git a/tests/test_transformer.py b/tests/test_transformer.py new file mode 100644 index 00000000..809b9cf9 --- /dev/null +++ b/tests/test_transformer.py @@ -0,0 +1,42 @@ +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.common.config import NormalizationType +from fast_llm.models.gpt.config import GPTBaseModelConfig +from fast_llm.models.gpt.model import GPTBaseModel +from fast_llm.utils import Assert + + +def test_variable_window_size(): + model = GPTBaseModel( + GPTBaseModelConfig.from_dict( + { + "layers": { + "default": {"window_size": 1024, "num_layers": 8, "normalization": {"type": "rms_norm"}}, + "layers": [ + { + # Layers 5, 6 and 7 + "layer_ranges": [{"begin": 5, "end": None}], + "updates": {"window_size": None, "normalization": {"epsilon": 1}}, + }, + { + # Layers 0, 3 and 5, but 5 already covered above so excluded. 
+ "layer_ranges": [{"begin": 0, "end": 1}, {"begin": 3, "end": 6, "step": 2}], + "updates": {"window_size": 512, "ffn_hidden_size": 64}, + }, + ], + } + } + ), + DistributedConfig(training_dtype=DataType.bfloat16), + ) + Assert.eq( + [layer._config.window_size for layer in model.layers[1:-1]], [512, 1024, 1024, 512, 1024, None, None, None] + ) + Assert.eq([layer._config.normalization.type for layer in model.layers[1:-1]], [NormalizationType.rms_norm] * 8) + Assert.eq([layer._config.normalization.epsilon for layer in model.layers[1:-1]], [1e-5] * 5 + [1] * 3) + Assert.eq( + [layer._config.ffn_hidden_size for layer in model.layers[1:-1]], [64, 4096, 4096, 64, 4096, 4096, 4096, 4096] + ) + # Non-architecture parameters (`window_size`) need to be ignored when converting to architecture config. + # (See `TransformerLayerRangeArchitectureConfig.setup`.) + model.config.get_architecture() From 8de8faec2cc470ed57f9e1786b29c7277944453b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 5 Mar 2025 18:56:20 -0500 Subject: [PATCH 7/7] fixes --- fast_llm/layers/transformer/config.py | 37 ++++++++++++++++++++------- tests/test_transformer.py | 22 +++++++++++++--- 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 1401c490..cc6d5a66 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -4,7 +4,16 @@ import typing import warnings -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import ( + Config, + Field, + FieldHint, + FieldUpdate, + check_field, + config_class, + process_field, + skip_valid_if_none, +) from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace @@ -656,6 +665,10 @@ def in_range(self, index) -> bool: ) +def process_config_updates(updates: dict[str | tuple[str, ...], typing.Any]) -> dict[tuple[str, ...], typing.Any]: + return {(tuple(key.split("/")) if isinstance(key, str) else key): value for (key, value) in updates.items()} + + @config_class() class TransformerLayerRangeArchitectureConfig(BaseModelArchitectureConfig): _abstract = False @@ -664,20 +677,24 @@ class TransformerLayerRangeArchitectureConfig(BaseModelArchitectureConfig): desc="Layer range.", hint=FieldHint.core, ) - updates: dict[str, typing.Any] = Field( - default_factory=dict, + updates: dict[tuple[str, ...], typing.Any] = Field( + default_factory=dict, valid=process_field(process_config_updates) ) config: TransformerLayerArchitectureConfig = Field(init=False) + _default: TransformerLayerArchitectureConfig = Field(init=False) def setup(self, default: TransformerLayerArchitectureConfig) -> None: - # Create the full config from the default and updates. - # We use `default.from_dict` so we also have the appropriate class in `TransformerLayerRangeConfig`. - # For the architecture class we need to set `strict=False` because of possible non-architecture parameters. 
- self.config = default.from_dict(default, self.updates, strict=True) # isinstance(self, BaseModelConfig)) + assert not hasattr(self, "_default") + self._default = default def _validate(self) -> None: - assert hasattr(self, "config") + assert hasattr(self, "_default") assert len(self.layer_ranges) > 0 + super()._validate() + # Create the full config from the default and updates. + # We use `default.from_dict` so we also have the appropriate class in `TransformerLayerRangeConfig`. + # For the architecture class we need to set `strict=False` because of possible non-architecture parameters. + self.config = self._default.from_dict(self._default, self.updates, strict=isinstance(self, BaseModelConfig)) self.config.validate() def in_range(self, index) -> bool: @@ -687,6 +704,7 @@ def in_range(self, index) -> bool: @config_class() class TransformerLayerRangeConfig(TransformerLayerRangeArchitectureConfig, BaseModelConfig): config: TransformerLayerConfig = FieldUpdate() + _default: TransformerLayerConfig = FieldUpdate() @config_class() @@ -698,13 +716,14 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig): def _validate(self) -> None: for layer in self.layers: layer.setup(self.default) + super()._validate() + for layer in self.layers: # Hidden layers must match Assert.eq(layer.config.hidden_size, self.default.hidden_size) # TODO: Move elsewhere? Kept here because used in a few places like default initialization. Assert.eq(layer.config.num_layers, self.default.num_layers) # TODO: Rotary preprocessor doesn't support variations across layers. Assert.eq(layer.config.rotary.to_serialized(), self.default.rotary.to_serialized()) - super()._validate() def get_layer_config_and_tensor_space( self, index: int, tensor_space: TensorSpace diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 809b9cf9..93287c7a 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -16,12 +16,14 @@ def test_variable_window_size(): { # Layers 5, 6 and 7 "layer_ranges": [{"begin": 5, "end": None}], - "updates": {"window_size": None, "normalization": {"epsilon": 1}}, + # Update normalization epsilon, keep rms norm. + "updates": {"window_size": None, "normalization/epsilon": 1}, }, { # Layers 0, 3 and 5, but 5 already covered above so excluded. 
"layer_ranges": [{"begin": 0, "end": 1}, {"begin": 3, "end": 6, "step": 2}], - "updates": {"window_size": 512, "ffn_hidden_size": 64}, + # Override the whole normalization, type reverts back to default (layer_norm) + "updates": {"window_size": 512, "ffn_hidden_size": 64, "normalization": {"epsilon": 1}}, }, ], } @@ -32,8 +34,20 @@ def test_variable_window_size(): Assert.eq( [layer._config.window_size for layer in model.layers[1:-1]], [512, 1024, 1024, 512, 1024, None, None, None] ) - Assert.eq([layer._config.normalization.type for layer in model.layers[1:-1]], [NormalizationType.rms_norm] * 8) - Assert.eq([layer._config.normalization.epsilon for layer in model.layers[1:-1]], [1e-5] * 5 + [1] * 3) + Assert.eq( + [layer._config.normalization.type for layer in model.layers[1:-1]], + [ + NormalizationType.layer_norm, + NormalizationType.rms_norm, + NormalizationType.rms_norm, + NormalizationType.layer_norm, + NormalizationType.rms_norm, + NormalizationType.rms_norm, + NormalizationType.rms_norm, + NormalizationType.rms_norm, + ], + ) + Assert.eq([layer._config.normalization.epsilon for layer in model.layers[1:-1]], [1, 1e-5, 1e-5, 1, 1e-5, 1, 1, 1]) Assert.eq( [layer._config.ffn_hidden_size for layer in model.layers[1:-1]], [64, 4096, 4096, 64, 4096, 4096, 4096, 4096] )