Dynamic max_num_tiles for mllama #309

Draft · wants to merge 5 commits into main
106 changes: 97 additions & 9 deletions QEfficient/transformers/models/mllama/modeling_mllama.py
@@ -33,6 +33,7 @@
MllamaTextCrossAttention,
MllamaTextModel,
MllamaTextSelfAttention,
MllamaVisionAttention,
MllamaVisionModel,
logger,
repeat_kv,
@@ -1133,24 +1134,24 @@ def get_dummy_inputs(self, kv_offload: bool = False):

if vis_cfg := getattr(self.config, "vision_config", None):
img_size = getattr(vis_cfg, "image_size", 448)
max_num_img_tiles = getattr(vis_cfg, "max_num_tiles", 4)
max_num_tiles = getattr(vis_cfg, "max_num_tiles", 4)
else:
img_size = 448
max_num_img_tiles = 4
max_num_tiles = 4

# vision inputs
vision_inputs = {
"pixel_values": torch.zeros(
(BS, MAX_NUM_IMG, max_num_img_tiles, NUM_CHANNEL, img_size, img_size), dtype=torch.float32
(BS, MAX_NUM_IMG, max_num_tiles, NUM_CHANNEL, img_size, img_size), dtype=torch.float32
),
"aspect_ratio_ids": torch.ones((BS, MAX_NUM_IMG), dtype=torch.int64),
"aspect_ratio_mask": torch.ones((BS, MAX_NUM_IMG, max_num_img_tiles), dtype=torch.int64),
"aspect_ratio_mask": torch.ones((BS, MAX_NUM_IMG, max_num_tiles), dtype=torch.int64),
}

# lang_inputs
lang_inputs = {
"input_ids": torch.zeros((BS, SEQ_LEN), dtype=torch.int64),
"cross_attention_mask": torch.zeros((BS, SEQ_LEN, MAX_NUM_IMG, max_num_img_tiles), dtype=torch.int64),
"cross_attention_mask": torch.zeros((BS, SEQ_LEN, MAX_NUM_IMG, max_num_tiles), dtype=torch.int64),
"attention_mask": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
}

@@ -1201,6 +1202,7 @@ def get_specializations(
):
vis_cfg = self.config.vision_config
max_num_images = compiler_options.pop("max_num_images", 1)
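# Tile count is now a compile-time option; the default of 4 matches the dummy-input shapes above.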
max_num_tiles = compiler_options.pop("max_num_tiles", 4)
prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
ctx_len = ctx_len if ctx_len else 128
if img_size is None and hasattr(vis_cfg, "image_size"):
@@ -1209,20 +1211,29 @@
img_size = 448
logger.warning("Setting `img_size=448` as it was neither passed nor found in vision_config")

vision = [{"batch_size": batch_size, "max_num_images": max_num_images, "img_size": img_size}]
vision = [
{
"batch_size": batch_size,
"max_num_images": max_num_images,
"max_num_tiles": max_num_tiles,
"img_size": img_size,
}
]
lang = [
{
"batch_size": batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
"max_num_images": max_num_images,
"max_num_tiles": max_num_tiles,
"img_size": img_size,
},
{
"batch_size": batch_size,
"seq_len": "1",
"ctx_len": ctx_len,
"max_num_images": max_num_images,
"max_num_tiles": max_num_tiles,
"img_size": img_size,
},
]
Expand All @@ -1241,15 +1252,15 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
cross_attention_layers = txt_cfg.cross_attention_layers

vision_dynamic_axes = {
"pixel_values": {0: "batch_size", 1: "max_num_images", 4: "img_size", 5: "img_size"},
"pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_num_tiles", 4: "img_size", 5: "img_size"},
"aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"},
"aspect_ratio_mask": {0: "batch_size", 1: "max_num_images"},
"aspect_ratio_mask": {0: "batch_size", 1: "max_num_images", 2: "max_num_tiles"},
}

lang_dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_len"},
"position_ids": {0: "batch_size", 1: "seq_len"},
"cross_attention_mask": {0: "batch_size", 1: "seq_len", 2: "max_num_images"},
"cross_attention_mask": {0: "batch_size", 1: "seq_len", 2: "max_num_images", 3: "max_num_tiles"},
}

for i in range(num_hidden_layers):
@@ -1305,3 +1316,80 @@ def get_inputs_info(self):
),
IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
]


class QEffMllamaVisionAttention(MllamaVisionAttention):
def compute_block_attention(self, query_states, key_states, value_states, attention_mask, start_idx, end_idx):
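# Attention for the query rows in [start_idx, end_idx) against the full key/value sequence.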
curr_attn_weights = torch.matmul(
query_states[:, :, start_idx:end_idx, :], key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
causal_mask_block = attention_mask[:, :, start_idx:end_idx, : key_states.shape[-2]]
curr_attn_weights += causal_mask_block
# upcast attention to fp32
curr_attn_weights = nn.functional.softmax(curr_attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
curr_attn_output = torch.matmul(curr_attn_weights, value_states)

return curr_attn_output

def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
block_size: Optional[int] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
query = self.q_proj(hidden_state)
key = self.k_proj(hidden_state)
value = self.v_proj(hidden_state)

batch_size, q_seq_len, _ = query.shape
_, kv_seq_len, _ = key.shape

query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)

if block_size is not None:
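# Blocked attention: walk the query in chunks of block_size (falling back to 1 when the
# sequence is shorter), so the largest score matrix is block_size x kv_seq_len.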
runtime_block_size = torch.where(
(q_seq_len // torch.tensor(block_size)) > 0, torch.tensor(block_size), torch.tensor(1)
)
remainder_block_size = q_seq_len % block_size # size of the left-over query block
attn_output = torch.zeros(batch_size, self.num_heads, q_seq_len, self.head_dim)
num_iterations = q_seq_len // runtime_block_size

for iteration in range(num_iterations):
start_idx = iteration * runtime_block_size
end_idx = (iteration + 1) * runtime_block_size
attn_output[:, :, start_idx:end_idx, :] = self.compute_block_attention(
query, key, value, attention_mask, start_idx, end_idx
)

if remainder_block_size:
start_idx = num_iterations * runtime_block_size
end_idx = start_idx + remainder_block_size
attn_output[:, :, start_idx:end_idx, :] = self.compute_block_attention(
query, key, value, attention_mask, start_idx, end_idx
)

else:
# Regular attention
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)

output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return output, attn_weights
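
The blocked path above is numerically equivalent to regular attention, since the softmax of each query row depends only on that row's scores. A minimal standalone sketch with plain tensors (shapes and block size are illustrative, not tied to the module above):

import math
import torch

torch.manual_seed(0)
bsz, heads, q_len, head_dim, block = 1, 2, 10, 8, 4
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, q_len, head_dim)
v = torch.randn(bsz, heads, q_len, head_dim)

# Full attention in one shot.
full = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(head_dim), dim=-1) @ v

# Same computation, one query block at a time (remainder block included).
blocked = torch.zeros_like(full)
for start in range(0, q_len, block):
    end = min(start + block, q_len)
    scores = q[:, :, start:end, :] @ k.transpose(2, 3) / math.sqrt(head_dim)
    blocked[:, :, start:end, :] = torch.softmax(scores, dim=-1) @ v

assert torch.allclose(full, blocked, atol=1e-6)
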
17 changes: 14 additions & 3 deletions QEfficient/transformers/models/modeling_auto.py
@@ -35,6 +35,7 @@
get_compilation_dims,
)
from QEfficient.transformers.models.pytorch_transforms import (
BlockAttentionTransform,
CustomOpsTransform,
KVCacheModuleMethodMapperTransform,
KVCacheTransform,
@@ -525,6 +526,7 @@ class _QEffAutoModelForImageTextToTextDualQPC:
def __init__(
self,
model: nn.Module,
block_size: int = None,
**kwargs,
):
if kwargs.pop("full_batch_size", None):
@@ -536,6 +538,9 @@ def __init__(

self.input_shapes, self.output_names = None, None

if block_size:
BlockAttentionTransform.apply(model, block_size=block_size)

@property
def model_name(self) -> str:
mname = self.model.__class__.__name__
@@ -627,7 +632,7 @@ def compile(

custom_io_vision = {}
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
custom_io_vision["pixel_values"] = kv_cache_dtype
custom_io_vision["pixel_values"] = "float16"
for output_name in output_names["vision"]:
custom_io_vision[output_name] = kv_cache_dtype

@@ -850,12 +855,16 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal
def __init__(
self,
model: nn.Module,
block_size: int = None,
**kwargs,
):
if kwargs.pop("full_batch_size", None):
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
super().__init__(model)

if block_size:
BlockAttentionTransform.apply(model, block_size=block_size)

# to handle internvl models
if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"):
self.model.config.llm_config.use_cache = True
@@ -1222,7 +1231,9 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs)

@classmethod
@with_replaced_quantizers
def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs):
def from_pretrained(
cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, block_size: int = None, **kwargs
):
"""Used to load models supported by transformers.AutoModelForImageTextToText for Cloud AI 100.

Args:
@@ -1241,7 +1252,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(model, kv_offload=kv_offload, **kwargs)
return cls(model, kv_offload=kv_offload, block_size=block_size, **kwargs)


MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {"InternVLChatModel": QEFFAutoModelForImageTextToText}
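
A possible end-to-end flow — a sketch only: the checkpoint name, import path, and compile arguments are illustrative, and it assumes compile() forwards extra keyword arguments such as max_num_images and max_num_tiles to get_specializations as compiler_options:

from QEfficient import QEFFAutoModelForImageTextToText

model = QEFFAutoModelForImageTextToText.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",  # illustrative checkpoint
    kv_offload=True,
    block_size=128,  # opt in to blocked vision attention via BlockAttentionTransform
)
model.compile(
    prefill_seq_len=32,
    ctx_len=512,
    max_num_images=1,
    max_num_tiles=4,  # the new compile-time specialization added by this PR
)
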
21 changes: 21 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -5,6 +5,7 @@
#
# -----------------------------------------------------------------------------

from functools import partial
from types import MethodType
from typing import Tuple

@@ -84,6 +85,7 @@
MllamaTextModel,
MllamaTextRMSNorm,
MllamaTextSelfAttention,
MllamaVisionAttention,
MllamaVisionModel,
)
from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
@@ -204,6 +206,7 @@
QEffMllamaTextCrossAttentionTwoQPC,
QEffMllamaTextModel,
QEffMllamaTextSelfAttention,
QEffMllamaVisionAttention,
QEffMllamaVisionModel,
)
from QEfficient.transformers.models.mpt.modeling_mpt import (
@@ -439,3 +442,21 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform):
"InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward},
}
_match_class_replace_method = {}


class BlockAttentionTransform(ModuleMappingTransform):
# supported architectures
_module_mapping = {
MllamaVisionAttention: QEffMllamaVisionAttention,
}

@classmethod
def apply(cls, model: nn.Module, block_size: int) -> Tuple[nn.Module, bool]:
transformed = False
for module in model.modules():
if repl_module := cls._module_mapping.get(type(module)):
module.__class__ = repl_module
# Bind the partial function to the instance
module.forward = MethodType(partial(repl_module.forward, block_size=block_size), module)
transformed = True # Set to True if at least one transformation occurs
return model, transformed
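
For reference, the partial/MethodType binding used in apply() can be reproduced in isolation; this toy example is not part of the PR:

from functools import partial
from types import MethodType


class Original:
    def forward(self, x):
        return x


class Patched(Original):
    def forward(self, x, block_size=None):
        return x * block_size


m = Original()
m.__class__ = Patched
# partial() pins block_size; MethodType re-binds the callable to the instance,
# so m.forward(x) now runs Patched.forward(m, x, block_size=2).
m.forward = MethodType(partial(Patched.forward, block_size=2), m)
assert m.forward(3) == 6
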