From 81cf999a40bfe08284c720d58dec369ce8c09960 Mon Sep 17 00:00:00 2001
From: Dipankar Sarkar
Date: Mon, 22 Dec 2025 11:22:56 +0000
Subject: [PATCH 1/5] Adding disagg mode support to Qwen3Moe

Signed-off-by: Dipankar Sarkar
---
 .../transformers/models/modeling_auto.py      |  11 +-
 .../transformers/models/pytorch_transforms.py |   7 +
 .../models/qwen3_moe/modeling_qwen3_moe.py    |  52 +++----
 .../causallm/example_pytorch_transforms.py    |  12 +-
 .../qwen3moe_disagg_mode_with_chunking.py     | 132 ++++++++++++++++++
 5 files changed, 177 insertions(+), 37 deletions(-)
 create mode 100644 examples/qwen3moe_disagg_mode_with_chunking.py

diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 236f6c9f5..5abeb3824 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -2588,6 +2588,7 @@ def export(
             self.model.config, fbs if self.continuous_batching else bs, seq_len
         )
         enable_chunking = kwargs.get("enable_chunking", False)
+
         if prefill_only:
             if not enable_chunking and self.continuous_batching:
                 raise NotImplementedError(
@@ -2602,7 +2603,11 @@ def export(
                 if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH
                 else seq_len
             )
-            kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len
+            kv_cache_shape[2] = (
+                seq_len + (0 if self.model.config.sliding_window is None else self.model.config.sliding_window)
+                if enable_chunking
+                else seq_len
+            )
         else:
             self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
             self.hash_params.pop("prefill_only", None)
@@ -2611,7 +2616,9 @@ def export(
             self.hash_params.pop("ENABLE_OPT_SWA", None)
             self.hash_params.pop("chunking", None)
             if kwargs.get("retain_full_kv", False):
-                kv_cache_shape[2] = seq_len + self.model.config.sliding_window
+                kv_cache_shape[2] = seq_len + (
+                    0 if self.model.config.sliding_window is None else self.model.config.sliding_window
+                )
                 self.hash_params["retain_full_kv"] = True

         example_inputs = {
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index b978b6193..3f027ab3d 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -419,6 +419,7 @@
     QEffQwen3Model,
 )
 from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import (
+    QEffPrefillChunkedQwen3MoeSparseMoeBlock,
     QEffQwen3MoeAttention,
     QEffQwen3MoeDecoderLayer,
     QEffQwen3MoeForCausalLM,
@@ -663,19 +664,25 @@ class PrefillOnlyTransform(ModuleMappingTransform):


 class PrefillOnlyChunkedTransform(ModuleMappingTransform):
     _module_mapping = {
+        # GPT_OSS
         QEffGptOssModel: QEffPrefillOnlyGptOssModel,
         QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
         QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP,
+        # Qwen3Moe
+        QEffQwen3MoeSparseMoeBlock: QEffPrefillChunkedQwen3MoeSparseMoeBlock,
     }


 class RevertPrefillKeepAttentionTransform(ModuleMappingTransform):
     _module_mapping = {
+        # GPT_OSS
         QEffGptOssModel: QEffPrefillOnlyGptOssModel,
         QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
         QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
         QEffPrefillOnlyGptOssMLP: QEffGptOssMLP,
         QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP,
+        # Qwen3Moe
+        QEffPrefillChunkedQwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock,
     }

diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index cbd80d8ca..3ee4472ec 100644
--- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -104,7 +104,6 @@ def eager_attention_forward(

     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
-
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
         attn_weights = torch.where(
@@ -118,53 +117,48 @@ def eager_attention_forward(
     return attn_output, attn_weights


-class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
-    def __qeff_init__(self):
-        self.gate_proj_w = []
-        self.up_proj_w = []
-        self.down_proj_w = []
-        with torch.no_grad():
-            for e in range(self.num_experts):
-                self.gate_proj_w.append(self.experts[e].gate_proj.weight.T)
-                self.up_proj_w.append(self.experts[e].up_proj.weight.T)
-                self.down_proj_w.append(self.experts[e].down_proj.weight.T)
-        self.gate_proj_w = torch.stack(self.gate_proj_w)
-        self.up_proj_w = torch.stack(self.up_proj_w)
-        self.down_proj_w = torch.stack(self.down_proj_w)
-
-    def alt_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
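+        # Every expert runs on every token below; the routing-weight mask then
+        # zeroes out the contribution of experts not selected for a token.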
         B, S, H = hidden_states.shape
         T = B * S
         x = hidden_states.view(T, H)
-
         router_logits = self.gate(x)  # [T, E]
         prob = F.softmax(router_logits, -1, dtype=torch.float)
         top_w, top_i = torch.topk(prob, self.top_k, -1)
-        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
-            top_w /= top_w.sum(-1, keepdim=True)
-        top_w = top_w.to(x.dtype)
+        top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype)
         masked_logits = torch.zeros_like(router_logits)
         masked_logits.scatter_(1, top_i, top_w)

-        # Routing weights for each expert [T, E]
         routing_weights = masked_logits

-        # ────────────────── allocate the output tensor ─────
         expert_out = x.new_zeros((T, H))  # accumulation buffer

-        # ───────────────────────── Expert computation loop ─────────────────────────────
         for e in range(self.num_experts):
             routing_weight = routing_weights[:, e].unsqueeze(-1)  # [T, 1]
-            W_g, W_u = self.experts[e].gate_proj, self.experts[e].up_proj  # [H, I], [H, I]
-            W_d = self.experts[e].down_proj  # [I, H]
-            gate = W_g(x)  # [T, I]
-            up = W_u(x)  # [T, I]
-            down = W_d(up * self.experts[e].act_fn(gate))  # [T, H]
-
+            W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T  # [H, I], [H, I]
+            W_d = self.experts[e].down_proj.weight.T  # [I, H]
+            gate = x @ W_g  # [T, I]
+            up = x @ W_u  # [T, I]
+            down = (up * self.experts[e].act_fn(gate)) @ W_d  # [T, H]
             masked_down = torch.where(routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out))
             expert_out += masked_down
         return expert_out.view(B, S, H), router_logits

+
+class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
+    def __qeff_init__(self):
+        self.gate_proj_w = []
+        self.up_proj_w = []
+        self.down_proj_w = []
+        with torch.no_grad():
+            for e in range(self.num_experts):
+                self.gate_proj_w.append(self.experts[e].gate_proj.weight.T)
+                self.up_proj_w.append(self.experts[e].up_proj.weight.T)
+                self.down_proj_w.append(self.experts[e].down_proj.weight.T)
+        self.gate_proj_w = torch.stack(self.gate_proj_w)
+        self.up_proj_w = torch.stack(self.up_proj_w)
+        self.down_proj_w = torch.stack(self.down_proj_w)
+
     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         B, S, H = hidden_states.shape
         T = B * S
diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py
index ff62588f9..503efc12d 100644
--- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py
+++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py
@@ -27,12 +27,6 @@
 from types import MethodType
 from typing import Callable, Optional, Tuple, Union

-from QEfficient.transformers.models.blueprint.modeling_blueprint import (
-    QEffBlueprintAttention,
-    QEffBlueprintDecoderLayer,
-    QEffBlueprintForCausalLM,
-    QEffBlueprintModel,
-)
 from torch import nn

 # Example imports for three representative models
@@ -62,6 +56,12 @@
 from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform
 from QEfficient.customop import CustomRMSNormAIC
 from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function
+from QEfficient.transformers.models.blueprint.modeling_blueprint import (
+    QEffBlueprintAttention,
+    QEffBlueprintDecoderLayer,
+    QEffBlueprintForCausalLM,
+    QEffBlueprintModel,
+)
 from QEfficient.transformers.models.llama.modeling_llama import (
     QEffLlamaAttention,
     QEffLlamaDecoderLayer,
diff --git a/examples/qwen3moe_disagg_mode_with_chunking.py b/examples/qwen3moe_disagg_mode_with_chunking.py
new file mode 100644
index 000000000..6dbbe1215
--- /dev/null
+++ b/examples/qwen3moe_disagg_mode_with_chunking.py
@@ -0,0 +1,132 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import time
+
+import numpy as np
+import torch
+from transformers import AutoConfig, AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+
+model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507"  # weights do not need to be converted to fp32
+prompt = """
+Explain quantum computing in simple terms.
+"""
+config = AutoConfig.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+PREFILL_SEQ_LEN = 128
+CTX_LEN = 128 * 3
+
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
+decode_qpc_path = qeff_model.compile(
+    prefill_seq_len=1,
+    ctx_len=CTX_LEN,
+    num_cores=16,
+    mxfp6_matmul=True,
+    mxint8_kv_cache=True,
+    num_devices=1,
+    mos=1,
+    aic_enable_depth_first=True,
+    num_speculative_tokens=None,
+    offload_pt_weights=False,  # Need the weights in memory for prefill-model export/compilation in the next step
+    retain_full_kv=True,
+)
+
+# The following command errors out by default; the user is expected to run the printed command and provide the generated qpc path as prefill_qpc_path, commenting out lines 55-68
+# prefill_qpc_path = "/home/dipankar/.cache/qeff_models/Qwen3MoeForCausalLM/Qwen3MoeForCausalLM-d6bec77055bbf321/qpc-6d69cd128947ac31/qpc"
+
+prefill_qpc_path = qeff_model.compile(
+    prefill_seq_len=PREFILL_SEQ_LEN,
+    ctx_len=CTX_LEN,
+    num_cores=16,
+    mxfp6_matmul=True,
+    mxint8_kv_cache=True,
+    num_devices=2,
+    split_retained_state_io=True,
+    mos=1,
+    aic_enable_depth_first=True,
+    num_speculative_tokens=None,
+    prefill_only=True,
+    enable_chunking=True,
+    use_onnx_subfunctions=True,
+)
+
+
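+# Tokenize the prompt and pad it to a multiple of PREFILL_SEQ_LEN so it can be
+# split into equal prefill chunks; padded positions get position_id -1.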
+inputs = tokenizer(prompt, return_tensors="np", padding=True)
+position_ids = inputs["attention_mask"].sum(1, keepdims=True)
+generation_len = CTX_LEN - position_ids.max()
+padded_len = inputs["input_ids"].shape[1]
+num_chunks = -(padded_len // -PREFILL_SEQ_LEN)  # ceil divide without float
+padded_len = num_chunks * PREFILL_SEQ_LEN  # Convert to a multiple of prompt_len
+inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+inputs.pop("token_type_ids", None)
+inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
+inputs.pop("past_key_values", None)
+inputs = {k: v.detach().numpy() for k, v in inputs.items()}
+
+
+prefill_session = QAICInferenceSession(prefill_qpc_path)
+decode_session = QAICInferenceSession(decode_qpc_path)
+
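+# Run prefill one PREFILL_SEQ_LEN-sized chunk at a time, feeding each chunk's
+# retained KV states back in as inputs for the next chunk.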
+""" +config = AutoConfig.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) +PREFILL_SEQ_LEN = 128 +CTX_LEN = 128 * 3 + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) +decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step + retain_full_kv=True, +) + +# Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68 +# prefill_qpc_path = "/home/dipankar/.cache/qeff_models/Qwen3MoeForCausalLM/Qwen3MoeForCausalLM-d6bec77055bbf321/qpc-6d69cd128947ac31/qpc" + +prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=2, + split_retained_state_io=True, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + use_onnx_subfunctions=True, +) + + +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +generation_len = CTX_LEN - position_ids.max() +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +inputs.pop("past_key_values", None) +inputs = {k: v.detach().numpy() for k, v in inputs.items()} + + +prefill_session = QAICInferenceSession(prefill_qpc_path) +decode_session = QAICInferenceSession(decode_qpc_path) + +all_outputs = [] +for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + ins = time.time() + qpc_out = prefill_session.run(chunk_inputs) + print(f"time for this run={time.time() - ins}") + for i in range(config.num_hidden_layers): + inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +all_outputs.append(np.argmax(qpc_out["logits"])) + +decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, +} +for i in range(config.num_hidden_layers): + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") +all_outputs.append(np.argmax(decode_out["logits"])) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, +} + +for i in range(config.num_hidden_layers): + 
+st = time.time()
+for i in range(generation_len - 2):
+    decode_out = decode_session.run(loop_decode_inputs)
+    all_outputs.append(np.argmax(decode_out["logits"]))
+    pos_id += 1
+    for i in range(config.num_hidden_layers):
+        loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
+        loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]
+
+    loop_decode_inputs.update(
+        {
+            "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
+            "position_ids": pos_id,
+        }
+    )
+ft = time.time()
+
+print(f"decode tok/sec={(generation_len - 2) / (ft - st)}")
+print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}")

From 4819fbca32cdb63bd651537d9d86f361bd2c322d Mon Sep 17 00:00:00 2001
From: Dipankar Sarkar
Date: Mon, 22 Dec 2025 11:28:10 +0000
Subject: [PATCH 2/5] Cleaning of example script

Signed-off-by: Dipankar Sarkar
---
 examples/qwen3moe_disagg_mode_with_chunking.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/qwen3moe_disagg_mode_with_chunking.py b/examples/qwen3moe_disagg_mode_with_chunking.py
index 6dbbe1215..7e16ac15b 100644
--- a/examples/qwen3moe_disagg_mode_with_chunking.py
+++ b/examples/qwen3moe_disagg_mode_with_chunking.py
@@ -39,7 +39,7 @@
 )

 # The following command errors out by default; the user is expected to run the printed command and provide the generated qpc path as prefill_qpc_path, commenting out lines 55-68
-# prefill_qpc_path = "/home/dipankar/.cache/qeff_models/Qwen3MoeForCausalLM/Qwen3MoeForCausalLM-d6bec77055bbf321/qpc-6d69cd128947ac31/qpc"
+# prefill_qpc_path = ""

 prefill_qpc_path = qeff_model.compile(
     prefill_seq_len=PREFILL_SEQ_LEN,

From d2ba2824ca374027689045139c33ff55c61611cb Mon Sep 17 00:00:00 2001
From: Dipankar Sarkar
Date: Mon, 22 Dec 2025 12:10:14 +0000
Subject: [PATCH 3/5] Lint fix

Signed-off-by: Dipankar Sarkar
---
 .../causallm/example_pytorch_transforms.py     | 14 +++++++-------
 examples/qwen3moe_disagg_mode_with_chunking.py |  1 +
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py
index 503efc12d..a57ac3d1b 100644
--- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py
+++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py
@@ -27,6 +27,12 @@
 from types import MethodType
 from typing import Callable, Optional, Tuple, Union

+from QEfficient.transformers.models.blueprint.modeling_blueprint import (
+    QEffBlueprintAttention,
+    QEffBlueprintDecoderLayer,
+    QEffBlueprintForCausalLM,
+    QEffBlueprintModel,
+)
 from torch import nn

 # Example imports for three representative models
@@ -56,12 +62,6 @@
 from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform
 from QEfficient.customop import CustomRMSNormAIC
 from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function
-from QEfficient.transformers.models.blueprint.modeling_blueprint import (
-    QEffBlueprintAttention,
-    QEffBlueprintDecoderLayer,
-    QEffBlueprintForCausalLM,
-    QEffBlueprintModel,
-)
 from QEfficient.transformers.models.llama.modeling_llama import (
     QEffLlamaAttention,
     QEffLlamaDecoderLayer,
@@ -288,4 +288,4 @@ def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Modu
         )
         model = PooledModel(model, pooling_method)
         warnings.warn("Pooling is applied to the model.")
-        return model, transformed
+        return model, transformed
\ No newline at end of file
diff --git a/examples/qwen3moe_disagg_mode_with_chunking.py b/examples/qwen3moe_disagg_mode_with_chunking.py
index 7e16ac15b..0653fac77 100644
--- a/examples/qwen3moe_disagg_mode_with_chunking.py
+++ b/examples/qwen3moe_disagg_mode_with_chunking.py
@@ -39,6 +39,7 @@
 )

 # The following command errors out by default; the user is expected to run the printed command and provide the generated qpc path as prefill_qpc_path, commenting out lines 55-68
+
 # prefill_qpc_path = ""

 prefill_qpc_path = qeff_model.compile(

From 4585c93ac5a8db61e93d3ec6054e63174b230ae9 Mon Sep 17 00:00:00 2001
From: Dipankar Sarkar
Date: Mon, 22 Dec 2025 13:53:15 +0000
Subject: [PATCH 4/5] Minor fixes

Signed-off-by: Dipankar Sarkar
---
 QEfficient/transformers/models/modeling_auto.py    |  4 ++--
 .../causallm/example_pytorch_transforms.py         | 14 +++++++-------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 5abeb3824..b458fe7da 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -2604,7 +2604,7 @@ def export(
                 else seq_len
             )
             kv_cache_shape[2] = (
-                seq_len + (0 if self.model.config.sliding_window is None else self.model.config.sliding_window)
+                seq_len + (self.model.config.sliding_window if self.model.config.sliding_window is not None else 0)
                 if enable_chunking
                 else seq_len
             )
@@ -2617,7 +2617,7 @@ def export(
             self.hash_params.pop("chunking", None)
             if kwargs.get("retain_full_kv", False):
                 kv_cache_shape[2] = seq_len + (
-                    0 if self.model.config.sliding_window is None else self.model.config.sliding_window
+                    self.model.config.sliding_window if self.model.config.sliding_window is not None else 0
                 )
                 self.hash_params["retain_full_kv"] = True

diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py
index a57ac3d1b..503efc12d 100644
--- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py
+++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py
@@ -27,12 +27,6 @@
 from types import MethodType
 from typing import Callable, Optional, Tuple, Union

-from QEfficient.transformers.models.blueprint.modeling_blueprint import (
-    QEffBlueprintAttention,
-    QEffBlueprintDecoderLayer,
-    QEffBlueprintForCausalLM,
-    QEffBlueprintModel,
-)
 from torch import nn

 # Example imports for three representative models
@@ -62,6 +56,12 @@
 from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform
 from QEfficient.customop import CustomRMSNormAIC
 from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function
+from QEfficient.transformers.models.blueprint.modeling_blueprint import (
+    QEffBlueprintAttention,
+    QEffBlueprintDecoderLayer,
+    QEffBlueprintForCausalLM,
+    QEffBlueprintModel,
+)
 from QEfficient.transformers.models.llama.modeling_llama import (
     QEffLlamaAttention,
     QEffLlamaDecoderLayer,
@@ -288,4 +288,4 @@ def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Modu
         )
         model = PooledModel(model, pooling_method)
         warnings.warn("Pooling is applied to the model.")
-        return model, transformed
\ No newline at end of file
+        return model, transformed

From f24dc83c431769d68087348b94133dfd3b64c414 Mon Sep 17 00:00:00 2001
From: Dipankar Sarkar
Date: Mon, 22 Dec 2025 18:08:56 +0000
Subject: [PATCH 5/5] Lint Fix 2

Signed-off-by: Dipankar Sarkar
---
 .../causallm/example_pytorch_transforms.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py
index 503efc12d..ff62588f9 100644
--- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py
+++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py
@@ -27,6 +27,12 @@
 from types import MethodType
 from typing import Callable, Optional, Tuple, Union

+from QEfficient.transformers.models.blueprint.modeling_blueprint import (
+    QEffBlueprintAttention,
+    QEffBlueprintDecoderLayer,
+    QEffBlueprintForCausalLM,
+    QEffBlueprintModel,
+)
 from torch import nn

 # Example imports for three representative models
@@ -56,12 +62,6 @@
 from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform
 from QEfficient.customop import CustomRMSNormAIC
 from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function
-from QEfficient.transformers.models.blueprint.modeling_blueprint import (
-    QEffBlueprintAttention,
-    QEffBlueprintDecoderLayer,
-    QEffBlueprintForCausalLM,
-    QEffBlueprintModel,
-)
 from QEfficient.transformers.models.llama.modeling_llama import (
     QEffLlamaAttention,
     QEffLlamaDecoderLayer,