diff --git a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml index 572331a84..9e662c4e1 100644 --- a/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml +++ b/examples/compress/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml @@ -1,3 +1,5 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model block_size: 8192 bos_rate: 0.5 data_column: messages diff --git a/examples/compress/main.py b/examples/compress/main.py index c8b287fcc..2c3343c37 100644 --- a/examples/compress/main.py +++ b/examples/compress/main.py @@ -29,18 +29,18 @@ """ import argparse -import datetime +from datetime import timedelta from pathlib import Path -import mip_and_realize_models -import torch -from puzzle_tools.hydra_utils import register_hydra_resolvers - +import modelopt.torch._compress.mip.mip_and_realize_models as mip_and_realize_models import modelopt.torch.nas as mtn +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel -from modelopt.torch._compress.runtime import NativeDdpRuntime +from modelopt.torch._compress.tools.hydra_utils import ( + initialize_hydra_config_for_dir, + register_hydra_resolvers, +) from modelopt.torch._compress.tools.logger import mprint -from tests.utils.test_utils import initialize_hydra_config_for_dir def parse_args(): @@ -70,50 +70,52 @@ def run_full_compress(hydra_config_path: str): config_path: Path to the YAML configuration file """ mprint("Compress Progress 1/8: starting compression pipeline") - with NativeDdpRuntime(dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)): - # Register Hydra custom resolvers (needed for config resolution) - register_hydra_resolvers() - - hydra_config_path = Path(hydra_config_path).resolve() - hydra_config_dir = str(hydra_config_path.parent) - hydra_config_name = hydra_config_path.stem - - # Load hydra config - hydra_cfg = initialize_hydra_config_for_dir( - config_dir=hydra_config_dir, - config_name=hydra_config_name, - overrides=[], - ) - - # Convert model (convert from HF to DeciLM, score pruning activations, - # prune the model and save pruned checkpoints) - input_model = CompressModel() - converted_model = mtn.convert( - input_model, - mode=[ - ( - "compress", - { - "puzzle_dir": str(hydra_cfg.puzzle_dir), - "input_model_path": hydra_cfg.input_hf_model_path, - "hydra_config_dir": hydra_config_dir, - "hydra_config_name": hydra_config_name, - "dataset_path": str(hydra_cfg.dataset_path), - }, - ) - ], - ) - - # Run NAS search (build replacement library and compute stats, - # compute one block scores, run MIP and realize models) - mtn.search( - converted_model, - constraints={}, # this is not used as the search space is defined in the hydra config - dummy_input=None, # Not used - config={}, # this is not used as the search space is defined in the hydra config - ) - - mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)") + dist.setup(timeout=timedelta(10)) + + # Register Hydra custom resolvers (needed for config resolution) + register_hydra_resolvers() + + hydra_config_path = Path(hydra_config_path).resolve() + hydra_config_dir = str(hydra_config_path.parent) + hydra_config_name = hydra_config_path.stem + + # Load hydra config + hydra_cfg = initialize_hydra_config_for_dir( 
+        config_dir=hydra_config_dir,
+        config_name=hydra_config_name,
+        overrides=[],
+    )
+
+    # Convert model (convert from HF to DeciLM, score pruning activations,
+    # prune the model and save pruned checkpoints)
+    input_model = CompressModel()
+    converted_model = mtn.convert(
+        input_model,
+        mode=[
+            (
+                "compress",
+                {
+                    "puzzle_dir": str(hydra_cfg.puzzle_dir),
+                    "input_model_path": hydra_cfg.input_hf_model_path,
+                    "hydra_config_dir": hydra_config_dir,
+                    "hydra_config_name": hydra_config_name,
+                    "dataset_path": str(hydra_cfg.dataset_path),
+                },
+            )
+        ],
+    )
+
+    # Run NAS search (build replacement library and compute stats,
+    # compute one block scores, run MIP and realize models)
+    mtn.search(
+        converted_model,
+        constraints={},  # this is not used as the search space is defined in the hydra config
+        dummy_input=None,  # Not used
+        config={},  # this is not used as the search space is defined in the hydra config
+    )
+
+    dist.cleanup()
+    mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)")
 
 
 def run_mip_only(hydra_config_path: str):
@@ -125,30 +127,29 @@ def run_mip_only(hydra_config_path: str):
     Args:
         hydra_config_path: Path to the YAML configuration file
     """
+    dist.setup(timeout=timedelta(10))
+
+    # Register Hydra custom resolvers (needed for config resolution)
+    register_hydra_resolvers()
+
+    hydra_config_path = Path(hydra_config_path).resolve()
+    hydra_config_dir = str(hydra_config_path.parent)
+    hydra_config_name = hydra_config_path.stem
+
+    # Load hydra config
+    hydra_cfg = initialize_hydra_config_for_dir(
+        config_dir=hydra_config_dir,
+        config_name=hydra_config_name,
+        overrides=[],
+    )
+
+    # mip_and_realize_models (distributed processing)
+    # TODO: How to make it part of mtn.search() api, similarly to run_full_compress() API
+    mprint("Compress Progress 7/8: running MIP and realizing models (multi-gpu)")
+    mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
-    with NativeDdpRuntime(
-        dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
-    ) as runtime:
-        # Register Hydra custom resolvers (needed for config resolution)
-        register_hydra_resolvers()
-
-        hydra_config_path = Path(hydra_config_path).resolve()
-        hydra_config_dir = str(hydra_config_path.parent)
-        hydra_config_name = hydra_config_path.stem
-
-        # Load hydra config
-        hydra_cfg = initialize_hydra_config_for_dir(
-            config_dir=hydra_config_dir,
-            config_name=hydra_config_name,
-            overrides=[],
-        )
-
-        # mip_and_realize_models (distributed processing)
-        # TODO: How to make it part of mnt.search() api, similarly to run_full_compress() API
-        mprint("Compress Progress 7/8: running MIP and realizing models (multi-gpu)")
-        mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime)
-
-    mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)")
+    dist.cleanup()
+    mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)")
 
 
 def main():
diff --git a/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py b/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py
index 6339d55ab..510f69111 100644
--- a/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py
+++ b/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py
@@ -18,7 +18,6 @@ activation scoring for pruning.
""" -import argparse import gc import json from abc import ABC, abstractmethod @@ -30,6 +29,8 @@ from omegaconf import DictConfig, OmegaConf from torch import nn +import modelopt.torch.utils.distributed as dist + # BlockConfig used at runtime, not just type hints (lines 680, 790) from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import BlockConfig # noqa: TC001 from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import ( @@ -38,7 +39,6 @@ from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMRMSNorm from modelopt.torch._compress.tools.logger import aprint from modelopt.torch._compress.tools.robust_json import json_dump -from modelopt.torch._compress.tools.runtime import IRuntime def clear_gpu_memory(clear: bool) -> None: @@ -97,9 +97,8 @@ def dump_activations_logs( cls: type["ActivationsHook"], activation_hooks: dict[str, "ActivationsHook"], activations_log_dir: Path | str, - args: argparse.Namespace, - runtime: IRuntime | None, - ): + args: DictConfig, + ) -> None: """ Default implementation for dumping final activation scores logs to disk. This is called only at the end of scoring to save final results. @@ -107,7 +106,7 @@ def dump_activations_logs( activations_log_dir = Path(activations_log_dir) activations_log_dir.mkdir(exist_ok=True, parents=True) - rank = runtime.global_rank if runtime is not None else 0 + rank = dist.rank() activations_log_path = activations_log_dir / f"rank_{rank}.pth" activations_log = { module_name: hook.to_dict() for module_name, hook in activation_hooks.items() @@ -116,14 +115,8 @@ def dump_activations_logs( if rank == 0: args.activation_hooks_kwargs.pop("model") - json_dump( - OmegaConf.to_container(args, resolve=True) - if isinstance(args, DictConfig) - else vars(args), - activations_log_dir / "args.json", - ) - if runtime is not None: - runtime.wait_for_everyone() # rank 0 will not wait before dumping args.json + json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json") + dist.barrier() aprint(f"Dumped final activations log to {activations_log_path}") @@ -132,8 +125,7 @@ def save_hook_states( cls: type["ActivationsHook"], activation_hooks: dict[str, "ActivationsHook"], activations_log_dir: Path | str, - runtime: IRuntime | None, - ): + ) -> None: """ Save hook states for checkpointing (separate from final results). This can be called periodically during scoring. @@ -141,7 +133,7 @@ def save_hook_states( """ activations_log_dir = Path(activations_log_dir) activations_log_dir.mkdir(exist_ok=True, parents=True) - rank = runtime.global_rank if runtime is not None else 0 + rank = dist.rank() hook_states_path = activations_log_dir / f"hook_states_rank_{rank}.pth" hook_states = { @@ -461,29 +453,28 @@ def dump_activations_logs( cls: type["LayerNormContributionHook"], activation_hooks: dict[str, "ActivationsHook"], activations_log_dir: Path | str, - args: argparse.Namespace, - runtime: IRuntime | None, - ): + args: DictConfig, + ) -> None: """ At the end of the default implementation of dumping activation scores to disc, save aggregated channel importance results. 
""" - super().dump_activations_logs(activation_hooks, activations_log_dir, args, runtime) + super().dump_activations_logs(activation_hooks, activations_log_dir, args) - rank = runtime.global_rank if runtime is not None else 0 + rank = dist.rank() if rank == 0: LayerNormContributionHook._save_channel_importance_results( activation_hooks, activations_log_dir, args ) - runtime.wait_for_everyone() + dist.barrier() @staticmethod def _save_channel_importance_results( activation_hooks: dict[str, ActivationsHook], activations_log_dir: Path, - args: argparse.Namespace, + args: DictConfig, ) -> None: """ Save channel importance results from activation hooks. diff --git a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py index 4a276e8e8..f271a5f4f 100644 --- a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py +++ b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py @@ -15,16 +15,12 @@ from pathlib import Path -import hydra import torch from omegaconf import DictConfig -from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.runtime import BaseRuntime, NativeDdpRuntime from modelopt.torch._compress.tools.validate_model import validate_model -from modelopt.torch._compress.utils.dist_utils import is_distributed -from modelopt.torch._compress.utils.parsing import format_global_config def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool: @@ -50,23 +46,20 @@ def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool: return method in supported_methods -def check_scoring_completion( - activations_log_dir: str, runtime, activation_hooks_kwargs=None -) -> bool: +def check_scoring_completion(activations_log_dir: str, activation_hooks_kwargs=None) -> bool: """ Check if scoring is already completed by looking for the expected output files. Also checks if the scoring method is safe for resume. Args: activations_log_dir: Directory where activation logs should be stored - runtime: Runtime object for distributed processing activation_hooks_kwargs: Hook configuration to check if resume is safe Returns: bool: True if scoring is completed (has rank files and args.json) """ - # Only check completion on main process (or if no distributed runtime) - if runtime is None or runtime.is_main_process: + # Only check completion on main process + if dist.is_master(): log_dir = Path(activations_log_dir) # Check if directory exists @@ -95,14 +88,13 @@ def check_scoring_completion( return False -def should_skip_scoring_completely(cfg: DictConfig, runtime) -> bool: +def should_skip_scoring_completely(cfg: DictConfig) -> bool: """ Determine if we should skip scoring entirely (only if 100% complete). Partial progress should proceed to validate_model for proper resume. 
Args: cfg: Configuration object - runtime: Runtime object for distributed processing Returns: bool: True if we should skip scoring (100% completed), False if we should run/resume it @@ -123,11 +115,11 @@ def should_skip_scoring_completely(cfg: DictConfig, runtime) -> bool: # Check if scoring is already completed is_completed = check_scoring_completion( - cfg.pruning.activations_log_dir, runtime, activation_hooks_kwargs + cfg.pruning.activations_log_dir, activation_hooks_kwargs ) # Broadcast the result to all processes in distributed mode - if runtime is not None and runtime.world_size > 1: + if dist.size() > 1: should_skip = [is_completed] # Use list for mutable object torch.distributed.broadcast_object_list(should_skip, src=0) is_completed = should_skip[0] @@ -141,34 +133,12 @@ def should_skip_scoring_completely(cfg: DictConfig, runtime) -> bool: # Old progress tracking removed - checkpoint manager handles all progress tracking -def launch_score_activations(cfg: DictConfig, runtime): +def launch_score_activations(cfg: DictConfig): # Check if we should skip scoring entirely (only if 100% complete) - if should_skip_scoring_completely(cfg, runtime): + if should_skip_scoring_completely(cfg): return mprint("Starting pruning activation scoring...") # The checkpoint manager inside validate_model handles all progress tracking - validate_model(args=cfg.pruning, runtime=runtime) - - -@hydra.main("", version_base="1.3") -def main(cfg: DictConfig) -> None: - cfg = hydra.utils.instantiate(cfg) - mprint(format_global_config(cfg, title="Score Pruning Activations")) - - _runtime = ( - NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes") - ) - if is_distributed() - else BaseRuntime(dtype=torch.bfloat16) - ) - with _runtime as runtime: - launch_score_activations(cfg, runtime) - runtime.wait_for_everyone() - - -if __name__ == "__main__": - register_hydra_resolvers() - main() + validate_model(args=cfg.pruning, pipeline_parallel=True) diff --git a/modelopt/torch/_compress/build_library_and_stats.py b/modelopt/torch/_compress/build_library_and_stats.py index f0735c98f..28e0f386c 100644 --- a/modelopt/torch/_compress/build_library_and_stats.py +++ b/modelopt/torch/_compress/build_library_and_stats.py @@ -88,22 +88,3 @@ def launch_build_library_and_stats(cfg: DictConfig) -> None: mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.subblock_stats_filename}") if hasattr(cfg.calc_subblock_stats, "moe_stats_filename"): mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.moe_stats_filename}") - - -@hydra.main("", version_base="1.3") -def main(cfg: DictConfig) -> None: - """ - Main entry point for the unified build library and stats command. - - This function uses Hydra for configuration management and runs both - build_replacement_library and calc_subblock_stats in sequence. 
- """ - cfg = hydra.utils.instantiate(cfg) - mprint("Unified Build Library and Stats Configuration:") - mprint(format_global_config(cfg)) - launch_build_library_and_stats(cfg) - - -if __name__ == "__main__": - register_hydra_resolvers() - main() diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py index 8504631cb..21e9df2af 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/_compress/compress.py @@ -27,12 +27,12 @@ import modelopt.torch._compress.mip.mip_and_realize_models as mip_and_realize_models import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts import modelopt.torch._compress.scoring.scoring as scoring +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir -from modelopt.torch._compress.tools.runtime import IRuntime def compress( - hydra_config_dir: str, hydra_config: str, puzzle_dir: str, dataset_path: str, runtime: IRuntime + hydra_config_dir: str, hydra_config: str, puzzle_dir: str, dataset_path: str ) -> DictConfig: """Compress a puzzletron model using the MIP-based NAS search algorithm. @@ -41,8 +41,6 @@ def compress( hydra_config (str): the corresponding hydra config file puzzle_dir (str): directory with a puzzletron model to compress dataset_path (str): dataset used for scoring and distillation - runtime: distributed runtime to use to run the compression steps, e.g., - NativeDdpRuntime(dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)) Returns: Hydra config object after compressing the model. @@ -60,22 +58,22 @@ def compress( ) # Step 1: score_pruning_activations (distributed processing) - score_pruning_activations.launch_score_activations(hydra_cfg, runtime) + score_pruning_activations.launch_score_activations(hydra_cfg) # Step 2: pruning_ckpts (single process) - if runtime.global_rank == 0: + if dist.is_master(): pruning_ckpts.launch_prune_ckpt(hydra_cfg) - runtime.wait_for_everyone() + dist.barrier() # Step 4: build_library_and_stats (single process) - if runtime.global_rank == 0: + if dist.is_master(): build_library_and_stats.launch_build_library_and_stats(hydra_cfg) - runtime.wait_for_everyone() + dist.barrier() # Step 5: calc_one_block_scores (distributed processing) - scoring.launch_scoring(hydra_cfg, runtime) + scoring.launch_scoring(hydra_cfg) # Step 6: mip_and_realize_models (distributed processing) - mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime) + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) return hydra_cfg diff --git a/modelopt/torch/_compress/dataset/prepare_dataset.py b/modelopt/torch/_compress/dataset/prepare_dataset.py index 49d63d122..072640777 100644 --- a/modelopt/torch/_compress/dataset/prepare_dataset.py +++ b/modelopt/torch/_compress/dataset/prepare_dataset.py @@ -18,7 +18,8 @@ import datasets import fire import numpy as np -from logger import mprint + +from modelopt.torch._compress.tools.logger import mprint def process_and_save_dataset( diff --git a/modelopt/torch/_compress/mip/mip_and_realize_models.py b/modelopt/torch/_compress/mip/mip_and_realize_models.py index f6d77d262..a3a1a84b9 100644 --- a/modelopt/torch/_compress/mip/mip_and_realize_models.py +++ b/modelopt/torch/_compress/mip/mip_and_realize_models.py @@ -19,19 +19,15 @@ from pathlib import Path from typing import List -import hydra import torch -import torch.distributed as dist from omegaconf import DictConfig +import modelopt.torch.utils.distributed as dist from 
modelopt.torch._compress.mip.run_puzzle import run_puzzle -from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.runtime import BaseRuntime, IRuntime, NativeDdpRuntime from modelopt.torch._compress.tools.validate_puzzle_with_multi_replacements import ( validate_puzzle_solutions, ) -from modelopt.torch._compress.utils.dist_utils import is_distributed def launch_mip(cfg: DictConfig) -> List[str]: @@ -39,19 +35,18 @@ def launch_mip(cfg: DictConfig) -> List[str]: return solution_paths -def launch_realize_model(cfg: DictConfig, runtime: IRuntime): - validate_puzzle_solutions(args=cfg.realize_model, runtime=runtime) +def launch_realize_model(cfg: DictConfig): + validate_puzzle_solutions(args=cfg.realize_model) -def launch_mip_and_realize_model(cfg: DictConfig, runtime: IRuntime): +def launch_mip_and_realize_model(cfg: DictConfig): # Determine device for distributed operations (NCCL requires CUDA tensors) device = "cpu" - if runtime.world_size > 1 and dist.is_initialized(): - backend = dist.get_backend() - if backend == "nccl": + if dist.size() > 1: + if torch.distributed.get_backend() == "nccl": device = torch.cuda.current_device() - if runtime.is_main_process: + if dist.is_master(): solution_paths = launch_mip(cfg) length_tensor = torch.tensor([len(solution_paths)], dtype=torch.long, device=device) else: @@ -59,39 +54,19 @@ def launch_mip_and_realize_model(cfg: DictConfig, runtime: IRuntime): length_tensor = torch.tensor([0], dtype=torch.long, device=device) if not cfg.skip_realize_model: - if runtime.world_size > 1: - dist.broadcast(length_tensor, src=0) + if dist.size() > 1: + torch.distributed.broadcast(length_tensor, src=0) list_length = length_tensor.item() - if runtime.global_rank != 0: + if not dist.is_master(): solution_paths = [None] * list_length - if runtime.world_size > 1: - dist.broadcast_object_list(solution_paths, src=0) + if dist.size() > 1: + torch.distributed.broadcast_object_list(solution_paths, src=0) for solution_path in solution_paths: mprint(f"Realize model for the solution: {solution_path}") cfg.realize_model.solutions_path = Path(solution_path) - launch_realize_model(cfg, runtime=runtime) - runtime.wait_for_everyone() - - -@hydra.main("", version_base="1.3") -def main(cfg: DictConfig) -> None: - cfg = hydra.utils.instantiate(cfg) - - _runtime = ( - NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes") - ) - if is_distributed() - else BaseRuntime(dtype=torch.bfloat16) - ) - with _runtime as runtime: - launch_mip_and_realize_model(cfg, runtime) - - -if __name__ == "__main__": - register_hydra_resolvers() - main() + launch_realize_model(cfg) + dist.barrier() diff --git a/modelopt/torch/_compress/mip/run_puzzle.py b/modelopt/torch/_compress/mip/run_puzzle.py index 5773349c1..4868479e2 100644 --- a/modelopt/torch/_compress/mip/run_puzzle.py +++ b/modelopt/torch/_compress/mip/run_puzzle.py @@ -226,7 +226,7 @@ def parse_args() -> argparse.Namespace: def run_single_puzzle_config( - args: argparse.Namespace, + args: argparse.Namespace | DictConfig, gathered_metrics: dict, subblock_stats: dict, subblock_stats_args: dict, @@ -426,7 +426,7 @@ def _get_minimal_unique_names(dicts: List[dict]) -> List[str]: return ["-".join(f"{k}_{d[k]}".replace(".", "_") for k in non_common_keys) for d in dicts] -def run_puzzle(args: argparse.Namespace) -> List[str]: +def run_puzzle(args: argparse.Namespace | DictConfig) -> 
List[str]: # Loads config from args/puzzle_profile if args.puzzle_profile is not None: with open(args.puzzle_profile) as f: diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index 5c08c693a..55b9d10b0 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -29,6 +29,7 @@ import modelopt.torch._compress.mip.mip_and_realize_models as mip_and_realize_models import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts import modelopt.torch._compress.scoring.scoring as scoring +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress import build_library_and_stats from modelopt.torch._compress.activation_scoring import score_pruning_activations from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( @@ -36,7 +37,6 @@ ) from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.runtime import NativeDdpRuntime from modelopt.torch.nas.conversion import NASModeRegistry from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.opt.mode import ( @@ -99,13 +99,6 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR The output of this step will be used by mnt.search() to perform the NAS search. """ - - # NativeDdpRuntime must be initialized/closed from outside of this function, so we are - # NOT calling runtime.cleanup() here. TODO: Not optimal - redesign it. - runtime = NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) - ) - # Required for mtn.search() to read NAS configuration model.hydra_config_dir = config.hydra_config_dir model.hydra_config_name = config.hydra_config_name @@ -124,26 +117,26 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR # Convert Llama3 model to DeciLM model # TODO: Make it generic, do not call convert_llama3_to_decilm directly. - if runtime.global_rank == 0: + if dist.is_master(): mprint("Compress Progress 2/8: converting model from HF to DeciLM (single-gpu)") hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable convert_llama3_to_decilm( input_dir=config.input_model_path, output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, ) - runtime.wait_for_everyone() + dist.barrier() # Score_pruning_activations (distributed processing) mprint("Compress Progress 3/8: scoring pruning activations (multi-gpu)") - score_pruning_activations.launch_score_activations(hydra_cfg, runtime) + score_pruning_activations.launch_score_activations(hydra_cfg) # Prune the model and save pruned checkpoints - if runtime.global_rank == 0: + if dist.is_master(): mprint( "Compress Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)" ) pruning_ckpts.launch_prune_ckpt(hydra_cfg) - runtime.wait_for_everyone() + dist.barrier() return model, {} @@ -203,12 +196,6 @@ def default_state_dict(self) -> SearchStateDict: return {} def run_search(self) -> None: - # NativeDdpRuntime must be initialized/closed from outside of this function, so we are - # NOT calling runtime.cleanup() here. TODO: Not optimal - redesign it. 
- runtime = NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) - ) - # Load hydra config hydra_cfg = initialize_hydra_config_for_dir( config_dir=self.model.hydra_config_dir, @@ -220,17 +207,17 @@ def run_search(self) -> None: ) # Build_library_and_stats (single process) - if runtime.global_rank == 0: + if dist.is_master(): mprint( "Compress Progress 5/8: building replacement library and subblock statistics (single-gpu)" ) build_library_and_stats.launch_build_library_and_stats(hydra_cfg) - runtime.wait_for_everyone() + dist.barrier() # Calc_one_block_scores (distributed processing) mprint("Compress Progress 6/8: calculating one block scores (multi-gpu)") - scoring.launch_scoring(hydra_cfg, runtime) + scoring.launch_scoring(hydra_cfg) # mip_and_realize_models (distributed processing) mprint("Compress Progress 7/8: running MIP and realizing models (multi-gpu)") - mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime) + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) diff --git a/modelopt/torch/_compress/pruning/pruning_ckpts.py b/modelopt/torch/_compress/pruning/pruning_ckpts.py index 4a0e5c15c..b413a3f78 100644 --- a/modelopt/torch/_compress/pruning/pruning_ckpts.py +++ b/modelopt/torch/_compress/pruning/pruning_ckpts.py @@ -337,15 +337,3 @@ def launch_prune_ckpt(cfg: DictConfig): raise NotImplementedError( f"checkpoint pruning is not currently supported for target layer: {target_layer}" ) - - -@hydra.main("", version_base="1.3") -def main(cfg: DictConfig) -> None: - cfg = hydra.utils.instantiate(cfg) - mprint(cfg) - launch_prune_ckpt(cfg) - - -if __name__ == "__main__": - register_hydra_resolvers() - main() diff --git a/modelopt/torch/_compress/replacement_library/build_replacement_library.py b/modelopt/torch/_compress/replacement_library/build_replacement_library.py index a8b2b7f9b..760952a60 100644 --- a/modelopt/torch/_compress/replacement_library/build_replacement_library.py +++ b/modelopt/torch/_compress/replacement_library/build_replacement_library.py @@ -40,7 +40,6 @@ from pathlib import Path from typing import Any, Type -import hydra import pandas as pd from omegaconf import DictConfig @@ -59,7 +58,6 @@ is_valid_decilm_checkpoint, load_model_config, ) -from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.tools.robust_json import json_dump from modelopt.torch._compress.utils.parsing import format_global_config @@ -591,15 +589,3 @@ def _build_single_sequence_replacement_solutions( ) return solutions - - -@hydra.main("", version_base="1.3") -def main(cfg: DictConfig) -> None: - cfg = hydra.utils.instantiate(cfg) - mprint(format_global_config(cfg)) - launch_build_replacement_library(cfg) - - -if __name__ == "__main__": - register_hydra_resolvers() - main() diff --git a/modelopt/torch/_compress/replacement_library/replacement_library.py b/modelopt/torch/_compress/replacement_library/replacement_library.py index ccfaaee0d..5e2fee6f0 100644 --- a/modelopt/torch/_compress/replacement_library/replacement_library.py +++ b/modelopt/torch/_compress/replacement_library/replacement_library.py @@ -30,6 +30,7 @@ from safetensors.torch import load_file as safe_load_file from torch import nn +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( 
DeciLMDecoderLayer, @@ -124,18 +125,11 @@ def create_model_config(self, layer_replacements: list[dict]): model_config = self.model_config.set_block_configs(block_configs) return model_config - def load_model( - self, - layer_replacements: list[dict], - world_size: int, - global_rank: int, - ) -> DeciLMForCausalLM: + def load_model(self, layer_replacements: list[dict]) -> DeciLMForCausalLM: block_configs, block_locations = extract_block_configs_and_locations(layer_replacements) model_config = self.model_config.set_block_configs(block_configs) - owned_block_indexes = _get_owned_block_indexes( - model_config.get_num_hidden_layers(), world_size, global_rank - ) + owned_block_indexes = _get_owned_block_indexes(model_config.get_num_hidden_layers()) model = create_dummy_model(model_config, self.dtype) is_first_shard = 0 in owned_block_indexes @@ -157,15 +151,10 @@ def load_model( self._move_inactive_blocks_to_cpu(active_blocks) return model - def load_checkpoint( - self, - checkpoint_dir: str | Path, - world_size: int, - global_rank: int, - ) -> DeciLMForCausalLM: + def load_checkpoint(self, checkpoint_dir: str | Path) -> DeciLMForCausalLM: checkpoint_dir = Path(checkpoint_dir).resolve() layer_replacements = self._locate_replacements_of_entire_checkpoint(checkpoint_dir) - model = self.load_model(layer_replacements, world_size, global_rank) + model = self.load_model(layer_replacements) return model def _locate_replacements_of_entire_checkpoint(self, checkpoint_dir: str | Path) -> list[dict]: @@ -371,18 +360,18 @@ def _error_message_ensure_split(checkpoint_dir: Path) -> str: ) -def _get_owned_block_indexes(n_layer: int, world_size: int, global_rank: int) -> list[int]: +def _get_owned_block_indexes(n_layer: int) -> list[int]: last_process_blocks = np.array([n_layer - 1]) # less params in last gpu, leave room for logits - if world_size == 1: + if dist.size() == 1: # Only one process: assign everything (including the "last process" block) to rank 0 owned_block_indexes_per_process = [ np.concatenate([np.arange(n_layer - 1), last_process_blocks]) ] else: # Multiple processes: split n_layer-1 blocks, reserve the last for "last process" - owned_block_indexes_per_process = np.array_split(range(n_layer - 1), world_size - 1) + owned_block_indexes_per_process = np.array_split(range(n_layer - 1), dist.size() - 1) owned_block_indexes_per_process.append(last_process_blocks) - owned_block_indexes = owned_block_indexes_per_process[global_rank].tolist() + owned_block_indexes = owned_block_indexes_per_process[dist.rank()].tolist() return owned_block_indexes diff --git a/modelopt/torch/_compress/scoring/scoring.py b/modelopt/torch/_compress/scoring/scoring.py index f17b8cd3e..5f745b399 100644 --- a/modelopt/torch/_compress/scoring/scoring.py +++ b/modelopt/torch/_compress/scoring/scoring.py @@ -27,13 +27,12 @@ import torch from omegaconf import DictConfig +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.runtime import BaseRuntime, IRuntime, NativeDdpRuntime from modelopt.torch._compress.tools.validate_puzzle_with_multi_replacements import ( validate_puzzle_solutions, ) -from modelopt.torch._compress.utils.dist_utils import is_distributed def extract_solution_id(filename): @@ -73,26 +72,19 @@ def get_solutions_to_validate(cfg: DictConfig): return _solutions_to_validate -def launch_scoring(cfg: DictConfig, runtime: IRuntime): +def 
launch_scoring(cfg: DictConfig): cfg.scoring.solutions_to_validate = get_solutions_to_validate(cfg) mprint(f"Solutions to validate: {cfg.scoring.solutions_to_validate}") - validate_puzzle_solutions(args=cfg.scoring, runtime=runtime) + validate_puzzle_solutions(args=cfg.scoring) @hydra.main("", version_base="1.3") def main(cfg: DictConfig) -> None: cfg = hydra.utils.instantiate(cfg) mprint(cfg) - - _runtime = ( - NativeDdpRuntime( - dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes") - ) - if is_distributed() - else BaseRuntime(dtype=torch.bfloat16) - ) - with _runtime as runtime: - launch_scoring(cfg, runtime) + dist.setup(timeout=cfg.nccl_timeout_minutes) + launch_scoring(cfg) + dist.cleanup() if __name__ == "__main__": diff --git a/modelopt/torch/_compress/sewing_kit/common.py b/modelopt/torch/_compress/sewing_kit/common.py deleted file mode 100644 index 5bc573232..000000000 --- a/modelopt/torch/_compress/sewing_kit/common.py +++ /dev/null @@ -1,19 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -logger = logging.getLogger("sewing_kit") -logger.setLevel(logging.WARN) diff --git a/modelopt/torch/_compress/sewing_kit/passage/core.py b/modelopt/torch/_compress/sewing_kit/passage/core.py index 71164f061..22c720b50 100644 --- a/modelopt/torch/_compress/sewing_kit/passage/core.py +++ b/modelopt/torch/_compress/sewing_kit/passage/core.py @@ -16,20 +16,13 @@ # mypy: ignore-errors from __future__ import annotations -import sys from collections.abc import Callable, Sequence from dataclasses import dataclass from typing import Any, ContextManager, Iterable, Mapping, Optional, Union -try: - from typing import Self -except ImportError: - from typing_extensions import Self - import torch.nn as nn from typing_extensions import override -from ..common import logger from ..utils import ( ActivityContext, dynamo_skip, diff --git a/modelopt/torch/_compress/sewing_kit/utils.py b/modelopt/torch/_compress/sewing_kit/utils.py index ff47c289b..25ee8c9ea 100644 --- a/modelopt/torch/_compress/sewing_kit/utils.py +++ b/modelopt/torch/_compress/sewing_kit/utils.py @@ -16,7 +16,7 @@ from __future__ import annotations import inspect -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from contextlib import contextmanager from typing import ( Any, @@ -76,65 +76,6 @@ def __init__(self, module: TModule): self.module = module -Reduction = Literal["none", "mean", "sum"] - - -def normalized_mse_loss( - input: Tensor, target: Tensor, reduction: Reduction = "mean", epsilon: float = 1e-6 -): - loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss( - target, torch.zeros_like(target) + epsilon, reduction=reduction - ) - return loss - - -def mse_loss(input: Tensor, target: Tensor, reduction: Reduction = "mean", epsilon: float = 1e-6): - loss = F.mse_loss(input, target, reduction=reduction) - 
return loss - - -class NormalizedMSELoss(nn.modules.loss._Loss): - __constants__ = ["reduction", "epsilon"] - - def __init__(self, reduction: Reduction = "mean", epsilon: float = 1e-6) -> None: - super().__init__(None, None, reduction) - self.epsilon = epsilon - - def forward(self, input: Tensor, target: Tensor) -> Tensor: - loss = normalized_mse_loss( - input, - target, - cast(Reduction, self.reduction), - self.epsilon, - ) - return loss - - -def vectorwise_normalized_mse_loss(input: Tensor, target: Tensor, epsilon: float = 1e-6): - """ - Like normalized_mse_loss, but the input is treated as a multi-dimensional batch of vectors. - Normalization is done on each vector separately (the last dim), then results are averaged. - """ - return batched_normalized_mse_loss(input, target, epsilon, batch_dims=range(input.ndim - 1)) - - -def batched_normalized_mse_loss( - input: Tensor, target: Tensor, epsilon: float = 1e-6, batch_dims: Sequence[int] = (0,) -): - """ - Like normalized_mse_loss, but the input is treated as a batch of tensors. - Normalization is done on the non-batch dims, then results are averaged. - """ - norm_dims = list(set(range(input.ndim)) - set(batch_dims)) - norm_of_target_vectors = F.mse_loss( - target, torch.zeros_like(target) + epsilon, reduction="none" - ).mean(dim=norm_dims) - vectorwise_mse = F.mse_loss(input, target, reduction="none").mean(dim=norm_dims) - normalized_vectorwise_mse = vectorwise_mse / norm_of_target_vectors - loss = normalized_vectorwise_mse.mean() - return loss - - class ActivityContextMaxDepthException(Exception): pass @@ -216,20 +157,6 @@ def is_submodule_or_same(module_name: str, other_module_name: str) -> bool: return result -def reduce_losses(losses: Iterable[Tensor]) -> Tensor: - total_loss = None - for loss in losses: - if total_loss is None: - total_loss = loss - else: - total_loss += loss - - if total_loss is None: - return torch.Tensor(torch.nan) - - return total_loss - - fake_mode = FakeTensorMode( allow_non_fake_inputs=True, # allow_fallback_kernels=False, @@ -423,30 +350,6 @@ def has_fake_tensor(v: Any) -> bool: return result -@dynamo_skip -def is_real_tensor(t: Any) -> bool: - return isinstance(t, Tensor) and not t.is_meta and not isinstance(t, FakeTensor) - - -@dynamo_skip -def get_parent_module_name(module_name: str): - if "." 
not in module_name: - return "" - else: - return module_name.rsplit(".", 1)[0] - - -@dynamo_skip -def get_parent_module_names(module_name: str): - parent_module_names = set[str]() - - while len(module_name) > 0: - module_name = get_parent_module_name(module_name) - parent_module_names.add(module_name) - - return parent_module_names - - def _get_device_for_distributed( group: Optional[torch.distributed.ProcessGroup] = None, ) -> str: diff --git a/modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py b/modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py index 7f5a41778..e25c8e38d 100644 --- a/modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py +++ b/modelopt/torch/_compress/subblock_stats/calc_subblock_params_and_memory.py @@ -50,7 +50,7 @@ def calculate_subblock_memory( prefill_queue_size: int, n_embd: int, n_head: int, - weights_dtype: torch.dtype | str, + weights_dtype: torch.dtype, kv_cache_dtype: torch.dtype, allocate_prefill_query: bool, ) -> float | dict[str, float]: @@ -174,7 +174,7 @@ def calculate_attention_memory( prefill_queue_size: int, n_embd: int, n_head: int, - weights_dtype: torch.dtype | str, + weights_dtype: torch.dtype, kv_cache_dtype: torch.dtype, allocate_prefill_query: bool, ) -> dict[str, float]: @@ -221,8 +221,8 @@ def calculate_mamba_memory( mamba_config: MambaConfig, n_embd: int, batch_size: int, - weights_dtype: torch.dtype | str, - kv_cache_dtype: torch.dtype | str, + weights_dtype: torch.dtype, + kv_cache_dtype: torch.dtype, ) -> int: return ( calculate_mamba_params(mamba_config, n_embd) * sizeof_dtype(weights_dtype) @@ -274,7 +274,7 @@ def _calculate_mamba_intermediates(mamba_config: MambaConfig) -> tuple[int, ...] def calculate_linear_memory( n_embd: int, - weights_dtype: torch.dtype | str, + weights_dtype: torch.dtype, ) -> float: return calculate_linear_params(n_embd) * sizeof_dtype(weights_dtype) / 2**20 @@ -288,7 +288,7 @@ def calculate_linear_params( def calculate_ffn_memory( ffn_config: FFNConfig, n_embd: int, - weights_dtype: torch.dtype | str, + weights_dtype: torch.dtype, ) -> float: num_params = calculate_ffn_params(ffn_config, n_embd) return num_params * sizeof_dtype(weights_dtype) / 2**20 diff --git a/modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py b/modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py index d3e73a0cf..76e6c3428 100644 --- a/modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py +++ b/modelopt/torch/_compress/subblock_stats/calc_subblock_stats.py @@ -16,20 +16,16 @@ """Calc subblock stats to compute memory and runtime statistics for subblocks.""" -import os -from itertools import product - -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - import dataclasses import json +import os from functools import partial +from itertools import product from pathlib import Path from typing import Iterable, Optional, Type, TypeVar -import hydra +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + import pandas as pd import torch from immutabledict import immutabledict @@ -42,6 +38,7 @@ FFNConfig, SubblockConfig, ) +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.replacement_library.replacement_utils import parse_layer_replacement from modelopt.torch._compress.subblock_stats.calc_subblock_params_and_memory import ( calc_subblock_active_params, 
@@ -51,7 +48,6 @@ calculate_subblock_params, ) from modelopt.torch._compress.tools.checkpoint_utils import load_model_config -from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.tools.robust_json import json_dump from modelopt.torch._compress.utils.parsing import format_global_config @@ -91,9 +87,7 @@ def calculate_subblock_stats( ) -> dict: is_calc_runtime = benchmark_iterations is not None if is_calc_runtime: - from puzzle_tools.subblock_stats.runtime_stats.calc_runtime_stats import ( - calc_runtime_ms_for_subblocks, - ) + raise NotImplementedError("Runtime stats calculation is not implemented yet") gpu = None if not torch.cuda.is_available() else torch.cuda.get_device_name() subblock_stats = { @@ -540,15 +534,3 @@ def _find_corresponding_bf16_stats(args: dict, subblock_stats: list[dict]) -> di if len(matching_bf16_stats) == 1: return matching_bf16_stats[0] raise ValueError(f"Found more than 1 matching bf16 stats for {args=}") - - -@hydra.main("configs", version_base="1.3", config_name="search_space") -def main(cfg: DictConfig) -> None: - cfg = hydra.utils.instantiate(cfg) - mprint(format_global_config(cfg)) - launch_calc_subblock_stats(cfg) - - -if __name__ == "__main__": - register_hydra_resolvers() - main() diff --git a/modelopt/torch/_compress/tools/bypassed_training/child_init.py b/modelopt/torch/_compress/tools/bypassed_training/child_init.py index d9ead79a1..3e2c42f09 100644 --- a/modelopt/torch/_compress/tools/bypassed_training/child_init.py +++ b/modelopt/torch/_compress/tools/bypassed_training/child_init.py @@ -39,7 +39,6 @@ ) from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.tools.logger import aprint, mprint -from modelopt.torch._compress.tools.runtime import IRuntime class GQAInitMode(Enum): @@ -331,7 +330,6 @@ def create_child_state_dict( new_config: DeciLMConfig, gqa_init_mode: GQAInitMode, ignore_fn: IgnoreFn = default_ignore_fn, - runtime: Optional[IRuntime] = Printer, mlp_init_mode: MlpInitMode = MlpInitMode.CopyAsIs, mlp_init_config: Optional[dict[str, Any]] = None, owned_block_indexes: Optional[set[int]] = None, diff --git a/modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py index dbb4eac0c..f06db92fb 100644 --- a/modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py +++ b/modelopt/torch/_compress/tools/bypassed_training/init_child_from_parent.py @@ -220,47 +220,3 @@ def init_child_from_parent( mprint(f"Total core processing: {total_core_time:.2f}s") mprint(f"Optimizations: I/O workers={actual_io_workers}, Layer workers={actual_layer_workers}") mprint(f"=========================\n") - - -def parse_args(): - parser = argparse.ArgumentParser() - - # Arguments for single checkpoint creation - parser.add_argument("--parent_checkpoint_dir", type=str, required=True) - parser.add_argument("--model_config_overrides_json", type=str, required=True) - parser.add_argument("--output_checkpoint_dir", type=str, required=True) - parser.add_argument( - "--gqa_init_mode", type=str, default="AverageKV", choices=GQAInitMode._member_names_ - ) - parser.add_argument( - "--mlp_init_mode", type=str, default="Truncate", choices=MlpInitMode._member_names_ - ) - parser.add_argument("--mlp_init_config_yaml", type=str, default=None) - parser.add_argument( - "--linear_init_mode", type=str, 
default="FromTeacher", choices=LinearInitMode._member_names_ - ) - parser.add_argument( - "--hidden_size_init_mode", type=str, default=None, choices=HiddenSizeInitMode._member_names_ - ) - parser.add_argument("--channel_importance_path", type=str, required=False) - parser.add_argument("--target_hidden_sizes", type=int, nargs="+", required=False) - - args = parser.parse_args() - return args - - -if __name__ == "__main__": - args = parse_args() - - init_child_from_parent( - parent_checkpoint_dir=args.parent_checkpoint_dir, - model_config_overrides_json=args.model_config_overrides_json, - output_checkpoint_dir=args.output_checkpoint_dir, - gqa_init_mode=GQAInitMode(args.gqa_init_mode), - mlp_init_mode=MlpInitMode(args.mlp_init_mode), - mlp_init_config_yaml=args.mlp_init_config_yaml, - linear_init_mode=LinearInitMode(args.linear_init_mode), - hidden_size_init_mode=HiddenSizeInitMode(args.hidden_size_init_mode) - if args.hidden_size_init_mode - else None, - ) diff --git a/modelopt/torch/_compress/tools/hydra.py b/modelopt/torch/_compress/tools/hydra.py deleted file mode 100644 index 8c36d309e..000000000 --- a/modelopt/torch/_compress/tools/hydra.py +++ /dev/null @@ -1,54 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from hydra import compose, initialize, initialize_config_dir -from omegaconf import DictConfig, OmegaConf - -""" -Utilities for hydra config initialization. -""" - - -def initialize_hydra_config_for_dir( - config_dir: str, config_name: str, overrides: list[str] -) -> DictConfig: - """Initialize a hydra config from an absolute path for a config directory - - Args: - config_dir (str): - config_name (str): - overrides (List[str]): - - Returns: - DictConfig: - """ - - with initialize_config_dir(version_base=None, config_dir=config_dir): - args = compose(config_name, overrides) - args._set_flag("allow_objects", True) - OmegaConf.resolve(args) # resolve object attributes - OmegaConf.set_struct(args, False) - - return args - - -def initialize_hydra_config(config_path: str, config_name: str, overrides: list[str]) -> DictConfig: - with initialize(version_base=None, config_path=config_path): - args = compose(config_name, overrides) - args._set_flag("allow_objects", True) - OmegaConf.resolve(args) # resolve object attributes - OmegaConf.set_struct(args, False) - - return args diff --git a/modelopt/torch/_compress/tools/runtime.py b/modelopt/torch/_compress/tools/runtime.py deleted file mode 100644 index 46f561a5d..000000000 --- a/modelopt/torch/_compress/tools/runtime.py +++ /dev/null @@ -1,556 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Classes for torch distributed runtime management""" - -import os -import random -from abc import ABC, abstractmethod -from collections.abc import Callable, Iterable, Iterator, Sequence -from contextlib import AbstractContextManager, suppress -from datetime import timedelta -from pathlib import Path -from typing import Literal, TypeVar, cast - -import numpy as np -import torch -import torch.distributed -import torch.nn as nn -from torch.utils.data import DataLoader -from tqdm import tqdm -from typing_extensions import override - -PrepareModelsT = TypeVar("PrepareModelsT", bound=Sequence[nn.Module]) -PrepareDataLoaderT = TypeVar("PrepareDataLoaderT", bound=DataLoader) -CompileT = TypeVar("CompileT", bound=nn.Module) -Filter = ( - Literal["main_process", "last", "local_main_process", "local_last", "all"] - | list[int] - | set[int] - | Callable[[int], bool] -) - - -class IRuntime(ABC): - @abstractmethod - def setup(self) -> None: ... - - @abstractmethod - def cleanup(self) -> None: ... - - @abstractmethod - def autocast(self) -> AbstractContextManager: ... - - @abstractmethod - def wait_for_everyone(self) -> None: ... - - @abstractmethod - def set_seed(self, seed: int, device_specific: bool = False) -> int: ... - - @abstractmethod - def prepare_models(self, models: PrepareModelsT) -> PrepareModelsT: ... - - @abstractmethod - def prepare_train_dataloader( - self, train_dataloader: PrepareDataLoaderT - ) -> PrepareDataLoaderT: ... - - @abstractmethod - def prepare_val_dataloader(self, val_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: ... - - @abstractmethod - def compile(self, model: CompileT) -> CompileT: ... - - @abstractmethod - def backward(self, loss: torch.Tensor) -> None: ... - - @abstractmethod - def clip_grad_norm_( - self, - parameters: Iterable[torch.Tensor] | torch.Tensor, - max_norm: float, - norm_type: float = 2, - ) -> torch.Tensor: ... - - @abstractmethod - def clip_grad_value_( - self, parameters: Iterable[torch.Tensor] | torch.Tensor, clip_value: float - ) -> None: ... - - @abstractmethod - def save_state(self, path: str | Path) -> None: ... - - @abstractmethod - def load_state(self, path: str | Path) -> None: ... - - @abstractmethod - def skip_first_batches(self, dataloader_iterator: Iterator, num_batches: int) -> None: ... - - @property - @abstractmethod - def sync_gradients(self) -> bool: ... - - @property - @abstractmethod - def device(self) -> torch.device: ... - - @property - @abstractmethod - def is_main_process(self) -> bool: ... - - @property - @abstractmethod - def is_local_main_process(self) -> bool: ... - - @property - @abstractmethod - def is_last_process(self) -> bool: ... - - @property - @abstractmethod - def is_local_last_process(self) -> bool: ... - - @property - @abstractmethod - def local_rank(self) -> int: ... - - @property - @abstractmethod - def global_rank(self) -> int: ... - - @property - @abstractmethod - def local_world_size(self) -> int: ... - - @property - @abstractmethod - def world_size(self) -> int: ... - - @property - @abstractmethod - def dtype(self) -> torch.dtype: ... 
- - def __enter__(self): - self.setup() - return self - - def __exit__(self, exc_type, exc_value, traceback): - # avoid barrier if exceution errored - if exc_type is None: - self.cleanup() - - # if exc_type is not None: - # raise exc_value - # Handle exceptions if necessary - # pass - - # def __del__(self): - # torch.distributed.barrier() - # torch.distributed.destroy_process_group() - - def check_filter(self, filter_: Filter): - return ( - filter_ == "all" - or (filter_ == "main_process" and self.is_main_process) - or (filter_ == "local_main_process" and self.is_local_main_process) - or (filter_ == "last" and self.is_last_process) - or (filter_ == "local_last" and self.is_local_last_process) - or (isinstance(filter_, (list, set)) and self.global_rank in filter_) - or (callable(filter_) and filter_(self.global_rank)) - ) - - def print( - self, *args, filter_: Filter = "main_process", rank_prefix=False, flush=True, **kwargs - ) -> None: - if not self.check_filter(filter_): - return - - if rank_prefix: - print(f"[global_rank={self.global_rank}]", *args, flush=flush, **kwargs) - else: - print(*args, flush=flush, **kwargs) - - def process_print( - self, *args, filter_: Filter = "all", rank_prefix=True, flush=True, **kwargs - ) -> None: - if not self.check_filter(filter_): - return - - if rank_prefix: - prefix = f"[global_rank={self.global_rank}]" - if len(args) == 1: # avoid out-of-order printing if possible - out = f"{prefix} {args[0]}" - args = (out,) - else: - args = (prefix, *args) - print(*args, flush=flush, **kwargs) - else: - print(*args, flush=flush, **kwargs) - - -class NativeDdpRuntime(IRuntime): - def __init__( - self, - dtype: torch.dtype = torch.float, - torch_distributed_timeout: timedelta | None = None, - ): - self._master_addr = os.environ["MASTER_ADDR"] - self._master_port = int(os.environ["MASTER_PORT"]) - self._local_rank = int(os.environ["LOCAL_RANK"]) - self._global_rank = int(os.environ["RANK"]) - self._local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - self._world_size = int(os.environ["WORLD_SIZE"]) - self._device = torch.device(self.local_rank) - self._dtype = dtype - self._torch_distributed_timeout = torch_distributed_timeout - - @override - def setup(self): - torch.cuda.set_device(self._device) - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group( - "cpu:gloo,cuda:nccl", timeout=self._torch_distributed_timeout - ) - input_tensors = [ - torch.tensor([0], dtype=torch.float32, device=self._device) - for _ in range(self.world_size) - ] - output_tensors = [ - torch.tensor([0], dtype=torch.float32, device=self._device) - for _ in range(self.world_size) - ] - torch.distributed.all_to_all(input_tensors, output_tensors) - - @override - def cleanup(self): - with suppress(Exception): - torch.distributed.barrier() - torch.distributed.destroy_process_group() - - @override - def autocast(self) -> AbstractContextManager: - result = torch.autocast(device_type="cuda", dtype=self._dtype, enabled=True) - return result - - @override - def wait_for_everyone(self): - torch.distributed.barrier() - - @override - def set_seed(self, seed: int, device_specific: bool = False) -> int: - """ - Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. - - Args: - seed (`int`): - The seed to set. - device_specific (`bool`, *optional*, defaults to `False`): - Whether to differ the seed on each device slightly with `self.process_index`. 
- """ - if device_specific: - seed += self.global_rank - - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - return seed - - @override - def prepare_models(self, models: PrepareModelsT) -> PrepareModelsT: - assert all(isinstance(x, nn.Module) for x in models) - new_models = [nn.parallel.DistributedDataParallel(m) for m in models] - new_models = cast("PrepareModelsT", new_models) - return new_models # type: ignore[return-value] - - @override - def prepare_train_dataloader(self, train_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: - return train_dataloader - - @override - def prepare_val_dataloader(self, val_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: - return val_dataloader - - @override - def compile(self, model: CompileT) -> CompileT: - result = torch.compile(model) - result = cast("CompileT", result) - return result - - @override - def backward(self, loss: torch.Tensor) -> None: - loss.backward() - - @override - def clip_grad_norm_( - self, - parameters: Iterable[torch.Tensor] | torch.Tensor, - max_norm: float, - norm_type: float = 2, - ) -> torch.Tensor: - result = torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type) - return result - - @override - def clip_grad_value_( - self, parameters: Iterable[torch.Tensor] | torch.Tensor, clip_value: float - ) -> None: - torch.nn.utils.clip_grad_value_(parameters, clip_value) - - @override - def save_state(self, path: str | Path) -> None: - pass - - @override - def load_state(self, path: str | Path) -> None: - pass - - @override - def skip_first_batches(self, dataloader_iterator: Iterator, num_batches: int) -> None: - for _ in tqdm( - range(num_batches), desc=f"rank {self._global_rank}: skip_first_batches({num_batches=})" - ): - next(dataloader_iterator) - - @property - @override - def sync_gradients(self) -> bool: - return True - - @property - @override - def is_main_process(self) -> bool: - result = self.global_rank == 0 - return result - - @property - @override - def is_local_main_process(self) -> bool: - result = self.local_rank == 0 - return result - - @property - @override - def is_last_process(self) -> bool: - result = self.global_rank == self.world_size - 1 - return result - - @property - @override - def is_local_last_process(self) -> bool: - result = self.local_rank == self.local_world_size - 1 - return result - - @property - @override - def local_rank(self) -> int: - return self._local_rank - - @property - @override - def global_rank(self) -> int: - return self._global_rank - - @property - @override - def local_world_size(self) -> int: - return self._local_world_size - - @property - @override - def world_size(self) -> int: - return self._world_size - - @property - @override - def device(self) -> torch.device: - return self._device - - @property - @override - def dtype(self) -> torch.dtype: - return self._dtype - - @property - def master_addr(self) -> str: - return self._master_addr - - @property - def master_port(self) -> int: - return self._master_port - - -class BaseRuntime(IRuntime): - def __init__(self, dtype: torch.dtype = torch.float): - self._device = torch.device(self.local_rank) - self._dtype = dtype - - @override - def setup(self): - torch.cuda.set_device(self._device) - - @override - def cleanup(self): ... - - @override - def autocast(self) -> AbstractContextManager: - result = torch.autocast(device_type="cuda", dtype=self._dtype, enabled=True) - return result - - @override - def wait_for_everyone(self): ... 
- - @override - def set_seed(self, seed: int, device_specific: bool = False) -> int: - """ - Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. - - Args: - seed (`int`): - The seed to set. - device_specific (`bool`, *optional*, defaults to `False`): - Whether to differ the seed on each device slightly with `self.process_index`. - """ - if device_specific: - seed += self.global_rank - - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - return seed - - @override - def prepare_models(self, models: PrepareModelsT) -> PrepareModelsT: - assert all(isinstance(x, nn.Module) for x in models) - return models - - @override - def prepare_train_dataloader(self, train_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: - return train_dataloader - - @override - def prepare_val_dataloader(self, val_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: - return val_dataloader - - @override - def compile(self, model: CompileT) -> CompileT: - result = torch.compile(model) - result = cast("CompileT", result) - return result - - @override - def backward(self, loss: torch.Tensor) -> None: - loss.backward() - - @override - def clip_grad_norm_( - self, - parameters: Iterable[torch.Tensor] | torch.Tensor, - max_norm: float, - norm_type: float = 2, - ) -> torch.Tensor: - result = torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type) - return result - - @override - def clip_grad_value_( - self, parameters: Iterable[torch.Tensor] | torch.Tensor, clip_value: float - ) -> None: - torch.nn.utils.clip_grad_value_(parameters, clip_value) - - @override - def save_state(self, path: str | Path) -> None: - pass - - @override - def load_state(self, path: str | Path) -> None: - pass - - @override - def skip_first_batches(self, dataloader_iterator: Iterator, num_batches: int) -> None: - for _ in tqdm( - range(num_batches), desc=f"rank {self.global_rank}: skip_first_batches({num_batches=})" - ): - next(dataloader_iterator) - - @property - @override - def sync_gradients(self) -> bool: - return True - - @property - @override - def is_main_process(self) -> bool: - result = self.global_rank == 0 - return result - - @property - @override - def is_local_main_process(self) -> bool: - result = self.local_rank == 0 - return result - - @property - @override - def is_last_process(self) -> bool: - result = self.global_rank == self.world_size - 1 - return result - - @property - @override - def is_local_last_process(self) -> bool: - result = self.local_rank == self.local_world_size - 1 - return result - - @property - @override - def local_rank(self) -> int: - return 0 - - @property - @override - def global_rank(self) -> int: - return 0 - - @property - @override - def local_world_size(self) -> int: - return 1 - - @property - @override - def world_size(self) -> int: - return 1 - - @property - @override - def device(self) -> torch.device: - return self._device - - @property - @override - def dtype(self) -> torch.dtype: - return self._dtype - - @property - def master_addr(self) -> str | None: - return None - - @property - def master_port(self) -> int | None: - return None diff --git a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py index 8d1a222c8..7a247bbdf 100644 --- a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py @@ -37,6 +37,7 @@ from transformers.utils.hub import cached_file, 
get_checkpoint_shard_files from typing_extensions import override +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, @@ -45,7 +46,6 @@ ) from modelopt.torch._compress.tools.checkpoint_utils import load_model_config, load_state_dict from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.runtime import IRuntime from modelopt.torch._compress.utils.utils import EmptyInitOnDevice @@ -144,14 +144,14 @@ def create_dummy_model( def load_and_shard_model( - runtime: IRuntime, checkpoint_path: str | Path, owned_block_indexes: set[int] | Literal["auto"] = "auto", model_config: DeciLMConfig | None = None, model_config_overrides: Mapping | None = None, + model_dtype: torch.dtype = torch.bfloat16, ) -> DeciLMForCausalLM: checkpoint_path = Path(checkpoint_path) - with runtime.device: + with torch.device(dist.local_rank()): if model_config is None: model_config = load_model_config( checkpoint_path, model_config_overrides, ignore_unexpected_config_keys=True @@ -159,14 +159,13 @@ def load_and_shard_model( if owned_block_indexes == "auto": owned_block_indexes = set( - np.array_split(np.arange(model_config.get_num_hidden_layers()), runtime.world_size)[ - runtime.global_rank + np.array_split(np.arange(model_config.get_num_hidden_layers()), dist.size())[ + dist.rank() ] ) mprint("Initializing model shards") model_shard = create_sharded_model( - runtime=runtime, model_config=model_config, owned_block_indexes=owned_block_indexes, ) @@ -182,7 +181,7 @@ def load_and_shard_model( shard_state_dict = load_sharded_state_dict( model_name_or_path=str(checkpoint_path), keys_to_load=shard_keys, - device=runtime.device, + device=torch.device(dist.local_rank()), ) new_names = set(shard_state_dict.keys()) @@ -196,15 +195,13 @@ def load_and_shard_model( model_shard.tie_weights() else: mprint("Loading state_dict in main process") - state_dict = load_state_dict(checkpoint_path) if runtime.is_main_process else None + state_dict = load_state_dict(checkpoint_path) if dist.is_master() else None mprint("Distributing model to shards") - load_state_dict_to_shards( - runtime=runtime, model_shard=model_shard, loaded_state_dict=state_dict - ) + load_state_dict_to_shards(model_shard=model_shard, loaded_state_dict=state_dict) del state_dict - model_shard.type(runtime.dtype) + model_shard.type(model_dtype) params_on_meta_device = [ param_name @@ -212,14 +209,13 @@ def load_and_shard_model( if param.device == torch.device("meta") ] assert len(params_on_meta_device) == 0, ( - f"[global_rank={runtime.global_rank}] Couldn't load params {params_on_meta_device}" + f"[global_rank={dist.rank()}] Couldn't load params {params_on_meta_device}" ) return model_shard def create_sharded_model( - runtime: IRuntime, model_config: DeciLMConfig, owned_block_indexes: set[int], device: str | torch.device | None = "meta", @@ -228,7 +224,7 @@ def create_sharded_model( if isinstance(device, str): device = torch.device(device) - runtime.wait_for_everyone() + dist.barrier() with EmptyInitOnDevice(device="meta", dtype=dtype): model = DeciLMForCausalLM(model_config) @@ -245,15 +241,18 @@ def create_sharded_model( def load_state_dict_to_shards( - runtime: IRuntime, model_shard: torch.nn.Module, loaded_state_dict: dict | None = None + model_shard: torch.nn.Module, loaded_state_dict: dict | None = None ) -> None: - from sewing_kit.utils import 
distributed_isend_obj, distributed_recv_obj + from modelopt.torch._compress.sewing_kit.utils import ( + distributed_isend_obj, + distributed_recv_obj, + ) model_shard.to("meta") local_state_dict_keys = list(model_shard.state_dict().keys()) - if runtime.is_main_process: - gathered_state_dict_keys = [None] * runtime.world_size + if dist.is_master(): + gathered_state_dict_keys = [None] * dist.size() torch.distributed.gather_object(local_state_dict_keys, gathered_state_dict_keys) assert loaded_state_dict is not None @@ -276,7 +275,7 @@ def load_state_dict_to_shards( torch.distributed.gather_object(local_state_dict_keys) shard_state_dict = distributed_recv_obj() - print(f"{runtime.global_rank=} loaded state_dict shard") + print(f"{dist.rank()} loaded state_dict shard") missing_keys, unexpected_keys = model_shard.load_state_dict( shard_state_dict, strict=False, assign=True @@ -284,20 +283,18 @@ def load_state_dict_to_shards( assert len(unexpected_keys) == 0 assert all("dummy_param" in key for key in missing_keys) - model_shard.to(runtime.device) + model_shard.cuda(dist.local_rank()) - runtime.wait_for_everyone() + dist.barrier() def save_sharded_model( - runtime: IRuntime, - model_shard: torch.nn.Module | dict[str, torch.Tensor], - out_path: str | Path, + model_shard: torch.nn.Module | dict[str, torch.Tensor], out_path: str | Path ): """ out_path is usually output_checkpoint_path / "model.safetensors" """ - runtime.wait_for_everyone() + dist.barrier() if isinstance(model_shard, torch.nn.Module): shard_state_dict = model_shard.state_dict() @@ -311,8 +308,8 @@ def save_sharded_model( weight.numel() * weight.element_size() for weight in shard_state_dict.values() ) - num_shards = runtime.world_size - idx = runtime.global_rank + num_shards = dist.size() + idx = dist.rank() out_path = Path(out_path) shard_file = out_path.with_stem(f"{out_path.stem}-{idx + 1:05d}-of-{num_shards:05d}") @@ -323,8 +320,8 @@ def save_sharded_model( "shard_file": str(shard_file), } - if runtime.is_main_process: - shard_metadatas = [{} for _ in range(runtime.world_size)] + if dist.is_master(): + shard_metadatas = [{} for _ in range(dist.size())] torch.distributed.gather_object(shard_metadata, shard_metadatas, dst=0) total_size = sum(x["total_shard_size"] for x in shard_metadatas) metadata = {"total_size": total_size} @@ -346,33 +343,7 @@ def save_sharded_model( else: torch.save(shard_state_dict, shard_file) - runtime.wait_for_everyone() - - -def save_sharded_state_dict( - state_dict: dict[str, torch.Tensor], - save_directory: str | Path, - max_shard_size: str = "10GB", -) -> None: - save_directory = Path(save_directory) - save_directory.mkdir(exist_ok=True, parents=True) - state_dict = {k: v.cpu() for k, v in state_dict.items()} - - state_dict_split = split_torch_state_dict_into_shards(state_dict, max_shard_size=max_shard_size) - - for shard_filename, param_names in tqdm( - state_dict_split.filename_to_tensors.items(), desc="saving sharded state dict" - ): - shard_path = save_directory / shard_filename - shard = {param_name: state_dict[param_name] for param_name in param_names} - safe_save_file(shard, shard_path, metadata={"format": "pt"}) - - index = { - "metadata": state_dict_split.metadata, - "weight_map": state_dict_split.tensor_to_filename, - } - index_path = save_directory / SAFE_WEIGHTS_INDEX_NAME - index_path.write_text(json.dumps(index, indent=2)) + dist.barrier() def load_sharded_state_dict( @@ -410,13 +381,3 @@ def _resolve_shard_paths(model_name_or_path: str) -> list[str]: def 
is_in_safetensors_format(checkpoint_dir: Path) -> bool: return len(list(checkpoint_dir.glob("*.safetensors"))) > 0 - - -def load_state_dict_shapes(model_name_or_path: str | Path) -> dict[str, tuple]: - shard_paths = _resolve_shard_paths(model_name_or_path) - state_dict_shapes = {} - for safetensors_path in shard_paths: - with safe_open(safetensors_path, framework="pt") as f: - for key in f.keys(): # noqa: SIM118 - safe_open objects require .keys(), not directly iterable - state_dict_shapes[key] = tuple(f.get_tensor(key).shape) - return state_dict_shapes diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py index 8ec1d6f17..d3d71a419 100644 --- a/modelopt/torch/_compress/tools/validate_model.py +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -21,11 +21,10 @@ TODO: Consider moving this a separate module dedicated for scoring. """ -import argparse import textwrap from pathlib import Path -import torch.distributed +import torch from omegaconf import DictConfig from torch import nn from torch.utils.data import DataLoader @@ -36,12 +35,12 @@ PreTrainedTokenizerBase, ) +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.activation_scoring.activation_hooks.utils import ( register_activation_hooks, ) from modelopt.torch._compress.tools.checkpoint_utils_hf import load_checkpoint from modelopt.torch._compress.tools.logger import aprint, mprint -from modelopt.torch._compress.tools.runtime import IRuntime, NativeDdpRuntime from modelopt.torch._compress.tools.sharded_checkpoint_utils import load_and_shard_model from modelopt.torch._compress.utils.data.dataloaders import create_validation_dataloader from modelopt.torch._compress.utils.parsing import simple_parse_args_string @@ -51,12 +50,6 @@ ) from modelopt.torch._compress.utils.validation import calculate_losses -# #TODO:Import slack from root utils directory -# root_path = os.path.join(os.path.dirname(__file__), "..", "..") -# if root_path not in sys.path: -# sys.path.append(root_path) -# from utils.slack import send_slack_message - """ Two goals: 1) Calculate lm loss and token accuracy for a model. @@ -67,88 +60,89 @@ 2) Register hooks to capture the inputs and the outputs of pytorch modules. For example, to collect activations scores for various layers (ffn, layer_norm, etc.) that are used for pruning (ffn_hidden_size, embedding_pruning, etc). -See --activations_log_dir and --activation_hooks_kwargs args arguments. - +See activations_log_dir and activation_hooks_kwargs arguments. """ -def build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument( - "--model_name_or_path", - type=str, - default=None, - help="Required unless a model is passed to the function", - ) - parser.add_argument("--dataset_path", type=str, required=True) - - parser.add_argument("--output_dir_name", type=str, default="validation") - parser.add_argument( - "--calculate_full_score_ablations", - action="store_true", - help="Calculates a diverse suite of teacher similarity scores. 
" - "By default only a small suite is calculated, which is good for most use-cases.", - ) - - parser.add_argument("--tokenizer_name", type=str, default=None) - parser.add_argument("--data_column", type=str, default="content") - # TODO: Add help text for FIM rate, also for others less obvious args - parser.add_argument("--fim_rate", type=float, default=0) - parser.add_argument("--fim_spm_rate", type=float, default=0) - parser.add_argument("--eval_samples", type=int, default=None) - parser.add_argument("--block_size", type=int, default=4096) - parser.add_argument("--micro_batch_size", type=int, default=4) - parser.add_argument("--val_dataset_name", type=str, default="__auto__") - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--source_datasets_to_discard", nargs="+", type=str) - parser.add_argument("--bos_rate", type=float, default=1.0) - parser.add_argument("--shuffle_seed", type=int, default=None) - parser.add_argument("--varlen", action="store_true") - parser.add_argument("--pipeline_parallel", action="store_true") - parser.add_argument("--write_results", action="store_true") - parser.add_argument("--activations_log_dir", type=str, default=None) - parser.add_argument( - "--activation_hooks_kwargs", - type=str, - default=None, - help="Comma separated string arguments, e.g. `arg1=val1,arg2=val2`", - ) - parser.add_argument( - "--calc_losses_on_cpu", - action="store_true", - help="Very slow, not recommended. Can help avoid OOM.", - ) - return parser - - -def parse_args() -> argparse.Namespace: - parser = build_arg_parser() - args, unknown_args = parser.parse_known_args() - return args - - @torch.no_grad() def validate_model( - args: argparse.Namespace | DictConfig, + args: DictConfig, model: PreTrainedModel | None = None, tokenizer: PreTrainedTokenizerBase | None = None, target_hidden_states_per_batch: list[torch.Tensor] | None = None, return_hidden_states: bool = False, - runtime: IRuntime | None = None, + pipeline_parallel: bool = False, calculate_full_score_ablations: bool = False, val_dataloader: DataLoader | None = None, ) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: + """Validate a language model on a dataset by calculating loss and optionally capturing activations. + + Args: + args: Configuration object containing the following attributes: + + Model Configuration: + - model_name_or_path (str): Path to model checkpoint or HuggingFace model name. + Required unless model is passed directly. + - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration: + - dataset_path (str): Path to the validation dataset. + - tokenizer_name (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified. + - data_column (str): Column name in dataset containing text data. + - block_size (int): Maximum sequence length for tokenization. + - eval_samples (int, optional): Number of samples to evaluate. Uses all if None. + - val_dataset_name (str): Name of validation dataset split. + - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. + - load_dataset_fn (callable, optional): Custom function to load the dataset. + + Data Processing: + - micro_batch_size (int): Batch size for evaluation. + - seed (int): Random seed for reproducibility. + - shuffle_seed (int, optional): Seed for shuffling data. Uses seed if None. 
+            - varlen (bool): Enable variable-length sequences.
+            - bos_rate (float): Rate of adding BOS token.
+            - fim_rate (float): Fill-in-the-middle rate for code completion tasks.
+            - fim_spm_rate (float): SPM-based fill-in-the-middle rate.
+
+            Activation Hooks:
+            - activations_log_dir (str, optional): Directory to log activation scores. If provided,
+                hooks will be registered to capture activations.
+            - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks.
+                If string, comma-separated format: "arg1=val1,arg2=val2".
+
+            Execution Options:
+            - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended.
+            - write_results (bool): Write validation results to file.
+
+        model: Pre-loaded model. If None, will be loaded from args.model_name_or_path.
+        tokenizer: Pre-loaded tokenizer. If None, will be loaded based on args.
+        target_hidden_states_per_batch: Target hidden states for pipeline parallel evaluation.
+        return_hidden_states: Whether to return hidden states from the model.
+        pipeline_parallel: Enable pipeline parallelism for large models.
+        calculate_full_score_ablations: If True, calculate a comprehensive suite of teacher similarity scores;
+            if False, only a small suite is calculated, which is sufficient for most use cases.
+        val_dataloader: Pre-created validation dataloader. If None, will be created from args.
+
+    Returns:
+        A tuple containing:
+        - losses: Dictionary mapping loss names to loss statistics (avg, per_sample).
+        - hidden_states_per_batch: Hidden states and LM head outputs if return_hidden_states is True, else None.
+        Returns (None, None) on ranks that do not produce losses (e.g. non-final pipeline ranks).
+    """
+    # convert model_dtype and autocast_dtype from string (e.g. "torch.bfloat16") to torch.dtype
+    if isinstance(args.model_dtype, str):
+        args.model_dtype = getattr(torch, args.model_dtype.removeprefix("torch."))
+    if isinstance(args.autocast_dtype, str):
+        args.autocast_dtype = getattr(torch, args.autocast_dtype.removeprefix("torch."))
+
     if val_dataloader is None:
-        val_dataloader = (
-            prepare_dataloader(args, tokenizer)
-            if (runtime is None or runtime.is_main_process)
-            else None
-        )
+        val_dataloader = prepare_dataloader(args, tokenizer) if dist.is_master() else None

     validation_full_iters = (
         args.eval_samples // args.micro_batch_size
     )  # model pipeline, single data rank

-    model = prepare_model(args, model, runtime)
+    model = prepare_model(args, model, pipeline_parallel)

     just_model_forward = False
     checkpoint_manager = None
@@ -175,7 +169,6 @@ def validate_model(
         )
         checkpoint_manager = ScoringCheckpointManager(
             checkpoint_dir=args.activations_log_dir,
-            runtime=runtime,
             activation_hooks=activation_hooks,
             checkpoint_interval=50,  # Save every 50 batches
         )
@@ -190,7 +183,7 @@ def validate_model(
         just_model_forward = True
         model.lm_head = nn.Identity()

-    if runtime is None:
+    if not pipeline_parallel:
         losses, hidden_states_per_batch = calculate_losses(
             model=model,
             dataloader=val_dataloader,
@@ -198,7 +191,6 @@ def validate_model(
     else:
         losses, hidden_states_per_batch = calculate_losses_pipeline(
-            runtime=runtime,
             stitched_model=model,
             dataloader=val_dataloader,
             target_hidden_states_per_batch=target_hidden_states_per_batch,
@@ -207,6 +199,7 @@ def validate_model(
             calc_on_cpu=args.calc_losses_on_cpu,
             just_model_forward=just_model_forward,
             checkpoint_manager=checkpoint_manager,
+            autocast_dtype=args.autocast_dtype,
         )

     if losses is not None:
@@ -223,26 +216,23 @@ def validate_model(
             aprint(results_str)
         if args.write_results:
             Path(f"{args.model_name_or_path}/validate_model_results.txt").write_text(results_str)
-            # TODO: send_slack_message(results_str)

         if
activation_hooks is not None: - hook_class.dump_activations_logs(activation_hooks, args.activations_log_dir, args, runtime) + hook_class.dump_activations_logs(activation_hooks, args.activations_log_dir, args) return losses, hidden_states_per_batch def prepare_model( - args: argparse.Namespace, - model: PreTrainedModel | None = None, - runtime: IRuntime | None = None, + args: DictConfig, model: PreTrainedModel | None = None, pipeline_parallel: bool = False ) -> nn.Module: if model is None: assert args.model_name_or_path is not None - if runtime is not None: + if pipeline_parallel: model = load_and_shard_model( - runtime, args.model_name_or_path, model_config_overrides={"block_size": args.block_size}, + model_dtype=args.model_dtype, ) else: try: @@ -265,8 +255,7 @@ def prepare_model( def prepare_dataloader( - args: argparse.Namespace, - tokenizer: PreTrainedTokenizerBase | None = None, + args: DictConfig, tokenizer: PreTrainedTokenizerBase | None = None ) -> DataLoader: if tokenizer is None: tokenizer_name = getattr(args, "tokenizer_name", None) @@ -295,16 +284,3 @@ def prepare_dataloader( ) return val_dataloader - - -def main(): - args = parse_args() - if args.pipeline_parallel: - with NativeDdpRuntime(dtype=torch.bfloat16) as runtime: - validate_model(args=args, runtime=runtime) - else: - validate_model(args=args, runtime=None) - - -if __name__ == "__main__": - main() diff --git a/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py index e947e97e4..ca0299868 100644 --- a/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py @@ -20,7 +20,6 @@ # mypy: ignore-errors -import argparse import json import shutil import warnings @@ -29,11 +28,12 @@ from typing import Optional import torch +from omegaconf import DictConfig from tqdm import tqdm from transformers import AutoTokenizer, PreTrainedTokenizerBase +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch._compress.replacement_library.build_replacement_library import infer_teacher_dir from modelopt.torch._compress.replacement_library.replacement_library import ReplacementLibrary from modelopt.torch._compress.replacement_library.replacement_utils import parse_layer_replacement from modelopt.torch._compress.tools import validate_model @@ -45,7 +45,6 @@ save_checkpoint, save_safetensors_index, ) -from modelopt.torch._compress.tools.runtime import IRuntime from modelopt.torch._compress.tools.validation_utils import ( validate_model_and_extract_hidden_states, validate_model_with_teacher_similarity_metrics, @@ -54,64 +53,71 @@ from modelopt.torch._compress.utils.validate_runtime_pipeline import perform_pipeline_stitches """ -Usage: -====== - -Validate single_block_replacement_solutions -=========================================== - -( -export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"; -PUZZLE_DIR=".../Llama-3_2-1B-Instruct/parallel_puzzle"; - -torchrun --nproc-per-node=8 \ - -m modelopt.torch._compress.tools.validate_puzzle_with_multi_replacements \ - --replacement_library_path ${PUZZLE_DIR}/replacement_library.json \ - --solutions_path ${PUZZLE_DIR}/single_sequence_replacement_solutions.json \ - --solutions_to_validate 0 \ - \ - --dataset_path .../v0.4/valid \ - --data_column conversation --block_size 8192 --seed 42 --shuffle_seed 444 
--bos_rate 0.5 \ - --eval_samples 32 --micro_batch_size 1 \ - \ - --save_models \ - -) +Usage Example: +============== +Validate single_block_replacement_solutions by calling validate_puzzle_solutions() directly +with an args object containing the required attributes. See the function docstring for details. """ -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() - parser.add_argument("--replacement_library_path", type=parse_path, required=True) - parser.add_argument("--solutions_path", type=parse_path, required=True) - parser.add_argument("--teacher_dir", type=parse_path, default=None) - parser.add_argument("--solutions_to_validate", type=int, nargs="+", default=None) - parser.add_argument("--sort_solutions_by", type=str, default=None) - parser.add_argument("--bigger_is_better", action="store_true") - parser.add_argument("--skip_validation", action="store_true") - parser.add_argument("--save_models", action="store_true") - args, unknown_args = parser.parse_known_args() - if not args.skip_validation: - validation_args = validate_model.build_arg_parser().parse_args(unknown_args) - args = argparse.Namespace( - **{**validation_args.__dict__, **args.__dict__} - ) # if arg names overlap, the latter one wins - else: - args.block_size = None - - args.teacher_dir = _try_infer_teacher_dir(args.replacement_library_path, args.teacher_dir) - - args.tokenizer_name = getattr(args, "tokenizer_name", None) - if args.tokenizer_name is None: - args.tokenizer_name = args.teacher_dir - - return args - - @torch.no_grad() -def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> None: +def validate_puzzle_solutions(args: DictConfig) -> None: + """Validate puzzle solutions by applying layer replacements and evaluating model performance. + + Args: + args: Configuration object containing the following attributes: + + Puzzle Configuration (Required): + - replacement_library_path (Path): Path to the replacement library JSON file. + - solutions_path (Path): Path to puzzle solutions JSON file or directory containing solution files. + - solutions_to_validate (list[int], optional): Indices of specific solutions to validate. + Validates all solutions if None. + - sort_solutions_by (str, optional): JSON field path to sort solutions by before validation. + - bigger_is_better (bool): If True, sort solutions in descending order. Used with sort_solutions_by. + - skip_validation (bool): If True, skip model validation and only save models if requested. + - save_models (bool): If True, save realized model checkpoints for each solution. + + Teacher/Tokenizer Configuration: + - teacher_dir (Path, optional): Path to teacher model directory. Auto-inferred if not provided. + - tokenizer_name (str, optional): Tokenizer name/path. Uses teacher_dir if not specified. + + Model Configuration (Required if skip_validation=False): + - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration (Required if skip_validation=False): + - dataset_path (str): Path to the validation dataset. + - data_column (str): Column name in dataset containing text data. + - block_size (int): Maximum sequence length for tokenization. + - eval_samples (int, optional): Number of samples to evaluate. + - val_dataset_name (str): Name of validation dataset split. + - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. 
+ - load_dataset_fn (callable, optional): Custom function to load the dataset. + + Data Processing (Required if skip_validation=False): + - micro_batch_size (int): Batch size for evaluation. + - seed (int): Random seed for reproducibility. + - shuffle_seed (int, optional): Seed for shuffling data. + - varlen (bool): Enable variable-length sequences. + - bos_rate (float): Rate of adding BOS token. + - fim_rate (float): Fill-in-the-middle rate for code completion tasks. + - fim_spm_rate (float): SPM-based fill-in-the-middle rate. + + Output Configuration: + - output_dir (Path, optional): Directory to save validation results. + Auto-generated from solutions_path if not provided. + + Execution Options (Optional if skip_validation=False): + - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. + - write_results (bool): Write validation results to file. + - activations_log_dir (str, optional): Directory to log activation scores. + - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. + + Returns: + None. Saves validation results and optionally model checkpoints to disk. + """ puzzle_solutions = load_puzzle_solutions( args.solutions_path, args.sort_solutions_by, args.bigger_is_better ) @@ -122,9 +128,7 @@ def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> No tokenizer = _load_tokenizer(args) if not args.skip_validation: val_dataloader = ( - validate_model.prepare_dataloader(args, tokenizer) - if (runtime is None or runtime.is_main_process) - else None + validate_model.prepare_dataloader(args, tokenizer) if dist.is_master() else None ) output_dir = ( @@ -137,18 +141,16 @@ def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> No teacher_hidden_states = None if (args.teacher_dir is not None) and (not args.skip_validation): - teacher_model = replacement_library.load_checkpoint( - args.teacher_dir, runtime.world_size, runtime.global_rank - ) - teacher_model.to(runtime.device) - stitched_model = perform_pipeline_stitches(teacher_model, runtime) + teacher_model = replacement_library.load_checkpoint(args.teacher_dir) + teacher_model.cuda(dist.local_rank()) + stitched_model = perform_pipeline_stitches(teacher_model) teacher_hidden_states = validate_model_and_extract_hidden_states( args, stitched_model, tokenizer, output_dir, model_name="teacher", - runtime=runtime, + pipeline_parallel=True, val_dataloader=val_dataloader, ) @@ -160,9 +162,7 @@ def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> No # realizable_as_symlinks = False model_config = replacement_library.create_model_config(layer_replacements) if (args.save_models and not realizable_as_symlinks) or (not args.skip_validation): - model = replacement_library.load_model( - layer_replacements, runtime.world_size, runtime.global_rank - ) + model = replacement_library.load_model(layer_replacements) model_config = model.config if args.save_models: @@ -171,10 +171,10 @@ def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> No / f"solution_{i_solution}" ) - model_config.dtype = "bfloat16" + model_config.dtype = args.model_dtype model_config.architectures = ["DeciLMForCausalLM"] if realizable_as_symlinks: - if runtime.global_rank == 0: + if dist.is_master(): save_checkpoint_as_symlinks( layer_replacements, model_config, checkpoint_dir, replacement_library ) @@ -184,13 +184,11 @@ def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> No copy_tokenizer(args.tokenizer_name, checkpoint_dir) 
copy_hf_code(checkpoint_dir) - runtime.wait_for_everyone() - - runtime.wait_for_everyone() + dist.barrier() if not args.skip_validation: - model.to(runtime.device) - stitched_model = perform_pipeline_stitches(model, runtime) + model.cuda(dist.local_rank()) + stitched_model = perform_pipeline_stitches(model) validate_model_with_teacher_similarity_metrics( args, stitched_model, @@ -199,11 +197,11 @@ def validate_puzzle_solutions(args: argparse.Namespace, runtime: IRuntime) -> No output_dir, model_name=f"solution_{i_solution}", extra_payload={"i_solution": i_solution, "puzzle_solution": puzzle_solution}, - runtime=runtime, + pipeline_parallel=True, val_dataloader=val_dataloader, ) - runtime.wait_for_everyone() + dist.barrier() def can_realize_as_symlinks(layer_replacements: list[dict]) -> bool: @@ -255,23 +253,7 @@ def copy_hf_code(checkpoint_dir: Path) -> None: shutil.copy(file, checkpoint_dir / file.name) -def _try_infer_teacher_dir( - replacement_library_path: str | Path, - teacher_dir: str | Path | None, -) -> Path | None: - if teacher_dir is not None: - return teacher_dir - - try: - teacher_dir = infer_teacher_dir( - master_puzzle_dir=Path(replacement_library_path).parent, teacher_checkpoint_dir=None - ) - return teacher_dir - except: - return None - - -def _load_tokenizer(args: argparse.Namespace) -> PreTrainedTokenizerBase: +def _load_tokenizer(args: DictConfig) -> PreTrainedTokenizerBase: tokenizer = None if (tokenizer_name := getattr(args, "tokenizer_name", None)) is not None: tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) @@ -324,7 +306,3 @@ def load_puzzle_solutions( print(f"sorted solutions by {sort_solutions_by}. {vals[:10]=} {vals[-10:]=}") return puzzle_solutions - - -if __name__ == "__main__": - validate_puzzle_solutions(args=parse_args()) diff --git a/modelopt/torch/_compress/tools/validation_utils.py b/modelopt/torch/_compress/tools/validation_utils.py index 907dee402..6f0b1fcb5 100644 --- a/modelopt/torch/_compress/tools/validation_utils.py +++ b/modelopt/torch/_compress/tools/validation_utils.py @@ -20,31 +20,32 @@ # mypy: ignore-errors -import argparse from pathlib import Path -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch from omegaconf import DictConfig, OmegaConf from torch import nn from transformers import PreTrainedTokenizerBase -from modelopt.torch._compress.sewing_kit import StitchedModule +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.tools import validate_model from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.tools.robust_json import json_dump -from modelopt.torch._compress.tools.runtime import IRuntime from modelopt.torch._compress.utils.validation import LowMemorySparseTensor +if TYPE_CHECKING: + from modelopt.torch._compress.sewing_kit import StitchedModule + def validate_model_and_extract_hidden_states( - args: argparse.Namespace, - model: nn.Module | StitchedModule, + args: DictConfig, + model: "nn.Module | StitchedModule", tokenizer: PreTrainedTokenizerBase, - output_dir: Union[str, Path], + output_dir: str | Path, model_name: str, extra_payload: Optional[dict[str, Any]] = None, - runtime: Optional[IRuntime] = None, + pipeline_parallel: bool = False, val_dataloader=None, ) -> list[torch.Tensor | LowMemorySparseTensor]: mprint(f""" @@ -59,10 +60,10 @@ def validate_model_and_extract_hidden_states( model, tokenizer, return_hidden_states=True, - runtime=runtime, + 
pipeline_parallel=pipeline_parallel, val_dataloader=val_dataloader, ) - if runtime is None or runtime.is_last_process: + if dist.is_last_process(): output_dir = output_dir if (output_dir is not None) else args.bypass_dir extra_payload = extra_payload if (extra_payload is not None) else dict() write_results(output_dir, model_name, args, {**losses, **extra_payload}) @@ -70,14 +71,14 @@ def validate_model_and_extract_hidden_states( def validate_model_with_teacher_similarity_metrics( - args: argparse.Namespace, - model: nn.Module | StitchedModule, + args: DictConfig, + model: "nn.Module | StitchedModule", tokenizer: PreTrainedTokenizerBase, target_hidden_states_per_batch: list[torch.Tensor], - output_dir: Union[str, Path], + output_dir: str | Path, model_name: str, extra_payload: Optional[dict[str, Any]] = None, - runtime: Optional[IRuntime] = None, + pipeline_parallel: bool = False, calculate_full_score_ablations: bool = False, val_dataloader=None, ) -> None: @@ -94,20 +95,17 @@ def validate_model_with_teacher_similarity_metrics( model, tokenizer, target_hidden_states_per_batch=target_hidden_states_per_batch, - runtime=runtime, + pipeline_parallel=pipeline_parallel, calculate_full_score_ablations=calculate_full_score_ablations, val_dataloader=val_dataloader, ) - if runtime is None or runtime.is_last_process: + if dist.is_last_process(): extra_payload = extra_payload if (extra_payload is not None) else dict() write_results(output_dir, model_name, args, {**losses, **extra_payload}) def write_results( - output_dir: Union[str, Path], - result_name: str, - args: argparse.Namespace, - payload: dict[str, Any], + output_dir: str | Path, result_name: str, args: DictConfig, payload: dict[str, Any] ) -> None: output_path = Path(output_dir) / f"{result_name}.json" output_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/modelopt/torch/_compress/utils/checkpoint_manager.py b/modelopt/torch/_compress/utils/checkpoint_manager.py index b96fd21a5..7a2733446 100644 --- a/modelopt/torch/_compress/utils/checkpoint_manager.py +++ b/modelopt/torch/_compress/utils/checkpoint_manager.py @@ -22,30 +22,27 @@ from pathlib import Path from typing import Any, Dict, Optional +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.tools.logger import aprint, mprint class ScoringCheckpointManager: """Manages checkpointing for activation hook scoring with periodic saves.""" - def __init__( - self, checkpoint_dir: str, runtime, activation_hooks=None, checkpoint_interval: int = 100 - ): + def __init__(self, checkpoint_dir: str, activation_hooks=None, checkpoint_interval: int = 100): """ Initialize checkpoint manager. 
Args: checkpoint_dir: Directory to save checkpoints - runtime: Runtime object for distributed processing activation_hooks: Dictionary of activation hooks to manage checkpoint_interval: Save checkpoint every N batches """ self.checkpoint_dir = Path(checkpoint_dir) - self.runtime = runtime self.activation_hooks = activation_hooks self.checkpoint_interval = checkpoint_interval - self.rank = runtime.global_rank if runtime is not None else 0 - self.is_main_process = runtime is None or runtime.is_main_process + self.rank = dist.rank() + self.is_main_process = dist.is_master() # Debug: Log checkpoint manager initialization hook_count = len(activation_hooks) if activation_hooks else 0 @@ -200,9 +197,7 @@ def update_progress(self, batch_idx: int, total_batches: int): ActivationsHook, ) - saved_path = ActivationsHook.save_hook_states( - self.activation_hooks, self.checkpoint_dir, self.runtime - ) + ActivationsHook.save_hook_states(self.activation_hooks, self.checkpoint_dir) except Exception as e: mprint(f"Warning: Failed to save hook states: {e}") @@ -211,8 +206,7 @@ def update_progress(self, batch_idx: int, total_batches: int): self.save_checkpoint() # Synchronize all ranks after checkpointing - if self.runtime is not None: - self.runtime.wait_for_everyone() + dist.barrier() def save_checkpoint(self): """ @@ -260,7 +254,7 @@ def finalize(self): ) saved_path = ActivationsHook.save_hook_states( - self.activation_hooks, self.checkpoint_dir, self.runtime + self.activation_hooks, self.checkpoint_dir ) mprint(f"Final hook states saved to {saved_path}") except Exception as e: @@ -273,5 +267,4 @@ def finalize(self): mprint(f"Scoring completed and finalized: {self.total_batches} batches processed") # Synchronize all ranks after finalization - if self.runtime is not None: - self.runtime.wait_for_everyone() + dist.barrier() diff --git a/modelopt/torch/_compress/utils/data/dataloaders.py b/modelopt/torch/_compress/utils/data/dataloaders.py index 584e32480..865ad89fb 100644 --- a/modelopt/torch/_compress/utils/data/dataloaders.py +++ b/modelopt/torch/_compress/utils/data/dataloaders.py @@ -17,7 +17,6 @@ DataLoader utilities for language model training and validation. 
""" -import os from collections.abc import Callable, Mapping, Sequence from functools import partial from typing import Protocol, TypeVar @@ -74,58 +73,6 @@ def load_streaming_fn( return dataset -def create_train_dataloader( - accelerator: Accelerator, - seed: int, - tokenizer: PreTrainedTokenizerBase, - block_size: int, - dataset: str | Mapping[str, Dataset], - content_field: str, - fim_rate: float, - fim_spm_rate: float, - micro_batch_size: int, - load_dataset_fn: LoadDatasetFn = load_from_disk_fn, - dataset_name="train", - keep_in_memory: bool = False, - shuffle_train_data_seed: int | None = None, - source_datasets_to_discard: Sequence[str] = (), - bos_rate: float = 1.0, - varlen: bool = True, -): - mprint(f"\ncreate_train_dataloader on rank {accelerator.process_index}") - if isinstance(dataset, str): - dataset = load_dataset_fn(dataset, content_field, keep_in_memory) - - train_data = dataset[dataset_name] - if shuffle_train_data_seed is not None: - train_data = train_data.shuffle(seed=shuffle_train_data_seed) - - train_dataset = ConstantLengthDataset( - tokenizer, - train_data, - infinite=True, - seq_length=block_size * micro_batch_size if varlen else block_size, - content_field=content_field, - fim_rate=fim_rate, - fim_spm_rate=fim_spm_rate, - seed=seed, - source_datasets_to_discard=source_datasets_to_discard, - bos_rate=bos_rate, - # return_cu_seqlens=varlen, - # seqlen_cap=block_size if varlen else None - ) - - train_dataloader = DataLoader( - train_dataset, - batch_size=1 if varlen else micro_batch_size, - pin_memory=True, - collate_fn=collate_fn_with_none_support, - num_workers=os.cpu_count() // 2 // 8, - ) - - return train_dataloader - - def create_validation_dataloader( accelerator: Accelerator | None, seed: int, @@ -231,75 +178,6 @@ def realize_dataset_in_memory(dataset: IterableDataset, eval_samples: int | None return offloaded_dataset -def create_dataloaders( - accelerator: Accelerator, - seed: int, - tokenizer: PreTrainedTokenizerBase, - block_size: int, - dataset_path: str, - content_field: str, - fim_rate: float, - fim_spm_rate: float, - micro_batch_size: int, - val_micro_batch_size: int | None = None, - eval_samples: int | None = None, - load_dataset_fn: LoadDatasetFn = load_from_disk_fn, - train_dataset_name: str = "train", - val_dataset_name: str = "__auto__", - disable_validation: bool = False, - keep_in_memory: bool = False, - shuffle_train_data_seed: int | None = None, - source_datasets_to_discard: Sequence[str] = (), - bos_rate: float = 1.0, - varlen: bool = True, -): - if val_micro_batch_size is None: - val_micro_batch_size = micro_batch_size - - dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory=keep_in_memory) - - train_dataloader = create_train_dataloader( - accelerator, - seed, - tokenizer, - block_size, - dataset, - content_field, - fim_rate, - fim_spm_rate, - micro_batch_size, - load_dataset_fn, - train_dataset_name, - shuffle_train_data_seed=shuffle_train_data_seed, - source_datasets_to_discard=source_datasets_to_discard, - bos_rate=bos_rate, - varlen=varlen, - ) - - if not disable_validation: - val_dataloader = create_validation_dataloader( - accelerator, - seed, - tokenizer, - block_size, - dataset, - content_field, - fim_rate, - fim_spm_rate, - val_micro_batch_size, - eval_samples, - load_dataset_fn, - val_dataset_name, - source_datasets_to_discard=source_datasets_to_discard, - bos_rate=bos_rate, - varlen=varlen, - ) - else: - val_dataloader = None - - return train_dataloader, val_dataloader - - TensorT = TypeVar("TensorT", 
bound=torch.Tensor) diff --git a/modelopt/torch/_compress/utils/dist_utils.py b/modelopt/torch/_compress/utils/dist_utils.py deleted file mode 100644 index 84f8f2bab..000000000 --- a/modelopt/torch/_compress/utils/dist_utils.py +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import torch.distributed as dist - - -def is_distributed(): - """ - From torchtune.utils.is_distributed() : https://docs.pytorch.org/torchtune/0.2/generated/torchtune.utils.is_distributed.html - """ - port = os.environ.get("MASTER_PORT", "") - addr = os.environ.get("MASTER_ADDR", "") - size = int(os.environ.get("WORLD_SIZE", 1)) - rank = int(os.environ.get("RANK", -1)) - avlb = dist.is_available() - return bool(port and addr and size > 1 and rank >= 0 and avlb) diff --git a/modelopt/torch/_compress/utils/utils.py b/modelopt/torch/_compress/utils/utils.py index d03ea8040..62b7678eb 100644 --- a/modelopt/torch/_compress/utils/utils.py +++ b/modelopt/torch/_compress/utils/utils.py @@ -63,7 +63,7 @@ def raise_unknown_subblock_config_error(subblock_config: Any) -> None: ) -def sizeof_dtype(dtype: torch.dtype | str) -> int | float: +def sizeof_dtype(dtype: torch.dtype) -> int | float: """Return the size in bytes of the given data type. TODO: Consider a better place for this function. 
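With utils/dist_utils.py deleted, every call site touched by this patch relies on the module-level helpers in modelopt.torch.utils.distributed instead of a runtime object threaded through the call stack: dist.rank(), dist.size(), dist.local_rank(), dist.is_master(), dist.is_last_process() and dist.barrier() replace the corresponding IRuntime properties, and dtypes are passed explicitly (model_dtype, autocast_dtype) rather than owned by the runtime. A minimal sketch of the resulting calling pattern, assuming a torchrun launch; the shard_block_indexes helper and the num_layers value below are illustrative only, not part of the library:

import numpy as np
import torch

import modelopt.torch.utils.distributed as dist


def shard_block_indexes(num_layers: int) -> set[int]:
    # Same sharding rule as load_and_shard_model: split layer indexes evenly across ranks.
    return set(np.array_split(np.arange(num_layers), dist.size())[dist.rank()])


def example_step(num_layers: int = 32) -> None:
    owned = shard_block_indexes(num_layers)
    device = torch.device(dist.local_rank())  # one GPU per local rank
    print(f"[rank {dist.rank()}] owns blocks {sorted(owned)} on {device}")

    if dist.is_master():
        # coordinator-only work, e.g. loading the full state_dict
        pass
    dist.barrier()  # replaces runtime.wait_for_everyone()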
diff --git a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py index aa8a4f304..b3be70644 100644 --- a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py +++ b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py @@ -23,16 +23,12 @@ """ # mypy: ignore-errors -from statistics import mean - import numpy as np import torch -import torch.distributed -import wandb from torch.utils.data import DataLoader from tqdm import tqdm -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +import modelopt.torch.utils.distributed as dist from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMForCausalLM, LMHead, @@ -52,148 +48,10 @@ fake_tensor, ) from modelopt.torch._compress.tools.checkpoint_utils import init_module_with_state_dict -from modelopt.torch._compress.tools.logger import mprint -from modelopt.torch._compress.tools.runtime import IRuntime from modelopt.torch._compress.tools.sharded_checkpoint_utils import DummyBlock from modelopt.torch._compress.utils.validation import _organize_outputs, calculate_batch_outputs -@torch.no_grad() -def validate_pipeline_inner( - runtime: IRuntime, - stitched_model: StitchedModule, - val_dataloader: DataLoader | None, -) -> float: - if runtime.is_main_process: - assert val_dataloader.batch_size is not None - model_device = next(stitched_model.parameters()).device - - with runtime.autocast(): - stitched_model.eval() - - all_logits: list[torch.Tensor] = [] - all_targets: list[torch.Tensor] = [] - losses: list[float] = [] - - if runtime.is_main_process: - input_ids: torch.Tensor - targets: torch.Tensor - - for i_batch, batch in enumerate(tqdm(val_dataloader)): - input_ids, targets = ( - batch["input_ids"].to(model_device), - batch["targets"].to(model_device), - ) - - if i_batch == 0: - num_batches = len(val_dataloader) - seq_len = input_ids.shape[1] - if torch.distributed.is_initialized(): - torch.distributed.broadcast_object_list([(num_batches, seq_len)]) - - all_targets.append(targets.cpu()) - - output = stitched_model({}, {}, input_ids) - logits = output.captured_outputs.get("model_output") - logits = getattr(logits, "logits", logits) - - if logits is not None: - all_logits.append(logits.cpu()) - - del output, logits - - if len(all_targets) > 0: - distributed_send_obj(all_targets, dst=runtime.world_size - 1) - - else: - obj_list: list[tuple] = [None] - torch.distributed.broadcast_object_list(obj_list) - num_batches, seq_len = obj_list[0] - - fake_input_ids = fake_tensor(1, seq_len, dtype=runtime.dtype) - - for i in range(num_batches): - output = stitched_model({}, {}, fake_input_ids) - logits = output.captured_outputs.get("model_output") - logits = getattr(logits, "logits", logits) - if logits is not None: - all_logits.append(logits.cpu()) - del output, logits - - if len(all_targets) == 0 and runtime.global_rank == runtime.world_size - 1: - all_targets = distributed_recv_obj(src=0) - - torch.distributed.barrier() - - if len(all_logits) > 0: - for logits, targets in zip(all_logits, all_targets): - logits = logits.to("cuda") - targets = targets.to("cuda") - logit_losses = torch.nn.functional.cross_entropy( - logits.transpose(1, 2), targets, ignore_index=-1, reduction="none" - ) - - mean_losses = logit_losses.cpu().mean(dim=-1) - losses.extend(mean_losses.tolist()) - - val_loss = mean(losses) - - if not runtime.is_main_process: - distributed_send_obj(val_loss, dst=0) - elif runtime.is_main_process: - val_loss = 
distributed_recv_obj() - else: - val_loss = float("nan") - - stitched_model.train() - - loss_list = [val_loss] - torch.distributed.broadcast_object_list(loss_list) - val_loss = loss_list[0] - - return val_loss - - -@torch.no_grad() -def validate_pipeline( - runtime: IRuntime, - stitched_model: StitchedModule, - model_config: DeciLMConfig, - val_dataloader: DataLoader, - iter_num: int | None = None, - max_iters: int | None = None, - model_name: str | None = None, - enable_print: bool = True, - enable_wandb_log: bool = False, - # pad_to_batchsize: bool = True, -) -> float: - if enable_print: - mprint("Validating ...") - - val_loss = validate_pipeline_inner( - runtime=runtime, - stitched_model=stitched_model, - val_dataloader=val_dataloader, - ) - - if runtime.is_main_process: - key = "val/loss" if model_name is None else f"val/{model_name}_loss" - if enable_print: - prefix = "" - if iter_num is not None: - prefix += f"iter {iter_num}" - if max_iters is not None: - prefix += f"/{max_iters}" - prefix += " - " - mprint(f"{prefix}{key}: {val_loss:.4f}") - if enable_wandb_log: - wandb.log({key: val_loss}, step=iter_num) - - runtime.wait_for_everyone() - - return val_loss - - class HiddenStatesAndLMHead(list): def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Tensor): super().__init__(hidden_states) @@ -202,7 +60,6 @@ def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Ten @torch.no_grad() def calculate_losses_pipeline( - runtime: IRuntime, stitched_model: StitchedModule | DeciLMForCausalLM, dataloader: DataLoader | None, target_hidden_states_per_batch: HiddenStatesAndLMHead | None = None, @@ -211,6 +68,7 @@ def calculate_losses_pipeline( calc_on_cpu: bool = False, just_model_forward: bool = False, checkpoint_manager=None, + autocast_dtype: torch.dtype = torch.bfloat16, ) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: """ Do model forward on each batch and calculate LM loss. 
@@ -232,27 +90,27 @@ def calculate_losses_pipeline( """ if isinstance(stitched_model, DeciLMForCausalLM): - stitched_model = perform_pipeline_stitches(stitched_model, runtime) + stitched_model = perform_pipeline_stitches(stitched_model) params = list(stitched_model.parameters()) model_device = params[0].device if params else "cpu" # Pre-populate outputs with dummy values for skipped batches start_batch = checkpoint_manager.current_batch if checkpoint_manager else 0 - if runtime.is_last_process: + if dist.is_last_process(): outputs = [{"lm_loss": [0.0]}] * start_batch else: outputs = None - if runtime.is_main_process: + if dist.is_master(): all_input_ids, all_targets = zip( *[(batch["input_ids"], batch["targets"]) for batch in dataloader] ) - if runtime.world_size > 1: - distributed_send_obj(all_targets, dst=runtime.world_size - 1) + if dist.size() > 1: + distributed_send_obj(all_targets, dst=dist.size() - 1) - if runtime.is_last_process: - if runtime.world_size > 1: + if dist.is_last_process(): + if dist.size() > 1: all_targets = distributed_recv_obj(src=0) lm_head: LMHead = next( @@ -268,37 +126,37 @@ def calculate_losses_pipeline( {"weight": lm_head_weights}, LMHead, *lm_head_weights.shape[::-1], bias=False ) - if runtime.is_main_process: + if dist.is_master(): num_batches = len(all_input_ids) seq_len = all_input_ids[0].shape[1] - if runtime.world_size > 1: + if dist.size() > 1: torch.distributed.broadcast_object_list([num_batches, seq_len]) # Create progress bar with sliced range starting from checkpoint position desc = ( - f"[rank {runtime.global_rank}] calculate_losses_pipeline(" + f"[rank {dist.rank()}] calculate_losses_pipeline(" f"{(target_hidden_states_per_batch is None)=}, {return_hidden_states=}, {num_batches=})" ) progress_bar = tqdm(range(start_batch, num_batches), desc=desc) else: obj_list = [None, None] - if runtime.world_size > 1: + if dist.size() > 1: torch.distributed.broadcast_object_list(obj_list) num_batches, seq_len = obj_list progress_bar = range(start_batch, num_batches) stitched_model.eval() - with runtime.autocast(): + with torch.autocast(device_type="cuda", dtype=autocast_dtype): for i_batch in progress_bar: - if runtime.is_main_process: + if dist.is_master(): input_ids = all_input_ids[i_batch].to(model_device) else: input_ids = fake_tensor(1, seq_len, dtype=torch.long) output = stitched_model({}, {}, input_ids) - if runtime.is_last_process: + if dist.is_last_process(): logits = output.captured_outputs.get("model_output") logits = getattr(logits, "logits", logits) hidden_states = output.captured_outputs.get("hidden_states") @@ -340,14 +198,11 @@ def calculate_losses_pipeline( hidden_states_per_batch, lm_head.weight.cpu() ) - runtime.wait_for_everyone() + dist.barrier() return losses, hidden_states_per_batch -def perform_pipeline_stitches( - model: DeciLMForCausalLM, - runtime: IRuntime, -) -> StitchedModule: +def perform_pipeline_stitches(model: DeciLMForCausalLM) -> StitchedModule: target = ModuleTarget("module", model) stitcher = Needle() @@ -356,10 +211,10 @@ def perform_pipeline_stitches( ) first_block, last_block = is_real_block.min(), is_real_block.max() - if runtime.global_rank != 0: + if dist.rank() != 0: # receive activations from previous rank stitcher.stitch( - RemoteTarget(peer_rank=runtime.global_rank - 1).value( + RemoteTarget(peer_rank=dist.rank() - 1).value( name="activations", adapter=lambda x: InputArgs(x) ), target.input( @@ -370,11 +225,11 @@ def perform_pipeline_stitches( ), ) - if not runtime.is_last_process: + if not dist.is_last_process(): 
# send activations to next rank stitcher.stitch( target.output(f"model.layers.{last_block}"), - RemoteTarget(peer_rank=runtime.global_rank + 1).value(name="activations"), + RemoteTarget(peer_rank=dist.rank() + 1).value(name="activations"), ) else: # register model output diff --git a/modelopt/torch/_compress/utils/validation.py b/modelopt/torch/_compress/utils/validation.py index 662ae4a2b..d970105e6 100644 --- a/modelopt/torch/_compress/utils/validation.py +++ b/modelopt/torch/_compress/utils/validation.py @@ -24,14 +24,10 @@ import functools import math from enum import Enum -from statistics import mean import numpy as np import torch -import torch.distributed import torch.nn.functional as F -import wandb -from accelerate import Accelerator from torch import nn from torch.utils.data import DataLoader from tqdm import tqdm @@ -39,159 +35,6 @@ from typing_extensions import Self from modelopt.torch._compress.tools import kd_model -from modelopt.torch._compress.utils.data.dataloaders import create_padded_tensor - - -@torch.no_grad() -def _validate_single( - accelerator: Accelerator, - model: torch.nn.Module, - rope_cache: torch.Tensor | None, - val_dataloader: DataLoader, - pad_to_batchsize: bool = True, - compute_kl_div: bool = False, - varlen: bool = False, - concat_token_id: int | None = None, -) -> list[float]: - assert val_dataloader.batch_sampler.batch_size is not None - desired_batch_size = val_dataloader.batch_sampler.batch_size - - with accelerator.device, accelerator.autocast(): - model.eval() - - losses: list[float] = [] - - input_ids: torch.LongTensor - targets: torch.LongTensor - is_first_batch = True - for batch in tqdm(val_dataloader, disable=not accelerator.is_main_process): - if is_first_batch: - print( - f"First batch, device {accelerator.device}, input_ids: {batch['input_ids'][:4]}" - ) - is_first_batch = False - input_ids, targets = ( - batch["input_ids"].to(accelerator.device), - batch["targets"].to(accelerator.device), - ) - batch_size = input_ids.size(0) - - if pad_to_batchsize: - input_ids = create_padded_tensor( - input_ids, (desired_batch_size, *input_ids.shape[1:]) - ) - targets = create_padded_tensor(targets, (desired_batch_size, *targets.shape[1:])) - - if rope_cache is not None: - logits = model( - input_ids, rope_cache=rope_cache, varlen=varlen, concat_token_id=concat_token_id - ) - else: - logits = model(input_ids) - - if hasattr(logits, "logits"): # For HF models - logits = logits.logits - - if isinstance(logits, tuple): # For KD - logits, teacher_logits, kd_block_loss, kd_logits_loss = logits - - if compute_kl_div: - # assumes kd_logits_loss has entry for each batch item - batch_losses = kd_logits_loss[:batch_size] - else: - batch_losses = torch.nn.functional.cross_entropy( - logits.transpose(1, 2), targets, ignore_index=-1, reduction="none" - )[:batch_size].mean(dim=-1) - - losses.extend(batch_losses.tolist()) - - model.train() - - return losses - - -@torch.no_grad() -def validate_parallel( - accelerator: Accelerator, - model: torch.nn.Module, - rope_cache: torch.Tensor | None, - val_dataloader: DataLoader, - pad_to_batchsize: bool = True, - compute_kl_div: bool = False, - varlen: bool = False, - concat_token_id: int | None = None, -) -> float: - losses = _validate_single( - accelerator=accelerator, - model=model, - rope_cache=rope_cache, - val_dataloader=val_dataloader, - pad_to_batchsize=pad_to_batchsize, - compute_kl_div=compute_kl_div, - varlen=varlen, - concat_token_id=concat_token_id, - ) - - results = [float("nan")] - if accelerator.is_main_process: 
-        gathered_results = [[float("nan")]] * accelerator.num_processes
-        torch.distributed.gather_object(losses, gathered_results)
-        gathered_losses = [l for result in gathered_results for l in result]
-        results[0] = mean(gathered_losses)
-    else:
-        torch.distributed.gather_object(losses)
-
-    torch.distributed.broadcast_object_list(results)
-    val_loss = results[0]
-
-    return val_loss
-
-
-@torch.no_grad()
-def validate(
-    accelerator: Accelerator,
-    model: torch.nn.Module,
-    rope_cache: torch.Tensor | None,
-    val_dataloader: DataLoader,
-    iter_num: int | None = None,
-    max_iters: int | None = None,
-    model_name: str | None = None,
-    enable_print: bool = True,
-    enable_wandb_log: bool = False,
-    pad_to_batchsize: bool = True,
-    compute_kl_div: bool = False,
-    varlen: bool = False,
-    concat_token_id: int | None = None,
-) -> float:
-    if enable_print:
-        accelerator.print("Validating ...")
-
-    val_loss = validate_parallel(
-        accelerator=accelerator,
-        model=model,
-        rope_cache=rope_cache,
-        val_dataloader=val_dataloader,
-        pad_to_batchsize=pad_to_batchsize,
-        compute_kl_div=compute_kl_div,
-        varlen=varlen,
-        concat_token_id=concat_token_id,
-    )
-
-    if accelerator.is_main_process:
-        key = "val/loss" if model_name is None else f"val/{model_name}_loss"
-        if enable_print:
-            prefix = ""
-            if iter_num is not None:
-                prefix += f"iter {iter_num}"
-                if max_iters is not None:
-                    prefix += f"/{max_iters}"
-                prefix += " - "
-            accelerator.print(f"{prefix}{key}: {val_loss:.4f}", show_delta=True)
-        if enable_wandb_log:
-            wandb.log({key: val_loss}, step=iter_num)
-    accelerator.wait_for_everyone()
-
-    return val_loss
 
 
 class UnshardedLowMemorySparseTensor:
@@ -325,37 +168,6 @@ def calculate_losses(
     return losses, None
 
 
-def calc_entropy(logits: torch.Tensor) -> torch.Tensor:
-    """
-    Returns per-token entropy given a logits tensor of shape [batch_size x seq_len x vocab_size].
-    The output will have shape [batch_size x seq_len].
-    """
-    # Convert logits to log-probabilities
-    log_probs = F.log_softmax(logits, dim=-1)  # shape: [B x T x V]
-
-    # Compute probabilities from log-probabilities
-    probs = torch.exp(log_probs)  # shape: [B x T x V]
-
-    # Entropy calculation: sum over V of (- p * log p)
-    ent = -torch.sum(probs * log_probs, dim=-1)  # shape: [B x T]
-
-    return ent
-
-
-def confidence_max_softmax(logits: torch.Tensor) -> torch.Tensor:
-    """
-    Returns per-token max-softmax confidence given a logits tensor of shape [batch_size x seq_len x vocab_size].
-    The output will have shape [batch_size x seq_len].
- """ - # Compute softmax probabilities - probs = F.softmax(logits, dim=-1) # shape: [B x T x V] - - # Take the maximum probability along the vocabulary dimension - max_confidence = torch.max(probs, dim=-1).values # shape: [B x T] - - return max_confidence - - def calculate_batch_outputs( hidden_states: torch.Tensor | None, target_hidden_states: torch.Tensor | None, @@ -380,8 +192,6 @@ def calculate_batch_outputs( batch_outputs = _calculate_ground_truth_based_scores(logits, targets) - # _DEBUG_calculate_per_token_entropy(batch_outputs, logits) - if (target_hidden_states is not None) or (target_logits is not None): batch_outputs.update( _calculate_teacher_similarity_scores( @@ -399,20 +209,6 @@ def calculate_batch_outputs( return batch_outputs -def _DEBUG_calculate_per_token_entropy(batch_outputs, logits, i_batch): - import os - - # calculate the per token entropy and per token top p - entropy = calc_entropy(logits).cpu() # .view(-1)#.tolist() - msftm = confidence_max_softmax(logits).cpu() # .view(-1)#.tolist() - teacher_dir = ".../meta-llama/Meta-Llama-3.1-70B-Instruct-new_rope/" - file_path = f"{teacher_dir}/validation/per_token_stats_{i_batch}.pth" - os.makedirs(os.path.dirname(file_path), exist_ok=True) - torch.save({"entropy": entropy, "max_softmax": msftm}, file_path) - batch_outputs["entropy"] = entropy - batch_outputs["max_softmax"] = msftm - - def _organize_outputs( outputs_per_batch: list[dict], ) -> tuple[dict[str, dict], list[torch.Tensor] | None]: @@ -473,28 +269,6 @@ def _calculate_ground_truth_based_scores( return scores -def _calculate_per_sample_kl_div_loss( - logits: torch.Tensor, - batch_target_probs: torch.Tensor | LowMemorySparseTensor, -) -> list[float]: - if isinstance(batch_target_probs, LowMemorySparseTensor): - logits = top_p_top_k(logits) - curr_target_probs = batch_target_probs.to_dense().to(logits.device) # .float() - per_sample_kl_div = [ - F.kl_div( - logits[i_sample].log_softmax(-1), - curr_target_probs[i_sample], - reduction="none", - log_target=False, - ) - .sum(-1) - .mean(-1) - .item() - for i_sample in range(logits.shape[0]) - ] - return per_sample_kl_div - - def cosine_embedding_loss( hidden_states: torch.Tensor, target_hidden_states: torch.Tensor, @@ -762,49 +536,6 @@ def tv_dist( DEFAULT_TOP_K = 1000 -def calculate_sparse_probs( - logits: torch.Tensor, - top_p: float | None = DEFAULT_TOP_P, - top_k: int | None = DEFAULT_TOP_K, - verbose: bool = False, -) -> LowMemorySparseTensor: - warped_logits = top_p_top_k(logits, top_p, top_k) - probs = warped_logits.softmax(-1) - sparse_probs = LowMemorySparseTensor(probs) - if True: # Always calculate these metrics (was: if verbose or True:) - probs_unfiltered = logits.softmax(-1) - num_active_per_token = (warped_logits > -1000).sum(-1).float() - prob_density = torch.tensor( - [ - probs_unfiltered[i, j, warped_logits[i, j] > -1000].sum(-1).float() - for j in range(probs_unfiltered.shape[1]) - for i in range(probs_unfiltered.shape[0]) - ] - ) - - print(f""" - Sparsity: - {num_active_per_token.mean().item()=} - {num_active_per_token.quantile(0.25).item()=} - {num_active_per_token.quantile(0.5).item()=} - {num_active_per_token.quantile(0.75).item()=} - {num_active_per_token.quantile(0.9).item()=} - {num_active_per_token.quantile(0.95).item()=} - {num_active_per_token.max().item()=} - - {probs_unfiltered.shape=} - {prob_density.shape=} - {prob_density.mean().item()=} - {prob_density.quantile(0.25).item()=} - {prob_density.quantile(0.5).item()=} - {prob_density.quantile(0.75).item()=} - 
-        {prob_density.quantile(0.9).item()=}
-        {prob_density.quantile(0.95).item()=}
-        {prob_density.max().item()=}
-        """)
-    return sparse_probs
-
-
 def top_p_top_k(
     logits: torch.Tensor,
     top_p: float | None = DEFAULT_TOP_P,
diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py
index 033b4aadb..9b32d1ac4 100644
--- a/modelopt/torch/utils/distributed.py
+++ b/modelopt/torch/utils/distributed.py
@@ -20,6 +20,8 @@
 import os
 import time
 from collections.abc import Callable
+from contextlib import suppress
+from datetime import timedelta
 from typing import Any
 
 import torch
@@ -70,11 +72,23 @@ def rank(group=None) -> int:
     return 0
 
 
+def local_rank() -> int:
+    """Returns the local rank of the current process."""
+    if "LOCAL_RANK" in os.environ:
+        return int(os.environ["LOCAL_RANK"])
+    raise RuntimeError("LOCAL_RANK environment variable not found.")
+
+
 def is_master(group=None) -> bool:
     """Returns whether the current process is the master process."""
     return rank(group=group) == 0
 
 
+def is_last_process(group=None) -> bool:
+    """Returns whether the current process is the last process."""
+    return rank(group=group) == size(group=group) - 1
+
+
 def _serialize(obj: Any) -> torch.Tensor:
     buffer = io.BytesIO()
     torch.save(obj, buffer)
@@ -184,6 +198,21 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
+def setup(timeout: timedelta | None = None):
+    """Sets up the distributed environment."""
+    torch.cuda.set_device(local_rank())
+    if not is_initialized():
+        torch.distributed.init_process_group("cpu:gloo,cuda:nccl", timeout=timeout)
+
+
+def cleanup():
+    """Cleans up the distributed environment."""
+    if is_initialized():
+        with suppress(Exception):
+            barrier()
+        torch.distributed.destroy_process_group()
+
+
 class DistributedProcessGroup:
     """A convenient wrapper around torch.distributed.ProcessGroup objects."""
diff --git a/setup.py b/setup.py
index 20a271fe1..e19935a88 100644
--- a/setup.py
+++ b/setup.py
@@ -111,7 +111,6 @@
         "omegaconf==2.3.0",
         "pandas",
         "typeguard",
-        "wandb~=0.17.5",
     ],
 }
diff --git a/tests/gpu/torch/_compress/compress_test_utils.py b/tests/gpu/torch/_compress/compress_test_utils.py
index 9df5f5bfc..1da08602b 100644
--- a/tests/gpu/torch/_compress/compress_test_utils.py
+++ b/tests/gpu/torch/_compress/compress_test_utils.py
@@ -21,14 +21,12 @@
 from datasets import Dataset, DatasetDict
 from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase
 
+import modelopt.torch.utils.distributed as dist
 from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers
 
 
 def setup_test_model_and_data(
-    project_root_path: Path,
-    tmp_path: Path,
-    rank: int,
-    runtime,
+    project_root_path: Path, tmp_path: Path, rank: int
 ) -> tuple[Path, Path, Path]:
     """
     Setup the test model and data for the compress NAS search.
@@ -37,7 +35,6 @@ def setup_test_model_and_data(
         project_root_path (Path): the root path of the project
         tmp_path (Path): the temporary path to use for the test
         rank (int): the rank of the process
-        runtime: the runtime to use for the test
 
     Returns:
         tuple[Path, Path, Path]:
@@ -63,7 +60,7 @@ def setup_test_model_and_data(
         create_and_save_small_llama_model(
             llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer
         )
-    runtime.wait_for_everyone()
+    dist.barrier()
 
     return (
         puzzle_dir,
diff --git a/tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py b/tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py
index dbbcbacd4..913bc2116 100644
--- a/tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py
+++ b/tests/gpu/torch/_compress/nas/plugins/test_nas_convert.py
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import datetime
 import os
+from datetime import timedelta
 from functools import partial
 from pathlib import Path
 
@@ -23,14 +23,10 @@
 from gpu.torch._compress.compress_test_utils import setup_test_model_and_data
 
 import modelopt.torch.nas as mtn
+import modelopt.torch.utils.distributed as dist
 from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel
-from modelopt.torch._compress.tools.runtime import NativeDdpRuntime
 
-#
-# See tests/gpu/torch/_compress/test_compress.py for instructions on how to run this test
-# TODO: Remove those instructions once this test runs automatically on CI
-#
 def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path):
     spawn_multiprocess_job(
         size=torch.cuda.device_count(),
@@ -42,51 +38,49 @@ def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path):
 def _test_nas_convert_ffn_pruning_multiprocess_job(
     project_root_path: Path, tmp_path: Path, rank: int, size: int
 ):
-    with NativeDdpRuntime(
-        dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
-    ) as runtime:
-        # Setup the test model and data.
-        puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data(
-            project_root_path, tmp_path, rank, runtime
-        )
-        hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs"
-        hydra_config_name = "Llama-3_1-8B-ffn-pruning"
-
-        #
-        # Run the mnt.convert() step
-        #
-        input_model = CompressModel()
-        mtn.convert(
-            input_model,
-            mode=[
-                (
-                    "compress",
-                    {
-                        "puzzle_dir": str(puzzle_dir),
-                        "input_model_path": str(llama_checkpoint_path),
-                        "hydra_config_dir": str(hydra_config_dir),
-                        "hydra_config_name": hydra_config_name,
-                        "dataset_path": str(dataset_path),
-                    },
-                )
-            ],
-        )
-
-        #
-        # Check assertions
-        #
-        if rank == 0:
-            # assertions for the score_pruning_activations step
-            rank = int(os.environ["RANK"])
-            rank_filepath = (
-                f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth"
+    dist.setup(timeout=timedelta(10))
+    # Setup the test model and data.
+    puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data(
+        project_root_path, tmp_path, rank
+    )
+    hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs"
+    hydra_config_name = "Llama-3_1-8B-ffn-pruning"
+
+    #
+    # Run the mtn.convert() step
+    #
+    input_model = CompressModel()
+    mtn.convert(
+        input_model,
+        mode=[
+            (
+                "compress",
+                {
+                    "puzzle_dir": str(puzzle_dir),
+                    "input_model_path": str(llama_checkpoint_path),
+                    "hydra_config_dir": str(hydra_config_dir),
+                    "hydra_config_name": hydra_config_name,
+                    "dataset_path": str(dataset_path),
+                },
             )
-            assert (puzzle_dir / rank_filepath).is_file()
+        ],
+    )
+
+    #
+    # Check assertions
+    #
+    if rank == 0:
+        # assertions for the score_pruning_activations step
+        rank = int(os.environ["RANK"])
+        rank_filepath = (
+            f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth"
+        )
+        assert (puzzle_dir / rank_filepath).is_file()
 
-            # assertions for the pruning_ckpts step
-            assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists()
+        # assertions for the pruning_ckpts step
+        assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists()
 
-        runtime.wait_for_everyone()
+    dist.cleanup()
 
     print("PYTEST SUMMARY: test_nas_convert_ffn_pruning() test has finished successfully")
 
 
@@ -102,53 +96,51 @@ def test_nas_convert_attn_pruning(project_root_path: Path, tmp_path: Path):
 def _test_nas_convert_attn_pruning_multiprocess_job(
     project_root_path: Path, tmp_path: Path, rank: int, size: int
 ):
-    with NativeDdpRuntime(
-        dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
-    ) as runtime:
-        # Setup the test model and data.
-        puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data(
-            project_root_path, tmp_path, rank, runtime
-        )
-        hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs"
-        hydra_config_name = "Llama-3_1-8B-attn-pruning"
-
-        #
-        # Run the mnt.convert() step
-        #
-        input_model = CompressModel()
-        mtn.convert(
-            input_model,
-            mode=[
-                (
-                    "compress",
-                    {
-                        "puzzle_dir": str(puzzle_dir),
-                        "input_model_path": str(llama_checkpoint_path),
-                        "hydra_config_dir": str(hydra_config_dir),
-                        "hydra_config_name": hydra_config_name,
-                        "dataset_path": str(dataset_path),
-                    },
-                )
-            ],
-        )
-
-        #
-        # Check assertions
-        #
-        if rank == 0:
-            # assertions for the score_pruning_activations step
-            rank = int(os.environ["RANK"])
-            rank_filepath = (
-                f"pruning/pruning_scores/attn_independent_kv_head_contribution/"
-                f"100samples_diverse_mini/rank_{rank}.pth"
+    dist.setup(timeout=timedelta(10))
+    # Setup the test model and data.
+    puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data(
+        project_root_path, tmp_path, rank
+    )
+    hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs"
+    hydra_config_name = "Llama-3_1-8B-attn-pruning"
+
+    #
+    # Run the mtn.convert() step
+    #
+    input_model = CompressModel()
+    mtn.convert(
+        input_model,
+        mode=[
+            (
+                "compress",
+                {
+                    "puzzle_dir": str(puzzle_dir),
+                    "input_model_path": str(llama_checkpoint_path),
+                    "hydra_config_dir": str(hydra_config_dir),
+                    "hydra_config_name": hydra_config_name,
+                    "dataset_path": str(dataset_path),
+                },
            )
-            assert (puzzle_dir / rank_filepath).is_file()
+        ],
+    )
+
+    #
+    # Check assertions
+    #
+    if rank == 0:
+        # assertions for the score_pruning_activations step
+        rank = int(os.environ["RANK"])
+        rank_filepath = (
+            f"pruning/pruning_scores/attn_independent_kv_head_contribution/"
+            f"100samples_diverse_mini/rank_{rank}.pth"
+        )
+        assert (puzzle_dir / rank_filepath).is_file()
 
-            # assertions for the pruning_ckpts step
-            assert (puzzle_dir / "ckpts/n_heads_in_group8").exists()
-            assert (puzzle_dir / "ckpts/n_heads_in_group16").exists()
-            assert (puzzle_dir / "ckpts/n_heads_in_group32").exists()
+        # assertions for the pruning_ckpts step
+        assert (puzzle_dir / "ckpts/n_heads_in_group8").exists()
+        assert (puzzle_dir / "ckpts/n_heads_in_group16").exists()
+        assert (puzzle_dir / "ckpts/n_heads_in_group32").exists()
 
-        runtime.wait_for_everyone()
+    dist.cleanup()
 
     print("PYTEST SUMMARY: test_nas_convert_attn_pruning() test has finished successfully")
diff --git a/tests/gpu/torch/_compress/nas/plugins/test_nas_search.py b/tests/gpu/torch/_compress/nas/plugins/test_nas_search.py
index e8ea24ece..1b4ed93c6 100644
--- a/tests/gpu/torch/_compress/nas/plugins/test_nas_search.py
+++ b/tests/gpu/torch/_compress/nas/plugins/test_nas_search.py
@@ -13,11 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-#
-# See tests/gpu/torch/_compress/test_compress.py for instructions on how to run this test
-# TODO: Remove those instructions once this test runs automatically on CI
-#
-import datetime
+from datetime import timedelta
 from functools import partial
 from pathlib import Path
 
@@ -26,8 +22,8 @@
 from gpu.torch._compress.compress_test_utils import setup_test_model_and_data
 
 import modelopt.torch.nas as mtn
+import modelopt.torch.utils.distributed as dist
 from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel
-from modelopt.torch._compress.tools.runtime import NativeDdpRuntime
 
 
 def test_nas_search(project_root_path: Path, tmp_path: Path):
@@ -41,72 +37,68 @@ def test_nas_search(project_root_path: Path, tmp_path: Path):
 def _test_nas_search_multiprocess_job(
     project_root_path: Path, tmp_path: Path, rank: int, size: int
 ):
-    with NativeDdpRuntime(
-        dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
-    ) as runtime:
-        # Setup the test model and data.
-        puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data(
-            project_root_path, tmp_path, rank, runtime
-        )
-        hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs"
-        hydra_config_name = "Llama-3_1-8B-ffn-pruning"
-
-        #
-        # Run the mnt.convert() step
-        #
-        input_model = CompressModel()
-        converted_model = mtn.convert(
-            input_model,
-            mode=[
-                (
-                    "compress",
-                    {
-                        "puzzle_dir": str(puzzle_dir),
-                        "input_model_path": str(llama_checkpoint_path),
-                        "hydra_config_dir": str(hydra_config_dir),
-                        "hydra_config_name": hydra_config_name,
-                        "dataset_path": str(dataset_path),
-                    },
-                )
-            ],
-        )
+    dist.setup(timeout=timedelta(10))
+    # Setup the test model and data.
+    puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data(
+        project_root_path, tmp_path, rank
+    )
+    hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs"
+    hydra_config_name = "Llama-3_1-8B-ffn-pruning"
+
+    #
+    # Run the mtn.convert() step
+    #
+    input_model = CompressModel()
+    converted_model = mtn.convert(
+        input_model,
+        mode=[
+            (
+                "compress",
+                {
+                    "puzzle_dir": str(puzzle_dir),
+                    "input_model_path": str(llama_checkpoint_path),
+                    "hydra_config_dir": str(hydra_config_dir),
+                    "hydra_config_name": hydra_config_name,
+                    "dataset_path": str(dataset_path),
+                },
+            )
+        ],
+    )
 
-        #
-        # Run the mnt.search() step
-        #
-        mtn.search(
-            converted_model,
-            constraints={},  # this is not used as the search space is defined in the hydra config
-            dummy_input=None,  # Not used
-            config={},  # this is not used as the search space is defined in the hydra config
-        )
+    #
+    # Run the mtn.search() step
+    #
+    mtn.search(
+        converted_model,
+        constraints={},  # this is not used as the search space is defined in the hydra config
+        dummy_input=None,  # Not used
+        config={},  # this is not used as the search space is defined in the hydra config
+    )
 
-        #
-        # Check assertions for mtn.search() step
-        #
-        if rank == 0:
-            # assertions for the build_library_and_stats step
-            assert (puzzle_dir / "replacement_library.json").is_file()
-            assert (puzzle_dir / "subblock_stats.json").is_file()
-
-            # assertions for the scoring step
-            solution_0_filepath = (
-                puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json"
-            )
+    #
+    # Check assertions for mtn.search() step
+    #
+    if rank == 0:
+        # assertions for the build_library_and_stats step
+        assert (puzzle_dir / "replacement_library.json").is_file()
+        assert (puzzle_dir / "subblock_stats.json").is_file()
+
+        # assertions for the scoring step
+        solution_0_filepath = (
+            puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json"
+        )
 
-            assert solution_0_filepath.exists()
+        assert solution_0_filepath.exists()
 
-            # assertions for the mip_and_realize_models step
-            solution_0_ckpt_config_path = (
-                puzzle_dir
-                / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json"
-            )
+        # assertions for the mip_and_realize_models step
+        solution_0_ckpt_config_path = (
+            puzzle_dir
+            / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json"
+        )
 
-            assert solution_0_ckpt_config_path.exists()
-            assert (
-                puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json"
-            ).exists()
+        assert solution_0_ckpt_config_path.exists()
+        assert (puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json").exists()
 
-        runtime.wait_for_everyone()
+    dist.cleanup()
 
     print("PYTEST SUMMARY: test_nas_search() test has finished successfully")
diff --git a/tests/gpu/torch/_compress/resources/configs/validate_model_defaults.yaml b/tests/gpu/torch/_compress/resources/configs/validate_model_defaults.yaml
index 178edb50d..192b82c75 100644
--- a/tests/gpu/torch/_compress/resources/configs/validate_model_defaults.yaml
+++ b/tests/gpu/torch/_compress/resources/configs/validate_model_defaults.yaml
@@ -1,3 +1,5 @@
+model_dtype: torch.bfloat16
+autocast_dtype: torch.bfloat16
 block_size: 8192
 bos_rate: 0.5
 data_column: conversation
diff --git a/tests/gpu/torch/_compress/test_compress.py b/tests/gpu/torch/_compress/test_compress.py
index e40756602..997bb9971 100644
--- a/tests/gpu/torch/_compress/test_compress.py
+++ b/tests/gpu/torch/_compress/test_compress.py
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import datetime
 import os
+from datetime import timedelta
 from functools import partial
 from pathlib import Path
 
@@ -22,11 +22,11 @@
 from _test_utils.torch.distributed.utils import spawn_multiprocess_job
 from gpu.torch._compress.compress_test_utils import setup_test_model_and_data
 
+import modelopt.torch.utils.distributed as dist
 from modelopt.torch._compress import compress
 from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import (
     convert_llama3_to_decilm,
 )
-from modelopt.torch._compress.tools.runtime import NativeDdpRuntime
 
 # The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search)
 # using a one-click command.
@@ -43,66 +43,60 @@ def test_compress(project_root_path: Path, tmp_path: Path):
 
 
 def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, rank: int, size: int):
-    with NativeDdpRuntime(
-        dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
-    ) as runtime:
-        # Setup the test model and data.
-        puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data(
-            project_root_path, tmp_path, rank, runtime
+    dist.setup(timeout=timedelta(10))
+    # Setup the test model and data.
+    puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data(
+        project_root_path, tmp_path, rank
+    )
+    hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs"
+    hydra_config_name = "Llama-3_1-8B-ffn-pruning"
+
+    # Convert the Llama model to DeciLM model.
+    if rank == 0:
+        convert_llama3_to_decilm(
+            input_dir=llama_checkpoint_path,
+            output_dir=puzzle_dir / "ckpts/teacher",
         )
-        hydra_config_dir = project_root_path / "tests/gpu/torch/_compress/resources/configs"
-        hydra_config_name = "Llama-3_1-8B-ffn-pruning"
-
-        # Convert the Llama model to DeciLM model.
-        if rank == 0:
-            convert_llama3_to_decilm(
-                input_dir=llama_checkpoint_path,
-                output_dir=puzzle_dir / "ckpts/teacher",
-            )
-        runtime.wait_for_everyone()
-
-        # Compress the model using a one-click approach
-        compress.compress(
-            str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path), runtime
+    dist.barrier()
+
+    # Compress the model using a one-click approach
+    compress.compress(str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path))
+
+    #
+    # Check assertions
+    #
+    if rank == 0:
+        # assertions for the score_pruning_activations step 1
+        rank = int(os.environ["RANK"])
+        rank_filepath = (
+            f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth"
         )
+        assert (puzzle_dir / rank_filepath).is_file()
 
-        #
-        # Check assertions
-        #
-        if rank == 0:
-            # assertions for the score_pruning_activations step 1
-            rank = int(os.environ["RANK"])
-            rank_filepath = (
-                f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth"
-            )
-            assert (puzzle_dir / rank_filepath).is_file()
-
-            # assertions for the pruning_ckpts step 2
-            assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists()
+        # assertions for the pruning_ckpts step 2
+        assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists()
 
-            # assertions for the build_library_and_stats step 4
+        # assertions for the build_library_and_stats step 4
 
-            assert (puzzle_dir / "replacement_library.json").is_file()
-            assert (puzzle_dir / "subblock_stats.json").is_file()
+        assert (puzzle_dir / "replacement_library.json").is_file()
+        assert (puzzle_dir / "subblock_stats.json").is_file()
 
-            # assertions for the scoring step 5
-            solution_0_filepath = (
-                puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json"
-            )
+        # assertions for the scoring step 5
+        solution_0_filepath = (
+            puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json"
+        )
 
-            assert solution_0_filepath.exists()
+        assert solution_0_filepath.exists()
 
-            # assertions for the mip_and_realize_models step 6
-            solution_0_ckpt_config_path = (
-                puzzle_dir
-                / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json"
-            )
+        # assertions for the mip_and_realize_models step 6
+        solution_0_ckpt_config_path = (
+            puzzle_dir
+            / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json"
+        )
 
-            assert solution_0_ckpt_config_path.exists()
-            assert (
-                puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json"
-            ).exists()
+        assert solution_0_ckpt_config_path.exists()
+        assert (puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json").exists()
 
-        runtime.wait_for_everyone()
+    dist.cleanup()
 
     print("PYTEST SUMMARY: test_compress_model() test has finished successfully")
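
For reference, every test in this patch follows the same migration: the NativeDdpRuntime context manager is replaced by the module-level helpers added to modelopt.torch.utils.distributed. A minimal sketch of that worker pattern, assuming a spawn_multiprocess_job-style entry point; the function name _example_worker is hypothetical, and timedelta(10) simply mirrors the timeout value used in the tests above:

from datetime import timedelta

import modelopt.torch.utils.distributed as dist


def _example_worker(rank: int, size: int):
    # Initialize the process group; replaces `with NativeDdpRuntime(...) as runtime:`
    dist.setup(timeout=timedelta(10))
    # Synchronize ranks; replaces runtime.wait_for_everyone()
    dist.barrier()
    # Tear down the process group; replaces leaving the NativeDdpRuntime context
    dist.cleanup()

Note that dist.cleanup() already wraps a final barrier() before destroy_process_group(), so callers no longer need an explicit wait-for-everyone before exiting.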