diff --git a/examples/training/mamba2/LICENSE b/examples/training/mamba2/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/examples/training/mamba2/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/examples/training/mamba2/NOTICE b/examples/training/mamba2/NOTICE
new file mode 100644
index 0000000..f0df437
--- /dev/null
+++ b/examples/training/mamba2/NOTICE
@@ -0,0 +1,3 @@
+Mamba-2 port for AWS Trainium
+Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
\ No newline at end of file
diff --git a/examples/training/mamba2/convert_mamba2_ssm_checkpoint.py b/examples/training/mamba2/convert_mamba2_ssm_checkpoint.py
new file mode 100644
index 0000000..becb6d3
--- /dev/null
+++ b/examples/training/mamba2/convert_mamba2_ssm_checkpoint.py
@@ -0,0 +1,233 @@
+# coding=utf-8
+# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
+# Modifications Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script converts checkpoints provided by the `mamba2_ssm` library into the format used by HuggingFace `transformers`.
+It requires the `mamba2_ssm` package to be installed.
+
+Unlike the Mamba2 implementation in `transformers`, we split `Mamba2Mixer.in_proj` into three separate linear layers.
+This version of the script adds a `--split_proj` flag to convert checkpoints to our format.
+"""
+
+import argparse
+import json
+from functools import partial
+from os import path
+from typing import Dict, Optional
+
+import torch
+from safetensors import safe_open
+from safetensors.torch import save_model
+
+from transformers import GPTNeoXTokenizerFast, LlamaTokenizerFast
+from mamba2.configuration_mamba2 import Mamba2Config
+from mamba2.modeling_mamba2 import Mamba2ForCausalLM
+
+
+def load_state_dict_from_safetensors(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
+ # Load weights and config from paths
+ original_state_dict = {}
+ with safe_open(path.join(mamba2_checkpoint_path, ckpt_name), framework="pt") as f:
+ for k in f.keys():
+ newk = k.removeprefix("model.")
+ original_state_dict[newk] = f.get_tensor(k).clone()
+ return original_state_dict
+
+
+def load_state_dict_from_torch(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
+ return torch.load(path.join(mamba2_checkpoint_path, ckpt_name), map_location="cpu")
+
+
+def convert_ssm_config_to_hf_config(config_ssm: Dict, mamba2_model_dict: Dict) -> Mamba2Config:
+ """Convert a Mamba2Config from mamba_ssm to a Mamba2Config from here."""
+ hf_config = Mamba2Config()
+
+ # Switch to a different dict depending on model type
+ config_dict = mamba2_model_dict
+
+ # Set important values from config and recalculate other resulting entries
+ hf_config.hidden_size = config_ssm[config_dict["hidden_size"]]
+ hf_config.num_heads = (hf_config.hidden_size * hf_config.expand) // hf_config.head_dim
+ hf_config.num_hidden_layers = config_ssm[config_dict["num_hidden_layers"]]
+ hf_config.n_groups = config_ssm.get(config_dict["n_groups"], 1)
+ hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
+ hf_config.bos_token_id = config_dict["bos_token_id"]
+ hf_config.pad_token_id = config_dict["pad_token_id"]
+ hf_config.eos_token_id = config_dict["eos_token_id"]
+
+    # Pad the vocab size to a multiple of pad_vocab_size_multiple (usually 16, though 32 is also common in some models)
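+    # e.g. (illustrative values) vocab_size=50277 with pad_vocab_size_multiple=16 is padded up to 50288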
+ vocab_size = config_ssm["vocab_size"]
+ pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
+ if (vocab_size % pad_vocab_size_multiple) != 0:
+ vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
+ hf_config.vocab_size = vocab_size
+
+ return hf_config
+
+
+def load_and_save_tokenizer(
+ mamba2_model_type: str,
+ output_dir: str,
+ tokenizer_model_path: Optional[str] = None,
+) -> None:
+ tokenizer = None
+
+ # Load tokenizer
+ if tokenizer_model_path is not None and mamba2_model_type == "codestral":
+ tokenizer_class = LlamaTokenizerFast
+ tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
+ elif mamba2_model_type == "mamba_ssm":
+ tokenizer = GPTNeoXTokenizerFast.from_pretrained("state-spaces/mamba-130m-hf", padding_side="left")
+
+ # Save tokenizer
+ if tokenizer is not None:
+ tokenizer.save_pretrained(output_dir)
+
+
+_MAMBA2_MODELS_DICT = {
+ "codestral": {
+ "hidden_size": "dim",
+ "num_hidden_layers": "n_layers",
+ "n_groups": "n_groups",
+ "bos_token_id": 0,
+ "pad_token_id": 1,
+ "eos_token_id": 2,
+ "config_name": "params.json",
+ "load_state_dict": partial(load_state_dict_from_safetensors, ckpt_name="consolidated.safetensors"),
+ "load_and_save_tokenizer": partial(load_and_save_tokenizer, "codestral"),
+ },
+ "mamba_ssm": {
+ "hidden_size": "d_model",
+ "num_hidden_layers": "n_layer",
+ "n_groups": "ngroups",
+ "bos_token_id": 0,
+ "pad_token_id": 0,
+ "eos_token_id": 0,
+ "config_name": "config.json",
+ "load_state_dict": partial(load_state_dict_from_torch, ckpt_name="pytorch_model.bin"),
+ "load_and_save_tokenizer": partial(load_and_save_tokenizer, "mamba_ssm"),
+ },
+}
+
+
+def split_projection_matrix(hf_model, state_dict):
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if not k.endswith('in_proj.weight'):
+ new_state_dict[k] = v
+ continue
+ layer_name = k.split('.in_proj.weight')[0]
+
+ layer = hf_model.get_submodule(layer_name)
+ w_z, w_xBC, w_dt = torch.split(v, [
+ layer.intermediate_size,
+ layer.conv_dim,
+ layer.num_heads], dim=0)
+
+ # since .split() gives a view, we do .clone() to ensure we don't keep the original tensor in memory
+ new_state_dict[layer_name + '.in_proj_z.weight'] = w_z.detach().clone()
+ new_state_dict[layer_name + '.in_proj_xBC.weight'] = w_xBC.detach().clone()
+ new_state_dict[layer_name + '.in_proj_dt.weight'] = w_dt.detach().clone()
+
+ return new_state_dict
+
+
+def convert_mamba2_checkpoint_file_to_huggingface_model_file(
+ mamba2_checkpoint_path: str,
+ mamba2_model_type: str,
+ precision: str,
+ output_dir: str,
+ tokenizer_model_path: Optional[str] = None,
+ split_proj: bool = False
+) -> None:
+ mamba2_model_dict = _MAMBA2_MODELS_DICT[mamba2_model_type]
+
+ # Load and save config based on name
+ config_path = path.join(mamba2_checkpoint_path, mamba2_model_dict["config_name"])
+ with open(config_path, "r", encoding="utf-8") as json_file:
+ config = json.load(json_file)
+ hf_config = convert_ssm_config_to_hf_config(config_ssm=config, mamba2_model_dict=mamba2_model_dict)
+ hf_config.save_pretrained(output_dir)
+
+ # Load state dict of the original model and transfer to hf model
+ original_state_dict = mamba2_model_dict["load_state_dict"](mamba2_checkpoint_path=mamba2_checkpoint_path)
+ hf_model = Mamba2ForCausalLM(hf_config)
+ if split_proj:
+ original_state_dict = split_projection_matrix(hf_model, original_state_dict)
+ hf_model.load_state_dict(original_state_dict)
+
+ # Save new model to pytorch_dump_path
+ dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
+ save_model(hf_model.to(dtype), path.join(output_dir, "model.safetensors"), metadata={"format": "pt"})
+
+ # Load and save tokenizer
+ mamba2_model_dict["load_and_save_tokenizer"](output_dir=output_dir, tokenizer_model_path=tokenizer_model_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-i",
+ "--mamba2_checkpoint_directory",
+ type=str,
+ required=True,
+ help="Path to a directory containing the `pytorch_model.bin` or `.safetensors` mamba2_ssm checkpoint file to be converted.",
+ )
+ parser.add_argument(
+ "-m",
+ "--mamba2_model_type",
+ type=str,
+ default="mamba_ssm",
+ required=True,
+ choices=("codestral", "mamba_ssm"),
+ help="The model type the conversion will be performed on. Can choose from either `codestral` or `mamba_ssm`.",
+ )
+ parser.add_argument(
+ "-p",
+ "--precision",
+ type=str,
+ default="fp16",
+ required=True,
+ choices=("fp32", "fp16", "bf16"),
+ help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
+ )
+ parser.add_argument(
+ "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
+ )
+ parser.add_argument(
+ "-t",
+ "--tokenizer_model_path",
+ type=str,
+ default=None,
+ required=False,
+ help="Path to a `codestral` tokenizer file.",
+ )
+ parser.add_argument(
+ "-s",
+ "--split_proj",
+ action="store_true",
+        help="Split the input projection matrix into 3 separate matrices",
+ )
+ args = parser.parse_args()
+
+ convert_mamba2_checkpoint_file_to_huggingface_model_file(
+ args.mamba2_checkpoint_directory,
+ args.mamba2_model_type,
+ args.precision,
+ args.output_dir,
+ args.tokenizer_model_path,
+ args.split_proj
+ )
diff --git a/examples/training/mamba2/mamba2/__init__.py b/examples/training/mamba2/mamba2/__init__.py
new file mode 100644
index 0000000..789dbf1
--- /dev/null
+++ b/examples/training/mamba2/mamba2/__init__.py
@@ -0,0 +1,2 @@
+from .modeling_mamba2 import *
+from .configuration_mamba2 import *
\ No newline at end of file
diff --git a/examples/training/mamba2/mamba2/configuration_mamba2.py b/examples/training/mamba2/mamba2/configuration_mamba2.py
new file mode 100644
index 0000000..71feffd
--- /dev/null
+++ b/examples/training/mamba2/mamba2/configuration_mamba2.py
@@ -0,0 +1,186 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+# Modifications Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Same as transformers.model.mamba2.Mamba2Config but uses different defaults and adds an option rmsnorm_within_groups."""
+
+import math
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Mamba2Config(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`Mamba2Model`]. It is used to instantiate a MAMBA2
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the MAMBA2
+ [state-spaces/mamba2-2.8b](https://huggingface.co/state-spaces/mamba2-2.8b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ num_heads (`int`, *optional*, defaults to 128):
+ Number of heads for the evolution matrices of mamba 2.
+ head_dim (`int`, *optional*, defaults to 64):
+ Dimension of each head.
+ vocab_size (`int`, *optional*, defaults to 32768):
+ Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Mamba2Model`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimensionality of the embeddings and hidden states.
+ state_size (`int`, *optional*, defaults to 128): shape of the state space latents.
+ num_hidden_layers (`int`, *optional*, defaults to 64):
+ Number of hidden layers in the model.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+ The epsilon to use in the layer normalization layers.
+ pad_token_id (`int`, *optional*, defaults to 1):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 0):
+ The id of the beginning of sentence token in the vocabulary.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the end of sentence token in the vocabulary.
+ expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
+ conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
+ n_groups (`int`, *optional*, defaults to 8):
+ Number of groups for the evolution matrices of mamba 2.
+ use_bias (`bool`, *optional*, defaults to `False`):
+ Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
+ use_conv_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not to use bias in the convolution layer of the mixer block.
+ hidden_act (`str`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ initializer_range (`float`, *optional*, defaults to 0.1):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ residual_in_fp32 (`bool`, *optional*, defaults to `True`):
+ Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model
+ time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
+ Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
+ time_step_min (`float`, *optional*, defaults to 0.001):
+ Minimum `time_step` used to bound `dt_proj.bias`.
+ time_step_max (`float`, *optional*, defaults to 0.1):
+ Maximum `time_step` used to bound `dt_proj.bias`.
+ time_step_floor (`float`, *optional*, defaults to 0.0001):
+ Minimum clamping value of the `dt_proj.bias` layer initialization.
+ time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`):
+ Accepted range of time step values.
+ rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
+ Whether or not to rescale `out_proj` weights when initializing.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the cache should be used.
+ rms_norm (`bool`, *optional*, defaults to `True`):
+ Whether to use RMS norm or not.
+ rmsnorm_within_groups (`bool`, *optional*, defaults to `True`):
+ Whether to use RMS norm independently within n_groups or not.
+ chunk_size (`int`, *optional*, defaults to 256):
+ Size of the chunks that will comprise the sequence.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie word embeddings or not.
+
+
+ Example:
+
+ ```python
+ >>> from transformers import Mamba2Config, Mamba2Model
+
+ >>> # Initializing a Mamba2 configuration
+ >>> configuration = Mamba2Config()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = Mamba2Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "mamba2"
+
+ def __init__(
+ self,
+ num_heads=128,
+ head_dim=64,
+ vocab_size=32768,
+ hidden_size=4096,
+ state_size=128,
+ num_hidden_layers=64,
+ layer_norm_epsilon=1e-5,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ expand=2,
+ conv_kernel=4,
+ n_groups=8,
+ use_bias=False,
+ use_conv_bias=True,
+ hidden_act="silu",
+ initializer_range=0.1,
+ residual_in_fp32=True,
+ time_step_rank="auto",
+ time_step_min=0.001,
+ time_step_max=0.1,
+ time_step_floor=1e-4,
+ time_step_limit=(0.0, float("inf")),
+ rescale_prenorm_residual=False,
+ use_cache=False, # fixme: the default in HF is True but we don't support cache yet
+ rms_norm=True,
+ rmsnorm_within_groups=True,
+ chunk_size=256,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.state_size = state_size
+ self.num_hidden_layers = num_hidden_layers
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.conv_kernel = conv_kernel
+ self.expand = expand
+
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.use_bias = use_bias
+ self.use_conv_bias = use_conv_bias
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
+ self.time_step_min = time_step_min
+ self.time_step_max = time_step_max
+ self.time_step_floor = time_step_floor
+ self.rescale_prenorm_residual = rescale_prenorm_residual
+ self.residual_in_fp32 = residual_in_fp32
+ self.use_cache = use_cache
+ self.n_groups = n_groups
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.rms_norm = rms_norm
+        self.rmsnorm_within_groups = rmsnorm_within_groups
+ self.chunk_size = chunk_size
+ self.time_step_limit = time_step_limit
+ self.tie_word_embeddings = tie_word_embeddings
+
+ super().__init__(
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
\ No newline at end of file
diff --git a/examples/training/mamba2/mamba2/conv1d_grouped.py b/examples/training/mamba2/mamba2/conv1d_grouped.py
new file mode 100644
index 0000000..5a4cb4c
--- /dev/null
+++ b/examples/training/mamba2/mamba2/conv1d_grouped.py
@@ -0,0 +1,365 @@
+"""NKI implementation of a causal channel-wise 1d convolution.
+
+This is a drop-in replacement of torch.nn.Conv1d when groups == in_channels == out_channels. It automatically
+applies the right amount of left zero-padding to ensure the length of the output is the same as the input.
+"""
+
+from typing import Union
+
+import torch
+
+import neuronxcc.nki as nki
+import neuronxcc.nki.language as nl
+import neuronxcc.nki.isa as nisa
+from torch_neuronx import nki_jit
+import neuronx_distributed.parallel_layers.parallel_state as ps
+
+import numpy as np
+import torch.nn as nn
+
+
+def diag(w):
+    """Build an (m, m) diagonal matrix from a column vector w of shape (m, 1)."""
+ m = w.shape[0]
+ i_p = nl.arange(m)[:, None]
+ i_f = nl.arange(m)[None, :]
+ w_diag = nisa.affine_select(i_p == i_f, w.broadcast_to((m, m)), 0)
+ return w_diag
+
+
+def apply_activation(x, bias, activation: str):
+ if activation is None:
+ return x + bias if bias is not None else x
+ elif activation == 'relu':
+ return nisa.activation(nl.relu, x, bias=bias)
+ elif activation == 'd_relu':
+ assert bias is None
+ return x >= 0
+ elif activation == 'silu':
+ z = x + bias if bias is not None else x
+ return z * nisa.activation(nl.sigmoid, z)
+ elif activation == 'd_silu':
+ z = x + bias if bias is not None else x
+ return nl.sigmoid(z) * (1 + z * (1 - nl.sigmoid(z)))
+ elif activation == 'd_identity':
+ return 1
+ else:
+ raise ValueError(f'Invalid activation {activation}')
+
+
+def _conv_tile_tensor_e(data_tile, weight_tile, b_tile, dtype, activation=None):
+ """Computes and returns the convolution of a data tile given weights and bias.
+
+ Isolated from the rest of the code since we use it both for forward and backward.
+ """
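+    # Strategy (descriptive note): output positions are processed in `kernel_size` interleaved
+    # strided groups (j, j + ks, j + 2*ks, ...). For each group, the k-tap dot product is
+    # accumulated as a sum of diagonal-matrix matmuls so the work runs on the tensor engine.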
+ p_size, n = data_tile.shape
+ kernel_size = weight_tile.shape[1]
+
+ conv = nl.zeros(shape=(p_size, n), dtype=dtype)
+
+ chunk_size = n // kernel_size
+
+ i_p = nl.arange(p_size)[:, None]
+
+ for j in nl.affine_range(kernel_size):
+ i_f_plus_j = j + nl.arange(chunk_size)[None, :] * kernel_size
+ res = nl.zeros((p_size, chunk_size), dtype=nl.float32, buffer=nl.psum)
+ for i in nl.affine_range(kernel_size):
+ w_diag = diag(weight_tile[i_p, i])
+ res += nisa.nc_matmul(w_diag, data_tile[i_p, i_f_plus_j + i])
+ conv[i_p, i_f_plus_j] = apply_activation(res, bias=b_tile, activation=activation)
+
+ return conv
+
+
+def _conv_tile_scalar_e(data_tile, weight_tile, b_tile, dtype, activation=None):
+ """Computes and returns the convolution of a data tile given weights and bias.
+
+ Isolated from the rest of the code since we use it both for forward and backward.
+ """
+ p_size, n = data_tile.shape
+ kernel_size = weight_tile.shape[1]
+
+ conv = nl.ndarray(shape=(p_size, n), dtype=dtype)
+
+ chunk_size = n // kernel_size
+
+ i_p = nl.arange(p_size)[:, None]
+
+ for j in nl.affine_range(kernel_size):
+ i_f_plus_j = j + nl.arange(chunk_size)[None, :] * kernel_size
+ res = nki.isa.tensor_scalar(data_tile[i_p, i_f_plus_j], op0=np.multiply,
+ operand0=weight_tile[i_p, 0], dtype=dtype)
+ for i in nl.static_range(1, kernel_size):
+ res = res + nki.isa.tensor_scalar(data_tile[i_p, i_f_plus_j + i], op0=np.multiply,
+ operand0=weight_tile[i_p, i],
+ dtype=dtype)
+ conv[i_p, i_f_plus_j] = apply_activation(res, bias=b_tile, activation=activation)
+ return conv
+
+
+# _conv_tile = _conv_tile_scalar_e
+_conv_tile = _conv_tile_tensor_e
+
+
+@nki_jit
+def conv1d_grouped_kernel(input_data, w, b, output, activation=None):
+ """NKI kernel to compute grouped 1d causal convolution, equivalent to:
+
+ D, L = x.shape
+
+ conv = nn.Conv1d(
+ in_channels=D,
+ out_channels=D,
+ bias=True,
+ kernel_size=kernel_size,
+ groups=D,
+ padding=kernel_size - 1,
+ )
+ y = conv(x)[:, :L]
+
+ Args:
+ input_data: input tensor of shape [D,L]
+ w: conv weights of shape [D, kernel_size]
+ b: conv bias of shape [D]
+ output: output tensor of shape [D, L]
+ """
+
+ batch_size, p, n = input_data.shape
+ ch, _, ks = w.shape
+ dtype = input_data.dtype
+
+ # fixme: make the code work for any size
+    assert p % 128 == 0 and ch == p
+    assert n % ks == 0 and n > ks  # check n is a multiple of kernel size
+    assert ks == 4  # fixme: don't think this constraint is needed
+ assert n <= 2048, "conv1d does not yet support sequence lengths larger than 2048"
+
+ i_w = nl.arange(ks)[None, :]
+ i_p = nl.arange(128)[:, None]
+ i_y = nl.arange(n)[None, :]
+
+ # Iterate over channel dimension then over batch dimension (so we load the weights only once for all samples)
+ for k in nl.affine_range(input_data.shape[1] // 128):
+ i_p_input = i_p + k * 128
+ # weights and biases for current tile
+ w_tile = nl.load(w.reshape((ch, ks))[i_p_input, i_w])
+ b_tile = nl.load(b[i_p_input])
+
+ for i_b in nl.affine_range(input_data.shape[0]):
+ # Load with padding
+ x = nl.zeros(shape=(128, n + ks - 1), dtype=dtype)
+ x[i_p, ks - 1 + i_y] = nl.load(input_data[i_b, i_p_input, i_y])
+ # run the convolution
+ conv = _conv_tile(x, w_tile, b_tile, dtype, activation=activation)
+ # The first positions contain the result of a zero padded window
+ nl.store(output[i_b, i_p_input, i_y], value=conv[i_p, i_y])
+
+
+@nki_jit
+def conv1d_grouped_kernel_longseq(input_data, w, b, output, activation=None):
+    """NKI kernel to compute a grouped 1d causal convolution for sequences of any length by processing them in sub-sequences.
+
+    It is equivalent to:
+
+ D, L = x.shape
+
+ conv = nn.Conv1d(
+ in_channels=D,
+ out_channels=D,
+ bias=True,
+ kernel_size=kernel_size,
+ groups=D,
+ padding=kernel_size - 1,
+ )
+ y = conv(x)[:, :L]
+
+ Args:
+ input_data: input tensor of shape [D,L]
+ w: conv weights of shape [D, kernel_size]
+ b: conv bias of shape [D]
+ output: output tensor of shape [D, L]
+ """
+
+ _, channels, seq_len = input_data.shape
+ ch, _, ks = w.shape
+ dtype = input_data.dtype
+
+ # fixme: make the code work for any size
+ assert channels % 128 == 0 and ch == channels
+ assert seq_len % ks == 0 and seq_len > ks # check seq_len is a multiple of kernel size
+    assert ks == 4  # fixme: don't think this constraint is needed
+
+ i_w = nl.arange(ks)[None, :]
+ i_p = nl.arange(128)[:, None]
+
+ sub_seq_len = min(2048, seq_len)
+ num_sub_seqs = (seq_len + sub_seq_len - 1) // sub_seq_len
+
+ padded_len = sub_seq_len + ks - 1
+ i_f_subseq = nl.arange(sub_seq_len)[None, :]
+ i_f_subseq_padded = nl.arange(padded_len)[None, :]
+
+ # Iterate over channel dimension then over batch dimension (so we load the weights only once for all samples)
+ for k in nl.affine_range(input_data.shape[1] // 128):
+ i_p_input = i_p + k * 128
+ # weights and biases for current tile
+ w_tile = nl.load(w.reshape((ch, ks))[i_p_input, i_w])
+ b_tile = nl.load(b[i_p_input])
+
+ for batch_id in nl.affine_range(input_data.shape[0]):
+
+ for subseq_id in nl.affine_range(num_sub_seqs):
+ i_f_subseq_padded_in = i_f_subseq_padded + subseq_id * sub_seq_len
+ x = nl.zeros(shape=(128, padded_len), dtype=dtype)
+ mask = (i_f_subseq_padded_in - (ks - 1) >= 0) & (i_f_subseq_padded_in - (ks - 1) < seq_len)
+ x[i_p, i_f_subseq_padded] = nl.load(input_data[batch_id, i_p_input, i_f_subseq_padded_in - (ks - 1)],
+ mask=mask)
+ # run the convolution
+ conv = _conv_tile(x, w_tile, b_tile, dtype, activation=activation)
+ mask_out = i_f_subseq + subseq_id * sub_seq_len < seq_len
+ # store the result
+ nl.store(output[batch_id, i_p_input, i_f_subseq + subseq_id * sub_seq_len], value=conv[i_p, i_f_subseq],
+ mask=mask_out)
+
+
+@nki_jit
+def conv1d_grouped_kernel_grad(input_data, w, d_output, d_input, d_w, d_b, activation=None):
+ batch_size, p, n = input_data.shape
+ ch, _, ks = w.shape
+ dtype = input_data.dtype
+
+    assert p % 128 == 0 and ch == p
+ assert n % ks == 0 and n > ks # check n is a multiple of kernel size
+ assert ks == 4
+
+ i_p = nl.arange(128)[:, None]
+ i_f_n = nl.arange(n)[None, :]
+ i_f_w = nl.arange(ks)[None, :]
+ seq_len = n + ks - 1
+ i_f_seq_len = nl.arange(seq_len)[None, :]
+
+ if activation is not None:
+ d_activation = 'd_' + activation
+ else:
+ d_activation = 'd_identity'
+
+ for chunk_id in nl.affine_range(input_data.shape[1] // 128):
+ i_p_input = chunk_id * 128 + nl.arange(128)[:, None]
+ w_tile = nl.load(w[i_p_input, 0, i_f_w])
+ # we don't need the bias to compute gradients
+ b_tile = None
+
+ db_accumulation = nl.zeros([128, 1], dtype=dtype)
+ dw_accumulation = nl.zeros([128, ks], dtype=dtype)
+
+ for batch_id in nl.affine_range(input_data.shape[0]):
+ # fixme: probably don't need to pad this
+ x = nl.zeros(shape=(128, n + ks - 1), dtype=dtype)
+ x[i_p, ks - 1 + i_f_n] = nl.load(input_data[batch_id, i_p_input, i_f_n])
+
+ if activation is not None:
+ preact_grad = _conv_tile(x, w_tile, b_tile, dtype, activation=d_activation)[i_p, i_f_n]
+ else:
+ preact_grad = 1
+ dout_tile = nl.zeros(shape=(128, n + ks - 1), dtype=dtype)
+ dout_tile[i_p, i_f_n] = preact_grad * nl.load(d_output[batch_id, i_p_input, i_f_n])
+
+ # Compute db
+ db_accumulation += nl.sum(dout_tile[i_p, i_f_n], axis=[1])
+
+ # Compute d_input
+ dout_reverse = nl.ndarray((128, seq_len), dtype=dtype)
+ # fixme: we should simply index the tile with flipped indexes, no need for the copy
+ # but it will break down later as double indexing tile[i_p, i_f][i_p1, i_f1] is not supported
+ dout_reverse[i_p, i_f_seq_len] = dout_tile[i_p, seq_len - 1 - i_f_seq_len]
+ # dout_reverse = dout_tile[i_p, seq_len - 1 - i_f_seq_len]
+
+ conv = _conv_tile(dout_reverse, w_tile, b_tile=None, dtype=dtype, activation=None)
+
+ # We flip the result while storing
+ nl.store(d_input[batch_id, i_p_input, i_f_n], conv[i_p, seq_len - ks - i_f_n])
+
+ dw_batch = nl.ndarray((128, 4), dtype=dtype)
+ # Compute dw
+ for i in nl.static_range(ks):
+ # todo: the vector engine should be able to execute both element-wise product and sum in one instruction
+ dw_batch[i_p, i] = nl.sum(x[i_p, i + i_f_n] * dout_tile[i_p, i_f_n], axis=[1])
+ dw_accumulation += dw_batch
+
+ nl.store(d_b[i_p_input], db_accumulation[i_p, 0])
+ nl.store(d_w[i_p_input, 0, i_f_w], dw_accumulation[i_p, i_f_w])
+
+
+class GroupedConv1dNKI(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, weight, bias, activation=None):
+        # fixme: if the output is too large we could avoid storing it and recompute it during backprop
+ output = torch.empty_like(input)
+ # if input.shape[2] <= 2048:
+ # output = conv1d_grouped_kernel(input, weight, bias, output, activation=activation)
+ # else:
+ output = conv1d_grouped_kernel_longseq(input, weight, bias, output, activation=activation)
+ ctx.save_for_backward(input, weight, bias)
+ ctx.activation = activation
+ return output
+
+ @staticmethod
+ def backward(ctx, d_output):
+ input, weight, bias = ctx.saved_tensors
+ dinput = torch.empty_like(input)
+ dweight = torch.empty_like(weight)
+ dbias = torch.empty_like(bias)
+ if input.shape[2] > 2048:
+ raise NotImplementedError('Gradient not implemented for conv1d with seq_len>2048')
+ # dinput, dweight, dbias = conv1d_grouped_kernel_bwd(input, weight, bias, d_output)
+ conv1d_grouped_kernel_grad(input, weight, d_output, dinput, dweight, dbias, activation=ctx.activation)
+ return dinput, dweight, dbias, None
+
+
+def nki_conv1d(input, weight, bias=None, activation=None):
+ return GroupedConv1dNKI.apply(input, weight, bias, activation)
+
+
+class ConvNKI(nn.Conv1d):
+ """
+    Custom layer implemented in NKI to efficiently compute a grouped convolution,
+    equivalent to nn.Conv1d with groups == in_channels == out_channels.
+
+    Parameters:
+        input: (B, C, L)
+        weight: (C, 1, kernel_size)
+        bias: (C,)
+    Return:
+        output: (B, C, L). Each input channel sequence input[b, c, :] is convolved with its own kernel weight[c, 0, :].
+        The results are then stacked together.
+ """
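+
+    # Illustrative construction (the values below are assumptions, not defaults):
+    #   conv = ConvNKI(in_channels=512, out_channels=512, kernel_size=4, groups=512,
+    #                  padding=3, bias=True, activation='silu')
+    #   y = conv(x)  # x: (batch, 512 // tp_size, seq_len) after the tensor-parallel split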
+
+ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1,
+ padding: Union[str, int] = 0, dilation: int = 1, groups: int = 1, bias: bool = True,
+ padding_mode: str = 'zeros', device=None, dtype=None, activation=None) -> None:
+ # We only support a very specific use case, check we are in it
+ assert groups == in_channels, "NKI grouped conv kernel only supports groups == in_channels"
+ assert padding == kernel_size - 1
+ assert padding_mode == 'zeros'
+ assert dilation == 1
+ assert stride == 1
+ super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode,
+ device, dtype)
+ self.activation = activation
+ self.parallel_split()
+
+ def parallel_split(self):
+ tp_rank = ps.get_tensor_model_parallel_rank()
+ tp_size = ps.get_tensor_model_parallel_size()
+
+ chunk = slice(self.out_channels // tp_size * tp_rank, self.out_channels // tp_size * (tp_rank + 1))
+ self.weight.data = self.weight.data[chunk].detach().clone()
+ self.bias.data = self.bias.data[chunk].detach().clone()
+ self.in_channels = self.out_channels // tp_size
+ self.out_channels = self.out_channels // tp_size
+
+ def forward(self, input):
+ return GroupedConv1dNKI.apply(input, self.weight, self.bias, self.activation)
diff --git a/examples/training/mamba2/mamba2/mamba2_kernel.py b/examples/training/mamba2/mamba2/mamba2_kernel.py
new file mode 100644
index 0000000..8a24479
--- /dev/null
+++ b/examples/training/mamba2/mamba2/mamba2_kernel.py
@@ -0,0 +1,375 @@
+"""NKI implementation of the forward and backward SSM kernels for a Mamba2 model.
+
+Paper: https://arxiv.org/abs/2405.21060
+"""
+
+import torch
+import neuronxcc.nki.language as nl
+from torch_neuronx import nki_jit
+import neuronxcc.nki.isa as nisa
+
+
+def compute_a_factors(logA, ones_triu, chunk_size, d_state, d_head=None):
+ dtype = logA.dtype
+ i_p = nl.arange(chunk_size)[:, None]
+ if d_head is None:
+ d_head = chunk_size
+ i_p_head = nl.arange(d_head)[:, None]
+ i_f = nl.arange(chunk_size)[None, :]
+
+ # we reuse this bcast later in the code, ensure it is large enough for all uses
+ bcast_size = max(chunk_size, d_state)
+
+ logA_bcast = logA.broadcast_to((chunk_size, bcast_size))
+ l_bcast = nl.matmul(ones_triu, logA_bcast, transpose_x=True)
+ l = l_bcast[:chunk_size, 0]
+ l_t = nl.transpose(l_bcast, dtype=dtype)
+
+ # === compute the _transpose_ of the 128x128 lower triangular matrix L ===
+ partial_sums = l_t[:chunk_size, :chunk_size] - l
+ L_full_t = nl.exp(partial_sums)
+ L_t = nisa.affine_select(i_f >= i_p, L_full_t, 0)
+
+ a_right = L_t[i_p, chunk_size - 1]
+ a_left = nl.exp(l_t[i_p_head, i_f])
+
+ if d_head != chunk_size:
+ a_center_t = a_left[:, chunk_size - 1].broadcast_to((d_head, bcast_size))
+ a_center = nl.transpose(a_center_t, dtype=dtype)
+ else:
+ a_center = a_left[:, chunk_size - 1]
+
+ return L_t, a_left, a_center, a_right
+
+
+def compute_chunk_output(BC_t, L_t, X, C_t, a_left, S,
+ transpose_gate=False,
+ transpose_broadcast=False):
+ """Utility function for computing the output, shared between forward and backward"""
+ # Diagonal term computation
+ M_diag_t = L_t * BC_t
+ Y_diag = nl.matmul(M_diag_t, X, transpose_x=not transpose_gate)
+ # Compute the off-diagonal contribution using the state
+ barC_t = C_t * a_left
+    # project the carried state S through the decay-scaled C coefficients
+ Y_off = nl.matmul(barC_t, S, transpose_x=not transpose_broadcast)
+ return Y_diag + Y_off
+
+
+@nki_jit
+def mamba_kernel_(dt_tensor, logA_tensor, X_tensor, B_tensor, C_tensor, out_Y_tensor, D_tensor=None):
+ """
+    dt_tensor: (batch_size, seq_len, n_heads)
+    logA_tensor: (n_heads)
+    X_tensor: (batch_size, seq_len, n_heads, d_head)
+    B_tensor: (batch_size, seq_len, n_groups, d_state)
+    C_tensor: (batch_size, seq_len, n_groups, d_state)
+    out_Y_tensor: (batch_size, seq_len, n_heads, d_head)
+    D_tensor: (n_heads)
+ """
+ # Since this kernel requires high-precision, we run all internal computations in fp32.
+ # Note: the speedup by using bf16 everywhere would be ~15%
+ dtype = nl.float32
+ block_size = 128 # we will split seq_len in chunks of size `block_size`
+
+ batch_size, seq_len, n_heads, d_head = X_tensor.shape
+ _, _, n_groups, d_state = B_tensor.shape
+ assert seq_len % block_size == 0
+ n_chunks = seq_len // block_size
+ n_heads_per_group = n_heads // n_groups
+
+ batch_id = nl.program_id(0)
+
+ i_p = nl.arange(block_size)[:, None]
+ i_f = nl.arange(block_size)[None, :]
+ i_f_state = nl.arange(d_state)[None, :]
+ i_f_head = nl.arange(d_head)[None, :]
+
+ # creates a constant upper triangular matrix of ones
+ ones_triu = nisa.affine_select(i_p <= i_f, nl.ones((block_size, block_size), dtype=dtype), 0)
+
+ for group_id in nl.affine_range(n_groups): # parallel for loop over each group (they are completely independent)
+ # === Preload/compute logA, B, C_t and B @ C_t ====
+ # (they are shared between multiple heads in the same group)
+ B_cache = nl.ndarray((block_size, n_chunks, d_state), dtype=dtype)
+ # todo: storing in this format may be a bad idea if d_state != 128?
+ C_t_cache = nl.ndarray((d_state, n_chunks, block_size), dtype=dtype)
+ BC_t_cache = nl.ndarray((block_size, n_chunks, block_size), dtype=dtype)
+ for chunk_id in nl.affine_range(n_chunks):
+ i_p_in = i_p + chunk_id * block_size
+ B = nl.load(B_tensor[batch_id, i_p_in, group_id, i_f_state], dtype=dtype)
+ C = nl.load(C_tensor[batch_id, i_p_in, group_id, i_f_state], dtype=dtype)
+ C_t = nisa.nc_transpose(C)
+ B_cache[:, chunk_id, :] = B
+ C_t_cache[:, chunk_id, :] = C_t
+ BC_t_cache[:, chunk_id, :] = nl.copy(nl.matmul(B, C_t), dtype=dtype)
+
+ logA_cache = nl.load(logA_tensor.reshape((1, n_heads)), dtype=dtype).broadcast_to((block_size, n_heads))
+ if D_tensor is not None:
+ D = nl.load(D_tensor.reshape((1, n_heads)), dtype=dtype).broadcast_to((block_size, n_heads))
+ else:
+ D = None
+
+ # == Actual code ===
+ for head_id_in_group in nl.affine_range(n_heads_per_group): # parallel for loop over the n_heads
+ # get the global head_id given current group and current head in group
+ head_id = group_id * n_heads_per_group + head_id_in_group
+ # We iterate over the diagonal blocks and compute each Y_diag
+ # At the same time, we update our running sum S and use it to compute Y_off.
+ # We store Y = Y_diag + Y_off, and we move to the next block
+ S = nl.zeros((d_state, d_head), dtype=dtype)
+ for chunk_id in nl.sequential_range(n_chunks):
+ i_p_in = i_p + chunk_id * block_size
+
+ # broadcast dt and logA together
+ dt = nl.load(dt_tensor[batch_id, i_p_in, head_id], dtype=dtype)
+ logA = logA_cache[:, head_id] * dt
+
+ # load from cache the relevant blocks
+ B = B_cache[:, chunk_id, :]
+ C_t = C_t_cache[:, chunk_id, :]
+ BC_t = BC_t_cache[:, chunk_id, :]
+
+ # broadcast X and dt
+ X0 = nl.load(X_tensor[batch_id, i_p_in, head_id, i_f_head], dtype=dtype)
+ X = dt * X0
+
+ # Compute all logA related factors for this chunk
+ L_t, a_left, a_center, a_right = compute_a_factors(logA, ones_triu, block_size, d_state)
+
+ Y = compute_chunk_output(BC_t, L_t, X, C_t, a_left, S)
+
+ # Update running sum S (will be used in the next iteration)
+ barB = B * a_right
+ barBX = nl.matmul(barB, X, transpose_x=True)
+ S[...] = a_center * S + barBX
+
+ if D is not None:
+ Y[...] = Y + D[:, head_id] * X0
+
+ nl.store(out_Y_tensor[batch_id, i_p_in, head_id, i_f_head], Y)
+
+
+@nki_jit
+def mamba_kernel_bwd_(dt_tensor, logA_tensor, X_tensor, B_tensor, C_tensor,
+ d_out_tensor,
+ ddt_tensor,
+ dlogA_tensor,
+ dX_tensor,
+ dB_tensor,
+ dC_tensor,
+ D_tensor=None,
+ dD_tensor=None
+ ):
+ """
+ dt_tensor: (batch_size, seq_len, n_heads)
+ logA_tensor: (n_heads)
+ X_tensor: (batch_size, seq_len, n_heads, d_head)
+ B_tensor: (batch_size, seq_len, n_groups, d_state)
+ C_tensor: (batch_size, seq_len, n_groups, d_state)
+ D_tensor: (n_heads)
+ d_out_tensor: (batch_size, seq_len, n_heads, d_head)
+    All other derivative tensors (d_*) have the same shape as their corresponding inputs, except dlogA_tensor and dD_tensor, which carry an extra leading batch dimension that is summed out by the autograd wrapper.
+ """
+
+ # Note: since saving the intermediate results of the forward pass would use too much memory, this kernel also
+ # recomputes the forward pass while computing the gradients.
+
+ # Since this kernel requires high-precision, we run all internal computations in fp32.
+ # Note: the speedup by using bf16 everywhere would be ~15%
+ dtype = nl.float32
+ block_size = 128 # we will split seq_len in chunks of size `block_size`
+ batch_size, seq_len, n_heads, d_head = X_tensor.shape
+ _, _, n_groups, d_state = B_tensor.shape
+ assert seq_len % block_size == 0
+ n_chunks = seq_len // block_size
+ n_heads_per_group = n_heads // n_groups
+
+ assert d_state == 128
+ assert d_head <= 128
+ assert block_size <= 128
+
+ batch_id = nl.program_id(0)
+
+ i_p = nl.arange(block_size)[:, None]
+ i_f = nl.arange(block_size)[None, :]
+ i_f_state = nl.arange(d_state)[None, :]
+ i_f_head = nl.arange(d_head)[None, :]
+
+ # upper triangular matrix of ones
+ ones_triu = nisa.affine_select(i_p <= i_f, nl.ones((block_size, block_size), dtype=dtype), 0)
+ ones_tril = nl.copy(nl.transpose(ones_triu), dtype=dtype)
+ ones_sum_right = nl.ones([d_state, 1], dtype=dtype)
+ ones_sum_left = nl.ones([1, d_state], dtype=dtype)
+ ones_sum_right_head = nl.ones([d_head, 1], dtype=dtype)
+
+ for group_id in nl.affine_range(n_groups): # iterate in parallel over all channel groups (they are independent)
+ # Preload/compute logA, B, C_t and B @ C_t (which are shared between multiple heads in the same group)
+ B_cache = nl.ndarray((block_size, n_chunks, d_state), dtype=dtype)
+ C_t_cache = nl.ndarray((d_state, n_chunks, block_size), dtype=dtype)
+ BC_t_cache = nl.ndarray((block_size, n_chunks, block_size), dtype=dtype)
+ for chunk_id in nl.affine_range(n_chunks):
+ i_p_in = i_p + chunk_id * block_size
+ B = nl.load(B_tensor[batch_id, i_p_in, group_id, i_f_state], dtype=dtype)
+ C = nl.load(C_tensor[batch_id, i_p_in, group_id, i_f_state], dtype=dtype)
+ C_t = nisa.nc_transpose(C)
+ B_cache[:, chunk_id, :] = B
+ C_t_cache[:, chunk_id, :] = C_t
+ BC_t_cache[:, chunk_id, :] = nl.copy(nl.matmul(B, C_t), dtype=dtype)
+
+ logA_cache = nl.load(logA_tensor.reshape((1, n_heads)), dtype=dtype).broadcast_to((block_size, n_heads))
+ if D_tensor is not None:
+ D = nl.load(D_tensor.reshape((1, n_heads)), dtype=dtype).broadcast_to((block_size, n_heads))
+ else:
+ D = None
+
+ dC_accumulation = nl.zeros((block_size, n_chunks, d_state), dtype=dtype)
+ dB_accumulation = nl.zeros((block_size, n_chunks, d_state), dtype=dtype)
+ dA_final = nl.zeros((1, n_heads), dtype=dtype)
+ if D is not None:
+ dD_final = nl.zeros((1, n_heads), dtype=dtype)
+
+ for head_id_in_group in nl.affine_range(n_heads_per_group): # the n_heads are completely independent
+ # get the global head_id given current group and current head in group
+ head_id = group_id * n_heads_per_group + head_id_in_group
+ dA_accumulation = nl.zeros((block_size, n_chunks, d_state), dtype=dtype)
+ S = nl.zeros((d_state, d_head), dtype=dtype)
+ for chunk_id in nl.sequential_range(n_chunks):
+ #
+ i_p_in = i_p + chunk_id * block_size
+ # broadcast dt and logA together
+ dt = nl.load(dt_tensor[batch_id, i_p_in, head_id])
+ logA = logA_cache[:, head_id] * dt
+ # Compute all logA related factors for this chunk
+ L_t, a_left, a_center, a_right = compute_a_factors(logA, ones_triu, block_size, d_state, d_head=d_head)
+ # load from cache the relevant blocks
+ B = B_cache[:, chunk_id, :]
+ C = nl.load(C_tensor[batch_id, i_p_in, group_id, i_f_state])
+ # broadcast X and dt
+ X0 = nl.load(X_tensor[batch_id, i_p_in, head_id, i_f_head])
+ X = dt * X0
+ #
+
+ # compute dC gradient
+ dO = nl.load(d_out_tensor[batch_id, i_p_in, head_id, i_f_head])
+ dO_t = nisa.nc_transpose(dO)
+ UdO_t = nl.matmul(X, dO_t) # (B, L, nheads, hdim)
+ S_t = nisa.nc_transpose(S)
+ dC = compute_chunk_output(UdO_t, L_t, B, dO_t, a_left, S_t)
+
+ #
+ # Update the state: running sum S (will be used in the next iteration)
+ barB = B * a_right
+ barBX = nl.matmul(barB, X, transpose_x=True)
+ S[...] = a_center * S + barBX
+ #
+ dC_accumulation[:, chunk_id, :] += dC
+ dA_accumulation[:, chunk_id, :] = dA_accumulation[:, chunk_id, :] + C * dC
+
+ dS = nl.zeros((d_state, d_head), dtype=dtype)
+ cumsum_dA = nl.zeros((1, d_state), dtype=dtype)
+ for chunk_id in nl.sequential_range(n_chunks):
+ chunk_id = n_chunks - 1 - chunk_id # To reverse time
+ i_p_in = i_p + chunk_id * block_size
+
+ # === Recompute forward pass ===
+ # broadcast dt and logA together
+ dt = nl.load(dt_tensor[batch_id, i_p_in, head_id])
+ logA = logA_cache[:, head_id] * dt
+ # Compute all logA related factors for this chunk
+ L_t, a_left, a_center, a_right = compute_a_factors(logA, ones_triu, block_size, d_state, d_head=d_head)
+ # load from cache the relevant blocks
+ B = B_cache[:, chunk_id, :]
+ C = nl.load(C_tensor[batch_id, i_p_in, group_id, i_f_state])
+ C_t = nisa.nc_transpose(C)
+ BC_t = BC_t_cache[:, chunk_id, :]
+ # broadcast X and dt
+ X0 = nl.load(X_tensor[batch_id, i_p_in, head_id, i_f_head])
+ X = dt * X0
+
+ # === Compute dX gradient ===
+ dO = nl.load(d_out_tensor[batch_id, i_p_in, head_id, i_f_head])
+ dU = compute_chunk_output(BC_t, L_t, dO, B, a_right, dS, transpose_gate=True, transpose_broadcast=True)
+
+                # === Prepare transposed terms for the dB gradient ===
+ X_t = nisa.nc_transpose(X)
+ dO_Xt = nl.matmul(dO, X_t)
+ L_t_ = nisa.nc_transpose(L_t)
+ dS_t = nisa.nc_transpose(dS)
+
+ C = nl.load(C_tensor[batch_id, i_p_in, group_id, i_f_state], dtype=dtype)
+
+ # === Compute dB gradient ===
+ dB = compute_chunk_output(dO_Xt, L_t_, C, X, a_right, dS_t, transpose_broadcast=True)
+
+ # === Update reverse time state dState ===
+ # Update the state: running sum dS (will be used in the next iteration)
+ barC = C_t * a_left[:1, :].broadcast_to((block_size, block_size))
+
+ barC_tX = nl.matmul(barC, dO, transpose_x=False)
+ dS[...] = a_center * dS + barC_tX
+
+ dB_accumulation[:, chunk_id, :] += dB
+ dA_accumulation[:, chunk_id, :] -= B * dB
+
+ # === Reverse cumulative sum for dA ===
+ cumsum_chunk = nl.matmul(ones_tril, dA_accumulation[:, chunk_id, :], transpose_x=True)
+ cumsum_chunk[...] = cumsum_chunk + nl.copy(cumsum_dA, dtype=dtype).broadcast_to((block_size, d_state))
+ cumsum_dA[0, i_f_state] = cumsum_chunk[0, i_f_state]
+
+ ddt = nl.matmul(cumsum_chunk * logA_cache[:, head_id], ones_sum_right) + nl.matmul(dU * X0,
+ ones_sum_right_head)
+
+ dA_chunk = nl.matmul(cumsum_chunk * dt, ones_sum_right)
+ dA_final[:, head_id] += nl.matmul(ones_sum_left, dA_chunk)
+
+ dX = dU * dt
+
+ if D is not None:
+ dD_chunk = nl.matmul(dO * X0, ones_sum_right_head)
+ dD_final[:, head_id] += nl.copy(nl.matmul(ones_sum_left, dD_chunk), dtype=dtype)
+ dX[...] = dX + dO * D[:, head_id]
+
+ nl.store(dX_tensor[batch_id, i_p_in, head_id, i_f_head], dX)
+ nl.store(ddt_tensor[batch_id, i_p_in, head_id], ddt)
+
+ nl.store(dlogA_tensor[batch_id, head_id], dA_final[0, head_id])
+ if D is not None:
+ nl.store(dD_tensor[batch_id, head_id], dD_final[0, head_id])
+
+ for chunk_id in nl.sequential_range(n_chunks):
+ i_p_in = i_p + chunk_id * block_size
+ nl.store(dC_tensor[batch_id, i_p_in, group_id, i_f_state], dC_accumulation[:, chunk_id, :])
+ nl.store(dB_tensor[batch_id, i_p_in, group_id, i_f_state], dB_accumulation[:, chunk_id, :])
+
+
+class Mamba2Kernel(torch.autograd.Function):
+    """Define the autograd function with the forward and backward kernels."""
+
+ @staticmethod
+ def forward(ctx, dt, A, X, B, C, D):
+ batch_size, seq_len, n_heads, d_head = X.shape
+ ctx.save_for_backward(dt, A, X, B, C, D)
+ out_Y = torch.empty_like(X)
+ mamba_kernel_[batch_size](dt, A, X, B, C, out_Y, D)
+ return out_Y
+
+ @staticmethod
+ def backward(ctx, d_output):
+ dt, A, X, B, C, D = ctx.saved_tensors
+ batch_size, seq_len, n_heads, d_head = X.shape
+
+ ddt = torch.empty_like(dt)
+ dA = torch.empty_like(A.unsqueeze(0).repeat(batch_size, 1))
+ dX = torch.empty_like(X)
+ dB = torch.empty_like(B)
+ dC = torch.empty_like(C)
+ dD = torch.empty_like(D.unsqueeze(0).repeat(batch_size, 1))
+
+ mamba_kernel_bwd_[batch_size](dt, A, X, B, C, d_output,
+ ddt, dA, dX, dB, dC, D, dD)
+ dA, dD = dA.sum(0), dD.sum(0)
+ return ddt, dA, dX, dB, dC, dD
+
+
+def mamba2_kernel(dt, A, X, B, C, D):
+ return Mamba2Kernel.apply(dt, A, X, B, C, D)
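+
+
+# Illustrative call (shapes follow the kernel docstrings; the concrete sizes are assumptions and
+# the tensors are expected to live on the Neuron/XLA device):
+#   batch, seqlen, nheads, dhead, ngroups, dstate = 1, 1024, 8, 64, 1, 128
+#   y = mamba2_kernel(
+#       dt=torch.rand(batch, seqlen, nheads),    # per-head, per-position step sizes
+#       A=-torch.rand(nheads),                   # log-decay rate per head
+#       X=torch.randn(batch, seqlen, nheads, dhead),
+#       B=torch.randn(batch, seqlen, ngroups, dstate),
+#       C=torch.randn(batch, seqlen, ngroups, dstate),
+#       D=torch.ones(nheads),
+#   )                                            # y: (batch, seqlen, nheads, dhead)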
diff --git a/examples/training/mamba2/mamba2/mamba2_mixer.py b/examples/training/mamba2/mamba2/mamba2_mixer.py
new file mode 100644
index 0000000..af30670
--- /dev/null
+++ b/examples/training/mamba2/mamba2/mamba2_mixer.py
@@ -0,0 +1,242 @@
+# coding=utf-8
+# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
+# Modifications Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional
+
+import torch
+from torch import nn
+from torch.types import Device
+from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache
+from transformers.activations import ACT2FN
+from einops import rearrange
+import neuronx_distributed.parallel_layers.parallel_state as ps
+from neuronx_distributed.parallel_layers import RowParallelLinear, ColumnParallelLinear
+
+from .configuration_mamba2 import Mamba2Config
+
+from .mamba2_kernel import mamba2_kernel
+from .conv1d_grouped import ConvNKI
+
+
+def softplus(x, threshold=10):
+ return torch.where(x < threshold, torch.log(1 + torch.exp(x)), x)
+
+
+# Note: this implementation differs from the same module in `transformers` when n_groups > 1:
+# we normalize each channel group independently, while the original normalizes all channels on the same device
+# regardless of their group. Our version ensures that a checkpoint behaves the same when used with a different
+# tp degree than during training (note, however, that using a large n_groups with few channels per group may
+# introduce training instabilities).
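+# For example (illustrative numbers): with d=512, n_groups=4 and tp_size=2, each core keeps 2 complete groups
+# of 128 channels and computes one RMS statistic per group, so the result is identical to running with tp_size=1.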
+class MambaRMSNormGated(nn.Module):
+ def __init__(self, d: int, eps: float = 1e-5, device: Device = None, n_groups: int = 1, rmsnorm_within_groups=True):
+ """Gated Root Mean Square Layer Normalization with support for groups
+
+ Paper: https://arxiv.org/abs/1910.07467
+
+ Mamba Official: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/layernorm_gated.py#L18
+ """
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(d, device=device))
+ self.n_groups = n_groups
+ self.rmsnorm_within_groups = rmsnorm_within_groups
+ self.parallel_split()
+
+ def parallel_split(self):
+ # Split weights across cores based on the current tensor parallelism rank.
+ tp_rank = ps.get_tensor_model_parallel_rank()
+ tp_size = ps.get_tensor_model_parallel_size()
+ dim = self.weight.shape[0]
+ assert dim % tp_size == 0
+ assert self.n_groups % tp_size == 0
+ self.n_groups = self.n_groups // tp_size
+ chunk = slice(dim // tp_size * tp_rank, dim // tp_size * (tp_rank + 1))
+ self.weight.data = self.weight.data[chunk].detach().clone()
+ return self
+
+ def forward(self, hidden_states, gate=None):
+ hidden_states = hidden_states.to(torch.float32)
+
+ if self.rmsnorm_within_groups:
+ hidden_states = rearrange(hidden_states, "... (g d) -> ... g d", g=self.n_groups)
+ gate = rearrange(gate, "... (g d) -> ... g d", g=self.n_groups)
+
+ if gate is not None:
+ hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
+
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+ if self.rmsnorm_within_groups:
+ res = self.weight * rearrange(hidden_states, "... g d -> ... (g d)", g=self.n_groups)
+ else:
+ res = self.weight * hidden_states
+
+ return res
+
+
+class Mamba2Mixer(nn.Module):
+ """
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
+ ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
+ and is why Mamba is called **selective** state spaces)
+ """
+
+ def __init__(self, config: Mamba2Config, layer_idx: int):
+ super().__init__()
+ self.num_heads = config.num_heads
+ self.hidden_size = config.hidden_size
+ self.ssm_state_size = config.state_size
+ self.conv_kernel_size = config.conv_kernel
+ self.intermediate_size = int(config.expand * self.hidden_size)
+ self.time_step_rank = int(config.time_step_rank)
+ self.layer_idx = layer_idx
+ self.use_conv_bias = config.use_conv_bias
+ self.activation = config.hidden_act
+ self.act = ACT2FN[config.hidden_act]
+
+ self.layer_norm_epsilon = config.layer_norm_epsilon
+ self.rms_norm = config.rms_norm
+ self.rmsnorm_within_groups = config.rmsnorm_within_groups
+
+ self.n_groups = config.n_groups
+ self.head_dim = config.head_dim
+ self.chunk_size = config.chunk_size
+
+ self.time_step_limit = config.time_step_limit
+ self.time_step_min = config.time_step_min
+ self.time_step_max = config.time_step_max
+
+ assert self.intermediate_size % self.head_dim == 0
+ assert self.intermediate_size // self.head_dim == self.num_heads
+
+ self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
+ # This is a custom replacement for a grouped conv1d, written as an NKI kernel for better efficiency.
+ # Note: the SiLU non-linearity is already applied inside the kernel.
+ self.conv1d = ConvNKI(
+ in_channels=self.conv_dim,
+ out_channels=self.conv_dim,
+ bias=config.use_conv_bias,
+ kernel_size=config.conv_kernel,
+ groups=self.conv_dim,
+ padding=config.conv_kernel - 1,
+ activation='silu',
+ )
+
+ # projection of the input hidden states
+ self.in_proj_z = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=config.use_bias, gather_output=False)
+ self.in_proj_xBC = ColumnParallelLinear(self.hidden_size, self.conv_dim, bias=config.use_bias, gather_output=False)
+ self.in_proj_dt = ColumnParallelLinear(self.hidden_size, self.num_heads, bias=config.use_bias, gather_output=False)
+
+ # time step projection (discretization)
+ dt = torch.exp(
+ torch.rand(config.num_heads)
+ * (math.log(config.time_step_max) - math.log(config.time_step_min))
+ + math.log(config.time_step_min)
+ ).clamp(min=config.time_step_floor)
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
+ self.dt_bias = nn.Parameter(inv_dt)
+ self.dt_bias._no_reinit = True
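+ # Sanity note: exp(inv_dt) = exp(dt) - 1, so softplus(inv_dt) = log(1 + exp(inv_dt)) = dt; the bias is thus
+ # chosen so that the softplus applied in the forward pass recovers the sampled dt values at initialization
+ # (up to the contribution of the dt projection).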
+
+ # S4D real initialization. These are not discretized!
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+ A = torch.arange(1, self.num_heads + 1)
+ self.A_log = nn.Parameter(torch.log(A))
+ self.A_log._no_weight_decay = True
+ self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon, n_groups=self.n_groups, rmsnorm_within_groups=self.rmsnorm_within_groups)
+ self.D = nn.Parameter(torch.ones(self.num_heads))
+ self.D._no_weight_decay = True
+
+ self.out_proj = RowParallelLinear(self.intermediate_size, self.hidden_size, bias=config.use_bias, input_is_parallel=True)
+ self.use_bias = config.use_bias
+ self.parallel_split()
+
+ def parallel_split(self):
+ # Split weights across cores based on the current tensor parallelism rank.
+ tp_rank = ps.get_tensor_model_parallel_rank()
+ tp_size = ps.get_tensor_model_parallel_size()
+ assert self.intermediate_size % tp_size == 0
+ assert self.n_groups % tp_size == 0
+ self.intermediate_size_tp = self.intermediate_size // tp_size
+ self.n_groups_tp = self.n_groups // tp_size
+ self.num_heads_tp = self.num_heads // tp_size
+ self.conv_dim_tp = self.conv_dim // tp_size
+ head_chunk = slice(self.num_heads_tp * tp_rank, self.num_heads_tp * (tp_rank + 1))
+ # note: we have to use .clone(), otherwise the result would be a view and the original would remain in memory
+ self.D.data = self.D.data[head_chunk].detach().clone()
+ self.A_log.data = self.A_log.data[head_chunk].detach().clone()
+ self.dt_bias.data = self.dt_bias.data[head_chunk].detach().clone()
+ return self
+
+ def nki_kernels_forward(
+ self,
+ hidden_states: torch.Tensor,
+ cache_params: Optional[Mamba2Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ # set up dimensions for reshapes later
+ batch_size, seq_len, _ = hidden_states.shape
+ groups_time_state_size_tp = self.n_groups_tp * self.ssm_state_size
+
+ assert cache_params is None, "cache not supported yet"
+ assert self.training, "only training supported right now"
+ assert attention_mask is None, "attention mask not supported yet"
+ assert self.time_step_limit[0] == 0.0 and self.time_step_limit[1] == float("inf"), "dt limit not supported yet"
+ assert self.activation in ["silu", "swish"]
+
+ A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
+
+ gate = self.in_proj_z(hidden_states)
+ hidden_states_B_C = self.in_proj_xBC(hidden_states)
+ time_step = self.in_proj_dt(hidden_states)
+
+ # 1D Convolution (SiLU non-linearity is fused inside)
+ hidden_states_B_C = self.conv1d(input=hidden_states_B_C.transpose(1, 2)).transpose(1, 2)
+ hidden_states, B, C = torch.split(
+ hidden_states_B_C,
+ [self.intermediate_size_tp, groups_time_state_size_tp, groups_time_state_size_tp],
+ dim=-1,
+ )
+
+ time_step = softplus(time_step + self.dt_bias)
+
+ scan_output = mamba2_kernel(time_step,
+ A,
+ hidden_states.view(batch_size, seq_len, self.num_heads_tp, -1),
+ B.view(batch_size, seq_len, self.n_groups_tp, -1),
+ C.view(batch_size, seq_len, self.n_groups_tp, -1),
+ self.D)
+
+ scan_output = scan_output.view(batch_size, seq_len, -1)
+ # Multiply "gate" branch and apply extra normalization layer
+ scan_output = self.norm(scan_output, gate)
+ out = self.out_proj(scan_output)
+
+ return out
+
+ def forward(
+ self,
+ hidden_states,
+ cache_params: Optional[Mamba2Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ assert "xla" in self.in_proj_xBC.weight.device.type, "This model only supports forward on an XLA device"
+ return self.nki_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
diff --git a/examples/training/mamba2/mamba2/modeling_mamba2.py b/examples/training/mamba2/mamba2/modeling_mamba2.py
new file mode 100644
index 0000000..ad491ee
--- /dev/null
+++ b/examples/training/mamba2/mamba2/modeling_mamba2.py
@@ -0,0 +1,470 @@
+# coding=utf-8
+# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
+# Modifications Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Adaptation of transformers.models.mamba2.modeling_mamba2 to be compatible with neuronx-distributed.
+
+ The class Mamba2Mixer has been moved to a separate file mamba2_mixer.py.
+"""
+
+import math
+from dataclasses import dataclass
+from functools import partial
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from neuronx_distributed.parallel_layers import ParallelEmbedding, RowParallelLinear, ColumnParallelLinear
+from neuronx_distributed.parallel_layers.loss_functions import parallel_cross_entropy
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from torch.types import Device
+import torch.distributed as distrib
+
+
+from transformers.generation import GenerationMixin
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
+
+from .configuration_mamba2 import Mamba2Config
+from .mamba2_mixer import Mamba2Mixer
+
+from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2RMSNorm, MAMBA2_START_DOCSTRING, \
+ MAMBA2_INPUTS_DOCSTRING, Mamba2CausalLMOutput
+
+logger = logging.get_logger(__name__)
+
+selective_state_update = None
+
+causal_conv1d_update, causal_conv1d_fn = None, None
+
+
+_CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1"
+_CONFIG_FOR_DOC = "Mamba2Config"
+
+def _init_normal(std, w):
+ return nn.init.normal_(w, mean=0.0, std=std)
+
+class PartitionBreak(torch.autograd.Function):
+ """Workaround to help the compiler detect layers.
+
+ The compiler looks for all_gather/batch_norm/clamp to separate layers. Since those are not present in our model,
+ we insert a fake torch.clamp operation with inf boundary (so it acts like the identity) to ensure it will try to
+ break the layer at that point.
+ """
+
+ @staticmethod
+ def forward(ctx, input):
+ return torch.clamp(input, min=None, max=torch.inf)
+
+ @staticmethod
+ def backward(ctx, d_output):
+ return torch.clamp(d_output, min=None, max=torch.inf)
+
+def partition_break(x):
+ return PartitionBreak.apply(x)
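+# Note: clamp with min=None and max=inf returns its input unchanged and has gradient 1 everywhere,
+# so partition_break(x) is numerically a no-op in both the forward and backward pass.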
+
+class Mamba2Block(nn.Module):
+ def __init__(self, config, layer_idx):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.residual_in_fp32 = config.residual_in_fp32
+ self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+ self.mixer = Mamba2Mixer(config, layer_idx=layer_idx)
+
+ def forward(
+ self,
+ hidden_states,
+ cache_params: Optional[Mamba2Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ residual = hidden_states
+ hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
+ if self.residual_in_fp32:
+ residual = residual.to(torch.float32)
+
+ hidden_states = self.mixer(
+ hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
+ )
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class Mamba2PreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = Mamba2Config
+ base_model_prefix = "backbone"
+ _no_split_modules = ["Mamba2Block"]
+ supports_gradient_checkpointing = True
+ _is_stateful = True
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, Mamba2Mixer):
+ module.A_log._no_weight_decay = True
+ module.D._no_weight_decay = True
+ elif isinstance(module, nn.Linear):
+ if module.bias is not None:
+ if not getattr(module.bias, "_no_reinit", False):
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Embedding):
+ # fixme: this won't give consistent initialization with tp
+ nn.init.normal_(module.weight, std=self.config.initializer_range)
+ elif isinstance(module, (ParallelEmbedding, RowParallelLinear, ColumnParallelLinear)):
+ # module.init_weight_cpu()
+ if hasattr(module, "bias") and module.bias is not None:
+ module.bias.data.zero_()
+
+ if self.config.rescale_prenorm_residual:
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+ #
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+ for name, p in module.named_parameters():
+ if name in ["out_proj.weight"]:
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+ # We need to reinit p since this code could be called multiple times
+ # Having just p *= scale would repeatedly scale it down
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+ with torch.no_grad():
+ p /= math.sqrt(self.config.num_hidden_layers)
+
+
+@dataclass
+# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2
+class Mamba2Output(ModelOutput):
+ """
+ Class for the MAMBA2 model outputs.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ cache_params (`Mamba2Cache`):
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+ avoid providing the old `input_ids`.
+
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ cache_params: Optional[Mamba2Cache] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@add_start_docstrings(
+ "The bare MAMBA2 Model transformer outputting raw hidden-states without any specific head on top.",
+ MAMBA2_START_DOCSTRING,
+)
+class Mamba2Model(Mamba2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ init_method = partial(_init_normal, config.initializer_range)
+ self.embeddings = ParallelEmbedding(config.vocab_size, config.hidden_size, init_method=init_method)
+
+ self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
+
+ self.gradient_checkpointing = False
+ self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+ # Initialize weights and apply final processing
+ self._register_load_state_dict_pre_hook(self.load_hook)
+ self.post_init()
+
+ def load_hook(self, state_dict, prefix, *args):
+ for k in state_dict:
+ if "embedding." in k:
+ state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
+ break
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def set_input_embeddings(self, new_embeddings):
+ self.embeddings = new_embeddings
+
+ @add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Mamba2Output,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ cache_params: Optional[Mamba2Cache] = None,
+ use_cache: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[Tuple, Mamba2Output]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embeddings(input_ids)
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ use_cache = False
+
+ if use_cache:
+ if cache_params is None:
+ cache_params = Mamba2Cache(
+ self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
+ )
+ cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
+ elif cache_position is None:
+ # cases when we do manual forward instead of using `model.generate` which will initiate
+ # `cache_position` and makes sure it is not None, throw error here instead of doing some
+ # hack to conjecture the current cache position
+ raise ValueError(
+ "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
+ "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
+ "be initialized for you automatically"
+ )
+ else:
+ cache_params = None
+
+ hidden_states = inputs_embeds
+ all_hidden_states = () if output_hidden_states else None
+ for mixer_block in self.layers:
+ if self.gradient_checkpointing and self.training:
+ raise NotImplementedError("gradient checkpointing is not implemented yet")
+ hidden_states = self._gradient_checkpointing_func(
+ mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
+ )
+ else:
+ hidden_states = partition_break(mixer_block(
+ hidden_states,
+ cache_params=cache_params,
+ cache_position=cache_position,
+ attention_mask=attention_mask,
+ ))
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if use_cache:
+ cache_params.seqlen_offset += inputs_embeds.shape[1]
+
+ hidden_states = self.norm_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
+
+ return Mamba2Output(
+ last_hidden_state=hidden_states,
+ cache_params=cache_params if use_cache else None,
+ hidden_states=all_hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input
+ embeddings).
+ """,
+ MAMBA2_START_DOCSTRING,
+)
+class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = []
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.backbone = Mamba2Model(config)
+ # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ init_method = partial(_init_normal, config.initializer_range)
+ self.lm_head = ColumnParallelLinear(
+ config.hidden_size,
+ config.vocab_size,
+ bias=False,
+ gather_output=False,
+ init_method=init_method
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def get_input_embeddings(self):
+ return self.backbone.get_input_embeddings()
+
+ def set_input_embeddings(self, new_embeddings):
+ return self.backbone.set_input_embeddings(new_embeddings)
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ inputs_embeds=None,
+ use_cache=None,
+ cache_params: Optional[Mamba2Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ # Overwritten -- uses `cache_params` as opposed to `past_key_values`
+
+ if inputs_embeds is not None:
+ past_len = inputs_embeds.shape[1] + input_ids.shape[1]
+ else:
+ past_len = input_ids.shape[1]
+ if use_cache:
+ # `cache_position` should have been initialized in `generate`
+ if cache_position is None:
+ raise ValueError(
+ "`cache_position` should not be None as it should have been initialized in "
+ "`model.generate`, you are responsible for passing in a valid `cache_position` if "
+ "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
+ )
+ # how do we detect that we are in decoding without cache?
+ if cache_position[0] > 0:
+ input_ids = input_ids[:, -1][..., None]
+ attention_mask = attention_mask[:, -1][..., None]
+ else:
+ # we initialize the `cache_position` to full size of `conv_states` at prefill stage
+ # considering padding will be applied when input length is shorter, and truncation
+ # will be applied when it is longer, so it will be equivalent to always have it match
+ # the length of `cache_params.conv_states`, which is `config.conv_kernel`
+ cache_position = torch.arange(0, past_len, device=input_ids.device)
+ # if the cache is not used, we also do have to extend the attention mask here
+ # TODO there is likely a cleverer way to do this
+ extended_mask = torch.ones(
+ attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device
+ )
+ attention_mask = torch.cat([attention_mask, extended_mask], dim=1)
+ cache_params = None
+
+ if attention_mask.shape[1] < past_len:
+ # we have to update manually the attention mask if
+ # we are in decoding without cache
+ # and we don't have position_ids here
+ # TODO but we should be able to use cache_position though at a later time
+ extended_mask = torch.ones(
+ attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device
+ )
+ attention_mask = torch.cat([attention_mask, extended_mask], dim=1)
+ if inputs_embeds is not None and cache_params is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "attention_mask": attention_mask,
+ "cache_params": cache_params,
+ "use_cache": use_cache,
+ "cache_position": cache_position,
+ }
+ )
+ return model_inputs
+
+ @add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Mamba2CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_params: Optional[Mamba2Cache] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs, # for now we need this for generation
+ ) -> Union[Tuple, Mamba2CausalLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ mamba2_outputs = self.backbone(
+ input_ids,
+ cache_params=cache_params,
+ inputs_embeds=inputs_embeds,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ attention_mask=attention_mask,
+ )
+ hidden_states = mamba2_outputs[0]
+
+ logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ loss_fct = parallel_cross_entropy
+ # Flatten the tokens
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+ loss = loss.mean()
+
+ if not return_dict:
+ output = (logits,) + mamba2_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return Mamba2CausalLMOutput(
+ loss=loss,
+ logits=logits,
+ cache_params=mamba2_outputs.cache_params,
+ hidden_states=mamba2_outputs.hidden_states,
+ )
diff --git a/examples/training/mamba2/predefined_configs.py b/examples/training/mamba2/predefined_configs.py
new file mode 100644
index 0000000..a342057
--- /dev/null
+++ b/examples/training/mamba2/predefined_configs.py
@@ -0,0 +1,43 @@
+"""Utility script to create a Mamba2Config given the size of the model."""
+
+from typing import Dict
+
+from mamba2 import Mamba2Config
+from dataclasses import dataclass, astuple
+
+
+@dataclass
+class ConfParams:
+ d_model: int
+ n_layers: int
+ head_dim: int = 128
+
+
+CONFIGS_KWARGS: Dict[str, ConfParams] = {
+ 'Mamba130M': ConfParams(d_model=768, n_layers=24),
+ 'Mamba370M': ConfParams(d_model=1024, n_layers=48),
+ 'Mamba780M': ConfParams(d_model=1536, n_layers=48),
+ 'Mamba1B': ConfParams(d_model=2048, n_layers=48),
+ 'Mamba3B': ConfParams(d_model=2560, n_layers=64),
+ 'Mamba7B': ConfParams(d_model=4096, n_layers=64),
+}
+
+
+def get_config(name: str, vocab_size, rmsnorm_within_groups=True, n_groups=8):
+ d_model, n_layers, head_dim = astuple(CONFIGS_KWARGS[name])
+ config = Mamba2Config(
+ vocab_size=vocab_size,
+ hidden_size=d_model,
+ head_dim=head_dim,
+ num_heads=(d_model * 2) // head_dim,
+ num_hidden_layers=n_layers,
+ tie_word_embeddings=True,
+ use_cache=False,
+ n_groups=n_groups,
+ bos_token_id=0,
+ eos_token_id=0,
+ pad_token_id=0,
+ rmsnorm_within_groups=rmsnorm_within_groups,
+ )
+ return config
+
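+# Illustrative usage (not part of the training entry point):
+#   config = get_config('Mamba130M', vocab_size=50272)
+#   assert config.hidden_size == 768 and config.num_heads == 12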
diff --git a/examples/training/mamba2/train.py b/examples/training/mamba2/train.py
new file mode 100644
index 0000000..4b47053
--- /dev/null
+++ b/examples/training/mamba2/train.py
@@ -0,0 +1,270 @@
+import os
+
+import torch
+import torch.distributed as dist
+import time
+
+from transformers import AdamW
+
+import neuronx_distributed as nxd
+import neuronx_distributed.parallel_layers.parallel_state as ps
+import torch_xla.distributed.parallel_loader as pl
+from neuronx_distributed.utils.adamw_fp32_optim_params import AdamW_FP32OptimParams
+from transformers.optimization import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
+
+from mamba2 import Mamba2ForCausalLM
+from training_utils import create_llama_pretraining_dataset
+from predefined_configs import get_config
+
+import transformers.modeling_utils as modeling_utils
+
+# For PT autocast.
+torch.cuda.is_bf16_supported = lambda: True
+
+# Environment variables set by torch.distributed.launch
+LOCAL_RANK = int(os.environ['LOCAL_RANK'])
+WORLD_SIZE = int(os.environ['WORLD_SIZE'])
+WORLD_RANK = int(os.environ['RANK'])
+
+
+def get_mixed_precision_config(args):
+ if args.use_zero_1:
+ return {
+ "use_master_weights": True,
+ "use_fp32_grad_acc": True,
+ "use_master_weights_in_ckpt": False,
+ }
+ else:
+ return {}
+
+
+def run(args, backend):
+ import numpy as np
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+
+ device = xm.xla_device()
+
+ config = get_config(args.model, args.vocab_size, rmsnorm_within_groups=(not args.rmsnorm_across_groups),
+ n_groups=args.tp)
+
+ def get_model():
+ model = Mamba2ForCausalLM(config).to(dtype=args.dtype)
+ model.train()
+
+ # check that weight tying worked
+ if model.config.tie_word_embeddings:
+ assert model.backbone.embeddings.weight is model.lm_head.weight
+ return model
+
+ nxd_config = nxd.neuronx_distributed_config(
+ tensor_parallel_size=args.tp,
+ optimizer_config={"zero_one_enabled": args.use_zero_1, "grad_clipping": True, "max_grad_norm": 1.0},
+ sequence_parallel=False,
+ model_init_config=None,
+ mixed_precision_config=get_mixed_precision_config(args),
+ )
+ model = nxd.initialize_parallel_model(nxd_config, get_model)
+ world_size = ps.get_data_parallel_size()
+ if xm.is_master_ordinal():
+ print('NEURON_CC_FLAGS: ', os.environ.get('NEURON_CC_FLAGS', None))
+ print('XLA_IR_DEBUG: ', os.environ.get('XLA_IR_DEBUG', None))
+ print('XLA_HLO_DEBUG: ', os.environ.get('XLA_HLO_DEBUG', None))
+ print('TP groups:', ps.get_tensor_model_parallel_group(as_list=True))
+ print('DP groups:', ps.get_data_parallel_group(as_list=True))
+ print('Config: ', config)
+ param_size, dtype = 0, None
+ for param in set(model.parameters()):
+ param_size += param.nelement()
+ dtype = param.dtype
+ print(f"Model size: {param_size / 10 ** 6:.1f}M parameters/core")
+ print(f"Param dtype: {dtype}")
+
+ param_optimizer = list(model.named_parameters())
+
+ no_decay = ["bias", "LayerNorm", "norm", "A", "D"]
+
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
+ "weight_decay": args.weight_decay,
+ },
+ {
+ "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ if args.no_optimizer_fp32:
+ optimizer_cls = AdamW
+ else:
+ optimizer_cls = AdamW_FP32OptimParams
+
+ # Creating NxD Optimizer
+ optimizer = nxd.initialize_parallel_optimizer(
+ nxd_config,
+ optimizer_cls,
+ optimizer_grouped_parameters,
+ lr=args.lr,
+ betas=(args.beta1, args.beta2),
+ )
+ optimizer.zero_grad()
+ if args.use_zero_1:
+ optimizer.optimizer.init_zero()
+
+ scheduler = get_cosine_schedule_with_warmup(
+ optimizer,
+ num_warmup_steps=args.warmup_steps,
+ num_training_steps=args.max_steps,
+ last_epoch=-1,
+ )
+
+ train_dataloader, _ = create_llama_pretraining_dataset(
+ args.data_dir,
+ args.batch,
+ ps.get_data_parallel_size(),
+ ps.get_data_parallel_rank(),
+ args.seed,
+ )
+ # We wrap the dataloader with MpDeviceLoader. This dataloader should take
+ # care of copying the tensors to device and also inserting the mark_step at
+ # iteration end.
+ train_device_loader = pl.MpDeviceLoader(train_dataloader, device)
+
+ running_loss = torch.zeros(1).to(device)
+ training_step, global_step = 0, 0
+ t0 = time.time()
+ xm.mark_step()
+ for i, data in enumerate(train_device_loader):
+ training_step += 1
+ input_ids = data["input_ids"]
+ labels = data["labels"]
+
+ out = model(input_ids=input_ids, labels=labels)
+ loss = out.loss / args.grad_accum_usteps
+ loss.backward()
+ xm.mark_step()
+
+ running_loss += loss.detach()
+
+ if training_step % args.grad_accum_usteps == 0:
+ xm.mark_step()
+ if xm.is_master_ordinal():
+ print(f"Global Step {global_step}")
+ print(f"Loss: {loss.item()}")
+ print(f"running_loss: {running_loss.item()}")
+
+ running_loss.zero_()
+ optimizer.step()
+ optimizer.zero_grad()
+ scheduler.step()
+ global_step += 1
+
+ xm.mark_step()
+ if xm.is_master_ordinal():
+ print(f'Processing time for this step: {time.time() - t0:.3f}s')
+ t0 = time.time()
+
+ if (args.checkpoint_freq > 0) and (global_step % args.checkpoint_freq == 0):
+ xm.add_step_closure(
+ nxd.save_checkpoint, (
+ args.checkpoint_dir, # checkpoint directory
+ f"{args.tag}_step_{global_step}", # tag
+ model, # model
+ optimizer, # optimizer
+ scheduler, # scheduler
+ {"global_step": global_step, "cli_args": args.__dict__}, # user content
+ )
+ )
+
+ if global_step >= args.max_steps:
+ xm.mark_step()
+ break
+ # xm.mark_step() # final mark_step not needed when using MpDeviceLoader
+
+
+def init_processes(args, backend):
+ dist.init_process_group(backend, rank=WORLD_RANK, world_size=WORLD_SIZE)
+ run(args=args, backend=backend)
+ xm.rendezvous("_mp_fn finished")
+
+
+def tp_loader(state_dict, tp_rank, tp_size, config):
+ """Load the correct slice of weights from a checkpoint for the current core give the tensor parallel degree."""
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if k.endswith('out_proj.weight'): # row parallel
+ dim_1_shape = v.shape[1]
+ cv = torch.split(v, dim_1_shape // tp_size, dim=1)
+ new_state_dict[k] = cv[tp_rank]
+ elif k.endswith('in_proj_xBC.weight') or 'conv1d' in k: # xBC projection and conv1d: column parallel, split within each of x/B/C
+ wx, wB, wC = torch.split(v, [config.hidden_size * 2, config.n_groups * config.state_size,
+ config.n_groups * config.state_size],
+ dim=0)
+ wx_tp = torch.split(wx, wx.shape[0] // tp_size, dim=0)[tp_rank]
+ wB_tp = torch.split(wB, wB.shape[0] // tp_size, dim=0)[tp_rank]
+ wC_tp = torch.split(wC, wC.shape[0] // tp_size, dim=0)[tp_rank]
+ xBC_tp = torch.cat((wx_tp, wB_tp, wC_tp), dim=0)
+ new_state_dict[k] = xBC_tp
+ elif 'norm' in k and 'mixer' not in k:
+ new_state_dict[k] = v
+ else: # everything else (mixer norm, z/dt projections, per-head params, embeddings, lm_head) is split along dim 0
+ dim_0_shape = v.shape[0]
+ rv = torch.split(v, dim_0_shape // tp_size, dim=0)
+ new_state_dict[k] = rv[tp_rank]
+
+ return new_state_dict
+
+
+if __name__ == '__main__':
+ import torch_xla.core.xla_model as xm
+ import os
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Mamba2Block Configuration")
+ parser.add_argument("--seed", type=int, default=100, help="Random seed")
+ parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"], help="Data type")
+ parser.add_argument("--batch", type=int, default=1, help="Batch size")
+ parser.add_argument("--vocab_size", type=int, default=50272, help="Vocab size")
+ parser.add_argument("--model", default="Mamba130M", help="Hugging face model to profile")
+ parser.add_argument("--backend", type=str, default="xla", choices=['xla', 'nccl', 'gloo'])
+ parser.add_argument("--tp", type=int, default=1, help="Tensor Parallel Size")
+
+ parser.add_argument("--no_optimizer_fp32", action="store_true", help="Do not use FP32 for the optimizer state.")
+ parser.add_argument("--use_zero_1", action="store_true", help="Use ZeRO-1.")
+ parser.add_argument("--lr", type=float, default=4e-4, help="Learning rate.")
+ parser.add_argument("--weight_decay", default=0.01, type=float, help="weight decay")
+ parser.add_argument("--beta1", default=0.9, type=float, help="beta1 parameter for Adam optimizer")
+ parser.add_argument("--beta2", default=0.999, type=float, help="beta2 parameter for Adam optimizer")
+ parser.add_argument("--data_dir", type=str, help="Pre-tokenized dataset directory.")
+
+ parser.add_argument("--warmup_steps", type=int, default=2000,
+ help="Number of warmup accumulation-steps for learning rate .")
+ parser.add_argument("--max_steps", type=int, help="Maximum total accumulation-steps to run.")
+ parser.add_argument("--grad_accum_usteps", type=int, default=1,
+ help="Gradient accumulation micro-steps (an accumulation-step has micro-steps.")
+ parser.add_argument("--rmsnorm_across_groups", action="store_true",
+ help="Uses (HF style) RMSNorm instead of the custom one that normalizes independently for each of the n_groups.")
+ parser.add_argument("--debug", "-d", action="store_true",
+ help="Enable Neuron debugging flags to dump model graph and compiler logs.")
+ parser.add_argument("--checkpoint_freq", type=int, help="ckpt save freq.")
+ parser.add_argument("--checkpoint_dir", type=str, default="./", help="ckpt saving dir")
+ parser.add_argument("--tag", type=str, default="exp", help="ckpt saving name")
+
+ args = parser.parse_args()
+
+ os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "0"
+ args.dtype = getattr(torch, args.dtype)
+
+ if args.dtype == torch.bfloat16:
+ modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16
+
+ if args.debug:
+ # Debug flags to dump the annotated HLO graph, useful for profiling
+ os.environ["XLA_IR_DEBUG"] = "1"
+ os.environ["XLA_HLO_DEBUG"] = "1"
+ os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
+
+ os.environ["NEURON_CC_FLAGS"] = " --model-type=transformer -O1"
+
+ init_processes(args, backend=args.backend)
diff --git a/examples/training/mamba2/training_utils.py b/examples/training/mamba2/training_utils.py
new file mode 100644
index 0000000..91c1778
--- /dev/null
+++ b/examples/training/mamba2/training_utils.py
@@ -0,0 +1,359 @@
+import json
+import math
+import os
+import queue
+import time
+from datetime import datetime, timezone
+from functools import partial
+from itertools import chain
+from typing import Any, Dict, List
+
+import datasets
+import torch
+from torch.utils.data import DistributedSampler
+from torch.utils.data.dataloader import DataLoader
+from transformers import default_data_collator, set_seed
+
+try:
+ from lr import CosineAnnealing
+except ImportError:
+ CosineAnnealing = None
+
+from collections import namedtuple
+
+Metric = namedtuple("Metric", ["name", "value", "units", "additional_data"])
+remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}
+
+
+# empty list to save remainder from batches to use in next batch
+def pack_dataset(dataset, chunk_length=2048):
+ print(f"Chunking dataset into chunks of {chunk_length} tokens.")
+
+ def chunk(sample, chunk_length=chunk_length):
+ # define global remainder variable to save remainder from batches to use in next batch
+ global remainder
+ # Concatenate all texts and add remainder from previous batch
+ concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
+ concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()}
+ # get total number of tokens for batch
+ batch_total_length = len(concatenated_examples[list(sample.keys())[0]])
+
+ # get max number of chunks for batch
+ if batch_total_length >= chunk_length:
+ batch_chunk_length = (batch_total_length // chunk_length) * chunk_length
+
+ # Split by chunks of max_len.
+ result = {
+ k: [t[i: i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
+ for k, t in concatenated_examples.items()
+ }
+ # add remainder to global variable for next batch
+ remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
+ # prepare labels
+ result["labels"] = result["input_ids"].copy()
+ return result
+
+ # tokenize and chunk dataset
+ lm_dataset = dataset.map(
+ partial(chunk, chunk_length=chunk_length),
+ batched=True,
+ )
+ print(f"Total number of samples: {len(lm_dataset)}")
+ return lm_dataset
+
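+# Illustrative example of the chunking above: if a mapped batch concatenates to 10 tokens and chunk_length=4,
+# two chunks of 4 tokens are emitted and the remaining 2 tokens are carried over (via the global `remainder`)
+# into the next batch.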
+
+def get_learning_rate_scheduler(optimizer, args, last_epoch=-1):
+ lr_scheduler = CosineAnnealing(
+ optimizer,
+ max_steps=args.max_steps,
+ min_lr=args.min_lr,
+ warmup_steps=args.warmup_steps,
+ constant_steps=args.constant_steps,
+ last_epoch=last_epoch,
+ )
+ return lr_scheduler
+
+
+def get_param_groups_by_weight_decay(model):
+ """Get param groups."""
+ if hasattr(model, "local_named_parameters") and hasattr(model, "partitioned") and model.partitioned:
+ # Zero1 use the first param in opt to decide the device
+ param_optimizer = list(model.local_named_parameters())
+ else:
+ param_optimizer = list(model.named_parameters())
+ no_decay = ["bias", "LayerNorm"] # gamma/beta are in LayerNorm.weight
+
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
+ "weight_decay": 0.01,
+ },
+ {
+ "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ ]
+ return optimizer_grouped_parameters
+
+
+def create_llama_pretraining_dataset(data_dir, mini_batch_size, dp_size, dp_rank, seed):
+ # Workaround because python functions are not picklable
+ class WorkerInitObj(object):
+ def __init__(self, seed):
+ self.seed = seed
+
+ def __call__(self, id):
+ set_seed(self.seed)
+
+ worker_init = WorkerInitObj(seed)
+ train_data = datasets.load_from_disk(data_dir)
+ train_sampler = DistributedSampler(
+ train_data,
+ num_replicas=dp_size,
+ rank=dp_rank,
+ shuffle=False,
+ drop_last=True,
+ )
+ train_dataloader = DataLoader(
+ train_data,
+ collate_fn=default_data_collator,
+ sampler=train_sampler,
+ batch_size=mini_batch_size,
+ num_workers=0,
+ worker_init_fn=worker_init,
+ drop_last=True,
+ pin_memory=True,
+ )
+ return train_dataloader, None
+
+
+def create_instruction_based_dataset(data_dir, mini_batch_size, dp_size, dp_rank, seed, tokenizer=None, task=None):
+ raw_datasets = datasets.load_dataset(data_dir, split="train")
+ if task:
+ raw_datasets = raw_datasets.filter(lambda example: example["category"] == task)
+ train_and_test_dataset = raw_datasets.train_test_split(test_size=8)
+ train_dataset = train_and_test_dataset["train"]
+ test_dataset = train_and_test_dataset["test"]
+
+ def preprocess_train_dataset(sample):
+ instruction = f"### Instruction\n{sample['instruction']}"
+ context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
+ response = f"### Answer\n{sample['response']}"
+ # join all the parts together
+ prompt = "\n".join([i for i in [instruction, context, response] if i is not None])
+ model_input = tokenizer(f"{prompt}{tokenizer.eos_token}")
+ return model_input
+
+ train_data = train_dataset.shuffle().map(preprocess_train_dataset, remove_columns=train_dataset.column_names)
+ train_data = pack_dataset(train_data, chunk_length=2048)
+
+ class WorkerInitObj(object):
+ def __init__(self, seed):
+ self.seed = seed
+
+ def __call__(self, id):
+ set_seed(self.seed)
+
+ worker_init = WorkerInitObj(seed)
+
+ train_sampler = DistributedSampler(
+ train_data,
+ num_replicas=dp_size,
+ rank=dp_rank,
+ shuffle=True,
+ drop_last=True,
+ )
+ train_dataloader = DataLoader(
+ train_data,
+ collate_fn=default_data_collator,
+ sampler=train_sampler,
+ batch_size=mini_batch_size,
+ num_workers=0,
+ worker_init_fn=worker_init,
+ drop_last=True,
+ pin_memory=True,
+ )
+
+ def preprocess_test_dataset(sample):
+ instruction = f"### Instruction\n{sample['instruction']}"
+ context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
+ response = "### Answer\n"
+ # join all the parts together
+ prompt = "\n".join([i for i in [instruction, context, response] if i is not None])
+ model_input = tokenizer(prompt, add_special_tokens=False)
+ labels = tokenizer(sample["response"], add_special_tokens=False)
+ model_input["labels"] = labels["input_ids"]
+ return model_input
+
+ test_data = test_dataset.map(preprocess_test_dataset, remove_columns=test_dataset.column_names)
+
+ test_sampler = DistributedSampler(
+ test_data,
+ num_replicas=dp_size,
+ rank=dp_rank,
+ shuffle=False,
+ drop_last=False,
+ )
+ test_dataloader = DataLoader(
+ test_data,
+ collate_fn=default_data_collator,
+ sampler=test_sampler,
+ batch_size=mini_batch_size,
+ num_workers=0,
+ drop_last=False,
+ pin_memory=True,
+ )
+
+ return train_dataloader, test_dataloader
+
+
+def create_partition(num_hidden_layers, pipeline_parallel_size):
+ """
+ Evenly split the transformer layers between the PP ranks
+ """
+ assert num_hidden_layers % pipeline_parallel_size == 0
+ num_layer_per_partition = num_hidden_layers // pipeline_parallel_size
+ pipeline_cuts = []
+ current_cut = num_layer_per_partition - 1
+ for i in range(pipeline_parallel_size - 1):
+ pipeline_cuts.append(f"model.layers.{current_cut}")
+ current_cut += num_layer_per_partition
+ return pipeline_cuts
+
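+# Illustrative example: create_partition(num_hidden_layers=8, pipeline_parallel_size=4)
+# returns ['model.layers.1', 'model.layers.3', 'model.layers.5'].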
+
+def get_sin_cos_matrix(config):
+ head_dim = config.hidden_size // config.num_attention_heads
+ base = config.rope_theta
+ inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+ t = torch.arange(config.max_position_embeddings, dtype=inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ return emb.cos()[None, None, :, :].to(torch.float32), emb.sin()[None, None, :, :].to(torch.float32)
+
+
+def get_dtype(model) -> str:
+ """
+ Reference: https://pytorch.org/xla/release/1.12/index.html#xla-tensors-and-bfloat16
+ """
+ if "XLA_USE_BF16" in os.environ:
+ return "torch.bfloat16"
+ if "XLA_DOWNCAST_BF16" in os.environ:
+ if "torch.float" in str(model.dtype):
+ return "torch.bfloat16"
+ if "torch.double" in str(model.dtype):
+ return "torch.float32"
+ return str(model.dtype)
+
+
+def print_logs(loss, global_norm, args, throughput, logger, total_steps, current_lr, input_ids, start):
+ total_norm_cpu = global_norm.cpu().item()
+ logger.log(total_steps, loss, total_norm_cpu, current_lr, input_ids, throughput, start)
+
+
+class TrainingMetrics:
+ """
+ This class is used for logging metrics to a JSON file. One can provide a
+ dictionary of metrics that need to be stored, and they will get
+ written to the file.
+ Arguments:
+ json_file: File used for logging. If no file exists, new file would be created.
+ """
+
+ def __init__(self, json_file):
+ self.json_file = json_file
+
+ def read_modify_write_file(self, data, key: str = "metrics") -> None:
+ """
+ data (dict of training parameters or list of metrics): Data to update in the file.
+ key (str): the dictionary key under which data is to be recorded
+ """
+ result_dict = {}
+ print(f"Writing data to the provided results file: {self.json_file}")
+ if os.path.exists(self.json_file):
+ with open(self.json_file, "r") as json_file:
+ content = json_file.read()
+ if not content.strip(): # Check if content is empty or contains only whitespace
+ print("File is empty or contains only whitespace.")
+ else:
+ result_dict = json.loads(content) or result_dict
+ print(f"Updating with {key} data: {data}")
+ if result_dict:
+ try:
+ # handle internal named entity if present
+ results = result_dict[next(iter(result_dict))]
+ except Exception:
+ results = result_dict
+ current = results.get(key)
+ if not current:
+ results[key] = data
+ else:
+ if isinstance(current, list):
+ current.extend(data)
+ elif isinstance(current, dict):
+ current.update(data)
+ else:
+ result_dict["results"] = {key: data}
+ with open(self.json_file, "w") as json_file:
+ json.dump(result_dict, json_file)
+
+ def store_metrics(self, metrics: List[Metric]) -> None:
+ """
+ Writes collected metrics to the file.
+ """
+ data = [
+ {
+ "MetricName": metric.name,
+ "MeasuredValue": metric.value,
+ "Units": metric.units,
+ "Timestamp": datetime.now(timezone.utc).isoformat(),
+ "AdditionalData": metric.additional_data,
+ }
+ for metric in metrics
+ ]
+ self.update(data=data, key="metrics")
+
+ def store_parameters(self, parameters: Dict[str, Any]) -> None:
+ """
+ Writes specified model and configuration parameters to the file.
+ """
+ self.update(data=parameters, key="parameters")
+
+ def update(self, **kwargs: Any) -> None:
+ """
+ Write specified data to the output file.
+ """
+ self.read_modify_write_file(**kwargs)
+
+
+class Throughput:
+ def __init__(self, batch_size, world_size, grad_accum_usteps, moving_avg_window_size=10, logging_interval=1):
+ """
+ Used to calculate the throughput over a moving window. It records the step time
+ between two calls and uses that time to calculate the throughput.
+ """
+ self.seqs_per_iteration = batch_size * world_size * grad_accum_usteps * logging_interval
+ self.moving_avg_window_size = math.ceil(moving_avg_window_size / logging_interval)
+ self.moving_avg_window = queue.Queue()
+ self.window_time = 0
+ self.start_time = time.time()
+
+ def get_throughput(self):
+ step_time = time.time() - self.start_time
+ self.start_time += step_time
+ self.window_time += step_time
+ self.moving_avg_window.put(step_time)
+ window_size = self.moving_avg_window.qsize()
+ if window_size > self.moving_avg_window_size:
+ self.window_time -= self.moving_avg_window.get()
+ window_size -= 1
+ throughput = window_size * self.seqs_per_iteration / self.window_time
+ return throughput
+
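+# Illustrative usage: throughput = Throughput(batch_size=8, world_size=32, grad_accum_usteps=4);
+# calling throughput.get_throughput() once per logged step returns sequences/second averaged over the moving window.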
+
+def get_mixed_precision_config(use_gpu_compatible_precision):
+ return {
+ "use_master_weights": bool(use_gpu_compatible_precision),
+ "use_fp32_grad_acc": bool(use_gpu_compatible_precision),
+ "use_master_weights_in_ckpt": False,
+ }