11 changes: 9 additions & 2 deletions QEfficient/transformers/models/modeling_auto.py
@@ -2588,6 +2588,7 @@ def export(
self.model.config, fbs if self.continuous_batching else bs, seq_len
)
enable_chunking = kwargs.get("enable_chunking", False)

if prefill_only:
if not enable_chunking and self.continuous_batching:
raise NotImplementedError(
@@ -2602,7 +2603,11 @@
if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH
else seq_len
)
kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len
kv_cache_shape[2] = (
seq_len + (self.model.config.sliding_window if self.model.config.sliding_window is not None else 0)
if enable_chunking
else seq_len
)
else:
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
self.hash_params.pop("prefill_only", None)
@@ -2611,7 +2616,9 @@
self.hash_params.pop("ENABLE_OPT_SWA", None)
self.hash_params.pop("chunking", None)
if kwargs.get("retain_full_kv", False):
kv_cache_shape[2] = seq_len + self.model.config.sliding_window
kv_cache_shape[2] = seq_len + (
self.model.config.sliding_window if self.model.config.sliding_window is not None else 0
)
self.hash_params["retain_full_kv"] = True

example_inputs = {
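The hunks above size the KV cache for chunked prefill and full-KV retention. A minimal sketch of that arithmetic, using an illustrative helper name that is not part of this PR:

# Illustrative only: how the KV-cache sequence length is chosen in the hunks above.
def kv_cache_seq_len(seq_len, sliding_window=None, enable_chunking=False, retain_full_kv=False):
    if enable_chunking or retain_full_kv:
        # Chunked prefill (and a decode graph that retains the full KV) reserves room
        # for the sliding-window overlap; a model without a sliding window adds nothing.
        return seq_len + (sliding_window if sliding_window is not None else 0)
    return seq_len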
7 changes: 7 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -419,6 +419,7 @@
QEffQwen3Model,
)
from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import (
QEffPrefillChunkedQwen3MoeSparseMoeBlock,
QEffQwen3MoeAttention,
QEffQwen3MoeDecoderLayer,
QEffQwen3MoeForCausalLM,
@@ -663,19 +664,25 @@ class PrefillOnlyTransform(ModuleMappingTransform):

class PrefillOnlyChunkedTransform(ModuleMappingTransform):
_module_mapping = {
# GPT_OSS
QEffGptOssModel: QEffPrefillOnlyGptOssModel,
QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP,
# Qwen3Moe
QEffQwen3MoeSparseMoeBlock: QEffPrefillChunkedQwen3MoeSparseMoeBlock,
}


class RevertPrefillKeepAttentionTransform(ModuleMappingTransform):
_module_mapping = {
# GPT_OSS
QEffGptOssModel: QEffPrefillOnlyGptOssModel,
QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
QEffPrefillOnlyGptOssMLP: QEffGptOssMLP,
QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP,
# Qwen3Moe
QEffPrefillChunkedQwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock,
}


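For context, a rough sketch of how a module-mapping transform rewires the model; the helper below is a hypothetical simplification, not the library's actual implementation:

# Hypothetical simplification: a module-mapping transform swaps each matched module's
# class, e.g. QEffQwen3MoeSparseMoeBlock -> QEffPrefillChunkedQwen3MoeSparseMoeBlock
# when PrefillOnlyChunkedTransform is applied.
def apply_module_mapping(model, module_mapping):
    for module in model.modules():
        replacement = module_mapping.get(type(module))
        if replacement is not None:
            module.__class__ = replacement
    return model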
52 changes: 23 additions & 29 deletions QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -104,7 +104,6 @@ def eager_attention_forward(
key_states = repeat_kv(key, module.num_key_value_groups)

value_states = repeat_kv(value, module.num_key_value_groups)

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
@@ -118,53 +117,48 @@ def eager_attention_forward(
return attn_output, attn_weights


class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
def __qeff_init__(self):
self.gate_proj_w = []
self.up_proj_w = []
self.down_proj_w = []
with torch.no_grad():
for e in range(self.num_experts):
self.gate_proj_w.append(self.experts[e].gate_proj.weight.T)
self.up_proj_w.append(self.experts[e].up_proj.weight.T)
self.down_proj_w.append(self.experts[e].down_proj.weight.T)
self.gate_proj_w = torch.stack(self.gate_proj_w)
self.up_proj_w = torch.stack(self.up_proj_w)
self.down_proj_w = torch.stack(self.down_proj_w)

def alt_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
B, S, H = hidden_states.shape
T = B * S
x = hidden_states.view(T, H)

router_logits = self.gate(x) # [T, E]
prob = F.softmax(router_logits, -1, dtype=torch.float)
top_w, top_i = torch.topk(prob, self.top_k, -1)
if self.norm_topk_prob: # only diff with mixtral sparse moe block!
top_w /= top_w.sum(-1, keepdim=True)
top_w = top_w.to(x.dtype)
top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype)
masked_logits = torch.zeros_like(router_logits)
masked_logits.scatter_(1, top_i, top_w)

# Routing weights for each expert [T, E]
routing_weights = masked_logits

# ────────────────── allocate the output tensor ─────
expert_out = x.new_zeros((T, H)) # accumulation buffer

# ───────────────────────── Expert computation loop ─────────────────────────────
for e in range(self.num_experts):
routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1]
W_g, W_u = self.experts[e].gate_proj, self.experts[e].up_proj # [H, I], [H, I]
W_d = self.experts[e].down_proj # [I, H]
gate = W_g(x) # [T, I]
up = W_u(x) # [T, I]
down = W_d(up * self.experts[e].act_fn(gate)) # [T, H]

W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T # [H, I], [H, I]
W_d = self.experts[e].down_proj.weight.T # [I, H]
gate = x @ W_g # [T, I]
up = x @ W_u # [T, I]
down = (up * self.experts[e].act_fn(gate)) @ W_d # [T, H]
masked_down = torch.where(routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out))
expert_out += masked_down
return expert_out.view(B, S, H), router_logits


class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
def __qeff_init__(self):
self.gate_proj_w = []
self.up_proj_w = []
self.down_proj_w = []
with torch.no_grad():
for e in range(self.num_experts):
self.gate_proj_w.append(self.experts[e].gate_proj.weight.T)
self.up_proj_w.append(self.experts[e].up_proj.weight.T)
self.down_proj_w.append(self.experts[e].down_proj.weight.T)
self.gate_proj_w = torch.stack(self.gate_proj_w)
self.up_proj_w = torch.stack(self.up_proj_w)
self.down_proj_w = torch.stack(self.down_proj_w)

def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
B, S, H = hidden_states.shape
T = B * S
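The chunked block above runs every expert densely over all tokens and masks the result by its routing weight instead of gathering tokens per expert, which keeps the graph free of data-dependent shapes. A small illustrative toy of that masking pattern (not taken from the PR):

# Illustrative toy: dense per-expert computation masked by routing weights.
import torch

T, H, E, K = 4, 8, 3, 2  # tokens, hidden size, experts, top-k
x = torch.randn(T, H)
logits = torch.randn(T, E)
top_w, top_i = torch.topk(torch.softmax(logits, dim=-1), K, dim=-1)
routing = torch.zeros(T, E).scatter_(1, top_i, top_w)  # zero weight for unselected experts
experts = [torch.nn.Linear(H, H, bias=False) for _ in range(E)]

out = torch.zeros(T, H)
for e in range(E):
    w = routing[:, e].unsqueeze(-1)  # [T, 1]; zero for tokens not routed to expert e
    out += torch.where(w > 0, experts[e](x) * w, torch.zeros_like(out))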
133 changes: 133 additions & 0 deletions examples/qwen3moe_disagg_mode_with_chunking.py
@@ -0,0 +1,133 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import time

import numpy as np
import torch
from transformers import AutoConfig, AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM
from QEfficient.generation.cloud_infer import QAICInferenceSession

model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507"  # the weights do not need to be converted to fp32
prompt = """
Explain quantum computing in simple terms.
"""
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
PREFILL_SEQ_LEN = 128
CTX_LEN = 128 * 3

qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
decode_qpc_path = qeff_model.compile(
prefill_seq_len=1,
ctx_len=CTX_LEN,
num_cores=16,
mxfp6_matmul=True,
mxint8_kv_cache=True,
num_devices=1,
mos=1,
aic_enable_depth_first=True,
num_speculative_tokens=None,
offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step
retain_full_kv=True,
)

# The following compile call errors out by default; run the command it prints, then provide the
# generated QPC path as prefill_qpc_path and comment out lines 55-68.

# prefill_qpc_path = ""

prefill_qpc_path = qeff_model.compile(
prefill_seq_len=PREFILL_SEQ_LEN,
ctx_len=CTX_LEN,
num_cores=16,
mxfp6_matmul=True,
mxint8_kv_cache=True,
num_devices=2,
split_retained_state_io=True,
mos=1,
aic_enable_depth_first=True,
num_speculative_tokens=None,
prefill_only=True,
enable_chunking=True,
use_onnx_subfunctions=True,
)


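# Tokenize the prompt, pad it up to a multiple of PREFILL_SEQ_LEN, and mark padding positions with -1 in position_ids.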
inputs = tokenizer(prompt, return_tensors="np", padding=True)
position_ids = inputs["attention_mask"].sum(1, keepdims=True)
generation_len = CTX_LEN - position_ids.max()
padded_len = inputs["input_ids"].shape[1]
num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float
padded_len = num_chunks * PREFILL_SEQ_LEN  # Pad the prompt length up to a multiple of PREFILL_SEQ_LEN
inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
inputs.pop("token_type_ids", None)
inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
inputs.pop("past_key_values", None)
inputs = {k: v.detach().numpy() for k, v in inputs.items()}


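# One inference session per QPC: chunked prefill and decode run as separate (disaggregated) deployments.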
prefill_session = QAICInferenceSession(prefill_qpc_path)
decode_session = QAICInferenceSession(decode_qpc_path)

all_outputs = []
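# Chunked prefill: run PREFILL_SEQ_LEN tokens at a time, feeding each chunk's retained KV back in
# as past_key/past_value for the next chunk.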
for i in range(num_chunks):
chunk_inputs = inputs.copy()
chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
ins = time.time()
qpc_out = prefill_session.run(chunk_inputs)
print(f"time for this run={time.time() - ins}")
for i in range(config.num_hidden_layers):
inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]

all_outputs.append(np.argmax(qpc_out["logits"]))

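# First decode step: start from the greedy token of the last prefill chunk, using the KV retained by prefill.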
decode_inputs = {
"input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1),
"position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1,
}
for i in range(config.num_hidden_layers):
decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]

st = time.time()
decode_out = decode_session.run(decode_inputs)
print(f"time for first run of decode with KV as input = {time.time() - st} sec\n")
all_outputs.append(np.argmax(decode_out["logits"]))
pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1
loop_decode_inputs = {
"input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
"position_ids": pos_id,
}

for i in range(config.num_hidden_layers):
loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]

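# Greedy decode loop: each step reuses the previous step's retained KV and appends one token.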
st = time.time()
for i in range(generation_len - 2):
decode_out = decode_session.run(loop_decode_inputs)
all_outputs.append(np.argmax(decode_out["logits"]))
pos_id += 1
for i in range(config.num_hidden_layers):
loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]

loop_decode_inputs.update(
{
"input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
"position_ids": pos_id,
}
)
ft = time.time()

print(f"decode tok/sec={(generation_len - 2) / (ft - st)}")
print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}")