diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 236f6c9f5..b458fe7da 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2588,6 +2588,7 @@ def export( self.model.config, fbs if self.continuous_batching else bs, seq_len ) enable_chunking = kwargs.get("enable_chunking", False) + if prefill_only: if not enable_chunking and self.continuous_batching: raise NotImplementedError( @@ -2602,7 +2603,11 @@ def export( if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH else seq_len ) - kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len + kv_cache_shape[2] = ( + seq_len + (self.model.config.sliding_window if self.model.config.sliding_window is not None else 0) + if enable_chunking + else seq_len + ) else: self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False)) self.hash_params.pop("prefill_only", None) @@ -2611,7 +2616,9 @@ def export( self.hash_params.pop("ENABLE_OPT_SWA", None) self.hash_params.pop("chunking", None) if kwargs.get("retain_full_kv", False): - kv_cache_shape[2] = seq_len + self.model.config.sliding_window + kv_cache_shape[2] = seq_len + ( + self.model.config.sliding_window if self.model.config.sliding_window is not None else 0 + ) self.hash_params["retain_full_kv"] = True example_inputs = { diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index b978b6193..3f027ab3d 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -419,6 +419,7 @@ QEffQwen3Model, ) from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import ( + QEffPrefillChunkedQwen3MoeSparseMoeBlock, QEffQwen3MoeAttention, QEffQwen3MoeDecoderLayer, QEffQwen3MoeForCausalLM, @@ -663,19 +664,25 @@ class PrefillOnlyTransform(ModuleMappingTransform): class PrefillOnlyChunkedTransform(ModuleMappingTransform): _module_mapping = { + # GPT_OSS QEffGptOssModel: QEffPrefillOnlyGptOssModel, QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP, + # Qwen3Moe + QEffQwen3MoeSparseMoeBlock: QEffPrefillChunkedQwen3MoeSparseMoeBlock, } class RevertPrefillKeepAttentionTransform(ModuleMappingTransform): _module_mapping = { + # GPT_OSS QEffGptOssModel: QEffPrefillOnlyGptOssModel, QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, QEffPrefillOnlyGptOssMLP: QEffGptOssMLP, QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP, + # Qwen3Moe + QEffPrefillChunkedQwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, } diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index cbd80d8ca..3ee4472ec 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -104,7 +104,6 @@ def eager_attention_forward( key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = torch.where( @@ -118,53 +117,48 @@ def eager_attention_forward( return attn_output, attn_weights -class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): - def __qeff_init__(self): - self.gate_proj_w = [] - self.up_proj_w = [] - self.down_proj_w = [] - with torch.no_grad(): - for e in range(self.num_experts): - self.gate_proj_w.append(self.experts[e].gate_proj.weight.T) - self.up_proj_w.append(self.experts[e].up_proj.weight.T) - self.down_proj_w.append(self.experts[e].down_proj.weight.T) - self.gate_proj_w = torch.stack(self.gate_proj_w) - self.up_proj_w = torch.stack(self.up_proj_w) - self.down_proj_w = torch.stack(self.down_proj_w) - - def alt_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape T = B * S x = hidden_states.view(T, H) - router_logits = self.gate(x) # [T, E] prob = F.softmax(router_logits, -1, dtype=torch.float) top_w, top_i = torch.topk(prob, self.top_k, -1) - if self.norm_topk_prob: # only diff with mixtral sparse moe block! - top_w /= top_w.sum(-1, keepdim=True) - top_w = top_w.to(x.dtype) + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) masked_logits = torch.zeros_like(router_logits) masked_logits.scatter_(1, top_i, top_w) - # Routing weights for each expert [T, E] routing_weights = masked_logits - # ────────────────── allocate the output tensor ───── expert_out = x.new_zeros((T, H)) # accumulation buffer - # ───────────────────────── Expert computation loop ───────────────────────────── for e in range(self.num_experts): routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] - W_g, W_u = self.experts[e].gate_proj, self.experts[e].up_proj # [H, I], [H, I] - W_d = self.experts[e].down_proj # [I, H] - gate = W_g(x) # [T, I] - up = W_u(x) # [T, I] - down = W_d(up * self.experts[e].act_fn(gate)) # [T, H] - + W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T # [H, I], [H, I] + W_d = self.experts[e].down_proj.weight.T # [I, H] + gate = x @ W_g # [T, I] + up = x @ W_u # [T, I] + down = (up * self.experts[e].act_fn(gate)) @ W_d # [T, H] masked_down = torch.where(routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out)) expert_out += masked_down return expert_out.view(B, S, H), router_logits + +class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): + def __qeff_init__(self): + self.gate_proj_w = [] + self.up_proj_w = [] + self.down_proj_w = [] + with torch.no_grad(): + for e in range(self.num_experts): + self.gate_proj_w.append(self.experts[e].gate_proj.weight.T) + self.up_proj_w.append(self.experts[e].up_proj.weight.T) + self.down_proj_w.append(self.experts[e].down_proj.weight.T) + self.gate_proj_w = torch.stack(self.gate_proj_w) + self.up_proj_w = torch.stack(self.up_proj_w) + self.down_proj_w = torch.stack(self.down_proj_w) + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape T = B * S diff --git a/examples/qwen3moe_disagg_mode_with_chunking.py b/examples/qwen3moe_disagg_mode_with_chunking.py new file mode 100644 index 000000000..0653fac77 --- /dev/null +++ b/examples/qwen3moe_disagg_mode_with_chunking.py @@ -0,0 +1,133 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import time + +import numpy as np +import torch +from transformers import AutoConfig, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession + +model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507" # weights are not required to convert to fp32 +prompt = """ +Explain quantum computing in simple terms. +""" +config = AutoConfig.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) +PREFILL_SEQ_LEN = 128 +CTX_LEN = 128 * 3 + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) +decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step + retain_full_kv=True, +) + +# Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68 + +# prefill_qpc_path = "" + +prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=2, + split_retained_state_io=True, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + use_onnx_subfunctions=True, +) + + +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +generation_len = CTX_LEN - position_ids.max() +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +inputs.pop("past_key_values", None) +inputs = {k: v.detach().numpy() for k, v in inputs.items()} + + +prefill_session = QAICInferenceSession(prefill_qpc_path) +decode_session = QAICInferenceSession(decode_qpc_path) + +all_outputs = [] +for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + ins = time.time() + qpc_out = prefill_session.run(chunk_inputs) + print(f"time for this run={time.time() - ins}") + for i in range(config.num_hidden_layers): + inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +all_outputs.append(np.argmax(qpc_out["logits"])) + +decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, +} +for i in range(config.num_hidden_layers): + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") +all_outputs.append(np.argmax(decode_out["logits"])) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, +} + +for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + +st = time.time() +for i in range(generation_len - 2): + decode_out = decode_session.run(loop_decode_inputs) + all_outputs.append(np.argmax(decode_out["logits"])) + pos_id += 1 + for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + + loop_decode_inputs.update( + { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + ) +ft = time.time() + +print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") +print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}")