14 changes: 9 additions & 5 deletions test/modules/model/TinyLlamaWithFusedRMSNorm/model.py
@@ -13,24 +13,28 @@
 # limitations under the License.

 import torch
+from tico.passes.module_fusion import llama_rmsnorm
+
-from tico.serialize.operators.adapters.llama_rmsnorm import patched_llama_rmsnorm
+from tico.passes.module_fusion.fusion_registry import replace_modules_with_fused
 from tico.utils.pytree_utils import register_dynamic_cache

 from transformers import AutoModelForCausalLM
+from transformers.models.llama.modeling_llama import LlamaRMSNorm

 from test.modules.base import TestModuleBase


 class TinyLlamaWithFusedRMSNorm(TestModuleBase):
     def __init__(self):
         super().__init__()
-        with patched_llama_rmsnorm():
-            self.model = AutoModelForCausalLM.from_pretrained(
-                "Maykeye/TinyLLama-v0"
-            ).to("cpu")
+        self.model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0").to(
+            "cpu"
+        )

         self.rtol = 1e-4
         self.atol = 1e-4
+
+        replace_modules_with_fused(self.model, [LlamaRMSNorm])
         register_dynamic_cache()

     def forward(self, x):
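For reference, the new flow in this test can also be exercised directly. A minimal sketch, not part of the diff, assuming the Maykeye/TinyLLama-v0 checkpoint is reachable and that importing tico.passes.module_fusion.llama_rmsnorm (as the test does) registers the fused factory as a side effect:

# Sketch: after replace_modules_with_fused, no bare LlamaRMSNorm should remain.
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaRMSNorm

from tico.passes.module_fusion import llama_rmsnorm  # noqa: F401  (assumed side-effect import that registers the factory)
from tico.passes.module_fusion.fusion_registry import replace_modules_with_fused

model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0").to("cpu")
replace_modules_with_fused(model, [LlamaRMSNorm])

# The fused class subclasses LlamaRMSNorm, so isinstance() still holds;
# compare concrete types to confirm every instance was swapped.
leftover = [n for n, m in model.named_modules() if type(m) is LlamaRMSNorm]
assert not leftover, f"unfused RMSNorm modules remain: {leftover}"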
81 changes: 81 additions & 0 deletions tico/passes/module_fusion/fusion_registry.py
@@ -0,0 +1,81 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, List, Type

import torch.nn as nn

# Maps original module classes to the callables that build their fused replacements.
# A value may be the fused module class itself, or a factory function that takes the
# original module instance and returns a fused module instance.
_FUSED_MODULE_MAPPING: Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]] = {}


def register_fused_module(original_module_class: Type[nn.Module]):
    """
    Decorator to register an original module class and its corresponding factory that creates the fused module
    """

    def decorator(fused_module_factory: Callable[[nn.Module], nn.Module]):
        _FUSED_MODULE_MAPPING[original_module_class] = fused_module_factory
        return fused_module_factory

    return decorator


def get_fused_module_factory(
    original_module_class: Type[nn.Module],
) -> Callable[[nn.Module], nn.Module] | None:
    """
    Returns the fused module factory corresponding to the registered original module class
    """
    return _FUSED_MODULE_MAPPING.get(original_module_class)


def replace_modules_with_fused(
    model: nn.Module, target_module_classes: List[Type[nn.Module]]
):
    """
    Replaces all instances within the model that correspond to target_module_classes
    with their fused versions registered in the registry
    """
    replaced_count = 0
    for name, module in model.named_modules():
        if type(module) in target_module_classes:
            fused_module_factory = get_fused_module_factory(type(module))
            if fused_module_factory:
                parent_module_name = ".".join(name.split(".")[:-1])
                module_short_name = name.split(".")[-1]

                parent_module = model
                if parent_module_name:
                    for part in parent_module_name.split("."):
                        parent_module = getattr(parent_module, part)

                new_module = fused_module_factory(module)

                setattr(parent_module, module_short_name, new_module)
                replaced_count += 1
                print(
                    f"Replaced {name} ({type(module).__name__}) with {type(new_module).__name__}"
                )
            else:
                print(
                    f"Warning: No fused module factory registered for {type(module).__name__}. Skipping replacement of {name}."
                )

    if replaced_count > 0:
        print(f"Successfully replaced {replaced_count} module instances.")
    else:
        print("No target module instances found to replace.")
@@ -12,24 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from contextlib import contextmanager
-
 import torch

 from transformers.models.llama.modeling_llama import LlamaRMSNorm

+from .fusion_registry import register_fused_module
+

+class FusedLlamaRMSNorm(LlamaRMSNorm):
+    def __init__(self, original_rmsnorm: LlamaRMSNorm):
+        super().__init__(
+            original_rmsnorm.weight.shape[0], original_rmsnorm.variance_epsilon
+        )
+        with torch.no_grad():
+            self.weight.copy_(original_rmsnorm.weight)
+
-def llama_rmsnorm_forward_adapter(self: LlamaRMSNorm, hidden_states: torch.Tensor):
-    return torch.ops.circle_custom.rms_norm(
-        hidden_states, self.weight, self.variance_epsilon
-    )
+    def forward(self, hidden_states):
+        return torch.ops.circle_custom.rms_norm(
+            hidden_states, self.weight, self.variance_epsilon
+        )


-@contextmanager
-def patched_llama_rmsnorm():
-    orig = LlamaRMSNorm.forward
-    LlamaRMSNorm.forward = llama_rmsnorm_forward_adapter
-    try:
-        yield
-    finally:
-        LlamaRMSNorm.forward = orig
+@register_fused_module(LlamaRMSNorm)
+def create_fused_llama_rmsnorm(original_module: LlamaRMSNorm) -> FusedLlamaRMSNorm:
+    return FusedLlamaRMSNorm(original_module)
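Because FusedLlamaRMSNorm copies the original weight and only swaps forward for the circle_custom op, a hedged numerical parity check against the stock module is possible. Assumptions: tico is installed so torch.ops.circle_custom.rms_norm resolves in eager mode, and the class is importable from tico.passes.module_fusion.llama_rmsnorm, as the test's import suggests:

import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

# Import path assumed from the test's `from tico.passes.module_fusion import llama_rmsnorm`.
from tico.passes.module_fusion.llama_rmsnorm import FusedLlamaRMSNorm

original = LlamaRMSNorm(hidden_size=64, eps=1e-6)
fused = FusedLlamaRMSNorm(original)

x = torch.randn(2, 8, 64)
# Tolerances mirror the test above (rtol/atol = 1e-4), since the custom op may
# compute in a different precision than the eager fp32 reference.
torch.testing.assert_close(fused(x), original(x), rtol=1e-4, atol=1e-4)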