25 changes: 22 additions & 3 deletions examples/multimodal/speech_llm/export/extract_salm_weights.py
@@ -20,16 +20,35 @@
import torch
from lightning.pytorch.trainer.trainer import Trainer
from megatron.core import dist_checkpointing
from omegaconf import OmegaConf
from omegaconf import OmegaConf, DictConfig
from typing import Dict

from nemo.collections.multimodal.speech_llm.modules.perception_modules import AudioPerceptionModule
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.collections.nlp.parts.utils_funcs import load_state_dict_helper
from nemo.collections.common.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.utils import logging
from nemo.utils.model_utils import inject_model_parallel_rank


def load_state_dict_helper(cls, cfg: DictConfig, trainer: Trainer, state_dict: Dict[str, torch.Tensor]):
"""Load state_dict for converted community, for example, HuggingFace models."""
model = cls(cfg, trainer)

missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys:
# Keys ending with '_extra_state' are related to Transformer Engine internals
missing_keys_non_extra = [key for key in missing_keys if not key.endswith("_extra_state")]
if missing_keys_non_extra:
logging.critical("Missing keys were detected during the load, something has gone wrong. Aborting.")
raise RuntimeError(f"Missing keys: \n{missing_keys_non_extra}")

if unexpected_keys:
logging.critical("Unexpected keys were detected which should not happen. Aborting.")
raise RuntimeError(f"Unexpected keys: \n{unexpected_keys}")

return model


def get_config_and_state_dict_from_nemo(filepath, map_location, output_dir, sharded_state_dict=None):
cwd = os.getcwd()
save_restore_connector = NLPSaveRestoreConnector()
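For reference, a minimal sketch of how the inlined load_state_dict_helper above might be called once a config and state_dict have been extracted from a .nemo file; the file names and trainer settings here are illustrative assumptions, not part of this script.

import torch
from lightning.pytorch.trainer.trainer import Trainer
from omegaconf import OmegaConf

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel

# Hypothetical artifacts produced by the extraction steps in this script.
llm_cfg = OmegaConf.load("llm_config.yaml")
llm_state_dict = torch.load("llm_weights.ckpt", map_location="cpu")

trainer = Trainer(accelerator="gpu", devices=1, num_nodes=1)

# Builds the model from the config and loads the weights; only keys ending in
# '_extra_state' (Transformer Engine internals) may be missing without raising.
model = load_state_dict_helper(MegatronGPTModel, llm_cfg, trainer, llm_state_dict)
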
2 changes: 1 addition & 1 deletion examples/multimodal/speech_llm/modular_audio_gpt_eval.py
@@ -19,7 +19,7 @@
from omegaconf.omegaconf import OmegaConf

from nemo.collections.multimodal.speech_llm.models.modular_models import ModularAudioGPTModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
from nemo.collections.common.parts.megatron_trainer_builder import MegatronTrainerBuilder
from nemo.core.config import hydra_runner
from nemo.utils import logging

2 changes: 1 addition & 1 deletion examples/multimodal/speech_llm/modular_audio_gpt_train.py
@@ -16,7 +16,7 @@
from omegaconf.omegaconf import OmegaConf, open_dict

from nemo.collections.multimodal.speech_llm.models.modular_models import ModularAudioGPTModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder
from nemo.collections.common.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder
from nemo.core.config import hydra_runner
from nemo.utils import logging, model_utils
from nemo.utils.exp_manager import exp_manager
@@ -29,7 +29,7 @@
# pylint: disable=line-too-long
from nemo.collections.common.video_tokenizers.cosmos_tokenizer import CausalVideoTokenizer
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam
from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy
from nemo.collections.common.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy
from nemo.core.config import hydra_runner

"""
@@ -31,7 +31,7 @@

# pylint: disable=line-too-long
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam
from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy
from nemo.collections.common.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy
from nemo.core.config import hydra_runner

"""
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,12 +17,11 @@

from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.plugins.environments import TorchElasticEnvironment
from omegaconf import DictConfig, open_dict

from nemo.collections.common.metrics.perf_metrics import FLOPsMeasurementCallback
from nemo.collections.nlp.parts.nlp_overrides import (
from nemo.collections.common.parts.nlp_overrides import (
CustomProgressBar,
FSDPMixedPrecisionPlugin,
GradScaler,
@@ -171,7 +170,8 @@ def _callbacks(self, callbacks: Optional[list]) -> list:
"""
if callbacks is None:
callbacks = []
# enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks
# enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False,
# CustomProgressBar is not appended to callbacks
if 'enable_progress_bar' not in self.cfg.trainer or self.cfg.trainer.enable_progress_bar:
callbacks.append(CustomProgressBar())

@@ -186,6 +186,7 @@ def _callbacks(self, callbacks: Optional[list]) -> list:
return callbacks

def create_trainer(self, callbacks=None) -> Trainer:
""" """
# Make a dummy train step if skip_train
if self.cfg.model.get("skip_train", False):
self.cfg.trainer.max_steps = 1
@@ -195,7 +196,8 @@ def create_trainer(self, callbacks=None) -> Trainer:
self.cfg.trainer.num_sanity_val_steps = 0
self.cfg.exp_manager.create_checkpoint_callback = False

# cfg.trainer.precision becomes None in Trainer if precision_plugins exist since both precision plugins and precision
# cfg.trainer.precision becomes None in Trainer if precision_plugins exist
# since both precision plugins and the precision flag cannot be set at the same time
precision = self.cfg.trainer.precision
strategy = self._training_strategy()
plugins = self._plugins()
@@ -206,54 +208,6 @@ def create_trainer(self, callbacks=None) -> Trainer:
return trainer


class MegatronBertTrainerBuilder(MegatronTrainerBuilder):
"""Builder for BERT model Trainer with overrides."""

def _grad_scaler(self) -> GradScaler:
return GradScaler(
init_scale=self.cfg.model.get('native_amp_init_scale', 2**32),
growth_interval=self.cfg.model.get('native_amp_growth_interval', 1000),
)


class MegatronT5TrainerBuilder(MegatronTrainerBuilder):
"""Builder for T5 model Trainer with overrides."""

def _callbacks(self, callbacks: Optional[list]) -> list:
callbacks = super()._callbacks(callbacks)
callbacks.append(ModelSummary(max_depth=3))
return callbacks

def create_trainer(self, callbacks=None) -> Trainer:
strategy = self._training_strategy()
plugins = self._plugins()
callbacks = self._callbacks(callbacks)
return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks)


class MegatronStableDiffusionTrainerBuilder(MegatronTrainerBuilder):
"""Builder for SD model Trainer with overrides."""

def _training_strategy(self) -> NLPDDPStrategy:
"""
Returns a ddp strategy passed to Trainer.strategy.
"""
ddp_overlap = self.cfg.model.get("ddp_overlap", True)
if ddp_overlap:
return NLPDDPStrategy(
no_ddp_communication_hook=False,
gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
find_unused_parameters=True,
bucket_cap_mb=256,
)
else:
return NLPDDPStrategy(
no_ddp_communication_hook=True,
gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
)


class MegatronLMPPTrainerBuilder(MegatronTrainerBuilder):
"""Builder for scripts where grad scaler is turned off for pipeline parallel LM model. E.g. PEFT tuning scripts"""

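A hedged sketch of driving the relocated builder from its new nemo.collections.common.parts path; the config values below are assumptions for illustration and would normally come from the Hydra YAMLs under examples/.

from omegaconf import OmegaConf

from nemo.collections.common.parts.megatron_trainer_builder import MegatronTrainerBuilder

# Assumed minimal config; gradient_as_bucket_view is read directly by the DDP strategy.
cfg = OmegaConf.create(
    {
        "trainer": {"devices": 1, "num_nodes": 1, "accelerator": "gpu", "precision": "bf16-mixed", "max_steps": 100},
        "model": {"gradient_as_bucket_view": True},
        "exp_manager": {"create_checkpoint_callback": False},
    }
)

trainer = MegatronTrainerBuilder(cfg).create_trainer()
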
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -66,7 +66,6 @@

from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from nemo.collections.nlp.modules.common.megatron.transformer import AutocastTransformerLayer, ParallelTransformerLayer
from nemo.collections.nlp.parts import utils_funcs
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.core.optim import MainParamsOptimizerWrapper
from nemo.core.optim.optimizers import init_optimizer_states
@@ -140,6 +139,21 @@
"""


def torch_dtype_from_precision(precision: Union[int, str], megatron_amp_O2: Optional[bool] = None) -> torch.dtype:
"""Mapping from PTL precision types to corresponding PyTorch parameter datatype."""
if megatron_amp_O2 is not None and megatron_amp_O2 is False:
return torch.float32

if precision in ['bf16', 'bf16-mixed']:
return torch.bfloat16
elif precision in [16, '16', '16-mixed']:
return torch.float16
elif precision in [32, '32', '32-true']:
return torch.float32
else:
raise ValueError(f"Could not parse the precision of `{precision}` to a valid torch.dtype")


def init_model_parallel(
sharp: bool, nccl_communicator_config_path: str = None, distributed_timeout_minutes: int = 30
) -> None:
@@ -846,9 +860,9 @@ def _set_mixed_precision_recipe(
raise ValueError(f"Was unable to infer precision type, received {precision!r}.")
# Over-write gradient reduction dtype to support bf16 computation with fp32 grad reduction
if grad_reduce_dtype is not None:
reduce_dtype = utils_funcs.torch_dtype_from_precision(grad_reduce_dtype, None)
reduce_dtype = torch_dtype_from_precision(grad_reduce_dtype, None)
if set_buffer_dtype is not None:
buffer_dtype = utils_funcs.torch_dtype_from_precision(buffer_dtype, None)
buffer_dtype = torch_dtype_from_precision(buffer_dtype, None)
return MixedPrecision(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
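The torch_dtype_from_precision helper added above is a local copy of the former nemo.collections.nlp.parts.utils_funcs mapping; a short illustration of the mapping it implements, matching the branches in the function body:

import torch

assert torch_dtype_from_precision("bf16-mixed") is torch.bfloat16
assert torch_dtype_from_precision("16-mixed") is torch.float16
assert torch_dtype_from_precision(32) is torch.float32
# megatron_amp_O2=False forces fp32 parameters regardless of the precision flag.
assert torch_dtype_from_precision("bf16", megatron_amp_O2=False) is torch.float32
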
@@ -54,8 +54,7 @@
get_iterator_k_split,
)
from nemo.collections.nlp.modules.common.text_generation_utils import get_computeprob_response
from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo.collections.multimodal.speech_llm.parts.peft_config import PEFT_CONFIG_MAP
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.classes.mixins import adapter_mixins
@@ -93,6 +92,10 @@
default_inference_config = {'tokens_to_generate': 30}


def get_last_rank():
    """Return the rank of the last process in the default process group."""
    return torch.distributed.get_world_size() - 1


class ModularAudioGPTModel(SpeechLLMAdapterMixin, MegatronGPTSFTModel):
"""Modularized speech GPT model."""

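get_last_rank is likewise defined locally now rather than imported from nemo.collections.nlp.parts.utils_funcs; it returns the highest rank in the default process group, which is where the final pipeline stage lives. A hedged sketch of the typical broadcast pattern, assuming torch.distributed is already initialized:

import torch
import torch.distributed as dist

# Illustrative only: tensor shape and device are assumptions.
result = torch.zeros(1, device="cuda")
dist.broadcast(result, src=get_last_rank())  # last rank == dist.get_world_size() - 1
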
@@ -48,8 +48,7 @@
build_position_ids,
get_iterator_k_split,
)
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo.collections.common.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.core.classes.mixins import adapter_mixins
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, MaskType, NeuralType
from nemo.utils import AppState, logging, model_utils
@@ -89,6 +88,10 @@
default_inference_config = {'tokens_to_generate': 30}


def get_last_rank():
    """Return the rank of the last process in the default process group."""
    return torch.distributed.get_world_size() - 1


class ModularizedAudioT5Model(MegatronT5LoraModel):
"""Modularized speech GPT model."""

@@ -25,7 +25,7 @@

NLPAdapterModelMixin = ABC

from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP, PEFTConfig
from nemo.collections.multimodal.speech_llm.parts.peft_config import PEFT_CONFIG_MAP, PEFTConfig
from nemo.utils import logging


@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -37,7 +37,6 @@
LoraDenseAttentionAdapterConfig,
LoraHto4HAdapterConfig,
LoraKQVAdapterConfig,
LoraKQVAdapterWeightTyingConfig,
LoraMoe4HtoHAdapterConfig,
LoraMoeHto4HAdapterConfig,
LoraUnfusedHto4HAdapterConfig,
21 changes: 19 additions & 2 deletions nemo/collections/nlp/modules/common/hyena/hyena.py
@@ -31,8 +31,11 @@
# https://github.com/HazyResearch/safari/blob/flashfftconv/src/models/sequence/hyena.py
# https://github.com/athms/mad-lab/blob/main/mad/model/layers/hyena.py

# flake8: noqa
# pylint: skip-file

from dataclasses import dataclass
from typing import Union
from typing import Union, Optional

import torch
import torch.nn as nn
@@ -57,7 +60,6 @@

from nemo.collections.common.parts.utils import activation_registry
from nemo.collections.nlp.modules.common.hyena.hyena_filter import HyenaFilter, HyenaFilterSubmodules
from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
from nemo.utils.metaclasses import Singleton

try:
@@ -102,6 +104,21 @@ def auto_assign_attrs(cls, **kwargs):
setattr(cls, k, v)


def torch_dtype_from_precision(precision: Union[int, str], megatron_amp_O2: Optional[bool] = None) -> torch.dtype:
"""Mapping from PTL precision types to corresponding PyTorch parameter datatype."""
if megatron_amp_O2 is not None and megatron_amp_O2 is False:
return torch.float32

if precision in ['bf16', 'bf16-mixed']:
return torch.bfloat16
elif precision in [16, '16', '16-mixed']:
return torch.float16
elif precision in [32, '32', '32-true']:
return torch.float32
else:
raise ValueError(f"Could not parse the precision of `{precision}` to a valid torch.dtype")


class CausalDepthWiseConv1d(nn.Module):
def __init__(self, channels, width, bias=True):
if not HAVE_CAUSAL_CONV1D:
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
# pylint: skip-file

from importlib.metadata import version
from typing import TYPE_CHECKING, Dict, Optional

@@ -20,7 +23,7 @@
import torch.nn.functional as F
from torch import Tensor, nn

from nemo.collections.nlp.parts.peft_config import LORA_CONFIG_TO_MCORE_MAP, get_target_modules
from nemo.collections.multimodal.speech_llm.parts.peft_config import LORA_CONFIG_TO_MCORE_MAP, get_target_modules
from nemo.utils import logging
from nemo.utils.import_utils import safe_import_from

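PEFT_CONFIG_MAP and get_target_modules now resolve from the speech_llm copy of peft_config; a minimal hedged sketch of the lookup, with the scheme name assumed for illustration:

from nemo.collections.multimodal.speech_llm.parts.peft_config import PEFT_CONFIG_MAP

peft_scheme = "lora"  # assumed; taken from cfg.model.peft.peft_scheme in the training scripts
peft_cfg_cls = PEFT_CONFIG_MAP[peft_scheme]
print(f"PEFT scheme '{peft_scheme}' maps to {peft_cfg_cls.__name__}")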