Merged

(first changed file: a compression config YAML; filename not captured in this view)

@@ -1,3 +1,5 @@
+model_dtype: torch.bfloat16 # dtype to cast the model for validate_model
+autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model
 block_size: 8192
 bos_rate: 0.5
 data_column: messages
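The two added keys arrive as plain strings. A minimal, hypothetical sketch (this helper is an assumption, not shown in the PR) of how such values can be resolved to real `torch.dtype` objects before casting the model and entering `torch.autocast` for `validate_model`:

```python
# Hypothetical helper (assumption, not from this PR): map "torch.bfloat16"-style
# config strings to torch.dtype objects.
import torch


def parse_dtype(value: str) -> torch.dtype:
    dtype = getattr(torch, value.removeprefix("torch."), None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"{value!r} is not a torch dtype")
    return dtype


model_dtype = parse_dtype("torch.bfloat16")     # dtype to cast the model
autocast_dtype = parse_dtype("torch.bfloat16")  # dtype for torch.autocast
```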
examples/compress/main.py (149 changes: 75 additions, 74 deletions)

@@ -29,18 +29,18 @@
"""

import argparse
import datetime
from datetime import timedelta
from pathlib import Path

import mip_and_realize_models
import torch
from puzzle_tools.hydra_utils import register_hydra_resolvers

import modelopt.torch._compress.mip.mip_and_realize_models as mip_and_realize_models
import modelopt.torch.nas as mtn
import modelopt.torch.utils.distributed as dist
from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel
from modelopt.torch._compress.runtime import NativeDdpRuntime
from modelopt.torch._compress.tools.hydra_utils import (
initialize_hydra_config_for_dir,
register_hydra_resolvers,
)
from modelopt.torch._compress.tools.logger import mprint
from tests.utils.test_utils import initialize_hydra_config_for_dir


def parse_args():
@@ -70,50 +70,52 @@ def run_full_compress(hydra_config_path: str):
         config_path: Path to the YAML configuration file
     """
     mprint("Compress Progress 1/8: starting compression pipeline")
-    with NativeDdpRuntime(dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)):
-        # Register Hydra custom resolvers (needed for config resolution)
-        register_hydra_resolvers()
-
-        hydra_config_path = Path(hydra_config_path).resolve()
-        hydra_config_dir = str(hydra_config_path.parent)
-        hydra_config_name = hydra_config_path.stem
-
-        # Load hydra config
-        hydra_cfg = initialize_hydra_config_for_dir(
-            config_dir=hydra_config_dir,
-            config_name=hydra_config_name,
-            overrides=[],
-        )
-
-        # Convert model (convert from HF to DeciLM, score pruning activations,
-        # prune the model and save pruned checkpoints)
-        input_model = CompressModel()
-        converted_model = mtn.convert(
-            input_model,
-            mode=[
-                (
-                    "compress",
-                    {
-                        "puzzle_dir": str(hydra_cfg.puzzle_dir),
-                        "input_model_path": hydra_cfg.input_hf_model_path,
-                        "hydra_config_dir": hydra_config_dir,
-                        "hydra_config_name": hydra_config_name,
-                        "dataset_path": str(hydra_cfg.dataset_path),
-                    },
-                )
-            ],
-        )
-
-        # Run NAS search (build replacement library and compute stats,
-        # compute one block scores, run MIP and realize models)
-        mtn.search(
-            converted_model,
-            constraints={}, # this is not used as the search space is defined in the hydra config
-            dummy_input=None, # Not used
-            config={}, # this is not used as the search space is defined in the hydra config
-        )
-
-        mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)")
+    dist.setup(timeout=timedelta(10))
+
+    # Register Hydra custom resolvers (needed for config resolution)
+    register_hydra_resolvers()
+
+    hydra_config_path = Path(hydra_config_path).resolve()
+    hydra_config_dir = str(hydra_config_path.parent)
+    hydra_config_name = hydra_config_path.stem
+
+    # Load hydra config
+    hydra_cfg = initialize_hydra_config_for_dir(
+        config_dir=hydra_config_dir,
+        config_name=hydra_config_name,
+        overrides=[],
+    )
+
+    # Convert model (convert from HF to DeciLM, score pruning activations,
+    # prune the model and save pruned checkpoints)
+    input_model = CompressModel()
+    converted_model = mtn.convert(
+        input_model,
+        mode=[
+            (
+                "compress",
+                {
+                    "puzzle_dir": str(hydra_cfg.puzzle_dir),
+                    "input_model_path": hydra_cfg.input_hf_model_path,
+                    "hydra_config_dir": hydra_config_dir,
+                    "hydra_config_name": hydra_config_name,
+                    "dataset_path": str(hydra_cfg.dataset_path),
+                },
+            )
+        ],
+    )
+
+    # Run NAS search (build replacement library and compute stats,
+    # compute one block scores, run MIP and realize models)
+    mtn.search(
+        converted_model,
+        constraints={}, # this is not used as the search space is defined in the hydra config
+        dummy_input=None, # Not used
+        config={}, # this is not used as the search space is defined in the hydra config
+    )
+
+    dist.cleanup()
+    mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)")


 def run_mip_only(hydra_config_path: str):
@@ -125,30 +127,29 @@ def run_mip_only(hydra_config_path: str):
     Args:
         hydra_config_path: Path to the YAML configuration file
     """
-    with NativeDdpRuntime(
-        dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
-    ) as runtime:
-        # Register Hydra custom resolvers (needed for config resolution)
-        register_hydra_resolvers()
-
-        hydra_config_path = Path(hydra_config_path).resolve()
-        hydra_config_dir = str(hydra_config_path.parent)
-        hydra_config_name = hydra_config_path.stem
-
-        # Load hydra config
-        hydra_cfg = initialize_hydra_config_for_dir(
-            config_dir=hydra_config_dir,
-            config_name=hydra_config_name,
-            overrides=[],
-        )
-
-        # mip_and_realize_models (distributed processing)
-        # TODO: How to make it part of mnt.search() api, similarly to run_full_compress() API
-        mprint("Compress Progress 7/8: running MIP and realizing models (multi-gpu)")
-        mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime)
-
-        mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)")
+    dist.setup(timeout=timedelta(10))
+
+    # Register Hydra custom resolvers (needed for config resolution)
+    register_hydra_resolvers()
+
+    hydra_config_path = Path(hydra_config_path).resolve()
+    hydra_config_dir = str(hydra_config_path.parent)
+    hydra_config_name = hydra_config_path.stem
+
+    # Load hydra config
+    hydra_cfg = initialize_hydra_config_for_dir(
+        config_dir=hydra_config_dir,
+        config_name=hydra_config_name,
+        overrides=[],
+    )
+
+    # mip_and_realize_models (distributed processing)
+    # TODO: make this part of the mtn.search() API, similarly to run_full_compress()
+    mprint("Compress Progress 7/8: running MIP and realizing models (multi-gpu)")
+    mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
+
+    dist.cleanup()
+    mprint("Compress Progress 8/8: compression pipeline completed (multi-gpu)")


 def main():
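Net effect of the `examples/compress/main.py` changes: the `NativeDdpRuntime` context manager is replaced by explicit `dist.setup()` / `dist.cleanup()` calls from `modelopt.torch.utils.distributed`, and `launch_mip_and_realize_model` no longer takes a runtime handle. A hypothetical driver sketch (config path and launcher invocation are illustrative, not from this PR):

```python
# Hypothetical usage sketch: each rank runs this once, e.g. under
#   torchrun --nproc_per_node=8 examples/compress/main.py ...
# Distributed init/teardown now happens inside the functions themselves.
hydra_config = "path/to/compress_config.yaml"  # illustrative path

run_full_compress(hydra_config)  # full pipeline (steps 1-8)
# ...or, to resume from the MIP step only:
# run_mip_only(hydra_config)
```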
(next changed file: activation-hooks scoring module for pruning; filename not captured in this view)

@@ -18,7 +18,6 @@
 activation scoring for pruning.
 """

-import argparse
 import gc
 import json
 from abc import ABC, abstractmethod
@@ -30,6 +29,8 @@
 from omegaconf import DictConfig, OmegaConf
 from torch import nn

+import modelopt.torch.utils.distributed as dist
+
 # BlockConfig used at runtime, not just type hints (lines 680, 790)
 from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import BlockConfig  # noqa: TC001
 from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import (
@@ -38,7 +39,6 @@
 from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMRMSNorm
 from modelopt.torch._compress.tools.logger import aprint
 from modelopt.torch._compress.tools.robust_json import json_dump
-from modelopt.torch._compress.tools.runtime import IRuntime


 def clear_gpu_memory(clear: bool) -> None:
@@ -97,17 +97,16 @@ def dump_activations_logs(
         cls: type["ActivationsHook"],
         activation_hooks: dict[str, "ActivationsHook"],
         activations_log_dir: Path | str,
-        args: argparse.Namespace,
-        runtime: IRuntime | None,
-    ):
+        args: DictConfig,
+    ) -> None:
         """
         Default implementation for dumping final activation scores logs to disk.
         This is called only at the end of scoring to save final results.
         """

         activations_log_dir = Path(activations_log_dir)
         activations_log_dir.mkdir(exist_ok=True, parents=True)
-        rank = runtime.global_rank if runtime is not None else 0
+        rank = dist.rank()
         activations_log_path = activations_log_dir / f"rank_{rank}.pth"
         activations_log = {
             module_name: hook.to_dict() for module_name, hook in activation_hooks.items()
@@ -116,14 +115,8 @@ def dump_activations_logs(

         if rank == 0:
             args.activation_hooks_kwargs.pop("model")
-            json_dump(
-                OmegaConf.to_container(args, resolve=True)
-                if isinstance(args, DictConfig)
-                else vars(args),
-                activations_log_dir / "args.json",
-            )
-        if runtime is not None:
-            runtime.wait_for_everyone()  # rank 0 will not wait before dumping args.json
+            json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json")
+        dist.barrier()

         aprint(f"Dumped final activations log to {activations_log_path}")

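Throughout this module the substitution is mechanical: `runtime.global_rank` becomes `dist.rank()`, `runtime.wait_for_everyone()` becomes `dist.barrier()`, and `IRuntime` drops out of every signature. A minimal standalone sketch of the resulting dump-and-sync pattern (the `state` dict and directory are stand-ins; assumes the process group is already set up):

```python
# Sketch of the per-rank dump pattern used above; `state` stands in for the
# activation-hook dict. dist.rank()/dist.barrier() are the calls from the diff.
from pathlib import Path

import torch

import modelopt.torch.utils.distributed as dist


def dump_per_rank(state: dict, log_dir: Path) -> None:
    log_dir.mkdir(exist_ok=True, parents=True)
    rank = dist.rank()  # replaces runtime.global_rank
    torch.save(state, log_dir / f"rank_{rank}.pth")
    dist.barrier()  # replaces runtime.wait_for_everyone()
```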
@@ -132,16 +125,15 @@ def save_hook_states(
         cls: type["ActivationsHook"],
         activation_hooks: dict[str, "ActivationsHook"],
         activations_log_dir: Path | str,
-        runtime: IRuntime | None,
-    ):
+    ) -> None:
         """
         Save hook states for checkpointing (separate from final results).
         This can be called periodically during scoring.
         Note: Synchronization should be handled at a higher level to avoid deadlocks.
         """
         activations_log_dir = Path(activations_log_dir)
         activations_log_dir.mkdir(exist_ok=True, parents=True)
-        rank = runtime.global_rank if runtime is not None else 0
+        rank = dist.rank()

         hook_states_path = activations_log_dir / f"hook_states_rank_{rank}.pth"
         hook_states = {
@@ -461,29 +453,28 @@ def dump_activations_logs(
         cls: type["LayerNormContributionHook"],
         activation_hooks: dict[str, "ActivationsHook"],
         activations_log_dir: Path | str,
-        args: argparse.Namespace,
-        runtime: IRuntime | None,
-    ):
+        args: DictConfig,
+    ) -> None:
         """
         At the end of the default implementation of dumping activation scores to disc,
         save aggregated channel importance results.
         """

-        super().dump_activations_logs(activation_hooks, activations_log_dir, args, runtime)
+        super().dump_activations_logs(activation_hooks, activations_log_dir, args)

-        rank = runtime.global_rank if runtime is not None else 0
+        rank = dist.rank()
         if rank == 0:
             LayerNormContributionHook._save_channel_importance_results(
                 activation_hooks, activations_log_dir, args
             )

-        runtime.wait_for_everyone()
+        dist.barrier()

     @staticmethod
     def _save_channel_importance_results(
         activation_hooks: dict[str, ActivationsHook],
         activations_log_dir: Path,
-        args: argparse.Namespace,
+        args: DictConfig,
     ) -> None:
         """
         Save channel importance results from activation hooks.
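With `args` typed as `DictConfig` everywhere, the old `isinstance` branch (`OmegaConf.to_container(...)` vs `vars(args)`) collapses into a single call. A self-contained sketch of that serialization path (values are illustrative):

```python
# Sketch: dumping a resolved DictConfig to JSON, as the simplified
# json_dump(OmegaConf.to_container(args, resolve=True), ...) call now does.
import json

from omegaconf import OmegaConf

args = OmegaConf.create({"activation_hooks_kwargs": {"clear_gpu_memory": True}})  # illustrative
container = OmegaConf.to_container(args, resolve=True)  # plain dict; interpolations resolved
with open("args.json", "w") as f:
    json.dump(container, f, indent=2)
```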