diff --git a/examples/training/mamba2/LICENSE b/examples/training/mamba2/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/examples/training/mamba2/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/examples/training/mamba2/NOTICE b/examples/training/mamba2/NOTICE new file mode 100644 index 0000000..f0df437 --- /dev/null +++ b/examples/training/mamba2/NOTICE @@ -0,0 +1,3 @@ +Mamba-2 port for AWS Trainium +Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team. \ No newline at end of file diff --git a/examples/training/mamba2/convert_mamba2_ssm_checkpoint.py b/examples/training/mamba2/convert_mamba2_ssm_checkpoint.py new file mode 100644 index 0000000..becb6d3 --- /dev/null +++ b/examples/training/mamba2/convert_mamba2_ssm_checkpoint.py @@ -0,0 +1,233 @@ +# coding=utf-8 +# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team. +# Modifications Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script can be used to convert checkpoints provided in the `mamba2_ssm` library into the format provided in HuggingFace `transformers`. +It depends on the `mamba2_ssm` package to be installed. + +Unlike the Mamba2 implementation in transformers, we split the Mamba2Mixer.in_proj in three separate linear layer. +This version of the script has an additional flag --split_proj to convert checkpoint to our format. 
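
For illustration, a typical conversion of an original `mamba_ssm` checkpoint into the
split-projection format used here could be invoked as follows (all paths are placeholders):

    python convert_mamba2_ssm_checkpoint.py \
        -i /path/to/mamba2_ssm_checkpoint_dir \
        -m mamba_ssm \
        -p bf16 \
        -o /path/to/output_dir \
        --split_proj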
+""" + +import argparse +import json +from functools import partial +from os import path +from typing import Dict, Optional + +import torch +from safetensors import safe_open +from safetensors.torch import save_model + +from transformers import GPTNeoXTokenizerFast, LlamaTokenizerFast +from mamba2.configuration_mamba2 import Mamba2Config +from mamba2.modeling_mamba2 import Mamba2ForCausalLM + + +def load_state_dict_from_safetensors(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]: + # Load weights and config from paths + original_state_dict = {} + with safe_open(path.join(mamba2_checkpoint_path, ckpt_name), framework="pt") as f: + for k in f.keys(): + newk = k.removeprefix("model.") + original_state_dict[newk] = f.get_tensor(k).clone() + return original_state_dict + + +def load_state_dict_from_torch(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]: + return torch.load(path.join(mamba2_checkpoint_path, ckpt_name), map_location="cpu") + + +def convert_ssm_config_to_hf_config(config_ssm: Dict, mamba2_model_dict: Dict) -> Mamba2Config: + """Convert a Mamba2Config from mamba_ssm to a Mamba2Config from here.""" + hf_config = Mamba2Config() + + # Switch to a different dict depending on model type + config_dict = mamba2_model_dict + + # Set important values from config and recalculate other resulting entries + hf_config.hidden_size = config_ssm[config_dict["hidden_size"]] + hf_config.num_heads = (hf_config.hidden_size * hf_config.expand) // hf_config.head_dim + hf_config.num_hidden_layers = config_ssm[config_dict["num_hidden_layers"]] + hf_config.n_groups = config_ssm.get(config_dict["n_groups"], 1) + hf_config.tie_word_embeddings = config_ssm["tie_embeddings"] + hf_config.bos_token_id = config_dict["bos_token_id"] + hf_config.pad_token_id = config_dict["pad_token_id"] + hf_config.eos_token_id = config_dict["eos_token_id"] + + # Padded vocab size, mostly of 16 but 32 is also very common in different models + vocab_size = config_ssm["vocab_size"] + pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"] + if (vocab_size % pad_vocab_size_multiple) != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + hf_config.vocab_size = vocab_size + + return hf_config + + +def load_and_save_tokenizer( + mamba2_model_type: str, + output_dir: str, + tokenizer_model_path: Optional[str] = None, +) -> None: + tokenizer = None + + # Load tokenizer + if tokenizer_model_path is not None and mamba2_model_type == "codestral": + tokenizer_class = LlamaTokenizerFast + tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True) + elif mamba2_model_type == "mamba_ssm": + tokenizer = GPTNeoXTokenizerFast.from_pretrained("state-spaces/mamba-130m-hf", padding_side="left") + + # Save tokenizer + if tokenizer is not None: + tokenizer.save_pretrained(output_dir) + + +_MAMBA2_MODELS_DICT = { + "codestral": { + "hidden_size": "dim", + "num_hidden_layers": "n_layers", + "n_groups": "n_groups", + "bos_token_id": 0, + "pad_token_id": 1, + "eos_token_id": 2, + "config_name": "params.json", + "load_state_dict": partial(load_state_dict_from_safetensors, ckpt_name="consolidated.safetensors"), + "load_and_save_tokenizer": partial(load_and_save_tokenizer, "codestral"), + }, + "mamba_ssm": { + "hidden_size": "d_model", + "num_hidden_layers": "n_layer", + "n_groups": "ngroups", + "bos_token_id": 0, + "pad_token_id": 0, + "eos_token_id": 0, + "config_name": "config.json", + "load_state_dict": partial(load_state_dict_from_torch, 
ckpt_name="pytorch_model.bin"), + "load_and_save_tokenizer": partial(load_and_save_tokenizer, "mamba_ssm"), + }, +} + + +def split_projection_matrix(hf_model, state_dict): + new_state_dict = {} + for k, v in state_dict.items(): + if not k.endswith('in_proj.weight'): + new_state_dict[k] = v + continue + layer_name = k.split('.in_proj.weight')[0] + + layer = hf_model.get_submodule(layer_name) + w_z, w_xBC, w_dt = torch.split(v, [ + layer.intermediate_size, + layer.conv_dim, + layer.num_heads], dim=0) + + # since .split() gives a view, we do .clone() to ensure we don't keep the original tensor in memory + new_state_dict[layer_name + '.in_proj_z.weight'] = w_z.detach().clone() + new_state_dict[layer_name + '.in_proj_xBC.weight'] = w_xBC.detach().clone() + new_state_dict[layer_name + '.in_proj_dt.weight'] = w_dt.detach().clone() + + return new_state_dict + + +def convert_mamba2_checkpoint_file_to_huggingface_model_file( + mamba2_checkpoint_path: str, + mamba2_model_type: str, + precision: str, + output_dir: str, + tokenizer_model_path: Optional[str] = None, + split_proj: bool = False +) -> None: + mamba2_model_dict = _MAMBA2_MODELS_DICT[mamba2_model_type] + + # Load and save config based on name + config_path = path.join(mamba2_checkpoint_path, mamba2_model_dict["config_name"]) + with open(config_path, "r", encoding="utf-8") as json_file: + config = json.load(json_file) + hf_config = convert_ssm_config_to_hf_config(config_ssm=config, mamba2_model_dict=mamba2_model_dict) + hf_config.save_pretrained(output_dir) + + # Load state dict of the original model and transfer to hf model + original_state_dict = mamba2_model_dict["load_state_dict"](mamba2_checkpoint_path=mamba2_checkpoint_path) + hf_model = Mamba2ForCausalLM(hf_config) + if split_proj: + original_state_dict = split_projection_matrix(hf_model, original_state_dict) + hf_model.load_state_dict(original_state_dict) + + # Save new model to pytorch_dump_path + dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16) + save_model(hf_model.to(dtype), path.join(output_dir, "model.safetensors"), metadata={"format": "pt"}) + + # Load and save tokenizer + mamba2_model_dict["load_and_save_tokenizer"](output_dir=output_dir, tokenizer_model_path=tokenizer_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--mamba2_checkpoint_directory", + type=str, + required=True, + help="Path to a directory containing the `pytorch_model.bin` or `.safetensors` mamba2_ssm checkpoint file to be converted.", + ) + parser.add_argument( + "-m", + "--mamba2_model_type", + type=str, + default="mamba_ssm", + required=True, + choices=("codestral", "mamba_ssm"), + help="The model type the conversion will be performed on. Can choose from either `codestral` or `mamba_ssm`.", + ) + parser.add_argument( + "-p", + "--precision", + type=str, + default="fp16", + required=True, + choices=("fp32", "fp16", "bf16"), + help="The precision the model will be saved in. Select from fp32, fp16 or bf16.", + ) + parser.add_argument( + "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to." 
+ ) + parser.add_argument( + "-t", + "--tokenizer_model_path", + type=str, + default=None, + required=False, + help="Path to a `codestral` tokenizer file.", + ) + parser.add_argument( + "-s", + "--split_proj", + action="store_true", + help="Split the input projection matrix in 3 separate matrices", + ) + args = parser.parse_args() + + convert_mamba2_checkpoint_file_to_huggingface_model_file( + args.mamba2_checkpoint_directory, + args.mamba2_model_type, + args.precision, + args.output_dir, + args.tokenizer_model_path, + args.split_proj + ) diff --git a/examples/training/mamba2/mamba2/__init__.py b/examples/training/mamba2/mamba2/__init__.py new file mode 100644 index 0000000..789dbf1 --- /dev/null +++ b/examples/training/mamba2/mamba2/__init__.py @@ -0,0 +1,2 @@ +from .modeling_mamba2 import * +from .configuration_mamba2 import * \ No newline at end of file diff --git a/examples/training/mamba2/mamba2/configuration_mamba2.py b/examples/training/mamba2/mamba2/configuration_mamba2.py new file mode 100644 index 0000000..71feffd --- /dev/null +++ b/examples/training/mamba2/mamba2/configuration_mamba2.py @@ -0,0 +1,186 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Modifications Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Same as transformers.model.mamba2.Mamba2Config but uses different defaults and adds an option rmsnorm_within_groups.""" + +import math + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Mamba2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`Mamba2Model`]. It is used to instantiate a MAMBA2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the MAMBA2 + [state-spaces/mamba2-2.8b](https://huggingface.co/state-spaces/mamba2-2.8b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + num_heads (`int`, *optional*, defaults to 128): + Number of heads for the evolution matrices of mamba 2. + head_dim (`int`, *optional*, defaults to 64): + Dimension of each head. + vocab_size (`int`, *optional*, defaults to 32768): + Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Mamba2Model`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimensionality of the embeddings and hidden states. + state_size (`int`, *optional*, defaults to 128): shape of the state space latents. + num_hidden_layers (`int`, *optional*, defaults to 64): + Number of hidden layers in the model. 
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + The id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the end of sentence token in the vocabulary. + expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. + conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel. + n_groups (`int`, *optional*, defaults to 8): + Number of groups for the evolution matrices of mamba 2. + use_bias (`bool`, *optional*, defaults to `False`): + Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block + use_conv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias in the convolution layer of the mixer block. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.1): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + residual_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model + time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): + Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` + time_step_min (`float`, *optional*, defaults to 0.001): + Minimum `time_step` used to bound `dt_proj.bias`. + time_step_max (`float`, *optional*, defaults to 0.1): + Maximum `time_step` used to bound `dt_proj.bias`. + time_step_floor (`float`, *optional*, defaults to 0.0001): + Minimum clamping value of the `dt_proj.bias` layer initialization. + time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`): + Accepted range of time step values. + rescale_prenorm_residual (`bool`, *optional*, defaults to `False`): + Whether or not to rescale `out_proj` weights when initializing. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the cache should be used. + rms_norm (`bool`, *optional*, defaults to `True`): + Whether to use RMS norm or not. + rmsnorm_within_groups (`bool`, *optional*, defaults to `True`): + Whether to use RMS norm independently within n_groups or not. + chunk_size (`int`, *optional*, defaults to 256): + Size of the chunks that will comprise the sequence. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie word embeddings or not. 
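
        Note (illustrative): the mixer built from this configuration assumes
        `expand * hidden_size == num_heads * head_dim` (with the defaults, 2 * 4096 == 128 * 64),
        so these four values should be kept consistent when customizing the model size.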
+ + + Example: + + ```python + >>> from transformers import Mamba2Config, Mamba2Model + + >>> # Initializing a Mamba2 configuration + >>> configuration = Mamba2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = Mamba2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mamba2" + + def __init__( + self, + num_heads=128, + head_dim=64, + vocab_size=32768, + hidden_size=4096, + state_size=128, + num_hidden_layers=64, + layer_norm_epsilon=1e-5, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + expand=2, + conv_kernel=4, + n_groups=8, + use_bias=False, + use_conv_bias=True, + hidden_act="silu", + initializer_range=0.1, + residual_in_fp32=True, + time_step_rank="auto", + time_step_min=0.001, + time_step_max=0.1, + time_step_floor=1e-4, + time_step_limit=(0.0, float("inf")), + rescale_prenorm_residual=False, + use_cache=False, # fixme: the default in HF is True but we don't support cache yet + rms_norm=True, + rmsnorm_within_groups=True, + chunk_size=256, + tie_word_embeddings=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.state_size = state_size + self.num_hidden_layers = num_hidden_layers + self.layer_norm_epsilon = layer_norm_epsilon + self.conv_kernel = conv_kernel + self.expand = expand + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.use_bias = use_bias + self.use_conv_bias = use_conv_bias + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_floor = time_step_floor + self.rescale_prenorm_residual = rescale_prenorm_residual + self.residual_in_fp32 = residual_in_fp32 + self.use_cache = use_cache + self.n_groups = n_groups + self.num_heads = num_heads + self.head_dim = head_dim + self.rms_norm = rms_norm + self.rmsnorm_within_groups=rmsnorm_within_groups + self.state_size = state_size + self.chunk_size = chunk_size + self.time_step_limit = time_step_limit + self.tie_word_embeddings = tie_word_embeddings + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/examples/training/mamba2/mamba2/conv1d_grouped.py b/examples/training/mamba2/mamba2/conv1d_grouped.py new file mode 100644 index 0000000..5a4cb4c --- /dev/null +++ b/examples/training/mamba2/mamba2/conv1d_grouped.py @@ -0,0 +1,365 @@ +"""NKI implementation of a causal channel-wise 1d convolution. + +This is a drop-in replacement of torch.nn.Conv1d when groups == in_channels == out_channels. It automatically +applies the right amount of left zero-padding to ensure the length of the output is the same as the input. 
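
For reference, the operation matches the following plain-PyTorch sketch (sizes are illustrative):

    import torch
    import torch.nn as nn

    D, L, kernel_size = 256, 1024, 4
    x = torch.randn(1, D, L)  # (batch, channels, length)
    conv = nn.Conv1d(in_channels=D, out_channels=D, kernel_size=kernel_size,
                     groups=D, bias=True, padding=kernel_size - 1)
    y = conv(x)[..., :L]  # trim the trailing positions so the output is causal and length-preserving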
+""" + +from typing import Union + +import torch + +import neuronxcc.nki as nki +import neuronxcc.nki.language as nl +import neuronxcc.nki.isa as nisa +from torch_neuronx import nki_jit +import neuronx_distributed.parallel_layers.parallel_state as ps + +import numpy as np +import torch.nn as nn + + +def diag(w): + """ + w: (128, 1) + """ + m = w.shape[0] + i_p = nl.arange(m)[:, None] + i_f = nl.arange(m)[None, :] + w_diag = nisa.affine_select(i_p == i_f, w.broadcast_to((m, m)), 0) + return w_diag + + +def apply_activation(x, bias, activation: str): + if activation is None: + return x + bias if bias is not None else x + elif activation == 'relu': + return nisa.activation(nl.relu, x, bias=bias) + elif activation == 'd_relu': + assert bias is None + return x >= 0 + elif activation == 'silu': + z = x + bias if bias is not None else x + return z * nisa.activation(nl.sigmoid, z) + elif activation == 'd_silu': + z = x + bias if bias is not None else x + return nl.sigmoid(z) * (1 + z * (1 - nl.sigmoid(z))) + elif activation == 'd_identity': + return 1 + else: + raise ValueError(f'Invalid activation {activation}') + + +def _conv_tile_tensor_e(data_tile, weight_tile, b_tile, dtype, activation=None): + """Computes and returns the convolution of a data tile given weights and bias. + + Isolated from the rest of the code since we use it both for forward and backward. + """ + p_size, n = data_tile.shape + kernel_size = weight_tile.shape[1] + + conv = nl.zeros(shape=(p_size, n), dtype=dtype) + + chunk_size = n // kernel_size + + i_p = nl.arange(p_size)[:, None] + + for j in nl.affine_range(kernel_size): + i_f_plus_j = j + nl.arange(chunk_size)[None, :] * kernel_size + res = nl.zeros((p_size, chunk_size), dtype=nl.float32, buffer=nl.psum) + for i in nl.affine_range(kernel_size): + w_diag = diag(weight_tile[i_p, i]) + res += nisa.nc_matmul(w_diag, data_tile[i_p, i_f_plus_j + i]) + conv[i_p, i_f_plus_j] = apply_activation(res, bias=b_tile, activation=activation) + + return conv + + +def _conv_tile_scalar_e(data_tile, weight_tile, b_tile, dtype, activation=None): + """Computes and returns the convolution of a data tile given weights and bias. + + Isolated from the rest of the code since we use it both for forward and backward. 
+ """ + p_size, n = data_tile.shape + kernel_size = weight_tile.shape[1] + + conv = nl.ndarray(shape=(p_size, n), dtype=dtype) + + chunk_size = n // kernel_size + + i_p = nl.arange(p_size)[:, None] + + for j in nl.affine_range(kernel_size): + i_f_plus_j = j + nl.arange(chunk_size)[None, :] * kernel_size + res = nki.isa.tensor_scalar(data_tile[i_p, i_f_plus_j], op0=np.multiply, + operand0=weight_tile[i_p, 0], dtype=dtype) + for i in nl.static_range(1, kernel_size): + res = res + nki.isa.tensor_scalar(data_tile[i_p, i_f_plus_j + i], op0=np.multiply, + operand0=weight_tile[i_p, i], + dtype=dtype) + conv[i_p, i_f_plus_j] = apply_activation(res, bias=b_tile, activation=activation) + return conv + + +# _conv_tile = _conv_tile_scalar_e +_conv_tile = _conv_tile_tensor_e + + +@nki_jit +def conv1d_grouped_kernel(input_data, w, b, output, activation=None): + """NKI kernel to compute grouped 1d causal convolution, equivalent to: + + D, L = x.shape + + conv = nn.Conv1d( + in_channels=D, + out_channels=D, + bias=True, + kernel_size=kernel_size, + groups=D, + padding=kernel_size - 1, + ) + y = conv(x)[:, :L] + + Args: + input_data: input tensor of shape [D,L] + w: conv weights of shape [D, kernel_size] + b: conv bias of shape [D] + output: output tensor of shape [D, L] + """ + + batch_size, p, n = input_data.shape + ch, _, ks = w.shape + dtype = input_data.dtype + + # fixme: make the code work for any size + assert p % 128 == 0 and ch == p and p == ch + assert n % ks == 0 and n > ks # check n is a multiple of kernel size + assert ks == 4 # fixme: don't think this constrain is needed + assert n <= 2048, "conv1d does not yet support sequence lengths larger than 2048" + + i_w = nl.arange(ks)[None, :] + i_p = nl.arange(128)[:, None] + i_y = nl.arange(n)[None, :] + + # Iterate over channel dimension then over batch dimension (so we load the weights only once for all samples) + for k in nl.affine_range(input_data.shape[1] // 128): + i_p_input = i_p + k * 128 + # weights and biases for current tile + w_tile = nl.load(w.reshape((ch, ks))[i_p_input, i_w]) + b_tile = nl.load(b[i_p_input]) + + for i_b in nl.affine_range(input_data.shape[0]): + # Load with padding + x = nl.zeros(shape=(128, n + ks - 1), dtype=dtype) + x[i_p, ks - 1 + i_y] = nl.load(input_data[i_b, i_p_input, i_y]) + # run the convolution + conv = _conv_tile(x, w_tile, b_tile, dtype, activation=activation) + # The first positions contain the result of a zero padded window + nl.store(output[i_b, i_p_input, i_y], value=conv[i_p, i_y]) + + +@nki_jit +def conv1d_grouped_kernel_longseq(input_data, w, b, output, activation=None): + """NKI kernel to compute grouped 1d causal convolution for sequences of all lengths by processing them in sub-sequences + + equivalent to: + + D, L = x.shape + + conv = nn.Conv1d( + in_channels=D, + out_channels=D, + bias=True, + kernel_size=kernel_size, + groups=D, + padding=kernel_size - 1, + ) + y = conv(x)[:, :L] + + Args: + input_data: input tensor of shape [D,L] + w: conv weights of shape [D, kernel_size] + b: conv bias of shape [D] + output: output tensor of shape [D, L] + """ + + _, channels, seq_len = input_data.shape + ch, _, ks = w.shape + dtype = input_data.dtype + + # fixme: make the code work for any size + assert channels % 128 == 0 and ch == channels + assert seq_len % ks == 0 and seq_len > ks # check seq_len is a multiple of kernel size + assert ks == 4 # fixme: don't think this constrain is needed + + i_w = nl.arange(ks)[None, :] + i_p = nl.arange(128)[:, None] + + sub_seq_len = min(2048, seq_len) + 
num_sub_seqs = (seq_len + sub_seq_len - 1) // sub_seq_len + + padded_len = sub_seq_len + ks - 1 + i_f_subseq = nl.arange(sub_seq_len)[None, :] + i_f_subseq_padded = nl.arange(padded_len)[None, :] + + # Iterate over channel dimension then over batch dimension (so we load the weights only once for all samples) + for k in nl.affine_range(input_data.shape[1] // 128): + i_p_input = i_p + k * 128 + # weights and biases for current tile + w_tile = nl.load(w.reshape((ch, ks))[i_p_input, i_w]) + b_tile = nl.load(b[i_p_input]) + + for batch_id in nl.affine_range(input_data.shape[0]): + + for subseq_id in nl.affine_range(num_sub_seqs): + i_f_subseq_padded_in = i_f_subseq_padded + subseq_id * sub_seq_len + x = nl.zeros(shape=(128, padded_len), dtype=dtype) + mask = (i_f_subseq_padded_in - (ks - 1) >= 0) & (i_f_subseq_padded_in - (ks - 1) < seq_len) + x[i_p, i_f_subseq_padded] = nl.load(input_data[batch_id, i_p_input, i_f_subseq_padded_in - (ks - 1)], + mask=mask) + # run the convolution + conv = _conv_tile(x, w_tile, b_tile, dtype, activation=activation) + mask_out = i_f_subseq + subseq_id * sub_seq_len < seq_len + # store the result + nl.store(output[batch_id, i_p_input, i_f_subseq + subseq_id * sub_seq_len], value=conv[i_p, i_f_subseq], + mask=mask_out) + + +@nki_jit +def conv1d_grouped_kernel_grad(input_data, w, d_output, d_input, d_w, d_b, activation=None): + batch_size, p, n = input_data.shape + ch, _, ks = w.shape + dtype = input_data.dtype + + assert p % 128 == 0 and ch == p and p == ch + assert n % ks == 0 and n > ks # check n is a multiple of kernel size + assert ks == 4 + + i_p = nl.arange(128)[:, None] + i_f_n = nl.arange(n)[None, :] + i_f_w = nl.arange(ks)[None, :] + seq_len = n + ks - 1 + i_f_seq_len = nl.arange(seq_len)[None, :] + + if activation is not None: + d_activation = 'd_' + activation + else: + d_activation = 'd_identity' + + for chunk_id in nl.affine_range(input_data.shape[1] // 128): + i_p_input = chunk_id * 128 + nl.arange(128)[:, None] + w_tile = nl.load(w[i_p_input, 0, i_f_w]) + # we don't need the bias to compute gradients + b_tile = None + + db_accumulation = nl.zeros([128, 1], dtype=dtype) + dw_accumulation = nl.zeros([128, ks], dtype=dtype) + + for batch_id in nl.affine_range(input_data.shape[0]): + # fixme: probably don't need to pad this + x = nl.zeros(shape=(128, n + ks - 1), dtype=dtype) + x[i_p, ks - 1 + i_f_n] = nl.load(input_data[batch_id, i_p_input, i_f_n]) + + if activation is not None: + preact_grad = _conv_tile(x, w_tile, b_tile, dtype, activation=d_activation)[i_p, i_f_n] + else: + preact_grad = 1 + dout_tile = nl.zeros(shape=(128, n + ks - 1), dtype=dtype) + dout_tile[i_p, i_f_n] = preact_grad * nl.load(d_output[batch_id, i_p_input, i_f_n]) + + # Compute db + db_accumulation += nl.sum(dout_tile[i_p, i_f_n], axis=[1]) + + # Compute d_input + dout_reverse = nl.ndarray((128, seq_len), dtype=dtype) + # fixme: we should simply index the tile with flipped indexes, no need for the copy + # but it will break down later as double indexing tile[i_p, i_f][i_p1, i_f1] is not supported + dout_reverse[i_p, i_f_seq_len] = dout_tile[i_p, seq_len - 1 - i_f_seq_len] + # dout_reverse = dout_tile[i_p, seq_len - 1 - i_f_seq_len] + + conv = _conv_tile(dout_reverse, w_tile, b_tile=None, dtype=dtype, activation=None) + + # We flip the result while storing + nl.store(d_input[batch_id, i_p_input, i_f_n], conv[i_p, seq_len - ks - i_f_n]) + + dw_batch = nl.ndarray((128, 4), dtype=dtype) + # Compute dw + for i in nl.static_range(ks): + # todo: the vector engine should be able to 
execute both element-wise product and sum in one instruction + dw_batch[i_p, i] = nl.sum(x[i_p, i + i_f_n] * dout_tile[i_p, i_f_n], axis=[1]) + dw_accumulation += dw_batch + + nl.store(d_b[i_p_input], db_accumulation[i_p, 0]) + nl.store(d_w[i_p_input, 0, i_f_w], dw_accumulation[i_p, i_f_w]) + + +class GroupedConv1dNKI(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias, activation=None): + # fixme: if output is too large we might avoid to store it and recomputed it during backprop + output = torch.empty_like(input) + # if input.shape[2] <= 2048: + # output = conv1d_grouped_kernel(input, weight, bias, output, activation=activation) + # else: + output = conv1d_grouped_kernel_longseq(input, weight, bias, output, activation=activation) + ctx.save_for_backward(input, weight, bias) + ctx.activation = activation + return output + + @staticmethod + def backward(ctx, d_output): + input, weight, bias = ctx.saved_tensors + dinput = torch.empty_like(input) + dweight = torch.empty_like(weight) + dbias = torch.empty_like(bias) + if input.shape[2] > 2048: + raise NotImplementedError('Gradient not implemented for conv1d with seq_len>2048') + # dinput, dweight, dbias = conv1d_grouped_kernel_bwd(input, weight, bias, d_output) + conv1d_grouped_kernel_grad(input, weight, d_output, dinput, dweight, dbias, activation=ctx.activation) + return dinput, dweight, dbias, None + + +def nki_conv1d(input, weight, bias=None, activation=None): + return GroupedConv1dNKI.apply(input, weight, bias, activation) + + +class ConvNKI(nn.Conv1d): + """ + Custom layer implemented in NKI to compute efficiently a grouped convolution, + equivalent to nn.Conv1d with groups == in_channels == out_channels. + + Parameters: + input: (B_tensor, C_tensor, L) + weight: (C_tensor, 1, kernel_size) + bias: (C_tensor) + Return: + output: (B_tensor, C_tensor, L) Each input channel sequence input[b, c, :] is convolved with its own conv weight[c, 0, :]. + The results are then stacked together. 
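
    Illustrative usage (sizes are placeholders; the layer must be created after the
    neuronx_distributed tensor-parallel state has been initialized, since __init__
    immediately shards the channels across tensor-parallel ranks):

        conv = ConvNKI(in_channels=1024, out_channels=1024, kernel_size=4,
                       groups=1024, padding=3, bias=True, activation='silu')
        y = conv(x)  # x: (B, C_local, L), already sharded along the channel dimension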
+ """ + + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, + padding: Union[str, int] = 0, dilation: int = 1, groups: int = 1, bias: bool = True, + padding_mode: str = 'zeros', device=None, dtype=None, activation=None) -> None: + # We only support a very specific use case, check we are in it + assert groups == in_channels, "NKI grouped conv kernel only supports groups == in_channels" + assert padding == kernel_size - 1 + assert padding_mode == 'zeros' + assert dilation == 1 + assert stride == 1 + super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, + device, dtype) + self.activation = activation + self.parallel_split() + + def parallel_split(self): + tp_rank = ps.get_tensor_model_parallel_rank() + tp_size = ps.get_tensor_model_parallel_size() + + chunk = slice(self.out_channels // tp_size * tp_rank, self.out_channels // tp_size * (tp_rank + 1)) + self.weight.data = self.weight.data[chunk].detach().clone() + self.bias.data = self.bias.data[chunk].detach().clone() + self.in_channels = self.out_channels // tp_size + self.out_channels = self.out_channels // tp_size + + def forward(self, input): + return GroupedConv1dNKI.apply(input, self.weight, self.bias, self.activation) diff --git a/examples/training/mamba2/mamba2/mamba2_kernel.py b/examples/training/mamba2/mamba2/mamba2_kernel.py new file mode 100644 index 0000000..8a24479 --- /dev/null +++ b/examples/training/mamba2/mamba2/mamba2_kernel.py @@ -0,0 +1,375 @@ +"""NKI implementation of the forward and backward SSM kernels for a Mamba2 model. + +Paper: https://arxiv.org/abs/2405.21060 +""" + +import torch +import neuronxcc.nki.language as nl +from torch_neuronx import nki_jit +import neuronxcc.nki.isa as nisa + + +def compute_a_factors(logA, ones_triu, chunk_size, d_state, d_head=None): + dtype = logA.dtype + i_p = nl.arange(chunk_size)[:, None] + if d_head is None: + d_head = chunk_size + i_p_head = nl.arange(d_head)[:, None] + i_f = nl.arange(chunk_size)[None, :] + + # we reuse this bcast later in the code, ensure it is large enough for all uses + bcast_size = max(chunk_size, d_state) + + logA_bcast = logA.broadcast_to((chunk_size, bcast_size)) + l_bcast = nl.matmul(ones_triu, logA_bcast, transpose_x=True) + l = l_bcast[:chunk_size, 0] + l_t = nl.transpose(l_bcast, dtype=dtype) + + # === compute the _transpose_ of the 128x128 lower triangular matrix L === + partial_sums = l_t[:chunk_size, :chunk_size] - l + L_full_t = nl.exp(partial_sums) + L_t = nisa.affine_select(i_f >= i_p, L_full_t, 0) + + a_right = L_t[i_p, chunk_size - 1] + a_left = nl.exp(l_t[i_p_head, i_f]) + + if d_head != chunk_size: + a_center_t = a_left[:, chunk_size - 1].broadcast_to((d_head, bcast_size)) + a_center = nl.transpose(a_center_t, dtype=dtype) + else: + a_center = a_left[:, chunk_size - 1] + + return L_t, a_left, a_center, a_right + + +def compute_chunk_output(BC_t, L_t, X, C_t, a_left, S, + transpose_gate=False, + transpose_broadcast=False): + """Utility function for computing the output, shared between forward and backward""" + # Diagonal term computation + M_diag_t = L_t * BC_t + Y_diag = nl.matmul(M_diag_t, X, transpose_x=not transpose_gate) + # Compute the off-diagonal contribution using the state + barC_t = C_t * a_left + # utility function for + Y_off = nl.matmul(barC_t, S, transpose_x=not transpose_broadcast) + return Y_diag + Y_off + + +@nki_jit +def mamba_kernel_(dt_tensor, logA_tensor, X_tensor, B_tensor, C_tensor, out_Y_tensor, D_tensor=None): + """ + 
    dt_tensor: (batch_size, seq_len, n_heads)
+    logA_tensor: (n_heads)
+    X_tensor: (batch_size, seq_len, n_heads, d_head)
+    B_tensor: (batch_size, seq_len, n_groups, d_state)
+    C_tensor: (batch_size, seq_len, n_groups, d_state)
+    D_tensor: (n_heads)
+    """
+    # Since this kernel requires high-precision, we run all internal computations in fp32.
+    # Note: the speedup by using bf16 everywhere would be ~15%
+    dtype = nl.float32
+    block_size = 128  # we will split seq_len in chunks of size `block_size`
+
+    batch_size, seq_len, n_heads, d_head = X_tensor.shape
+    _, _, n_groups, d_state = B_tensor.shape
+    assert seq_len % block_size == 0
+    n_chunks = seq_len // block_size
+    n_heads_per_group = n_heads // n_groups
+
+    batch_id = nl.program_id(0)
+
+    i_p = nl.arange(block_size)[:, None]
+    i_f = nl.arange(block_size)[None, :]
+    i_f_state = nl.arange(d_state)[None, :]
+    i_f_head = nl.arange(d_head)[None, :]
+
+    # creates a constant upper triangular matrix of ones
+    ones_triu = nisa.affine_select(i_p <= i_f, nl.ones((block_size, block_size), dtype=dtype), 0)
+
+    for group_id in nl.affine_range(n_groups):  # parallel for loop over each group (they are completely independent)
+        # === Preload/compute logA, B, C_t and B @ C_t ====
+        # (they are shared between multiple heads in the same group)
+        B_cache = nl.ndarray((block_size, n_chunks, d_state), dtype=dtype)
+        # todo: storing in this format may be a bad idea if d_state != 128?
+        C_t_cache = nl.ndarray((d_state, n_chunks, block_size), dtype=dtype)
+        BC_t_cache = nl.ndarray((block_size, n_chunks, block_size), dtype=dtype)
+        for chunk_id in nl.affine_range(n_chunks):
+            i_p_in = i_p + chunk_id * block_size
+            B = nl.load(B_tensor[batch_id, i_p_in, group_id, i_f_state], dtype=dtype)
+            C = nl.load(C_tensor[batch_id, i_p_in, group_id, i_f_state], dtype=dtype)
+            C_t = nisa.nc_transpose(C)
+            B_cache[:, chunk_id, :] = B
+            C_t_cache[:, chunk_id, :] = C_t
+            BC_t_cache[:, chunk_id, :] = nl.copy(nl.matmul(B, C_t), dtype=dtype)
+
+        logA_cache = nl.load(logA_tensor.reshape((1, n_heads)), dtype=dtype).broadcast_to((block_size, n_heads))
+        if D_tensor is not None:
+            D = nl.load(D_tensor.reshape((1, n_heads)), dtype=dtype).broadcast_to((block_size, n_heads))
+        else:
+            D = None
+
+        # == Actual code ===
+        for head_id_in_group in nl.affine_range(n_heads_per_group):  # parallel for loop over the n_heads
+            # get the global head_id given current group and current head in group
+            head_id = group_id * n_heads_per_group + head_id_in_group
+            # We iterate over the diagonal blocks and compute each Y_diag
+            # At the same time, we update our running sum S and use it to compute Y_off.
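            # (Sketch of the chunked recurrence implemented below: Y_diag = (L ∘ (C @ B^T)) @ X covers
            #  positions attending within the chunk, Y_off applies the decayed incoming state S through C,
            #  and the state is then advanced with S <- a_center * S + (B * a_right)^T @ X
            #  before moving on to the next chunk.)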
+ # We store Y = Y_diag + Y_off, and we move to the next block + S = nl.zeros((d_state, d_head), dtype=dtype) + for chunk_id in nl.sequential_range(n_chunks): + i_p_in = i_p + chunk_id * block_size + + # broadcast dt and logA together + dt = nl.load(dt_tensor[batch_id, i_p_in, head_id], dtype=dtype) + logA = logA_cache[:, head_id] * dt + + # load from cache the relevant blocks + B = B_cache[:, chunk_id, :] + C_t = C_t_cache[:, chunk_id, :] + BC_t = BC_t_cache[:, chunk_id, :] + + # broadcast X and dt + X0 = nl.load(X_tensor[batch_id, i_p_in, head_id, i_f_head], dtype=dtype) + X = dt * X0 + + # Compute all logA related factors for this chunk + L_t, a_left, a_center, a_right = compute_a_factors(logA, ones_triu, block_size, d_state) + + Y = compute_chunk_output(BC_t, L_t, X, C_t, a_left, S) + + # Update running sum S (will be used in the next iteration) + barB = B * a_right + barBX = nl.matmul(barB, X, transpose_x=True) + S[...] = a_center * S + barBX + + if D is not None: + Y[...] = Y + D[:, head_id] * X0 + + nl.store(out_Y_tensor[batch_id, i_p_in, head_id, i_f_head], Y) + + +@nki_jit +def mamba_kernel_bwd_(dt_tensor, logA_tensor, X_tensor, B_tensor, C_tensor, + d_out_tensor, + ddt_tensor, + dlogA_tensor, + dX_tensor, + dB_tensor, + dC_tensor, + D_tensor=None, + dD_tensor=None + ): + """ + dt_tensor: (batch_size, seq_len, n_heads) + logA_tensor: (n_heads) + X_tensor: (batch_size, seq_len, n_heads, d_head) + B_tensor: (batch_size, seq_len, n_groups, d_state) + C_tensor: (batch_size, seq_len, n_groups, d_state) + D_tensor: (n_heads) + d_out_tensor: (batch_size, seq_len, n_heads, d_head) + All other derivative tensors (d_*) have the same shape as their corresponding input counterparts. + """ + + # Note: since saving the intermediate results of the forward pass would use too much memory, this kernel also + # recomputes the forward pass while computing the gradients. + + # Since this kernel requires high-precision, we run all internal computations in fp32. 
+ # Note: the speedup by using bf16 everywhere would be ~15% + dtype = nl.float32 + block_size = 128 # we will split seq_len in chunks of size `block_size` + batch_size, seq_len, n_heads, d_head = X_tensor.shape + _, _, n_groups, d_state = B_tensor.shape + assert seq_len % block_size == 0 + n_chunks = seq_len // block_size + n_heads_per_group = n_heads // n_groups + + assert d_state == 128 + assert d_head <= 128 + assert block_size <= 128 + + batch_id = nl.program_id(0) + + i_p = nl.arange(block_size)[:, None] + i_f = nl.arange(block_size)[None, :] + i_f_state = nl.arange(d_state)[None, :] + i_f_head = nl.arange(d_head)[None, :] + + # upper triangular matrix of ones + ones_triu = nisa.affine_select(i_p <= i_f, nl.ones((block_size, block_size), dtype=dtype), 0) + ones_tril = nl.copy(nl.transpose(ones_triu), dtype=dtype) + ones_sum_right = nl.ones([d_state, 1], dtype=dtype) + ones_sum_left = nl.ones([1, d_state], dtype=dtype) + ones_sum_right_head = nl.ones([d_head, 1], dtype=dtype) + + for group_id in nl.affine_range(n_groups): # iterate in parallel over all channel groups (they are independent) + # Preload/compute logA, B, C_t and B @ C_t (which are shared between multiple heads in the same group) + B_cache = nl.ndarray((block_size, n_chunks, d_state), dtype=dtype) + C_t_cache = nl.ndarray((d_state, n_chunks, block_size), dtype=dtype) + BC_t_cache = nl.ndarray((block_size, n_chunks, block_size), dtype=dtype) + for chunk_id in nl.affine_range(n_chunks): + i_p_in = i_p + chunk_id * block_size + B = nl.load(B_tensor[batch_id, i_p_in, group_id, i_f_state], dtype=dtype) + C = nl.load(C_tensor[batch_id, i_p_in, group_id, i_f_state], dtype=dtype) + C_t = nisa.nc_transpose(C) + B_cache[:, chunk_id, :] = B + C_t_cache[:, chunk_id, :] = C_t + BC_t_cache[:, chunk_id, :] = nl.copy(nl.matmul(B, C_t), dtype=dtype) + + logA_cache = nl.load(logA_tensor.reshape((1, n_heads)), dtype=dtype).broadcast_to((block_size, n_heads)) + if D_tensor is not None: + D = nl.load(D_tensor.reshape((1, n_heads)), dtype=dtype).broadcast_to((block_size, n_heads)) + else: + D = None + + dC_accumulation = nl.zeros((block_size, n_chunks, d_state), dtype=dtype) + dB_accumulation = nl.zeros((block_size, n_chunks, d_state), dtype=dtype) + dA_final = nl.zeros((1, n_heads), dtype=dtype) + if D is not None: + dD_final = nl.zeros((1, n_heads), dtype=dtype) + + for head_id_in_group in nl.affine_range(n_heads_per_group): # the n_heads are completely independent + # get the global head_id given current group and current head in group + head_id = group_id * n_heads_per_group + head_id_in_group + dA_accumulation = nl.zeros((block_size, n_chunks, d_state), dtype=dtype) + S = nl.zeros((d_state, d_head), dtype=dtype) + for chunk_id in nl.sequential_range(n_chunks): + # + i_p_in = i_p + chunk_id * block_size + # broadcast dt and logA together + dt = nl.load(dt_tensor[batch_id, i_p_in, head_id]) + logA = logA_cache[:, head_id] * dt + # Compute all logA related factors for this chunk + L_t, a_left, a_center, a_right = compute_a_factors(logA, ones_triu, block_size, d_state, d_head=d_head) + # load from cache the relevant blocks + B = B_cache[:, chunk_id, :] + C = nl.load(C_tensor[batch_id, i_p_in, group_id, i_f_state]) + # broadcast X and dt + X0 = nl.load(X_tensor[batch_id, i_p_in, head_id, i_f_head]) + X = dt * X0 + # + + # compute dC gradient + dO = nl.load(d_out_tensor[batch_id, i_p_in, head_id, i_f_head]) + dO_t = nisa.nc_transpose(dO) + UdO_t = nl.matmul(X, dO_t) # (B, L, nheads, hdim) + S_t = nisa.nc_transpose(S) + dC = 
compute_chunk_output(UdO_t, L_t, B, dO_t, a_left, S_t) + + # + # Update the state: running sum S (will be used in the next iteration) + barB = B * a_right + barBX = nl.matmul(barB, X, transpose_x=True) + S[...] = a_center * S + barBX + # + dC_accumulation[:, chunk_id, :] += dC + dA_accumulation[:, chunk_id, :] = dA_accumulation[:, chunk_id, :] + C * dC + + dS = nl.zeros((d_state, d_head), dtype=dtype) + cumsum_dA = nl.zeros((1, d_state), dtype=dtype) + for chunk_id in nl.sequential_range(n_chunks): + chunk_id = n_chunks - 1 - chunk_id # To reverse time + i_p_in = i_p + chunk_id * block_size + + # === Recompute forward pass === + # broadcast dt and logA together + dt = nl.load(dt_tensor[batch_id, i_p_in, head_id]) + logA = logA_cache[:, head_id] * dt + # Compute all logA related factors for this chunk + L_t, a_left, a_center, a_right = compute_a_factors(logA, ones_triu, block_size, d_state, d_head=d_head) + # load from cache the relevant blocks + B = B_cache[:, chunk_id, :] + C = nl.load(C_tensor[batch_id, i_p_in, group_id, i_f_state]) + C_t = nisa.nc_transpose(C) + BC_t = BC_t_cache[:, chunk_id, :] + # broadcast X and dt + X0 = nl.load(X_tensor[batch_id, i_p_in, head_id, i_f_head]) + X = dt * X0 + + # === Compute dX gradient === + dO = nl.load(d_out_tensor[batch_id, i_p_in, head_id, i_f_head]) + dU = compute_chunk_output(BC_t, L_t, dO, B, a_right, dS, transpose_gate=True, transpose_broadcast=True) + + # === Compute dB gradient === + X_t = nisa.nc_transpose(X) + dO_Xt = nl.matmul(dO, X_t) + L_t_ = nisa.nc_transpose(L_t) + dS_t = nisa.nc_transpose(dS) + + C = nl.load(C_tensor[batch_id, i_p_in, group_id, i_f_state], dtype=dtype) + + # === Compute dB gradient === + dB = compute_chunk_output(dO_Xt, L_t_, C, X, a_right, dS_t, transpose_broadcast=True) + + # === Update reverse time state dState === + # Update the state: running sum dS (will be used in the next iteration) + barC = C_t * a_left[:1, :].broadcast_to((block_size, block_size)) + + barC_tX = nl.matmul(barC, dO, transpose_x=False) + dS[...] = a_center * dS + barC_tX + + dB_accumulation[:, chunk_id, :] += dB + dA_accumulation[:, chunk_id, :] -= B * dB + + # === Reverse cumulative sum for dA === + cumsum_chunk = nl.matmul(ones_tril, dA_accumulation[:, chunk_id, :], transpose_x=True) + cumsum_chunk[...] = cumsum_chunk + nl.copy(cumsum_dA, dtype=dtype).broadcast_to((block_size, d_state)) + cumsum_dA[0, i_f_state] = cumsum_chunk[0, i_f_state] + + ddt = nl.matmul(cumsum_chunk * logA_cache[:, head_id], ones_sum_right) + nl.matmul(dU * X0, + ones_sum_right_head) + + dA_chunk = nl.matmul(cumsum_chunk * dt, ones_sum_right) + dA_final[:, head_id] += nl.matmul(ones_sum_left, dA_chunk) + + dX = dU * dt + + if D is not None: + dD_chunk = nl.matmul(dO * X0, ones_sum_right_head) + dD_final[:, head_id] += nl.copy(nl.matmul(ones_sum_left, dD_chunk), dtype=dtype) + dX[...] 
= dX + dO * D[:, head_id] + + nl.store(dX_tensor[batch_id, i_p_in, head_id, i_f_head], dX) + nl.store(ddt_tensor[batch_id, i_p_in, head_id], ddt) + + nl.store(dlogA_tensor[batch_id, head_id], dA_final[0, head_id]) + if D is not None: + nl.store(dD_tensor[batch_id, head_id], dD_final[0, head_id]) + + for chunk_id in nl.sequential_range(n_chunks): + i_p_in = i_p + chunk_id * block_size + nl.store(dC_tensor[batch_id, i_p_in, group_id, i_f_state], dC_accumulation[:, chunk_id, :]) + nl.store(dB_tensor[batch_id, i_p_in, group_id, i_f_state], dB_accumulation[:, chunk_id, :]) + + +class Mamba2Kernel(torch.autograd.Function): + """Define the autograd function wih forward and backward kernel.""" + + @staticmethod + def forward(ctx, dt, A, X, B, C, D): + batch_size, seq_len, n_heads, d_head = X.shape + ctx.save_for_backward(dt, A, X, B, C, D) + out_Y = torch.empty_like(X) + mamba_kernel_[batch_size](dt, A, X, B, C, out_Y, D) + return out_Y + + @staticmethod + def backward(ctx, d_output): + dt, A, X, B, C, D = ctx.saved_tensors + batch_size, seq_len, n_heads, d_head = X.shape + + ddt = torch.empty_like(dt) + dA = torch.empty_like(A.unsqueeze(0).repeat(batch_size, 1)) + dX = torch.empty_like(X) + dB = torch.empty_like(B) + dC = torch.empty_like(C) + dD = torch.empty_like(D.unsqueeze(0).repeat(batch_size, 1)) + + mamba_kernel_bwd_[batch_size](dt, A, X, B, C, d_output, + ddt, dA, dX, dB, dC, D, dD) + dA, dD = dA.sum(0), dD.sum(0) + return ddt, dA, dX, dB, dC, dD + + +def mamba2_kernel(dt, A, X, B, C, D): + return Mamba2Kernel.apply(dt, A, X, B, C, D) diff --git a/examples/training/mamba2/mamba2/mamba2_mixer.py b/examples/training/mamba2/mamba2/mamba2_mixer.py new file mode 100644 index 0000000..af30670 --- /dev/null +++ b/examples/training/mamba2/mamba2/mamba2_mixer.py @@ -0,0 +1,242 @@ +# coding=utf-8 +# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team. +# Modifications Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch +from torch import nn +from torch.types import Device +from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache +from transformers.activations import ACT2FN +from einops import rearrange +import neuronx_distributed.parallel_layers.parallel_state as ps +from neuronx_distributed.parallel_layers import RowParallelLinear, ColumnParallelLinear + +from .configuration_mamba2 import Mamba2Config + +from .mamba2_kernel import mamba2_kernel +from .conv1d_grouped import ConvNKI + + +def softplus(x, threshold=10): + return torch.where(x < threshold, torch.log(1 + torch.exp(x)), x) + + +# Note: this implementation is different from the same module in `transformers` when n_groups>1 +# we normalize each channel group independently, while the original normalizes all channels in the same device +# regardless of their group. 
Our version ensures that the checkpoint will behave the same when used with +# a different tp degrees than during training (note however that using a large n_groups with few channels may +# introduce training instabilities). +class MambaRMSNormGated(nn.Module): + def __init__(self, d: int, eps: float = 1e-5, device: Device = None, n_groups: int = 1, rmsnorm_within_groups=True): + """Gated Root Mean Square Layer Normalization with support for groups + + Paper: https://arxiv.org/abs/1910.07467 + + Mamba Official: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/layernorm_gated.py#L18 + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(d, device=device)) + self.n_groups = n_groups + self.rmsnorm_within_groups = rmsnorm_within_groups + self.parallel_split() + + def parallel_split(self): + # Split weights across cores based on the current tensor parallelism rank. + tp_rank = ps.get_tensor_model_parallel_rank() + tp_size = ps.get_tensor_model_parallel_size() + dim = self.weight.shape[0] + assert dim % tp_size == 0 + assert self.n_groups % tp_size == 0 + self.n_groups = self.n_groups // tp_size + chunk = slice(dim // tp_size * tp_rank, dim // tp_size * (tp_rank + 1)) + self.weight.data = self.weight.data[chunk].detach().clone() + return self + + def forward(self, hidden_states, gate=None): + hidden_states = hidden_states.to(torch.float32) + + if self.rmsnorm_within_groups: + hidden_states = rearrange(hidden_states, "... (g d) -> ... g d", g=self.n_groups) + gate = rearrange(gate, "... (g d) -> ... g d", g=self.n_groups) + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) + + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.rmsnorm_within_groups: + res = self.weight * rearrange(hidden_states, "... g d -> ... (g d)", g=self.n_groups) + else: + res = self.weight * hidden_states + + return res + + +class Mamba2Mixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: Mamba2Config, layer_idx: int): + super().__init__() + self.num_heads = config.num_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = int(config.expand * self.hidden_size) + self.time_step_rank = int(config.time_step_rank) + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + + self.layer_norm_epsilon = config.layer_norm_epsilon + self.rms_norm = config.rms_norm + self.rmsnorm_within_groups = config.rmsnorm_within_groups + + self.n_groups = config.n_groups + self.head_dim = config.head_dim + self.chunk_size = config.chunk_size + + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max + + assert self.intermediate_size % self.head_dim == 0 + assert self.intermediate_size // self.head_dim == self.num_heads + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + # This is a custom replacement of a grouped conv1d written as a NKI kernel for better efficiency. + # Note: the SiLU non-linearity is already applied inside the kernel. + self.conv1d = ConvNKI( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.conv_dim, + padding=config.conv_kernel - 1, + activation='silu', + ) + + # projection of the input hidden states + self.in_proj_z = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=config.use_bias, gather_output=False) + self.in_proj_xBC = ColumnParallelLinear(self.hidden_size, self.conv_dim, bias=config.use_bias, gather_output=False) + self.in_proj_dt = ColumnParallelLinear(self.hidden_size, self.num_heads, bias=config.use_bias, gather_output=False) + + # time step projection (discretization) + dt = torch.exp( + torch.rand(config.num_heads) + * (math.log(config.time_step_max) - math.log(config.time_step_min)) + + math.log(config.time_step_min) + ).clamp(min=config.time_step_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + self.dt_bias._no_reinit = True + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon, n_groups=self.n_groups, rmsnorm_within_groups=self.rmsnorm_within_groups) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = RowParallelLinear(self.intermediate_size, self.hidden_size, bias=config.use_bias, input_is_parallel=True) + self.use_bias = config.use_bias + self.parallel_split() + + def parallel_split(self): + # Split weights across cores based on the current tensor parallelism rank. 
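+        # Only the per-head parameters owned directly by this module (A_log, D and dt_bias) are
+        # sliced here, along the head dimension: each rank keeps a contiguous block of
+        # num_heads // tp_size heads, matching the column-parallel split of the input projections above.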
+ tp_rank = ps.get_tensor_model_parallel_rank() + tp_size = ps.get_tensor_model_parallel_size() + assert self.intermediate_size % tp_size == 0 + assert self.n_groups % tp_size == 0 + self.intermediate_size_tp = self.intermediate_size // tp_size + self.n_groups_tp = self.n_groups // tp_size + self.num_heads_tp = self.num_heads // tp_size + self.conv_dim_tp = self.conv_dim // tp_size + head_chunk = slice(self.num_heads_tp * tp_rank, self.num_heads_tp * (tp_rank + 1)) + # note: we have to use .clone(), otherwise the result would be a view and the original would remain in memory + self.D.data = self.D.data[head_chunk].detach().clone() + self.A_log.data = self.A_log.data[head_chunk].detach().clone() + self.dt_bias.data = self.dt_bias.data[head_chunk].detach().clone() + return self + + def nki_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size_tp = self.n_groups_tp * self.ssm_state_size + + assert cache_params is None, "cache not supported yet" + assert self.training, "only training supported right now" + assert attention_mask is None, "attention mask not supported yet" + assert self.time_step_limit[0] == 0.0 and self.time_step_limit[1] == float("inf"), "dt limit not supported yet" + assert self.activation in ["silu", "swish"] + + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + + gate = self.in_proj_z(hidden_states) + hidden_states_B_C = self.in_proj_xBC(hidden_states) + time_step = self.in_proj_dt(hidden_states) + + # 1D Convolution (SiLU non-linearity is fused inside) + hidden_states_B_C = self.conv1d(input=hidden_states_B_C.transpose(1, 2)).transpose(1, 2) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size_tp, groups_time_state_size_tp, groups_time_state_size_tp], + dim=-1, + ) + + time_step = softplus(time_step + self.dt_bias) + + scan_output = mamba2_kernel(time_step, + A, + hidden_states.view(batch_size, seq_len, self.num_heads_tp, -1), + B.view(batch_size, seq_len, self.n_groups_tp, -1), + C.view(batch_size, seq_len, self.n_groups_tp, -1), + self.D) + + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + out = self.out_proj(scan_output) + + return out + + def forward( + self, + hidden_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + assert "xla" in self.in_proj_xBC.weight.device.type, "This model only supports forward on an XLA device" + return self.nki_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) diff --git a/examples/training/mamba2/mamba2/modeling_mamba2.py b/examples/training/mamba2/mamba2/modeling_mamba2.py new file mode 100644 index 0000000..ad491ee --- /dev/null +++ b/examples/training/mamba2/mamba2/modeling_mamba2.py @@ -0,0 +1,470 @@ +# coding=utf-8 +# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team. +# Modifications Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Adaptation of transformers.models.mamba2.modeling_mamba2 to be compatible with neuronx-distributed. + + The class Mamba2Mixer has been moved to a separate file mamba2_mixer.py. +""" + +import math +from dataclasses import dataclass +from functools import partial +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from neuronx_distributed.parallel_layers import ParallelEmbedding, RowParallelLinear, ColumnParallelLinear +from neuronx_distributed.parallel_layers.loss_functions import parallel_cross_entropy +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.types import Device +import torch.distributed as distrib + + +from transformers.generation import GenerationMixin +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) + +from .configuration_mamba2 import Mamba2Config +from .mamba2_mixer import Mamba2Mixer + +from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2RMSNorm, MAMBA2_START_DOCSTRING, \ + MAMBA2_INPUTS_DOCSTRING, Mamba2CausalLMOutput + +logger = logging.get_logger(__name__) + +selective_state_update = None + +causal_conv1d_update, causal_conv1d_fn = None, None + + +_CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1" +_CONFIG_FOR_DOC = "Mamba2Config" + +def _init_normal(std, w): + return nn.init.normal_(w, mean=0.0, std=std) + +class PartitionBreak(torch.autograd.Function): + """Workaround to help the compiler detect layers. + + The compiler looks for all_gather/batch_norm/clamp to separate layers. Since those are not present in our model, + we insert a fake torch.clamp operation with inf boundary (so it acts like the identity) to ensure it will try to + break the layer at that point. 
+ """ + + @staticmethod + def forward(ctx, input): + return torch.clamp(input, min=None, max=torch.inf) + + @staticmethod + def backward(ctx, d_output): + return torch.clamp(d_output, min=None, max=torch.inf) + +def partition_break(x): + return PartitionBreak.apply(x) + +class Mamba2Block(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.mixer = Mamba2Mixer(config, layer_idx=layer_idx) + + def forward( + self, + hidden_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + hidden_states = self.mixer( + hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask + ) + hidden_states = residual + hidden_states + return hidden_states + + +class Mamba2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Mamba2Config + base_model_prefix = "backbone" + _no_split_modules = ["Mamba2Block"] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, Mamba2Mixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + elif isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + # fixme: this won't give consistent initialization with tp + nn.init.normal_(module.weight, std=self.config.initializer_range) + elif isinstance(module, (ParallelEmbedding, RowParallelLinear, ColumnParallelLinear)): + # module.init_weight_cpu() + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_hidden_layers) + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2 +class Mamba2Output(ModelOutput): + """ + Class for the MAMBA2 model outputs. 
+ + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`Mamba2Cache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[Mamba2Cache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@add_start_docstrings( + "The bare MAMBA2 Model transformer outputting raw hidden-states without any specific head on top.", + MAMBA2_START_DOCSTRING, +) +class Mamba2Model(Mamba2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + init_method = partial(_init_normal, config.initializer_range) + self.embeddings = ParallelEmbedding(config.vocab_size, config.hidden_size, init_method=init_method) + + self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." 
in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + @add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Mamba2Output, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + cache_params: Optional[Mamba2Cache] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, Mamba2Output]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if use_cache: + if cache_params is None: + cache_params = Mamba2Cache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device) + elif cache_position is None: + # cases when we do manual forward instead of using `model.generate` which will initiate + # `cache_position` and makes sure it is not None, throw error here instead of doing some + # hack to conjecture the current cache position + raise ValueError( + "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, " + "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will " + "be initialized for you automatically" + ) + else: + cache_params = None + + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + for mixer_block in self.layers: + if self.gradient_checkpointing and self.training: + raise NotImplementedError("gradient checkpointing is not implemented yet") + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask + ) + else: + hidden_states = partition_break(mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + )) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if use_cache: + cache_params.seqlen_offset += inputs_embeds.shape[1] + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return Mamba2Output( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + + +@add_start_docstrings( + """ + The MAMBA2 Model transformer with 
a language modeling head on top (linear layer with weights not tied to the input + embeddings). + """, + MAMBA2_START_DOCSTRING, +) +class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin): + _tied_weights_keys = [] + + def __init__(self, config): + super().__init__(config) + self.backbone = Mamba2Model(config) + # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + init_method = partial(_init_normal, config.initializer_range) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + bias=False, + gather_output=False, + init_method=init_method + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def prepare_inputs_for_generation( + self, + input_ids, + inputs_embeds=None, + use_cache=None, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + # Overwitten -- uses `cache_params` as opposed to `past_key_values` + + if inputs_embeds is not None: + past_len = inputs_embeds.shape[1] + input_ids.shape[1] + else: + past_len = input_ids.shape[1] + if use_cache: + # `cache_position` should have been initialized in `generate` + if cache_position is None: + raise ValueError( + "`cache_position` should not be None as it should have been initialized in " + "`model.generate`, you are responsible for passing in a valid `cache_position` if " + "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" + ) + # how do we detect that we are in decoding without cache? 
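+            # cache_position[0] > 0 means the prefill step has already been cached, so we are in
+            # the decoding phase and only the most recent token (and its attention-mask column)
+            # needs to be passed to the model.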
+ if cache_position[0] > 0: + input_ids = input_ids[:, -1][..., None] + attention_mask = attention_mask[:, -1][..., None] + else: + # we initialize the `cache_position` to full size of `conv_states` at prefill stage + # considering padding will be applied when input length is shorter, and truncation + # will be applied when it is longer, so it will be equivalent to always have it match + # the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = torch.arange(0, past_len, device=input_ids.device) + # if the cache is not used, we also do have to extend the attention mask here + # TODO there is likely a cleverer way to do this + extended_mask = torch.ones( + attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device + ) + attention_mask = torch.cat([attention_mask, extended_mask], dim=1) + cache_params = None + + if attention_mask.shape[1] < past_len: + # we have to update manually the attention mask if + # we are in decoding without cache + # and we don't have position_ids here + # TODO but we should be able to use cache_position though at a later time + extended_mask = torch.ones( + attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device + ) + attention_mask = torch.cat([attention_mask, extended_mask], dim=1) + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "attention_mask": attention_mask, + "cache_params": cache_params, + "use_cache": use_cache, + "cache_position": cache_position, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Mamba2CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_params: Optional[Mamba2Cache] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, # for now we need this for generation + ) -> Union[Tuple, Mamba2CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + mamba2_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = mamba2_outputs[0] + + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = parallel_cross_entropy + # Flatten the tokens + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + loss = loss.mean() + + if not return_dict: + output = (logits,) + mamba2_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Mamba2CausalLMOutput( + loss=loss, + logits=logits, + cache_params=mamba2_outputs.cache_params, + hidden_states=mamba2_outputs.hidden_states, + ) diff --git a/examples/training/mamba2/predefined_configs.py b/examples/training/mamba2/predefined_configs.py new file mode 100644 index 0000000..a342057 --- /dev/null +++ b/examples/training/mamba2/predefined_configs.py @@ -0,0 +1,43 @@ +"""Utility script to create a Mamba2Config given the size of the model.""" + +from typing import Dict + +from mamba2 import Mamba2Config +from dataclasses import dataclass, astuple + + +@dataclass +class ConfParams: + d_model: int + n_layers: int + head_dim: int = 128 + + +CONFIGS_KWARGS: Dict[str, ConfParams] = { + 'Mamba130M': ConfParams(d_model=768, n_layers=24), + 'Mamba370M': ConfParams(d_model=1024, n_layers=48), + 'Mamba780M': ConfParams(d_model=1536, n_layers=48), + 'Mamba1B': ConfParams(d_model=2048, n_layers=48), + 'Mamba3B': ConfParams(d_model=2560, n_layers=64), + 'Mamba7B': ConfParams(d_model=4096, n_layers=64), +} + + +def get_config(name: str, vocab_size, rmsnorm_within_groups=True, n_groups=8): + d_model, n_layers, head_dim = astuple(CONFIGS_KWARGS[name]) + config = Mamba2Config( + vocab_size=vocab_size, + hidden_size=d_model, + head_dim=head_dim, + num_heads=(d_model * 2) // head_dim, + num_hidden_layers=n_layers, + tie_word_embeddings=True, + use_cache=False, + n_groups=n_groups, + bos_token_id=0, + eos_token_id=0, + pad_token_id=0, + rmsnorm_within_groups=rmsnorm_within_groups, + ) + return config + diff --git a/examples/training/mamba2/train.py b/examples/training/mamba2/train.py new file mode 100644 index 0000000..4b47053 --- /dev/null +++ b/examples/training/mamba2/train.py @@ -0,0 +1,270 @@ +import os + +import torch +import torch.distributed as dist +import time + +from transformers import AdamW + +import neuronx_distributed as nxd +import neuronx_distributed.parallel_layers.parallel_state as ps +import torch_xla.distributed.parallel_loader as pl +from neuronx_distributed.utils.adamw_fp32_optim_params import AdamW_FP32OptimParams +from transformers.optimization import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup + +from mamba2 import Mamba2ForCausalLM +from training_utils import create_llama_pretraining_dataset +from 
predefined_configs import get_config + +import transformers.modeling_utils as modeling_utils + +# For PT autocast. +torch.cuda.is_bf16_supported = lambda: True + +# Environment variables set by torch.distributed.launch +LOCAL_RANK = int(os.environ['LOCAL_RANK']) +WORLD_SIZE = int(os.environ['WORLD_SIZE']) +WORLD_RANK = int(os.environ['RANK']) + + +def get_mixed_precision_config(args): + if args.use_zero_1: + return { + "use_master_weights": True, + "use_fp32_grad_acc": True, + "use_master_weights_in_ckpt": False, + } + else: + return {} + + +def run(args, backend): + import numpy as np + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + device = xm.xla_device() + + config = get_config(args.model, args.vocab_size, rmsnorm_within_groups=(not args.rmsnorm_across_groups), + n_groups=args.tp) + + def get_model(): + model = Mamba2ForCausalLM(config).to(dtype=args.dtype) + model.train() + + # check that weight tying worked + if model.config.tie_word_embeddings: + assert model.backbone.embeddings.weight is model.lm_head.weight + return model + + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=args.tp, + optimizer_config={"zero_one_enabled": args.use_zero_1, "grad_clipping": True, "max_grad_norm": 1.0}, + sequence_parallel=False, + model_init_config=None, + mixed_precision_config=get_mixed_precision_config(args), + ) + model = nxd.initialize_parallel_model(nxd_config, get_model) + world_size = ps.get_data_parallel_size() + if xm.is_master_ordinal(): + print('NEURON_CC_FLAGS: ', os.environ.get('NEURON_CC_FLAGS', None)) + print('XLA_IR_DEBUG: ', os.environ.get('XLA_IR_DEBUG', None)) + print('XLA_HLO_DEBUG: ', os.environ.get('XLA_HLO_DEBUG', None)) + print('TP groups:', ps.get_tensor_model_parallel_group(as_list=True)) + print('DP groups:', ps.get_data_parallel_group(as_list=True)) + print('Config: ', config) + param_size, dtype = 0, None + for param in set(model.parameters()): + param_size += param.nelement() + dtype = param.dtype + print(f"Model size: {param_size / 10 ** 6:.1f}M parameters/core") + print(f"Param dtype: {dtype}") + + param_optimizer = list(model.named_parameters()) + + no_decay = ["bias", "LayerNorm", "norm", "A", "D"] + + optimizer_grouped_parameters = [ + { + "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + if args.no_optimizer_fp32: + optimizer_cls = AdamW + else: + optimizer_cls = AdamW_FP32OptimParams + + # Creating NxD Optimizer + optimizer = nxd.initialize_parallel_optimizer( + nxd_config, + optimizer_cls, + optimizer_grouped_parameters, + lr=args.lr, + betas=(args.beta1, args.beta2), + ) + optimizer.zero_grad() + if args.use_zero_1: + optimizer.optimizer.init_zero() + + scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=args.warmup_steps, + num_training_steps=args.max_steps, + last_epoch=-1, + ) + + train_dataloader, _ = create_llama_pretraining_dataset( + args.data_dir, + args.batch, + ps.get_data_parallel_size(), + ps.get_data_parallel_rank(), + args.seed, + ) + # We wrap the dataloader with MpDeviceLoader. This dataloader should take + # care of copying the tensors to device and also inserting the mark_step at + # iteration end. 
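+    # On XLA, mark_step is what cuts the lazily recorded graph, so emitting it once per
+    # iteration keeps each training step as its own compiled graph.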
+ train_device_loader = pl.MpDeviceLoader(train_dataloader, device) + + running_loss = torch.zeros(1).to(device) + training_step, global_step = 0, 0 + t0 = time.time() + xm.mark_step() + for i, data in enumerate(train_device_loader): + training_step += 1 + input_ids = data["input_ids"] + labels = data["labels"] + + out = model(input_ids=input_ids, labels=labels) + loss = out.loss / args.grad_accum_usteps + loss.backward() + xm.mark_step() + + running_loss += loss.detach() + + if training_step % args.grad_accum_usteps == 0: + xm.mark_step() + if xm.is_master_ordinal(): + print(f"Global Step {global_step}") + print(f"Loss: {loss.item()}") + print(f"running_loss: {running_loss.item()}") + + running_loss.zero_() + optimizer.step() + optimizer.zero_grad() + scheduler.step() + global_step += 1 + + xm.mark_step() + if xm.is_master_ordinal(): + print(f'process time of this batch is: {time.time() - t0}') + t0 = time.time() + + if (args.checkpoint_freq > 0) and (global_step % args.checkpoint_freq == 0): + xm.add_step_closure( + nxd.save_checkpoint, ( + args.checkpoint_dir, # checkpoint directory + f"{args.tag}_step_{global_step}", # tag + model, # model + optimizer, # optimizer + scheduler, # scheduler + {"global_step": global_step, "cli_args": args.__dict__}, # user content + ) + ) + + if global_step >= args.max_steps: + xm.mark_step() + break + # xm.mark_step() # final mark_step not needed when using MpDeviceLoader + + +def init_processes(args, backend): + dist.init_process_group(backend, rank=WORLD_RANK, world_size=WORLD_SIZE) + run(args=args, backend=backend) + xm.rendezvous("_mp_fn finished") + + +def tp_loader(state_dict, tp_rank, tp_size, config): + """Load the correct slice of weights from a checkpoint for the current core give the tensor parallel degree.""" + new_state_dict = {} + for k, v in state_dict.items(): + if k.endswith('out_proj.weight'): # row parallel + dim_1_shape = v.shape[1] + cv = torch.split(v, dim_1_shape // tp_size, dim=1) + new_state_dict[k] = cv[tp_rank] + elif k.endswith('in_proj_xBC.weight') or 'conv1d' in k: # xBC and Conv Col para + wx, wB, wC = torch.split(v, [config.hidden_size * 2, config.n_groups * config.state_size, + config.n_groups * config.state_size], + dim=0) + wx_tp = torch.split(wx, wx.shape[0] // tp_size, dim=0)[tp_rank] + wB_tp = torch.split(wB, wB.shape[0] // tp_size, dim=0)[tp_rank] + wC_tp = torch.split(wC, wC.shape[0] // tp_size, dim=0)[tp_rank] + xBC_tp = torch.cat((wx_tp, wB_tp, wC_tp), dim=0) + new_state_dict[k] = xBC_tp + elif 'norm' in k and 'mixer' not in k: + new_state_dict[k] = v + else: # norm weight and z and dt + dim_0_shape = v.shape[0] + rv = torch.split(v, dim_0_shape // tp_size, dim=0) + new_state_dict[k] = rv[tp_rank] + + return new_state_dict + + +if __name__ == '__main__': + import torch_xla.core.xla_model as xm + import os + import argparse + + parser = argparse.ArgumentParser(description="Mamba2Block Configuration") + parser.add_argument("--seed", type=int, default=100, help="Random seed") + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"], help="Data type") + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--vocab_size", type=int, default=50272, help="Vocab size") + parser.add_argument("--model", default="Mamba130M", help="Hugging face model to profile") + parser.add_argument("--backend", type=str, default="xla", choices=['xla', 'nccl', 'gloo']) + parser.add_argument("--tp", type=int, default=1, help="Tensor Parallel Size") + + 
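+    # Note: run() passes n_groups=args.tp to get_config(), and the mixer asserts that both
+    # n_groups and intermediate_size are divisible by the tensor parallel degree, so --tp must
+    # be chosen accordingly for the selected model size.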
parser.add_argument("--no_optimizer_fp32", action="store_true", help="Do not use FP32 for the optimizer state.") + parser.add_argument("--use_zero_1", action="store_true", help="Use ZeRO-1.") + parser.add_argument("--lr", type=float, default=4e-4, help="Learning rate.") + parser.add_argument("--weight_decay", default=0.01, type=float, help="weight decay") + parser.add_argument("--beta1", default=0.9, type=float, help="beta1 parameter for Adam optimizer") + parser.add_argument("--beta2", default=0.999, type=float, help="beta2 parameter for Adam optimizer") + parser.add_argument("--data_dir", type=str, help="Pre-tokenized dataset directory.") + + parser.add_argument("--warmup_steps", type=int, default=2000, + help="Number of warmup accumulation-steps for learning rate .") + parser.add_argument("--max_steps", type=int, help="Maximum total accumulation-steps to run.") + parser.add_argument("--grad_accum_usteps", type=int, default=1, + help="Gradient accumulation micro-steps (an accumulation-step has micro-steps.") + parser.add_argument("--rmsnorm_across_groups", action="store_true", + help="Uses (HF style) RMSNorm instead of the custom one that normalizes independently for each of the n_groups.") + parser.add_argument("--debug", "-d", action="store_true", + help="Enable Neuron debugging flags to dump model graph and compiler logs.") + parser.add_argument("--checkpoint_freq", type=int, help="ckpt save freq.") + parser.add_argument("--checkpoint_dir", type=str, default="./", help="ckpt saving dir") + parser.add_argument("--tag", type=str, default="exp", help="ckpt saving name") + + args = parser.parse_args() + + os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "0" + args.dtype = getattr(torch, args.dtype) + + if args.dtype == torch.bfloat16: + modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16 + + if args.debug: + # Debug flags to dump the annotated HLO graph, useful for profiling + os.environ["XLA_IR_DEBUG"] = "1" + os.environ["XLA_HLO_DEBUG"] = "1" + os.environ["NEURON_FRAMEWORK_DEBUG"] = "1" + + os.environ["NEURON_CC_FLAGS"] = " --model-type=transformer -O1" + + init_processes(args, backend=args.backend) diff --git a/examples/training/mamba2/training_utils.py b/examples/training/mamba2/training_utils.py new file mode 100644 index 0000000..91c1778 --- /dev/null +++ b/examples/training/mamba2/training_utils.py @@ -0,0 +1,359 @@ +import json +import math +import os +import queue +import time +from datetime import datetime, timezone +from functools import partial +from itertools import chain +from typing import Any, Dict, List + +import datasets +import torch +from torch.utils.data import DistributedSampler +from torch.utils.data.dataloader import DataLoader +from transformers import default_data_collator, set_seed + +try: + from lr import CosineAnnealing +except ImportError: + CosineAnnealing = None + +from collections import namedtuple + +Metric = namedtuple("Metric", ["name", "value", "units", "additional_data"]) +remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []} + + +# empty list to save remainder from batches to use in next batch +def pack_dataset(dataset, chunk_length=2048): + print(f"Chunking dataset into chunks of {chunk_length} tokens.") + + def chunk(sample, chunk_length=chunk_length): + # define global remainder variable to save remainder from batches to use in next batch + global remainder + # Concatenate all texts and add remainder from previous batch + concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()} + 
concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()} + # get total number of tokens for batch + batch_total_length = len(concatenated_examples[list(sample.keys())[0]]) + + # get max number of chunks for batch + if batch_total_length >= chunk_length: + batch_chunk_length = (batch_total_length // chunk_length) * chunk_length + + # Split by chunks of max_len. + result = { + k: [t[i: i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)] + for k, t in concatenated_examples.items() + } + # add remainder to global variable for next batch + remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()} + # prepare labels + result["labels"] = result["input_ids"].copy() + return result + + # tokenize and chunk dataset + lm_dataset = dataset.map( + partial(chunk, chunk_length=chunk_length), + batched=True, + ) + print(f"Total number of samples: {len(lm_dataset)}") + return lm_dataset + + +def get_learning_rate_scheduler(optimizer, args, last_epoch=-1): + lr_scheduler = CosineAnnealing( + optimizer, + max_steps=args.max_steps, + min_lr=args.min_lr, + warmup_steps=args.warmup_steps, + constant_steps=args.constant_steps, + last_epoch=last_epoch, + ) + return lr_scheduler + + +def get_param_groups_by_weight_decay(model): + """Get param groups.""" + if hasattr(model, "local_named_parameters") and hasattr(model, "partitioned") and model.partitioned: + # Zero1 use the first param in opt to decide the device + param_optimizer = list(model.local_named_parameters()) + else: + param_optimizer = list(model.named_parameters()) + no_decay = ["bias", "LayerNorm"] # gamma/beta are in LayerNorm.weight + + optimizer_grouped_parameters = [ + { + "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], + "weight_decay": 0.01, + }, + { + "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +def create_llama_pretraining_dataset(data_dir, mini_batch_size, dp_size, dp_rank, seed): + # Workaround because python functions are not picklable + class WorkerInitObj(object): + def __init__(self, seed): + self.seed = seed + + def __call__(self, id): + set_seed(self.seed) + + worker_init = WorkerInitObj(seed) + train_data = datasets.load_from_disk(data_dir) + train_sampler = DistributedSampler( + train_data, + num_replicas=dp_size, + rank=dp_rank, + shuffle=False, + drop_last=True, + ) + train_dataloader = DataLoader( + train_data, + collate_fn=default_data_collator, + sampler=train_sampler, + batch_size=mini_batch_size, + num_workers=0, + worker_init_fn=worker_init, + drop_last=True, + pin_memory=True, + ) + return train_dataloader, None + + +def create_instruction_based_dataset(data_dir, mini_batch_size, dp_size, dp_rank, seed, tokenizer=None, task=None): + raw_datasets = datasets.load_dataset(data_dir, split="train") + if task: + raw_datasets = raw_datasets.filter(lambda example: example["category"] == task) + train_and_test_dataset = raw_datasets.train_test_split(test_size=8) + train_dataset = train_and_test_dataset["train"] + test_dataset = train_and_test_dataset["test"] + + def preprocess_train_dataset(sample): + instruction = f"### Instruction\n{sample['instruction']}" + context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None + response = f"### Answer\n{sample['response']}" + # join all the parts together + prompt = "\n".join([i for i in [instruction, context, 
response] if i is not None]) + model_input = tokenizer(f"{prompt}{tokenizer.eos_token}") + return model_input + + train_data = train_dataset.shuffle().map(preprocess_train_dataset, remove_columns=train_dataset.column_names) + train_data = pack_dataset(train_data, chunk_length=2048) + + class WorkerInitObj(object): + def __init__(self, seed): + self.seed = seed + + def __call__(self, id): + set_seed(self.seed) + + worker_init = WorkerInitObj(seed) + + train_sampler = DistributedSampler( + train_data, + num_replicas=dp_size, + rank=dp_rank, + shuffle=True, + drop_last=True, + ) + train_dataloader = DataLoader( + train_data, + collate_fn=default_data_collator, + sampler=train_sampler, + batch_size=mini_batch_size, + num_workers=0, + worker_init_fn=worker_init, + drop_last=True, + pin_memory=True, + ) + + def preprocess_test_dataset(sample): + instruction = f"### Instruction\n{sample['instruction']}" + context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None + response = "### Answer\n" + # join all the parts together + prompt = "\n".join([i for i in [instruction, context, response] if i is not None]) + model_input = tokenizer(prompt, add_special_tokens=False) + labels = tokenizer(sample["response"], add_special_tokens=False) + model_input["labels"] = labels["input_ids"] + return model_input + + test_data = test_dataset.map(preprocess_test_dataset, remove_columns=test_dataset.column_names) + + test_sampler = DistributedSampler( + test_data, + num_replicas=dp_size, + rank=dp_rank, + shuffle=False, + drop_last=False, + ) + test_dataloader = DataLoader( + test_data, + collate_fn=default_data_collator, + sampler=test_sampler, + batch_size=mini_batch_size, + num_workers=0, + drop_last=False, + pin_memory=True, + ) + + return train_dataloader, test_dataloader + + +def create_partition(num_hidden_layers, pipeline_parallel_size): + """ + Evenly split the transformer layers between the PP ranks + """ + assert num_hidden_layers % pipeline_parallel_size == 0 + num_layer_per_partition = num_hidden_layers // pipeline_parallel_size + pipeline_cuts = [] + current_cut = num_layer_per_partition - 1 + for i in range(pipeline_parallel_size - 1): + pipeline_cuts.append(f"model.layers.{current_cut}") + current_cut += num_layer_per_partition + return pipeline_cuts + + +def get_sin_cos_matrix(config): + head_dim = config.hidden_size // config.num_attention_heads + base = config.rope_theta + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + t = torch.arange(config.max_position_embeddings, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos()[None, None, :, :].to(torch.float32), emb.sin()[None, None, :, :].to(torch.float32) + + +def get_dtype(model) -> str: + """ + Reference: https://pytorch.org/xla/release/1.12/index.html#xla-tensors-and-bfloat16 + """ + if "XLA_USE_BF16" in os.environ: + return "torch.bfloat16" + if "XLA_DOWNCAST_BF16" in os.environ: + if "torch.float" in str(model.dtype): + return "torch.bfloat16" + if "torch.double" in str(model.dtype): + return "torch.float32" + return str(model.dtype) + + +def print_logs(loss, global_norm, args, throughput, logger, total_steps, current_lr, input_ids, start): + total_norm_cpu = global_norm.cpu().item() + logger.log(total_steps, loss, total_norm_cpu, current_lr, input_ids, throughput, start) + + +class TrainingMetrics: + """ + This 
class is used for logging metrics to a JSON file. One can provide a
+    dictionary of metrics that needs to be stored, and it would get
+    written to the file.
+    Arguments:
+        json_file: File used for logging. If the file does not exist, a new one is created.
+    """
+
+    def __init__(self, json_file):
+        self.json_file = json_file
+
+    def read_modify_write_file(self, data, key: str = "metrics") -> None:
+        """
+        data (dict of training parameters or list of metrics): Data to update in the file.
+        key (str): the dictionary key under which data is to be recorded
+        """
+        result_dict = {}
+        print(f"Writing data to the provided results file: {self.json_file}")
+        if os.path.exists(self.json_file):
+            with open(self.json_file, "r") as json_file:
+                content = json_file.read()
+                if not content.strip():  # Check if content is empty or contains only whitespace
+                    print("File is empty or contains only whitespace.")
+                else:
+                    result_dict = json.loads(content) or result_dict
+        print(f"Updating with {key} data: {data}")
+        if result_dict:
+            try:
+                # handle internal named entity if present
+                results = result_dict[next(iter(result_dict))]
+            except Exception:
+                results = result_dict
+            current = results.get(key)
+            if not current:
+                results[key] = data
+            else:
+                if isinstance(current, list):
+                    current.extend(data)
+                elif isinstance(current, dict):
+                    current.update(data)
+        else:
+            result_dict["results"] = {key: data}
+        with open(self.json_file, "w") as json_file:
+            json.dump(result_dict, json_file)
+
+    def store_metrics(self, metrics: List[Metric]) -> None:
+        """
+        Writes collected metrics to the file.
+        """
+        data = [
+            {
+                "MetricName": metric.name,
+                "MeasuredValue": metric.value,
+                "Units": metric.units,
+                "Timestamp": datetime.now(timezone.utc).isoformat(),
+                "AdditionalData": metric.additional_data,
+            }
+            for metric in metrics
+        ]
+        self.update(data=data, key="metrics")
+
+    def store_parameters(self, parameters: Dict[str, Any]) -> None:
+        """
+        Writes specified model and configuration parameters to the file.
+        """
+        self.update(data=parameters, key="parameters")
+
+    def update(self, **kwargs: Any) -> None:
+        """
+        Write specified data to the output file.
+        """
+        self.read_modify_write_file(**kwargs)
+
+
+class Throughput:
+    def __init__(self, batch_size, world_size, grad_accum_usteps, moving_avg_window_size=10, logging_interval=1):
+        """
+        Used to calculate the throughput over a moving window. It records the step time
+        between two calls and uses that time to calculate the throughput.
+        """
+        self.seqs_per_iteration = batch_size * world_size * grad_accum_usteps * logging_interval
+        self.moving_avg_window_size = math.ceil(moving_avg_window_size / logging_interval)
+        self.moving_avg_window = queue.Queue()
+        self.window_time = 0
+        self.start_time = time.time()
+
+    def get_throughput(self):
+        step_time = time.time() - self.start_time
+        self.start_time += step_time
+        self.window_time += step_time
+        self.moving_avg_window.put(step_time)
+        window_size = self.moving_avg_window.qsize()
+        if window_size > self.moving_avg_window_size:
+            self.window_time -= self.moving_avg_window.get()
+            window_size -= 1
+        throughput = window_size * self.seqs_per_iteration / self.window_time
+        return throughput
+
+
+def get_mixed_precision_config(use_gpu_compatible_precision):
+    return {
+        "use_master_weights": bool(use_gpu_compatible_precision),
+        "use_fp32_grad_acc": bool(use_gpu_compatible_precision),
+        "use_master_weights_in_ckpt": False,
+    }
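
A quick way to sanity-check the predefined configs above is the following minimal sketch (assuming it is run from the examples/training/mamba2 directory so that predefined_configs and the local mamba2 package are importable; vocab_size here is just the train.py default):

    from predefined_configs import get_config

    # 'Mamba130M': d_model=768, 24 layers, head_dim=128 -> (768 * 2) // 128 = 12 heads
    cfg = get_config("Mamba130M", vocab_size=50272)
    print(cfg.hidden_size, cfg.num_heads, cfg.num_hidden_layers)  # expected: 768 12 24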