Skip to content

[WIP][Tracing] Mistral3ForConditionalGeneration #1387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 91 additions & 6 deletions src/llmcompressor/pipelines/sequential/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Union

from loguru import logger
from compressed_tensors import has_offloaded_params
from compressed_tensors.quantization import find_name_or_class_matches
from torch.fx import Graph, GraphModule, Node
from torch.fx.graph import PythonCode
from torch.fx.proxy import Argument
from torch.nn import Module
from torch.fx import _symbolic_trace
from transformers import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.utils.fx import HFTracer
Expand Down Expand Up @@ -82,14 +84,18 @@ def trace_subgraphs(
"""
# find modules
sequential_targets = match_modules(model, sequential_targets)
ignore = match_modules(model, ignore)
ignore = ["MistralModel._update_causal_mask"]

# initialize arguments
tracer = get_tracer(model, sequential_targets, ignore)
concrete_args = populate_concrete_args(model, sample_input)

from torch import compiler

wrap_module_methods(model, ignore)

# trace
with calibration_forward_context(model), HooksMixin.disable_hooks():
with calibration_forward_context(model), HooksMixin.disable_hooks(), patch_attr(compiler, "_is_compiling_flag", True):
graph = GraphModule(
model,
tracer.trace(
Expand All @@ -115,6 +121,69 @@ def trace_subgraphs(
return subgraphs


def get_sequential_ancestors(model: Module, targets: Set[Module]) -> Set[Module]:
ancestors = set()
def dfs(module: Module) -> bool:
if module in ancestors:
return True

if module in targets:
return True

# search all children (do not early stop)
is_ancestor = False
for child in module.children():
is_ancestor |= dfs(child)

if is_ancestor:
ancestors.add(module)

return is_ancestor

dfs(model)

return ancestors


def wrap_module_methods(model: Module, ignore: List[str]):
module_classes = set(type(module) for module in model.modules())

for ignore_pattern in ignore:
num_dots = ignore_pattern.count(".")

if num_dots == 0:
method_name = ignore_pattern
num_match = 0
for cls in module_classes:
if hasattr(cls, method_name):
_symbolic_trace._wrapped_methods_to_patch.append((cls, method_name))
num_match += 1

if num_match <= 0:
raise ValueError()

if num_match >= 2:
logger.warning()

elif num_dots == 1:
cls_name, method_name = ignore_pattern.split(".")
num_match = 0
for cls in module_classes:
if cls.__name__ == cls_name and hasattr(cls, method_name):
_symbolic_trace._wrapped_methods_to_patch.append((cls, method_name))
print(f"wrapped {(cls, method_name)}")
num_match += 1

if num_match <= 0:
raise ValueError()

if num_match >= 2:
logger.warning()

else:
raise ValueError()


def get_tracer(
model: Module, sequential_targets: Set[Module], ignore: Set[Module]
) -> HFTracer:
Expand All @@ -131,7 +200,18 @@ def get_tracer(
"""
# TODO: redefine skip_trace_modules to all non-ancestors of sequential_targets
offloaded_modules = set(m for m in model.modules() if has_offloaded_params(m))
skip_trace_modules = sequential_targets | offloaded_modules | ignore
sequential_ancestors = get_sequential_ancestors(model, sequential_targets)
print([ancestor.__class__.__name__ for ancestor in sequential_ancestors])
#skip_trace_modules = sequential_targets | offloaded_modules | ignore


from torch.fx import _symbolic_trace

from transformers import MistralModel

print(_symbolic_trace._wrapped_methods_to_patch)
_symbolic_trace._wrapped_methods_to_patch.append((MistralModel, "_update_causal_mask"))
print(_symbolic_trace._wrapped_methods_to_patch)

class SequentialTracer(HFTracer):
def create_arg(self, a: Any) -> Argument:
Expand All @@ -144,11 +224,16 @@ def create_arg(self, a: Any) -> Argument:
return super().create_arg(a)

def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
return module in skip_trace_modules or super().is_leaf_module(
module, module_qualified_name
)
#return module in skip_trace_modules or super().is_leaf_module(
# module, module_qualified_name
# )
return module not in sequential_ancestors

def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph:
print(f"trace: {root.__class__.__name__}")

# todo; if has ignored function name as method, add to list of methods to patch

if isinstance(root, Module):
# due to a bug in Tracer.create_args_for_root (_patch_function),
# we must unwrap function wrappers prior to tracing, for example
Expand Down
20 changes: 15 additions & 5 deletions src/llmcompressor/transformers/tracing/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from llmcompressor.transformers import TextGenerationDataset
from llmcompressor.args import DatasetArguments

from llmcompressor.utils.dev import skip_weights_download

__all__ = [
"get_model_class"
]
Expand All @@ -24,6 +26,7 @@ def parse_args():
parser.add_argument("--sequential_targets", type=str, nargs="*", default=None, metavar="TARGET", help="List of targets for sequential tracing") # noqa: E501
parser.add_argument("--ignore", type=str, nargs="*", default=[], metavar="PATTERN", help="List of patterns to ignore during tracing") # noqa: E501
parser.add_argument("--modality", type=str, default="text", help="Modality of calibration dataset, defaults to text") # noqa: E501
parser.add_argument("--trust_remote_code", type=bool, default=False, help="Whether to trust model remote code") # noqa: E501
return parser.parse_args()


Expand All @@ -33,6 +36,7 @@ def trace(
sequential_targets: Optional[Union[List[str], str]] = None,
ignore: Union[List[str], str] = [],
modality: str = "text",
trust_remote_code: bool = True
):
"""
Debug traceability by tracing a pre-trained model into subgraphs
Expand All @@ -44,6 +48,7 @@ def trace(
inference
:param ignore: patterns to ignore during tracing
:param modality: data modality for dummy tracing data, defaults to 'text'
:param trust_remote_code: trust remote model code

Example usage from CLI
llmcompressor.trace \
Expand All @@ -54,12 +59,16 @@ def trace(
--modality text
"""
# Load model
model = model_class.from_pretrained(
model_id,
device_map="auto",
torch_dtype="auto",
with skip_weights_download(model_class):
model = model_class.from_pretrained(
model_id,
device_map="cpu",
torch_dtype="auto",
trust_remote_code=trust_remote_code,
)
processor = AutoProcessor.from_pretrained(
model_id, trust_remote_code=trust_remote_code
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
print("Loaded model")

# Prepare sample data
Expand Down Expand Up @@ -138,6 +147,7 @@ def main():
sequential_targets=args.sequential_targets,
ignore=args.ignore,
modality=args.modality,
trust_remote_code=args.trust_remote_code
)


Expand Down
Loading