diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py
index 125517c1a..71cb8eda6 100644
--- a/src/llmcompressor/pipelines/sequential/helpers.py
+++ b/src/llmcompressor/pipelines/sequential/helpers.py
@@ -3,12 +3,14 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Set, Union
 
+from loguru import logger
 from compressed_tensors import has_offloaded_params
 from compressed_tensors.quantization import find_name_or_class_matches
 from torch.fx import Graph, GraphModule, Node
 from torch.fx.graph import PythonCode
 from torch.fx.proxy import Argument
 from torch.nn import Module
+from torch.fx import _symbolic_trace
 from transformers import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils.fx import HFTracer
@@ -82,14 +84,18 @@ def trace_subgraphs(
     """
     # find modules
     sequential_targets = match_modules(model, sequential_targets)
-    ignore = match_modules(model, ignore)
+    ignore = ["MistralModel._update_causal_mask"]  # note: hard-coded override of the ignore argument
 
     # initialize arguments
     tracer = get_tracer(model, sequential_targets, ignore)
     concrete_args = populate_concrete_args(model, sample_input)
 
+    from torch import compiler
+
+    wrap_module_methods(model, ignore)
+
     # trace
-    with calibration_forward_context(model), HooksMixin.disable_hooks():
+    with calibration_forward_context(model), HooksMixin.disable_hooks(), patch_attr(compiler, "_is_compiling_flag", True):
         graph = GraphModule(
             model,
             tracer.trace(
@@ -115,6 +121,69 @@
     return subgraphs
 
 
+def get_sequential_ancestors(model: Module, targets: Set[Module]) -> Set[Module]:
+    """Find all modules which are ancestors of at least one target module"""
+    ancestors = set()
+
+    def dfs(module: Module) -> bool:
+        if module in ancestors:
+            return True
+
+        if module in targets:
+            return True
+
+        # search all children (do not early stop)
+        is_ancestor = False
+        for child in module.children():
+            is_ancestor |= dfs(child)
+
+        if is_ancestor:
+            ancestors.add(module)
+
+        return is_ancestor
+
+    dfs(model)
+
+    return ancestors
+
+
+def wrap_module_methods(model: Module, ignore: List[str]):
+    """Register ignored methods with torch.fx so they are wrapped as leaf calls"""
+    module_classes = set(type(module) for module in model.modules())
+
+    for ignore_pattern in ignore:
+        num_dots = ignore_pattern.count(".")
+
+        if num_dots == 0:
+            method_name = ignore_pattern
+            num_match = 0
+            for cls in module_classes:
+                if hasattr(cls, method_name):
+                    _symbolic_trace._wrapped_methods_to_patch.append((cls, method_name))
+                    num_match += 1
+
+            if num_match <= 0:
+                raise ValueError(f"Could not find any module methods named {method_name}")
+
+            if num_match >= 2:
+                logger.warning(f"Wrapping {num_match} module methods named {method_name}")
+
+        elif num_dots == 1:
+            cls_name, method_name = ignore_pattern.split(".")
+            num_match = 0
+            for cls in module_classes:
+                if cls.__name__ == cls_name and hasattr(cls, method_name):
+                    _symbolic_trace._wrapped_methods_to_patch.append((cls, method_name))
+                    print(f"wrapped {(cls, method_name)}")
+                    num_match += 1
+
+            if num_match <= 0:
+                raise ValueError(f"Could not find any module methods matching {ignore_pattern}")
+
+            if num_match >= 2:
+                logger.warning(f"Wrapping {num_match} module methods matching {ignore_pattern}")
+
+        else:
+            raise ValueError(f"Unable to parse ignore pattern {ignore_pattern}")
+
+
 def get_tracer(
     model: Module, sequential_targets: Set[Module], ignore: Set[Module]
 ) -> HFTracer:
@@ -131,7 +200,18 @@
     """
     # TODO: redefine skip_trace_modules to all non-ancestors of sequential_targets
    offloaded_modules = set(m for m in model.modules() if has_offloaded_params(m))
-    skip_trace_modules = sequential_targets | offloaded_modules | ignore
+    sequential_ancestors = get_sequential_ancestors(model, sequential_targets)
+    print([ancestor.__class__.__name__ for ancestor in sequential_ancestors])
+    # skip_trace_modules = sequential_targets | offloaded_modules | ignore
+
+    from torch.fx import _symbolic_trace
+
+    from transformers import MistralModel
+
+    # debugging: manually wrap MistralModel._update_causal_mask as a leaf call
+    print(_symbolic_trace._wrapped_methods_to_patch)
+    _symbolic_trace._wrapped_methods_to_patch.append((MistralModel, "_update_causal_mask"))
+    print(_symbolic_trace._wrapped_methods_to_patch)
 
     class SequentialTracer(HFTracer):
         def create_arg(self, a: Any) -> Argument:
@@ -144,11 +224,16 @@
             return super().create_arg(a)
 
         def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
-            return module in skip_trace_modules or super().is_leaf_module(
-                module, module_qualified_name
-            )
+            # return module in skip_trace_modules or super().is_leaf_module(
+            #     module, module_qualified_name
+            # )
+            return module not in sequential_ancestors
 
         def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph:
+            print(f"trace: {root.__class__.__name__}")
+
+            # TODO: if an ignored name is a method on this module, add it to the patch list
+
             if isinstance(root, Module):
                 # due to a bug in Tracer.create_args_for_root (_patch_function),
                 # we must unwrap function wrappers prior to tracing, for example
diff --git a/src/llmcompressor/transformers/tracing/debug.py b/src/llmcompressor/transformers/tracing/debug.py
index fe68ecf2a..87afa48fa 100644
--- a/src/llmcompressor/transformers/tracing/debug.py
+++ b/src/llmcompressor/transformers/tracing/debug.py
@@ -12,6 +12,8 @@
 from llmcompressor.transformers import TextGenerationDataset
 from llmcompressor.args import DatasetArguments
 
+from llmcompressor.utils.dev import skip_weights_download
+
 __all__ = [
     "get_model_class"
 ]
@@ -24,6 +26,7 @@
     parser.add_argument("--sequential_targets", type=str, nargs="*", default=None, metavar="TARGET", help="List of targets for sequential tracing")  # noqa: E501
     parser.add_argument("--ignore", type=str, nargs="*", default=[], metavar="PATTERN", help="List of patterns to ignore during tracing")  # noqa: E501
     parser.add_argument("--modality", type=str, default="text", help="Modality of calibration dataset, defaults to text")  # noqa: E501
+    parser.add_argument("--trust_remote_code", action="store_true", help="Whether to trust model remote code")  # noqa: E501
     return parser.parse_args()
 
 
@@ -33,6 +36,7 @@ def trace(
     sequential_targets: Optional[Union[List[str], str]] = None,
     ignore: Union[List[str], str] = [],
     modality: str = "text",
+    trust_remote_code: bool = True,
 ):
     """
     Debug traceability by tracing a pre-trained model into subgraphs
@@ -44,6 +48,7 @@ def trace(
         inference
     :param ignore: patterns to ignore during tracing
     :param modality: data modality for dummy tracing data, defaults to 'text'
+    :param trust_remote_code: whether to trust the model's remote code
 
     Example usage from CLI
     llmcompressor.trace \
@@ -54,12 +59,16 @@ def trace(
         --modality text
     """
     # Load model
-    model = model_class.from_pretrained(
-        model_id,
-        device_map="auto",
-        torch_dtype="auto",
+    with skip_weights_download(model_class):
+        model = model_class.from_pretrained(
+            model_id,
+            device_map="cpu",
+            torch_dtype="auto",
+            trust_remote_code=trust_remote_code,
+        )
+    processor = AutoProcessor.from_pretrained(
+        model_id, trust_remote_code=trust_remote_code
     )
-    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
     print("Loaded model")
 
     # Prepare sample data
@@ -138,6 +147,7 @@ def main():
         sequential_targets=args.sequential_targets,
         ignore=args.ignore,
        modality=args.modality,
+        trust_remote_code=args.trust_remote_code,
     )
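
For context, here is a minimal, self-contained sketch (not part of the diff; the toy module and method names are illustrative) of the torch.fx mechanism that wrap_module_methods relies on: appending a (class, method_name) pair to _symbolic_trace._wrapped_methods_to_patch causes tracing to record a single call_method node for that method instead of tracing through its body, which is what lets otherwise-untraceable helpers such as MistralModel._update_causal_mask be skipped.

import torch
from torch import nn
from torch.fx import symbolic_trace, _symbolic_trace


class ToyAttention(nn.Module):
    def _make_mask(self, x: torch.Tensor) -> torch.Tensor:
        # data-dependent control flow like this normally breaks symbolic tracing
        return x if x.sum() > 0 else torch.zeros_like(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self._make_mask(x)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.attn(x)


# register the method as "wrapped": the tracer emits one call_method node for
# _make_mask rather than tracing into its body (note: this mutates global state)
_symbolic_trace._wrapped_methods_to_patch.append((ToyAttention, "_make_mask"))

gm = symbolic_trace(ToyModel())
print(gm.graph)           # the graph contains a call_method node with target=_make_mask
print(gm(torch.ones(4)))  # the wrapped method still runs eagerly at execution time

HFTracer builds on torch.fx.Tracer, so the same registry applies when the SequentialTracer defined in the diff runs.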