diff --git a/doc/en/Hunyuan.md b/doc/en/Hunyuan.md
new file mode 100644
index 00000000..6a0d8bbb
--- /dev/null
+++ b/doc/en/Hunyuan.md
@@ -0,0 +1,77 @@
+# HunYuan Support for KTransformers
+
+## Introduction
+
+### Overview
+We are excited to announce that **KTransformers now supports HunYuan models with AMX optimization**.
+
+- **HunYuan-Standard (AMX bf16)**: ~12 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~441 GB of DRAM. MoE expert computation is accelerated with Intel AMX.
+
+### Model & Resource Links
+- *[Hunyuan-A13B-Instruct](https://huggingface.co/tencent/Hunyuan-A13B-Instruct)*
+
+---
+
+## Installation Guide
+
+### 1. Resource Requirements
+
+| Model            | Precision | Experts | DRAM Needed | GPU Memory Needed\* | TPS (approx.) |
+| ---------------- | --------- | ------- | ----------- | ------------------- | ------------- |
+| HunYuan-Standard | bf16      | 64      | \~441 GB    | 14 GB               | \~12 TPS      |
+
+\* Exact GPU memory usage depends on sequence length, batch size, and the kernels used.
+
+### 2. Prepare Models
+
+```bash
+# Example: download the original safetensors (adjust repos/paths to your setup)
+# (Fill in actual repos/filenames yourself)
+
+# HunYuan-Standard
+huggingface-cli download --resume-download tencent/Hunyuan-A13B-Instruct \
+  --local-dir ./Hunyuan-A13B-Instruct
+```
+
+### 3. Install KTransformers
+
+Follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).
+
+```bash
+pip install ktransformers  # or install from source if you need bleeding-edge features
+```
+
+### 4. Run HunYuan Inference Server
+
+```bash
+# --gguf_path points to the directory containing the model weights (.gguf or .safetensors)
+python ktransformers/server/main.py \
+  --port 10002 \
+  --model_path /abs/path/to/Hunyuan-A13B-Instruct \
+  --model_name Hunyuan-A13B-Instruct \
+  --gguf_path /abs/path/to/Hunyuan-A13B-Instruct \
+  --optimize_config_path ktransformers/optimize/optimize_rules/Hunyuan-serve-amx.yaml \
+  --max_new_tokens 1024 \
+  --cache_lens 32768 \
+  --chunk_size 256 \
+  --max_batch_size 4 \
+  --backend_type balance_serve
+```
+
+### 5. 
Access Server + +```bash +curl http://127.0.0.1:10002/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Hunyuan-A13B-Instruct", + "messages": [ + {"role": "user", "content": "介绍一下西伯利亚森林猫"} + ], + "temperature": 0.7, + "max_tokens": 200, + "stream": false + }' +``` + + +--- \ No newline at end of file diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 0a685d71..7390d83c 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -28,6 +28,7 @@ from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM +from ktransformers.models.modeling_hunyuan import HunYuanMoEV1ForCausalLM from ktransformers.util.utils import prefill_and_generate, get_compute_capability, xpu_fp16_model from ktransformers.server.config.config import Config from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled @@ -39,6 +40,7 @@ "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, "MixtralForCausalLM": MixtralForCausalLM, + "HunYuanMoEV1ForCausalLM": HunYuanMoEV1ForCausalLM, } ktransformer_rules_dir = ( @@ -50,6 +52,7 @@ "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml", "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml", + "HunYuanMoEV1ForCausalLM": ktransformer_rules_dir + "Hunyuan-serve.yaml", } @@ -96,6 +99,8 @@ def local_chat( config._attn_implementation = "eager" if "Mixtral" in config.architectures[0]: config._attn_implementation = "flash_attention_2" + if "HunYuan" in config.architectures[0]: + config._attn_implementation = "flash_attention_2" if torch.xpu.is_available(): config._attn_implementation = "eager" model = custom_models[config.architectures[0]](config) diff --git a/ktransformers/models/configuration_hunyuan.py b/ktransformers/models/configuration_hunyuan.py new file mode 100644 index 00000000..fe794621 --- /dev/null +++ b/ktransformers/models/configuration_hunyuan.py @@ -0,0 +1,336 @@ +# coding=utf-8 +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +""" HunYuan model configuration""" +from torch import nn +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging +from typing import List, Union, Optional + + +logger = logging.get_logger(__name__) + + +class HunYuanConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`HunYuanModel`]. It is used to instantiate an + HunYuan model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the HunYuan-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`HunYuanModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations or shared MLP representations. 
+ moe_intermediate_size (`int` or `List`, *optional*, defaults to 11008): + Dimension of the MLP representations in MoE. Use a list if you want a different size per layer. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. 
+ attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + use_qk_norm (`bool`, *optional*, defaults to `False`): + Whether query and key in attention use norm + use_cla (`bool`, *optional*, defaults to `False`): + Whether to use CLA in attention + cla_share_factor (`int`, *optional*, defaults to 1): + The share factor of CLA + num_experts (`int` or `List`, *optional*, defaults to 1): + The number of experts for moe. If it is a list, it will be used as the number of experts for each layer. + num_shared_expert (`int` or `List`, *optional*, defaults to 1): + The number of shared experts for moe. If it is a list, it will be used as the number of shared experts for each layer. + moe_topk (`int` or `List`, *optional*, defaults to 1): + The topk value for moe. If it is a list, it will be used as the topk value for each layer. + capacity_factor (Not used) (`float` or `List`, *optional*, defaults to 1.0): + The capacity factor for moe. If it is a list, it will be used as the capacity factor for each layer. + moe_layer_num_skipped (`int`, *optional*, defaults to 0): + First moe_layer_num_skipped layers do not use MoE. + """ + + model_type = "hunyuan_v1_moe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=290943, + org_vocab_size=290943, + hidden_size=4096, + intermediate_size: int=11008, + moe_intermediate_size: Union[int, List]=None, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + attention_head_dim=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + eod_token_id=3, + sep_token_id=4, + im_start_id=5, + im_end_id=6, + text_start_id=7, + text_end_id=8, + image_token_id=9, + video_start_id=10, + video_end_id=11, + im_newline_id=12, + mask_init_id=13, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + mlp_bias=False, + attention_dropout=0.0, + use_qk_norm=False, + use_rotary_pos_emb=True, + use_cla=False, + cla_share_factor=1, + norm_type="hf_rms", + num_experts: Union[int, List]=1, + use_mixed_mlp_moe=False, + num_shared_expert: Union[int, List]=1, + moe_topk: Union[int, List]=1, + # capacity_factor: Union[int, List]=1.0, + moe_drop_tokens=False, + moe_random_routing_dropped_token=False, + use_mla=False, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + moe_layer_num_skipped=0, + norm_topk_prob=True, + routed_scaling_factor=1.0, + group_limited_greedy=False, + n_group=None, + topk_group=None, + vit_path=None, + num_media_embeds=257, + vit_type="AnyResVit", + vit_input_resolution=224, + vit_token=64, + vit_patch=1, + vit_mapping_type="simple_conv_mlp", + vit_norm_type="fused", + vit_used_rms_norm=True, + vit_remove_prenorm=True, + vit_add_patchemb_bias=True, + anyres_vit_max_image_size=2048, + anyres_pooling_size=2, + anyres_vit_two_views=False, + skip_cls_token=False, + position_embedding_xdrope=False, + xdrope_section=None, + add_classification_head=False, + class_num=0, + pool_type="last", + pad_id=-1, + **kwargs, + ): + self.vocab_size = vocab_size + self.org_vocab_size = org_vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + 
self.num_attention_heads = num_attention_heads + self.num_experts = num_experts + self.use_mixed_mlp_moe = use_mixed_mlp_moe + self.num_shared_expert = num_shared_expert + self.moe_topk = moe_topk + # For compatibility with KTransformers which expects num_experts_per_tok + # Set it as a real attribute, not just a property + self.num_experts_per_tok = moe_topk[0] if isinstance(moe_topk, list) else moe_topk + # self.capacity_factor = capacity_factor + self.moe_drop_tokens = moe_drop_tokens + self.moe_random_routing_dropped_token = moe_random_routing_dropped_token + + if attention_head_dim is not None: + self.attention_head_dim = attention_head_dim + else: + self.attention_head_dim = self.hidden_size // num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # self._rope_scaling_validation() # TODO: Need validation? + self.attention_bias = attention_bias + self.mlp_bias = mlp_bias + self.attention_dropout = attention_dropout + self.use_qk_norm = use_qk_norm + self.use_rotary_pos_emb = use_rotary_pos_emb + self.use_cla = use_cla + self.cla_share_factor = cla_share_factor + self.norm_type = norm_type + # MLA args + self.use_mla = use_mla + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.v_head_dim = v_head_dim + + # DeepSeek related args + self.moe_layer_num_skipped = moe_layer_num_skipped + self.norm_topk_prob = norm_topk_prob + self.routed_scaling_factor = routed_scaling_factor + self.group_limited_greedy = group_limited_greedy + self.n_group = n_group + self.topk_group = topk_group + self.add_classification_head = add_classification_head + self.class_num = class_num + self.pool_type = pool_type + self.pad_id = pad_id + + if self.class_num is not None: + self.dense_list = [self.hidden_size, self.class_num] + + # Vit args + self.vit_path = vit_path + self.num_media_embeds = num_media_embeds + self.vit_type = vit_type + self.vit_input_resolution = vit_input_resolution + self.vit_token = vit_token + self.vit_patch = vit_patch + self.vit_mapping_type = vit_mapping_type + self.vit_norm_type = vit_norm_type + self.vit_used_rms_norm = vit_used_rms_norm + self.vit_remove_prenorm = vit_remove_prenorm + self.vit_add_patchemb_bias = vit_add_patchemb_bias + self.anyres_vit_max_image_size = anyres_vit_max_image_size + self.anyres_pooling_size = anyres_pooling_size + self.anyres_vit_two_views = anyres_vit_two_views + self.skip_cls_token = skip_cls_token + self.position_embedding_xdrope = position_embedding_xdrope + self.xdrope_section = xdrope_section + + # token id + self.eod_token_id = eod_token_id + self.im_start_id = im_start_id + self.im_end_id = im_end_id + self.text_start_id = text_start_id + self.text_end_id = text_end_id + self.image_token_id = image_token_id + self.video_start_id = video_start_id + self.video_end_id = video_end_id + self.im_newline_id = im_newline_id + self.mask_init_id = mask_init_id + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + sep_token_id=sep_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def 
_rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
+        """
+        if self.rope_scaling is None:
+            return
+
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+            raise ValueError(
+                "`rope_scaling` must be a dictionary with two fields, `type` and `factor` or `type` and `alpha`, "
+                f"got {self.rope_scaling}"
+            )
+        rope_scaling_type = self.rope_scaling.get("type", None)
+        rope_scaling_factor = self.rope_scaling.get("factor", None)
+        rope_scaling_alpha = self.rope_scaling.get("alpha", None)
+        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+            raise ValueError(
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+            )
+        if rope_scaling_factor is None and rope_scaling_alpha is None:
+            raise ValueError("`rope_scaling` must contain either a `factor` or an `alpha` field, got neither")
+        if rope_scaling_factor is not None:
+            if not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+                raise ValueError(f"`rope_scaling`'s factor field must be a float > 1.0, got {rope_scaling_factor}")
+        if rope_scaling_alpha is not None:
+            if not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0:
+                raise ValueError(f"`rope_scaling`'s alpha field must be a float > 1.0, got {rope_scaling_alpha}")
\ No newline at end of file
diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py
index 350af73e..f6716c30 100644
--- a/ktransformers/models/custom_cache.py
+++ b/ktransformers/models/custom_cache.py
@@ -330,4 +330,97 @@ def get_k_cache(self, layer_idx):
         return self.k_caches[layer_idx]
 
     def get_v_cache(self, layer_idx):
+        return self.v_caches[layer_idx]
+
+
+class KHunYuanCache(nn.Module):
+    """
+    HunYuan-specific cache implementation for GQA with flashinfer compatibility.
+    Handles KV cache with proper reshaping for paged attention format.
+    """
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        page_size: int = 256,
+        dtype=torch.bfloat16,
+        device=torch.device("cuda:0"),
+    ):
+        super().__init__()
+        self.config = config
+        self.dtype = dtype
+        self.device = device
+        self.page_size = page_size
+        self.k_caches = []
+        self.v_caches = []
+
+        # HunYuan specific parameters
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
+
+
+    def load(self, inference_context: "sched_ext.InferenceContext"):
+        """
+        Load and reshape KV caches from inference context to match flashinfer format.
+        HunYuan uses GQA with 32 attention heads and 8 KV heads.
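+        Raw 2D caches of shape [total_tokens, kv_dim] are viewed as
+        [num_pages, page_size, num_kv_heads, head_dim] pages before use.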
+ """ + print(f"Loading HunYuan cache for {self.config.num_hidden_layers} layers") + + for i in range(self.config.num_hidden_layers): + k_cache_raw = inference_context.k_cache[0][i] + v_cache_raw = inference_context.v_cache[0][i] + + # Check if reshaping is needed based on tensor dimensions + if k_cache_raw.ndim == 2: + total_tokens = k_cache_raw.shape[0] + num_pages = total_tokens // self.page_size + + # Reshape k_cache: [total_tokens, kv_dim] -> [num_pages, page_size, num_kv_heads, head_dim] + k_cache = k_cache_raw.view(num_pages, self.page_size, self.num_kv_heads, self.head_dim) + v_cache = v_cache_raw.view(num_pages, self.page_size, self.num_kv_heads, self.head_dim) + elif k_cache_raw.ndim == 3: + num_pages = k_cache_raw.shape[0] + k_cache = k_cache_raw.view(num_pages, self.page_size, self.num_kv_heads, self.head_dim) + v_cache = v_cache_raw.view(num_pages, self.page_size, self.num_kv_heads, self.head_dim) + elif k_cache_raw.ndim == 4: + k_cache = k_cache_raw + v_cache = v_cache_raw + else: + raise ValueError(f"Unexpected cache dimension: k_cache has {k_cache_raw.ndim} dimensions") + + self.k_caches.append(k_cache) + self.v_caches.append(v_cache) + + if len(self.k_caches) > 0: + self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1] + print(f"Cache loaded: shape {self.k_caches[0].shape}, max_cache_len {self.max_cache_len}") + + def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, + kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.tensor): + """Get page table for paged attention.""" + page_offset = cache_position % self.page_size + page_idx_local = cache_position // self.page_size + query_ids = torch.zeros_like(cache_position) + + for i in range(len(q_indptr) - 1): + start_idx = q_indptr[i] + end_idx = q_indptr[i + 1] + query_ids[start_idx:end_idx] = i + + page_idx = torch.zeros_like(page_idx_local) + for i in range(bsz_tensors[0]): + query_id = query_ids[i] + local_block = page_idx_local[i] + start_block = kv_indptr[query_id] + if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]: + page_idx[i] = kv_indices[start_block + local_block] + + return page_idx, page_offset + + def get_k_cache(self, layer_idx): + """Get k_cache for specific layer.""" + return self.k_caches[layer_idx] + + def get_v_cache(self, layer_idx): + """Get v_cache for specific layer.""" return self.v_caches[layer_idx] \ No newline at end of file diff --git a/ktransformers/models/custom_modeling_hunyuan.py b/ktransformers/models/custom_modeling_hunyuan.py new file mode 100644 index 00000000..7ebfbd49 --- /dev/null +++ b/ktransformers/models/custom_modeling_hunyuan.py @@ -0,0 +1,301 @@ +""" +Custom Hunyuan model implementation for KTransformers with optimized inference +""" + +import math +from dataclasses import dataclass +import torch +import torch.nn as nn +from torch.nn import functional as F +from typing import List, Optional, Tuple, Union +import torch.utils.checkpoint +import numpy as np +import os +from datetime import datetime + +from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput +from ktransformers.models.custom_cache import KHunYuanCache +from ktransformers.models.modeling_hunyuan import HunYuanModel, HunYuanPreTrainedModel +from ktransformers.models.configuration_hunyuan import HunYuanConfig +from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn + +torch.set_grad_enabled(False) +torch.set_default_dtype(torch.bfloat16) + +# Simple debug tensor 
recording +DEBUG_TENSORS = {} + +def save_debug_tensors(): + if not DEBUG_TENSORS: + return + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + dir_name = f"hunyuan_debug_{timestamp}" + os.makedirs(dir_name, exist_ok=True) + + for key, tensor in DEBUG_TENSORS.items(): + tensor_cpu = tensor.cpu().detach() + if tensor_cpu.dtype == torch.bfloat16: + tensor_cpu = tensor_cpu.float() + array = tensor_cpu.numpy() + np.save(os.path.join(dir_name, f"{key}.npy"), array) + print(f"Saved {key}: shape={array.shape}") + + print(f"Saved {len(DEBUG_TENSORS)} tensors to: {dir_name}") + return dir_name + +try: + import flashinfer +except ImportError: + flashinfer = None + +class KHunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + cache: KHunYuanCache + use_cuda_graph = False + + def __init__( + self, + config: HunYuanConfig, + cache = None, + ): + # ALWAYS print this to verify our file is being used + print("=" * 80) + print("[INIT] KHunYuanMoEV1ForCausalLM from custom_modeling_hunyuan.py is being initialized!") + print("=" * 80) + + super().__init__(config) + self.model = HunYuanModel(config) + self.config = config + self.cache = cache + self.vocab_size = config.vocab_size + # Don't create new lm_head weights - use reference to embed_tokens.weight + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Critical: Tie weights to embed_tokens after creation + self.lm_head.weight = self.model.embed_tokens.weight + self.attn = [None] * 100 + + # Initialize weights and apply final processing + self.post_init() + + def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0): + if flashinfer: + self.attn[cuda_graph_idx] = flashInferAttn( + use_cuda_graph=use_cuda_graph, + max_batch_token=max_batch_token, + max_batch_size=max_batch_size, + max_pages=max_pages, + device=device + ) + + def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors, + num_q_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: int, + causal: bool, + q_data_type: torch.dtype, + kv_data_type: torch.dtype, + cuda_graph_idx: int = 0 + ): + """Plan flashinfer attention computation for the batch""" + minibatch = batch.minibatch + if self.attn[cuda_graph_idx] is not None: + self.attn[cuda_graph_idx].plan( + minibatch.q_indptr, + minibatch.kv_indptr, + minibatch.kv_indices, + minibatch.kv_last_page_len, + bsz_tensors, + num_tokens_tensors, + num_q_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + q_data_type=q_data_type, + kv_data_type=kv_data_type + ) + + def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"): + features = [] + for i in range(batch.batch_size): + tokens = batch.minibatch.tokens.contiguous() + + # Step-by-step embedding processing for debugging + tokens_cpu = tokens.to(torch.device('cpu')) + embed_output = self.model.embed_tokens(tokens_cpu) + + + # Convert dtype and device + feature = embed_output.to(torch.bfloat16).to(device=device) + + + features.append(feature) + return features + + def forward( + self, + batch: ForwardBatchInput | None = None, + features: List[torch.Tensor] | None = None, + bsz_tensors: torch.Tensor | None = None, + num_tokens_tensors: torch.Tensor | None = None, + page_idx: torch.Tensor | None = None, + page_offset: torch.Tensor | None = None, + cuda_graph_idx: int | None = 0 + ) -> ForwardBatchOutput: + current_stream = torch.cuda.current_stream() + forward_batch_output = ForwardBatchOutput() + + hidden_states 
= features[0] + + + + if flashinfer and self.attn[cuda_graph_idx] is not None: + self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0]) + + with torch.cuda.stream(current_stream): + # Initialize residual - will be set properly in each layer + residual = None + + # Register layer 0 hooks for detailed tensor tracking + if not hasattr(self, '_layer0_hooks_registered'): + self._layer0_hooks_registered = True + layer0 = self.model.layers[0] + + for i, decode_layer in enumerate(self.model.layers): + # Handle device transfer if needed + if hasattr(self.model, 'transfer_map') and self.model.transfer_map is not None and i in self.model.transfer_map: + prev_stream = torch.cuda.current_stream() + cur_device = self.model.transfer_map[i] + + if not hasattr(self.model, 'stream_device_map'): + self.model.stream_device_map = {} + + if cur_device not in self.model.stream_device_map: + self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) + + torch.cuda.set_device(cur_device) + self.model.stream_device_map[cur_device].wait_stream(prev_stream) + torch.cuda.set_stream(self.model.stream_device_map[cur_device]) + hidden_states = hidden_states.to(self.model.transfer_map[i], non_blocking=True) + + if batch and batch.minibatch.position_ids is not None: + batch.minibatch.position_ids = batch.minibatch.position_ids.to( + self.model.transfer_map[i], non_blocking=True + ) + + # Apply layer normalization + if hasattr(decode_layer, 'input_layernorm'): + if num_tokens_tensors is not None: + # Save current hidden_states as residual before normalization + # This matches standard HunYuan behavior + residual = hidden_states + + # KHunYuanRMSNorm now only does normalization (no residual handling) + hidden_states = decode_layer.input_layernorm( + hidden_states, batch_size_tensor=num_tokens_tensors + ) + + else: + # Standard path without batch_size_tensor + residual = hidden_states + hidden_states = decode_layer.input_layernorm(hidden_states) + + # Self-attention with CLA support + # Track KV states for Cross-Layer Attention + kv_states_for_cla = None + if hasattr(self, '_layer_kv_states'): + # Check if this layer should use CLA (cross-attention) + if hasattr(decode_layer.self_attn, 'attention_type') and decode_layer.self_attn.attention_type == 'cross': + # Find the source layer for KV states (should be a layer where idx % cla_share_factor == 0) + cla_share_factor = getattr(self.config, 'cla_share_factor', 1) + source_layer_idx = (i // cla_share_factor) * cla_share_factor + if source_layer_idx in self._layer_kv_states: + kv_states_for_cla = self._layer_kv_states[source_layer_idx] + + # Call attention with optional kv_states for CLA + # Check if this is KHunYuanAttention (which always returns tuple) + has_kv_states_param = (hasattr(decode_layer.self_attn, 'forward') and + 'kv_states' in decode_layer.self_attn.forward.__code__.co_varnames) + + if has_kv_states_param: + # KHunYuanAttention - pass kv_states and expect tuple return + attn_result = decode_layer.self_attn( + hidden_states, + self.cache, + position_ids=batch.minibatch.position_ids if batch else None, + wrapper=self.attn[cuda_graph_idx] if self.attn[cuda_graph_idx] is not None else None, + bsz_tensors=num_tokens_tensors, + page_idx=page_idx, + page_offset=page_offset, + kv_states=kv_states_for_cla + ) + # KHunYuanAttention always returns (attn_output, (key_states, value_states)) + attn_output, layer_kv_states = attn_result + + # Store KV states for potential CLA use by later layers + if not hasattr(self, '_layer_kv_states'): + 
self._layer_kv_states = {} + # Only store KV states from layers where idx % cla_share_factor == 0 + if hasattr(self.config, 'use_cla') and self.config.use_cla: + cla_share_factor = getattr(self.config, 'cla_share_factor', 1) + if i % cla_share_factor == 0: + self._layer_kv_states[i] = layer_kv_states + else: + # Other attention types - standard call without kv_states + attn_output = decode_layer.self_attn( + hidden_states, + self.cache, + position_ids=batch.minibatch.position_ids if batch else None, + wrapper=self.attn[cuda_graph_idx] if self.attn[cuda_graph_idx] is not None else None, + bsz_tensors=num_tokens_tensors, + page_idx=page_idx, + page_offset=page_offset + ) + + # Add residual connection after attention (matching standard HunYuan) + if residual is not None: + hidden_states = residual + attn_output + else: + hidden_states = attn_output + + # Post-attention layer norm and MLP + if hasattr(decode_layer, 'post_attention_layernorm'): + # Update residual to current hidden_states before post-attention norm + # This matches standard HunYuan behavior + residual = hidden_states + + # KHunYuanRMSNorm now only does normalization (no residual handling) + hidden_states = decode_layer.post_attention_layernorm( + hidden_states, num_tokens_tensors + ) + + # MLP layer + if hasattr(decode_layer, 'mlp'): + # Keep original 3D tensor format [batch_size, seq_len, hidden_size] for native HunYuan compatibility + mlp_output = decode_layer.mlp( + hidden_states, num_tokens_tensors, cuda_graph_idx + ) + # Add residual connection after MLP + if residual is not None: + hidden_states = residual + mlp_output + else: + hidden_states = mlp_output + + # Final layer norm + hidden_states = self.model.norm(hidden_states) + + # Handle dimension conversion for lm_head (expects 2D input) + if hidden_states.dim() == 3: + # For 3D tensor: [batch_size, seq_len, hidden_size] -> take last token + logits = self.lm_head(hidden_states[:, -1, :], num_tokens_tensors) + else: + # For 2D tensor: [batch_size, hidden_size] -> already the last token + logits = self.lm_head(hidden_states, num_tokens_tensors) + + forward_batch_output = ForwardBatchOutput() + forward_batch_output.logits.append(logits) + + return forward_batch_output \ No newline at end of file diff --git a/ktransformers/models/modeling_hunyuan.py b/ktransformers/models/modeling_hunyuan.py new file mode 100644 index 00000000..0d91b91e --- /dev/null +++ b/ktransformers/models/modeling_hunyuan.py @@ -0,0 +1,1728 @@ +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Tencent/Tencent-Hunyuan-Large/blob/main/License.docx +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" PyTorch HunYuan model.""" + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from ktransformers.models.configuration_hunyuan import HunYuanConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "HunYuanConfig" + + +def topkgating(logits: Tensor, topk: int): + logits = logits.float() + gates = F.softmax(logits, dim=1) + # expert_capacity = topk * gates.shape[0] + expert_capacity = max(topk, topk * gates.shape[0] // gates.shape[1]) + num_experts = int(gates.shape[1]) + # Top-k router probability and corresponding expert indices for each token. + # Shape: [tokens_per_group, num_selected_experts]. + expert_gate, expert_index = torch.topk(gates, topk) + expert_mask = F.one_hot(expert_index, num_experts) + # For a given token, determine if it was routed to a given expert. + # Shape: [tokens_per_group, num_experts] + expert_mask_aux = expert_mask.max(dim=-2)[0] + tokens_per_group_and_expert = torch.mean(expert_mask_aux.float(), dim=-2) + router_prob_per_group_and_expert = torch.mean(gates.float(), dim=-2) + l_aux = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) + + gates_s = torch.clamp( + torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps + ) + router_probs = gates / gates_s + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 choices, + # etc. + expert_index = torch.transpose(expert_index, 0, 1) + # Shape: [num_selected_experts * tokens_per_group] + expert_index = expert_index.reshape(-1) + + # Create mask out of indices. + # Shape: [tokens_per_group * num_selected_experts, num_experts]. + expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32) + exp_counts = torch.sum(expert_mask, dim=0).detach() + + # Experts have a fixed capacity that we cannot exceed. 
A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [tokens_per_group * num_selected_experts, num_experts]. + token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1 + # Shape: [num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((topk, -1, num_experts)) + # Shape: [tokens_per_group, num_selected_experts, num_experts]. + token_priority = torch.transpose(token_priority, 0, 1) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [tokens_per_group, num_experts]. + token_priority = torch.max(token_priority, dim=1)[0] + + # Token T can only be routed to expert E if its priority is positive and + # less than the expert capacity. One-hot matrix will ignore indices outside + # the range [0, expert_capacity). + # Shape: [tokens_per_group, num_experts, expert_capacity]. + valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity) + token_priority = torch.masked_fill(token_priority, ~valid_mask, 0) + dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool) + valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity) + dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0) + + # The combine array will be used for combining expert outputs, scaled by the + # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, + # expert_capacity]. + combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) + exp_counts_capacity = torch.sum(dispatch_mask) + exp_capacity_rate = exp_counts_capacity / (logits.shape[0]*topk) + + return [l_aux, exp_capacity_rate], combine_weights, dispatch_mask, exp_counts + + +def top1gating(logits: Tensor, random_routing_dropped_token: bool = False): + """Implements Top1Gating on logits.""" + # everything is in fp32 in this function + logits = logits.float() + gates = F.softmax(logits, dim=1) + capacity = gates.shape[0] + + # Create a mask for 1st's expert per token + # noisy gating + indices1_s = torch.argmax(gates, dim=1) + num_experts = int(gates.shape[1]) + mask1 = F.one_hot(indices1_s, num_classes=num_experts) + + # gating decisions + # exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') + exp_counts = torch.sum(mask1, dim=0).detach() + + # Compute l_aux + me = torch.mean(gates, dim=0) + ce = torch.mean(mask1.float(), dim=0) + l_aux = torch.sum(me * ce) * num_experts + mask1_rand = mask1 + + top_idx = torch.topk(mask1_rand, k=capacity, dim=0)[1] + + new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) + mask1 = new_mask1 + mask1_bk = mask1 + if random_routing_dropped_token: + not_full = capacity - new_mask1.sum(dim=0) + sorted_notfull, indices_notfull = torch.sort(not_full, descending=True) + sorted_notfull = sorted_notfull.to(torch.int64) + not_full_experts_ids = torch.repeat_interleave(indices_notfull, sorted_notfull) + shuffle_not_full_ids = torch.randperm(not_full_experts_ids.shape[0]) + not_full_experts_ids = not_full_experts_ids[shuffle_not_full_ids] + indices1_s_after_drop = torch.argmax(new_mask1, dim=1) + # get drop idx + drop_mask = 1 - new_mask1.sum(dim=1) + drop_mask = drop_mask.bool() + drop_idx = drop_mask.nonzero().view(-1) + drop_num = drop_mask.sum().to(torch.int64) + 
indices1_s_after_drop.scatter_(0, drop_idx, not_full_experts_ids[:drop_num]) + nodrop_mask1 = F.one_hot(indices1_s_after_drop, num_classes=num_experts) + mask1 = nodrop_mask1 + + # Compute locations in capacity buffer + locations1 = torch.cumsum(mask1, dim=0) - 1 + + # Store the capacity location for each token + locations1_s = torch.sum(locations1 * mask1, dim=1) + + # Normalize gate probabilities + mask1_float = mask1.float() + gates = gates * mask1_float + + locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float() # one hot to float + combine_weights = torch.einsum("se,sc->sec", gates, locations1_sc) + + dispatch_mask = combine_weights.bool() + + exp_counts_capacity = torch.sum(mask1_bk) + exp_capacity_rate = exp_counts_capacity / (logits.shape[0]) + return [l_aux, exp_capacity_rate], combine_weights, dispatch_mask, exp_counts + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + warnings.warn( + "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be " + "removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask" + ) + return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + warnings.warn( + "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in " + "v4.37. Use `transformers.models.llama.modeling_llama.AttentionMaskConverter._make_causal_mask" + ) + return AttentionMaskConverter._make_causal_mask( + input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length + ) + + +class HunYuanRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + HunYuanRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(HunYuanRMSNorm) + + +class HunYuanRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + # inv_freq = inv_freq.bfloat16() + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
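+        # The cos/sin tables are precomputed here up to `max_position_embeddings` and are
+        # extended lazily in forward() whenever a longer sequence length is requested.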
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) + + self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).float() + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached or self.inv_freq.dtype != torch.float32: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class HunYuanLinearScalingRotaryEmbedding(HunYuanRotaryEmbedding): + """HunYuanRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class HunYuanDynamicNTKScalingRotaryEmbedding(HunYuanRotaryEmbedding): + """ + HunYuanRotaryEmbedding extended with Dynamic NTK scaling. + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class HunYuanDynamicNTKAlphaRotaryEmbedding(HunYuanRotaryEmbedding): + """ + HunYuanRotaryEmbedding extended with Dynamic NTK scaling. 
+ Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_alpha=1.0): + self.scaling_alpha = scaling_alpha + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + base = self.base * self.scaling_alpha ** (self.dim / (self.dim-2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class HunYuanMLP(nn.Module): + def __init__(self, config: HunYuanConfig, layer_idx=None, is_shared_mlp=False): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + if is_shared_mlp: + self.intermediate_size = config.intermediate_size * config.num_shared_expert[0] + else: + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +class HunYuanTopKGate(nn.Module): + def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.moe_topk = config.moe_topk + self.drop_tokens = config.moe_drop_tokens + self.min_capacity = 8 + self.random_routing_dropped_token = config.moe_random_routing_dropped_token + self.wg = nn.Linear(config.hidden_size, config.num_experts, bias=False, dtype=torch.float32) + + def forward(self, hidden_states): + *_, hidden_size = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_size) + if self.wg.weight.dtype == torch.float32: + hidden_states = hidden_states.float() + logits = self.wg(hidden_states) + if self.moe_topk == 1: + gate_output = top1gating(logits, random_routing_dropped_token=self.random_routing_dropped_token) + else: + gate_output = topkgating(logits, self.moe_topk[0]) + + return gate_output + + +class HunYuanMoE(nn.Module): + def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.moe_topk = config.moe_topk + self.num_experts = config.num_experts + if config.use_mixed_mlp_moe: + self.shared_mlp = HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True) + self.gate = HunYuanTopKGate(config, layer_idx=layer_idx) + self.experts = nn.ModuleList( + [HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(config.num_experts)] + ) + + def forward(self, hidden_states): + bsz, seq_len, hidden_size = hidden_states.shape + + if self.config.use_mixed_mlp_moe: + hidden_states_mlp = self.shared_mlp(hidden_states) + + l_moe, combine_weights, dispatch_mask, exp_counts = self.gate(hidden_states) + + reshaped_input = 
hidden_states.reshape(-1, hidden_size) + + dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input) + + chunks = dispatched_input.chunk(self.num_experts, dim=0) + expert_outputs = [] + for chunk, expert in zip(chunks, self.experts): + expert_outputs.append(expert(chunk)) + + expert_output = torch.cat(expert_outputs, dim=0) + combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output) + combined_output = combined_output.reshape(bsz, seq_len, hidden_size) + + if self.config.use_mixed_mlp_moe: + output = hidden_states_mlp + combined_output + else: + output = combined_output + + return output + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class HunYuanAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + # layer_idx 从 0 开始 + self.attention_type = 'cross' if config.use_cla and layer_idx % config.cla_share_factor != 0 else 'self' + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.use_qk_norm = config.use_qk_norm + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + if self.attention_type == 'self': + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + if self.use_qk_norm: + self.query_layernorm = HunYuanRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_layernorm = HunYuanRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = HunYuanRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + scaling_alpha = self.config.rope_scaling["alpha"] + if scaling_type == "linear": + self.rotary_emb = HunYuanLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + if scaling_alpha: + self.rotary_emb = HunYuanDynamicNTKAlphaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_alpha=scaling_alpha, + base=self.rope_theta, + ) + else: + self.rotary_emb = HunYuanDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + kv_states: torch.Tensor = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use " + "`attention_mask` instead.`" + ) + + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + if self.attention_type == "cross" and kv_states is not None and isinstance(kv_states, tuple): + orig_key_states, orig_value_states = kv_states + key_states, value_states = kv_states + else: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + orig_key_states, orig_value_states = key_states, value_states + + else: + query_states = self.q_proj(hidden_states) + if self.attention_type == "cross" and kv_states is not None and isinstance(kv_states, tuple): + orig_key_states, orig_value_states = kv_states + key_states, value_states = kv_states + else: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + orig_key_states, orig_value_states = key_states, value_states + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if self.use_qk_norm: + query_states = self.query_layernorm(query_states) + key_states = self.key_layernorm(key_states) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value, (orig_key_states, orig_value_states) + + +class HunYuanFlashAttention2(HunYuanAttention): + """ + HunYuan flash attention module. This module inherits from `HunYuanAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + kv_states: torch.Tensor = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # HunYuanFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use " + "`attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + if self.attention_type == "cross" and kv_states is not None and isinstance(kv_states, tuple): + orig_key_states, orig_value_states = kv_states + key_states, value_states = kv_states + else: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + orig_key_states, orig_value_states = key_states, value_states + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if self.use_qk_norm: + query_states = self.query_layernorm(query_states) + key_states = self.key_layernorm(key_states) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (HunYuanRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. 
We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value, (orig_key_states, orig_value_states) + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + 
) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class HunYuanSdpaAttention(HunYuanAttention): + """ + HunYuan attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `HunYuanAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt + to SDPA API. + """ + + # Adapted from HunYuanAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + kv_states: torch.Tensor = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + logger.warning_once( + 'HunYuanModel is using HunYuanSdpaAttention,' + 'but `torch.nn.functional.scaled_dot_product_attention`' + 'does not support `output_attentions=True`. Falling back to the manual attention implementation, ' + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. ' + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + if self.attention_type == "cross" and kv_states is not None and isinstance(kv_states, tuple): + orig_key_states, orig_value_states = kv_states + key_states, value_states = kv_states + else: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + orig_key_states, orig_value_states = key_states, value_states + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if self.use_qk_norm: + query_states = self.query_layernorm(query_states) + key_states = self.key_layernorm(key_states) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size 
{(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with + # custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a + # causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value, (orig_key_states, orig_value_states) + + +HUNYUAN_ATTENTION_CLASSES = { + "eager": HunYuanAttention, + "flash_attention_2": HunYuanFlashAttention2, + "sdpa": HunYuanSdpaAttention, +} + + +class HunYuanDecoderLayer(nn.Module): + def __init__(self, config: HunYuanConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = HUNYUAN_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + if config.num_experts > 1: + self.mlp = HunYuanMoE(config, layer_idx=layer_idx) + else: + self.mlp = HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False) + self.input_layernorm = HunYuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = HunYuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + kv_states: Optional[Tuple[torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + kv_states (`Tuple(torch.FloatTensor)`, *optional*): Used when CLA is enabled, + key and value states from past attention blocks + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use " + "`attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value, kv_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + kv_states=kv_states, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + outputs += (kv_states,) + + return outputs + + +HUNYUAN_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`HunYuanConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare HunYuan Model outputting raw hidden-states without any specific head on top.", + HUNYUAN_START_DOCSTRING, +) +class HunYuanPreTrainedModel(PreTrainedModel): + config_class = HunYuanConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["HunYuanDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +HUNYUAN_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare HunYuan Model outputting raw hidden-states without any specific head on top.", + HUNYUAN_START_DOCSTRING, +) +class HunYuanModel(HunYuanPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`HunYuanDecoderLayer`] + + Args: + config: HunYuanConfig + """ + + def __init__(self, config: HunYuanConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._use_sdpa = config._attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = HunYuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.cla = config.use_cla + self.cla_share_factor = config.cla_share_factor + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # Fix lora with gradient checkpointing training + if self.training and inputs_embeds.is_leaf: + inputs_embeds.requires_grad = True + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + prev_kv_states = None + for layer_idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + prev_kv_states, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + kv_states=prev_kv_states + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + kv_states = layer_outputs[-1] + + if self.cla and layer_idx % self.cla_share_factor == 0: + prev_kv_states = kv_states + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: HunYuanConfig): + super().__init__(config) + self.model = HunYuanModel(config) + self.vocab_size = config.vocab_size + 
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_cache_shape() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1]:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The HunYuan Model transformer with a sequence classification head on top (linear layer). + + [`HunYuanForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + HUNYUAN_START_DOCSTRING, +) +class HunYuanForSequenceClassification(HunYuanPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = HunYuanModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( + logits.device + ) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py index 968c7b98..424415b9 100644 --- a/ktransformers/operators/RoPE.py +++ b/ktransformers/operators/RoPE.py @@ -22,6 +22,10 @@ yarn_linear_ramp_mask, yarn_find_correction_range ) +from ktransformers.models.modeling_hunyuan import ( + HunYuanRotaryEmbedding, + HunYuanDynamicNTKAlphaRotaryEmbedding +) from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_loader import GGUFLoader from ktransformers.util.utils import InferenceState @@ -527,4 +531,65 @@ def forward(self, x, position_ids): cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) \ No newline at end of file + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class KHunYuanRotaryEmbedding(BaseInjectedModule, HunYuanDynamicNTKAlphaRotaryEmbedding): + """HunYuan RoPE with KTransformers optimizations - simplified for HunYuanDynamicNTKAlphaRotaryEmbedding""" + + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + generate_device: str 
= "cuda", + prefill_device: str = "cuda", + **kwargs, + ): + BaseInjectedModule.__init__( + self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs + ) + + self.generate_device = generate_device + self.prefill_device = prefill_device + + self.orig_module.__init__( + dim=orig_module.dim, + max_position_embeddings=orig_module.max_position_embeddings, + base=orig_module.base, + device=None, # Will be set in load() + scaling_alpha=orig_module.scaling_alpha if hasattr(orig_module, 'scaling_alpha') else 1000.0 + ) + + + def forward(self, x, position_ids=None, seq_len=None): + """Forward pass with KTransformers compatibility + + Args: + x: Input tensor + position_ids: Position indices (KTransformers mode) + seq_len: Sequence length (native Hunyuan mode) + """ + # Convert position_ids to seq_len if needed + if position_ids is not None and seq_len is None: + # We need seq_len = max(position_ids) + 1 to avoid index out of bounds + # But avoid .item() for CUDA Graph compatibility + seq_len = self.max_seq_len_cached + elif seq_len is None: + seq_len = x.shape[-2] if x.ndim > 2 else x.shape[0] + + # Call parent's forward with seq_len (no .item() needed) + return super().forward(x, seq_len=seq_len) + + def load(self): + """Reinitialize the module on the correct device after loading""" + BaseInjectedModule.load(self) + + self.orig_module.__init__( + dim=self.orig_module.dim, + max_position_embeddings=self.orig_module.max_position_embeddings, + base=self.orig_module.base, + device=self.generate_device, + scaling_alpha=self.orig_module.scaling_alpha if hasattr(self.orig_module, 'scaling_alpha') else 1000.0 + ) \ No newline at end of file diff --git a/ktransformers/operators/balance_serve_attention.py b/ktransformers/operators/balance_serve_attention.py index ecb614f8..4b7ef8fa 100644 --- a/ktransformers/operators/balance_serve_attention.py +++ b/ktransformers/operators/balance_serve_attention.py @@ -11,6 +11,7 @@ from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention from ktransformers.models.modeling_smallthinker import SmallthinkerAttention from ktransformers.models.modeling_glm4_moe import Glm4MoeAttention +from ktransformers.models.modeling_hunyuan import HunYuanAttention from ktransformers.models.modeling_qwen3_next import Qwen3NextGatedDeltaNet from typing import Optional, Tuple from ktransformers.operators.base_operator import BaseInjectedModule @@ -19,7 +20,7 @@ from transformers.configuration_utils import PretrainedConfig from flashinfer import BatchMLAPagedAttentionWrapper from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn -from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache +from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache, KHunYuanCache logger = logging.getLogger("attention") # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -645,6 +646,163 @@ def forward(self, attn_output = self.o_proj(attn_output.view(q_len, self.config.num_attention_heads * self.head_dim), bsz_tensors) return attn_output + + +class KHunYuanAttention(BaseInjectedModule, HunYuanAttention): + """HunYuan attention for balance serve mode with GQA and CLA support + + Key features: + 1. GQA (Grouped Query Attention): 32 query heads, 8 KV heads + - Flashinfer handles the repeat internally for memory efficiency + - KV cache stores only 8 heads, expanded during attention computation + 2. 
CLA (Cross-Layer Attention): Some layers reuse KV from previous layers + - Reduces computation by sharing KV states across layers + """ + + def __init__(self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + **kwargs): + BaseInjectedModule.__init__( + self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs + ) + # Store layer_idx before calling HunYuanAttention init + layer_idx = orig_module.layer_idx + # Initialize HunYuanAttention components manually to avoid __setattr__ conflicts + self.config = config + self.layer_idx = layer_idx + self.prefill_device = prefill_device + self.generate_device = generate_device + + # Initialize attention parameters from config + # HunYuan uses 32 attention heads and 8 KV heads (GQA with 4x groups) + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads # 32 for HunYuan + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads # 8 for HunYuan + self.num_key_value_groups = self.num_heads // self.num_key_value_heads # 4 groups + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + # Cross-Layer Attention (CLA) configuration + self.use_cla = getattr(config, "use_cla", False) + self.cla_share_factor = getattr(config, "cla_share_factor", 1) + # Determine attention type based on CLA configuration + if self.use_cla and layer_idx is not None and layer_idx % self.cla_share_factor != 0: + self.attention_type = 'cross' + else: + self.attention_type = 'self' + + if getattr(config, "use_qk_norm", False): + self.use_qk_norm = True + else: + self.use_qk_norm = False + + # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb + def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Following original Hunyuan implementation pattern. 
+ """ + # Use original Hunyuan pattern: cos/sin indexed by position_ids then unsqueezed + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (self.rotate_half(q) * sin) + k_embed = (k * cos) + (self.rotate_half(k) * sin) + return q_embed, k_embed + + def rotate_half(self, x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KHunYuanCache, + position_ids: torch.Tensor, + wrapper: flashInferAttn, + bsz_tensors: torch.Tensor, + page_idx: torch.Tensor, + page_offset: torch.Tensor, + kv_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Balance serve forward implementation for HunYuan attention with CLA support + + Args: + hidden_states: Input hidden states + kv_cache: KV cache for current layer + position_ids: Position IDs for RoPE + wrapper: FlashInfer attention wrapper + bsz_tensors: Batch size tensors + page_idx: Page indices for paged attention + page_offset: Page offsets for paged attention + kv_states: Optional pre-computed KV states from previous layers (for CLA) + """ + q_len, _ = hidden_states.size() + + # Compute query projections (always from current hidden_states) + query_states = self.q_proj(hidden_states, bsz_tensors) + + # Handle Cross-Layer Attention (CLA) + if self.attention_type == "cross" and kv_states is not None: + # Use pre-computed KV states from a previous layer + key_states, value_states = kv_states + else: + key_states = self.k_proj(hidden_states, bsz_tensors) + value_states = self.v_proj(hidden_states, bsz_tensors) + + # Reshape for attention heads + query_states = query_states.view(q_len, self.num_heads, self.head_dim) + # Only reshape if we computed new KV (not using CLA) + if self.attention_type != "cross" or kv_states is None: + key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) + + # Apply RoPE following KTransformers pattern - avoid .item() for CUDA Graph compatibility + if hasattr(self, 'rotary_emb'): + # Pass position_ids directly to rotary_emb like other models (Qwen, DeepSeek) + cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0)) + # Apply RoPE with proper unsqueeze for dimension compatibility + query_states, key_states = self.apply_rotary_pos_emb( + query_states.unsqueeze(0), + key_states.unsqueeze(0), + cos, + sin, + position_ids.unsqueeze(0), + unsqueeze_dim=2 + ) + query_states = query_states.squeeze(0) + key_states = key_states.squeeze(0) + + # Apply QK normalization if configured + if self.use_qk_norm: + query_states = self.query_layernorm(query_states) + key_states = self.key_layernorm(key_states) + + # Store original KV states for CLA sharing + orig_key_states = key_states + orig_value_states = value_states + + # Get k_cache and v_cache for current layer + k_cache = kv_cache.get_k_cache(self.layer_idx) + v_cache = kv_cache.get_v_cache(self.layer_idx) + + # It expects key_states and value_states with original num_kv_heads (8) + # Query_states with num_heads (32) + attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states) + + attn_output = attn_output.view(q_len, self.num_heads * self.head_dim) + attn_output = self.o_proj(attn_output, bsz_tensors) + + return 
attn_output, (orig_key_states, orig_value_states) from ktransformers.models.modeling_qwen3_next import apply_mask_to_padding_states import torch.nn.functional as F diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index df2088c1..d9aa5171 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -185,12 +185,26 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = N hidden_type = 1 # fp16 else: hidden_type = 30 # bf16 + + # Unified parameter handling for different model types + # This will not influence other models + if hasattr(self.config, 'model_type') and self.config.model_type == "hunyuan_v1_moe": + # Hunyuan uses moe_topk instead of num_experts_per_tok + moe_topk = getattr(self.config, 'moe_topk', 8) + experts_per_tok = moe_topk[0] if isinstance(moe_topk, list) else moe_topk + moe_intermediate_size = self.config.moe_intermediate_size + if isinstance(moe_intermediate_size, list): + moe_intermediate_size = moe_intermediate_size[0] + else: + # Other models use num_experts_per_tok directly + experts_per_tok = self.config.num_experts_per_tok + moe_intermediate_size = self.config.moe_intermediate_size if self.backend == "llamafile": moe_config = MOEConfig( n_routed_experts, - self.config.num_experts_per_tok, + experts_per_tok, self.config.hidden_size, - self.config.moe_intermediate_size, + moe_intermediate_size, 64, 10, 1024, @@ -211,9 +225,9 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = N assert self.down_type == GGMLQuantizationType.BF16 moe_config = AMX_MOEConfig( n_routed_experts, - self.config.num_experts_per_tok, + experts_per_tok, self.config.hidden_size, - self.config.moe_intermediate_size, + moe_intermediate_size, max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size, self.config.hidden_act == 'silu', gate_ptr, @@ -230,9 +244,9 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = N assert self.down_type == GGMLQuantizationType.BF16 moe_config = AMX_MOEConfig( n_routed_experts, - self.config.num_experts_per_tok, + experts_per_tok, self.config.hidden_size, - self.config.moe_intermediate_size, + moe_intermediate_size, max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size, self.config.hidden_act == 'silu', gate_ptr, @@ -242,8 +256,8 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = N self.moe = AMXInt8_MOE(moe_config) self.cpu_infer.submit(self.moe.load_weights()) self.cpu_infer.sync() - # print(n_routed_experts, hidden_size, moe_intermediate_size) - num_experts_per_tok = self.config.num_experts_per_tok + # Use the unified experts_per_tok value for buffer allocation + num_experts_per_tok = experts_per_tok if warmup: self.cpu_infer.submit(self.moe.warm_up()) self.cpu_infer.sync() @@ -297,9 +311,7 @@ def sync_for_one_decode(self, cuda_graph_idx=0): KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) return KExpertsCPU.output_gpu_map[self.out_device] - def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0): - # generate, capture and run cuda graph - # print(expert_ids) + def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0): if bsz_tensor is None and (not torch.xpu.is_available() or input_tensor.size(0) > 1): bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32) if torch.cuda.is_available() and 
torch.cuda.is_current_stream_capturing(): @@ -338,7 +350,8 @@ def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph bsz_tensor = bsz_tensor.contiguous().cpu() output = torch.empty_like(input_tensor).contiguous() self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr())) - self.cpu_infer.sync() + self.cpu_infer.sync() + return output.to(device=object.__getattribute__(self, "out_device")) def unload(self): @@ -734,6 +747,7 @@ def set_inference_mode(self, mode: InferenceState): from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock from ktransformers.models.modeling_smallthinker import SmallthinkerMoeBlock from ktransformers.models.modeling_glm4_moe import Glm4MoeMoE +from ktransformers.models.modeling_hunyuan import HunYuanMoE from ktransformers.models.modeling_qwen3_next import Qwen3NextSparseMoeBlock @@ -1302,7 +1316,8 @@ def unload(self): def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0): if self.mode == InferenceState.GENERATE: assert self.generate_experts is not None, "generate_experts is None" - return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) + result = self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) + return result elif self.mode == InferenceState.PREFILL: assert self.prefill_experts is not None, "prefill_experts is None" return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) @@ -1941,6 +1956,118 @@ def moe_infer(self, x, topk_ids, topk_weight): return final_out +class KHunyuanGateWrapper(nn.Module): + """Wrapper for HunYuan gate that converts output to standard MoE format""" + + def __init__(self, original_gate, config): + super().__init__() + self.original_gate = original_gate + self.config = config + self.moe_topk = config.moe_topk[0] if isinstance(config.moe_topk, list) else config.moe_topk + + def forward(self, hidden_states): + """Forward pass that converts HunYuan format to standard MoE format""" + l_moe, combine_weights, dispatch_mask, exp_counts = self.original_gate(hidden_states) + + mask = dispatch_mask.bool() + + # Aggregate weights per expert (sum over the capacity dimension) + # combine_weights may be non-zero at inactive positions, so multiply by the mask to zero them out + expert_weights = (combine_weights * mask.float()).sum(dim=-1) # [batch_size, num_experts] + + # Direct top-k selection (fixed shape, CUDA Graph friendly) + topk_weight, topk_ids = torch.topk(expert_weights, k=self.moe_topk, dim=1, largest=True, sorted=True) + + # Handle case where some tokens have fewer valid experts than moe_topk + # IMPORTANT: Use 0 instead of -1 for invalid entries to avoid uint64_t overflow in C++ + # The weight being 0 ensures these entries won't contribute to the output + valid = topk_weight > 0 + topk_ids = torch.where(valid, topk_ids, torch.zeros_like(topk_ids)) # Use 0 instead of -1 + topk_weight = torch.where(valid, topk_weight, torch.zeros_like(topk_weight)) + + return topk_ids, topk_weight + + +class KHunyuanMoE(BaseInjectedModule, HunYuanMoE): + """HunYuan MoE with KTransformers optimizations""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Wrap the gate to output 
standard format + if hasattr(self, 'gate'): + self.gate = KHunyuanGateWrapper(self.gate, self.config) + + def forward(self, hidden_states: torch.Tensor, bsz_tensor=None, cuda_graph_idx=0) -> torch.Tensor: + """HunYuan MoE forward pass following KQwen3MoeSparseMoeBlockV2 structure""" + + orig_shape = hidden_states.shape + + # Compute shared MLP if using mixed MLP mode (HunYuan specific) + shared_output = None + if getattr(self.config, 'use_mixed_mlp_moe', True) and hasattr(self, 'shared_mlp'): + + if bsz_tensor is not None: + shared_output = self.shared_mlp(hidden_states, bsz_tensor) + else: + shared_output = self.shared_mlp(hidden_states) + + topk_ids, topk_weight = self.gate(hidden_states) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + # Determine which path to use based on expert type + if isinstance(self.experts, KExpertsBase): + # Use KTransformers optimized path with standard format + y = self.moe_on_cpuinfer(hidden_states, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx) + + elif hidden_states.size(0) > 10: + # Use batch processing for larger inputs + y = self.moe_infer(hidden_states, topk_ids, topk_weight) + else: + # Use simple iteration for small inputs + y = self.moe_infer_simple(hidden_states, topk_ids, topk_weight) + + y = y.view(*orig_shape) + + # Combine shared and expert outputs + if shared_output is not None: + return shared_output + y + else: + return y + + @torch.no_grad() + def moe_on_cpuinfer(self, x, topk_ids, topk_weight, bsz_tensor=None, cuda_graph_idx=0): + """Use KTransformers experts with standard MoE format""" + + # Use standard KTransformers experts interface - pass all 4 required parameters + # KTransformersExpertsV2.forward requires (input_tensor, expert_ids, weights, bsz_tensor) + outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx) + + return outs + + @torch.no_grad() + def moe_infer_simple(self, x, topk_ids, topk_weight): + """Simple routing for small batches using standard MoE format""" + outs = torch.zeros_like(x) + for token_idx in range(topk_ids.size(0)): + for expert_idx in range(topk_ids.size(1)): + expert_id = topk_ids[token_idx, expert_idx].item() + weight = topk_weight[token_idx, expert_idx].item() + # Skip if weight is 0 (invalid expert) + if weight > 0: + expert = self.experts[expert_id] + outs[token_idx] += expert.forward(x[token_idx]) * weight + return outs + + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + """Batch routing using standard MoE format""" + # Use simple implementation for now - can be optimized later + return self.moe_infer_simple(x, topk_ids, topk_weight) + class KQwen3NextSparseMoeBlockV2(BaseInjectedModule, Qwen3NextSparseMoeBlock): def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0): @@ -2064,4 +2191,4 @@ def moe_infer(self, x, topk_ids, topk_weight): .sum(dim=1) .type(new_x.dtype) ) - return final_out \ No newline at end of file + return final_out diff --git a/ktransformers/operators/layernorm.py b/ktransformers/operators/layernorm.py index 7ca7e1d5..c1be082d 100644 --- a/ktransformers/operators/layernorm.py +++ b/ktransformers/operators/layernorm.py @@ -31,6 +31,7 @@ from ktransformers.models.modeling_qwen3_next import Qwen3NextRMSNorm from ktransformers.models.modeling_smallthinker import SmallthinkerRMSNorm from ktransformers.models.modeling_glm4_moe import Glm4MoeRMSNorm +from ktransformers.models.modeling_hunyuan import HunYuanRMSNorm from ktransformers.operators.base_operator import BaseInjectedModule from 
ktransformers.util.custom_loader import GGUFLoader if not torch.xpu.is_available(): @@ -344,6 +345,71 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = rms_norm_forward(self, x) return output.to(x.dtype) + def load(self): + BaseInjectedModule.load(self) + if self.weight.dtype not in [torch.float32, torch.float16]: + self.weight = self.weight.float() + + +class KHunYuanRMSNorm(HunYuanRMSNorm, BaseInjectedModule): + """HunYuan RMSNorm with KTransformers optimizations + + Unlike Qwen/DeepSeek models, HunYuan's LayerNorm is designed to ONLY normalize, + without handling residual connections. Residual additions happen explicitly + in the HunYuanDecoderLayer after attention and MLP blocks. + """ + + def __init__(self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + # Use the same pattern as other RMSNorm classes - call original module's __init__ + # For QK normalization, use the weight shape to determine the correct size + # orig_module.weight.shape[0] gives us the actual dimension (head_dim=128 for QK norm, hidden_size=4096 for regular) + actual_size = orig_module.weight.shape[0] + self.orig_module.__init__(actual_size, orig_module.variance_epsilon) + + def forward( + self, + x: torch.Tensor, + batch_size_tensor: torch.Tensor = None, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Forward pass for HunYuan RMSNorm - pure normalization only + + IMPORTANT: HunYuan's architecture handles residual connections externally. + This RMSNorm should ONLY normalize, never add residuals. + The residual parameter is ignored to maintain compatibility with the interface. 
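+        Illustrative usage (a sketch; `norm` and `hidden_states` are placeholder names): +            out = norm(hidden_states)                      # no batch tensor: native PyTorch path +            out = norm(hidden_states, batch_size_tensor)   # balance-serve path, still returns only the normalized tensor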
+ """ + # Explicitly ignore residual parameter (kept for interface compatibility) + _ = residual + + if batch_size_tensor is None: + return self.forward_native(x) + + # Use flashinfer optimized rmsnorm for pure normalization + # We explicitly DO NOT use fused_add_rmsnorm here + # out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon) + out = self.forward_native(x) + + # Return normalized output only (no residual handling) + return out + + def forward_native(self, hidden_states): + """Native PyTorch implementation as fallback""" + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + # Ensure weight matches input dtype to prevent type promotion to float32 + weight = self.weight.to(input_dtype) if self.weight.dtype != input_dtype else self.weight + return weight * hidden_states.to(input_dtype) + def load(self): BaseInjectedModule.load(self) if self.weight.dtype not in [torch.float32, torch.float16]: diff --git a/ktransformers/operators/mlp.py b/ktransformers/operators/mlp.py index 6d3e8120..7ea5d4f8 100644 --- a/ktransformers/operators/mlp.py +++ b/ktransformers/operators/mlp.py @@ -7,6 +7,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeMLP from ktransformers.models.modeling_smallthinker import SmallthinkerDenseMlpBlock from ktransformers.models.modeling_glm4_moe import Glm4MoeMLP +from ktransformers.models.modeling_hunyuan import HunYuanMLP class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule): def __init__(self, key: str, @@ -67,4 +68,28 @@ def __init__(self, self.orig_module.__init__(orig_module.config, orig_module.hidden_size, orig_module.intermediate_size) def forward(self, x, bsz_tensor): down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor) - return down_proj \ No newline at end of file + return down_proj + + +class KHunYuanMLP(HunYuanMLP, BaseInjectedModule): + def __init__(self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.config, + orig_module.intermediate_size) + + def forward(self, x, bsz_tensor=None): + if bsz_tensor is not None: + # If batch tensor is provided, use it for optimized computation + gate_proj = self.gate_proj(x, bsz_tensor) + up_proj = self.up_proj(x, bsz_tensor) + down_proj = self.down_proj(self.act_fn(gate_proj) * up_proj, bsz_tensor) + return down_proj + else: + return HunYuanMLP.forward(self, x) \ No newline at end of file diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index e136b57a..d97eb810 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -65,6 +65,10 @@ LlamaRMSNorm, LlamaRotaryEmbedding, ) +from ktransformers.models.modeling_hunyuan import ( + HunYuanDecoderLayer, + HunYuanMoE, +) if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -1373,3 +1377,293 @@ def _update_causal_mask( ) return causal_mask + +class KHunyuanModel(BaseInjectedModule): + """ + HunYuan MoE Model with KTransformers optimizations + + Transformer decoder consisting of *config.num_hidden_layers* layers. 
+ Each layer is a [`HunYuanDecoderLayer`] with MoE support. + + Args: + config: HunYuanConfig + """ + + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + device: str = "cuda", + per_layer_prefill_intput_threshold: int = 30000, + transfer_map: dict = None, + **kwargs, + ): + BaseInjectedModule.__init__( + self, key, gguf_loader, config, orig_module, device, **kwargs + ) + self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold + self.transfer_map = transfer_map + self.stream_device_map = dict() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + per_layer_prefill_intput_threshold: Optional[int] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + + if per_layer_prefill_intput_threshold is None: + per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold + + per_layer_prefill_flag = False + seq_length = ( + inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1) + ) + + if ( + per_layer_prefill_intput_threshold + and per_layer_prefill_intput_threshold < seq_length + ): + per_layer_prefill_flag = True + for layer in self.layers: + self.load_layer_to(layer, InferenceState.UNLOAD) + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + use_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + input_ids = input_ids.to("cpu") + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = inputs_embeds.to("cuda") + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + hidden_states = inputs_embeds + + # Create position embeddings for HunYuan if needed + position_embeddings = None + + # Decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for i, decoder_layer in enumerate(self.layers): + if self.transfer_map is not None and i in self.transfer_map: + prev_stream = torch.cuda.current_stream() + cur_device = self.transfer_map[i] + if cur_device not in self.stream_device_map: + self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) + torch.cuda.set_device(cur_device) + self.stream_device_map[cur_device].wait_stream(prev_stream) + torch.cuda.set_stream(self.stream_device_map[cur_device]) + hidden_states = hidden_states.to( + self.transfer_map[i], non_blocking=True + ) + causal_mask = ( + causal_mask.to(self.transfer_map[i], non_blocking=True) + if causal_mask is not None + else None + ) + position_ids = ( + position_ids.to(self.transfer_map[i], non_blocking=True) + if position_ids is not None + else None + ) + cache_position = ( + cache_position.to(self.transfer_map[i], non_blocking=True) + if cache_position is not None + else None + ) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + ) + else: + if per_layer_prefill_flag: + self.load_layer_to(decoder_layer, InferenceState.PREFILL) + torch.cuda.empty_cache() + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + if per_layer_prefill_flag: + self.load_layer_to(decoder_layer, InferenceState.UNLOAD) + torch.cuda.empty_cache() + + hidden_states = layer_outputs[0] + + if use_cache and len(layer_outputs) > 1: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits and layer_outputs[-1] is not None: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + if per_layer_prefill_flag: + per_layer_prefill_flag = False + for layer in self.layers: + self.load_layer_to(layer, InferenceState.GENERATE) + + if 
output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + if next_decoder_cache is not None: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + else: + next_cache = past_key_values + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_router_logits, + ] + if v is not None + ) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + def load_layer_to(self, layer: HunYuanDecoderLayer, target: InferenceState): + """Load HunYuan layer to target device/state""" + assert isinstance( + layer, HunYuanDecoderLayer + ), "module should be HunYuanDecoderLayer" + + device = "cpu" if target == InferenceState.UNLOAD else "cuda" + + # Attention components + layer.self_attn.q_proj.set_inference_mode(target) + layer.self_attn.k_proj.set_inference_mode(target) + layer.self_attn.v_proj.set_inference_mode(target) + layer.self_attn.o_proj.set_inference_mode(target) + layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(device) + + # MLP/MoE components + if isinstance(layer.mlp, HunYuanMoE): + # HunYuan MoE with mixed architecture + layer.mlp.gate.set_inference_mode(target) + layer.mlp.experts.set_inference_mode(target) + + # Shared MLP components + if hasattr(layer.mlp, 'shared_mlp'): + layer.mlp.shared_mlp.gate_proj.set_inference_mode(target) + layer.mlp.shared_mlp.up_proj.set_inference_mode(target) + layer.mlp.shared_mlp.down_proj.set_inference_mode(target) + layer.mlp.shared_mlp.act_fn.to(device) + else: + # Regular MLP + layer.mlp.gate_proj.set_inference_mode(target) + layer.mlp.up_proj.set_inference_mode(target) + layer.mlp.down_proj.set_inference_mode(target) + layer.mlp.act_fn.to(device) + + # Layer normalization + layer.input_layernorm.to(device) + layer.post_attention_layernorm.to(device) diff --git a/ktransformers/optimize/optimize_rules/Hunyuan-serve-amx.yaml b/ktransformers/optimize/optimize_rules/Hunyuan-serve-amx.yaml new file mode 100644 index 00000000..12b18454 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/Hunyuan-serve-amx.yaml @@ -0,0 +1,102 @@ +# HunYuan MoE v1 with AMX Balance Serve Configuration +# Optimized configuration for balance serve mode with CPU/GPU hybrid deployment + +# RoPE configuration - simplified for HunYuanDynamicNTKAlphaRotaryEmbedding only +- match: + class: ktransformers.models.modeling_hunyuan.HunYuanDynamicNTKAlphaRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.KHunYuanRotaryEmbedding + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^lm_head$" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "VLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\..*$" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_hunyuan.HunYuanMoE + replace: + class: ktransformers.operators.experts.KHunyuanMoE + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: 
"^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExpertsV2 + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + backend: "AMXBF16" # or "AMXBF16" or "llamafile" (default) + recursive: False + +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.balance_serve_attention.KHunYuanAttention + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KHunyuanModel" + kwargs: + per_layer_prefill_intput_threshold: 0 + +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + class: ktransformers.models.modeling_hunyuan.HunYuanRMSNorm + replace: + class: ktransformers.operators.layernorm.KHunYuanRMSNorm + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + class: ktransformers.models.modeling_hunyuan.HunYuanMLP + replace: + class: ktransformers.operators.mlp.KHunYuanMLP + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + class: ktransformers.models.modeling_hunyuan.HunYuanAttention + replace: + class: ktransformers.operators.attention.KHunYuanAttention + kwargs: + generate_device: "cuda" + prefill_device: "cuda" \ No newline at end of file diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py index bfbe4604..24c706bf 100644 --- a/ktransformers/server/args.py +++ b/ktransformers/server/args.py @@ -152,7 +152,7 @@ def parse_args(self): raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.") - if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" or model_config.architectures[0] == "SmallThinkerForCausalLM" or model_config.architectures[0] == "Glm4MoeForCausalLM": + if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" or model_config.architectures[0] == "HunYuanMoEV1ForCausalLM"or model_config.architectures[0] == "SmallThinkerForCausalLM" or model_config.architectures[0] == "Glm4MoeForCausalLM": args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim args.architectures = model_config.architectures[0] else: diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py index c5122f42..36e4680f 100644 --- a/ktransformers/server/backend/interfaces/balance_serve.py +++ b/ktransformers/server/backend/interfaces/balance_serve.py @@ -1,5 +1,5 @@ from typing import Any, AsyncIterator, List, Optional, Set -from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache +from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache, KHunYuanCache from transformers import ( AutoTokenizer, AutoConfig, @@ -26,6 +26,7 @@ from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM +from ktransformers.models.custom_modeling_hunyuan import KHunYuanMoEV1ForCausalLM from ktransformers.models.custom_modeling_qwen3_next import KQwen3NextForCausalLM from 
ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig from ktransformers.models.configuration_smallthinker import SmallthinkerConfig @@ -68,6 +69,7 @@ "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml", "SmallThinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml", "Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml", + "HunYuanMoEV1ForCausalLM": ktransformer_rules_dir + "Hunyuan-serve.yaml", "Qwen3NextForCausalLM": ktransformer_rules_dir + "Qwen3Next-serve.yaml", } @@ -119,7 +121,7 @@ class Engine: model_runner: ModelRunner sampler: Sampler query_manager: QueryManager - cache: KDeepSeekV3Cache | KGQACache + cache: KDeepSeekV3Cache | KGQACache | KHunYuanCache def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None): self.args = args @@ -178,6 +180,9 @@ def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue self.cache = KGQACache(config, self.args.page_size) self.model = KQwen3NextForCausalLM(config, self.cache) + elif config.architectures[0] == "HunYuanMoEV1ForCausalLM": + self.cache = KHunYuanCache(config, self.args.page_size) + self.model = KHunYuanMoEV1ForCausalLM(config, self.cache) context = zmq.Context() @@ -228,7 +233,8 @@ def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue self.block_num = inference_context.k_cache[0].size(1) self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num) #@TODO add config - if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM" or config.architectures[0] == "Qwen3NextForCausalLM": + + if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "HunYuanMoEV1ForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM" or config.architectures[0] == "Qwen3NextForCausalLM": self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num) else: self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num) diff --git a/ktransformers/server/balance_serve/inference/model_runner.py b/ktransformers/server/balance_serve/inference/model_runner.py index 5e5f32dd..38d55923 100644 --- a/ktransformers/server/balance_serve/inference/model_runner.py +++ b/ktransformers/server/balance_serve/inference/model_runner.py @@ -31,6 +31,7 @@ from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM +from ktransformers.models.custom_modeling_hunyuan import KHunYuanMoEV1ForCausalLM from ktransformers.models.custom_modeling_qwen3_next import KQwen3NextForCausalLM from ktransformers.server.balance_serve.inference.query_manager import QueryManager from ktransformers.server.balance_serve.settings import sched_ext @@ -56,7 +57,8 @@ def generate_cuda_graphs(chunk_size: int) -> list: class ModelRunner: """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile.""" - model: KDeepseekV3ForCausalLM | 
KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallThinkerForCausalLM | KGlm4MoeForCausalLM | KQwen3NextForCausalLM + + model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KHunYuanMoEV1ForCausalLM | KSmallThinkerForCausalLM | KGlm4MoeForCausalLM | KQwen3NextForCausalLM input: ForwardBatchInput | list[ForwardBatchInput] output: ForwardBatchOutput @@ -96,7 +98,8 @@ def model_attn_plan(self, batch, cuda_graph_idx=0): num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True, sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) - elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM) or isinstance(self.model, KQwen3NextForCausalLM): + + elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KHunYuanMoEV1ForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM) or isinstance(self.model, KQwen3NextForCausalLM): self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf, num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads, head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_dim') else self.model.config.hidden_size // self.model.config.num_attention_heads, @@ -127,7 +130,9 @@ def capture_graphs(cuda_graph_idx): num_tokens = self.features_buf[i][0].size(0) print("capturing cuda graph", batch_size, num_tokens) - if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM) or isinstance(self.model, KQwen3NextForCausalLM): + + if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KHunYuanMoEV1ForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM) or isinstance(self.model, KQwen3NextForCausalLM): + self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens ,batch_size, self.block_num, i) # TODO: 1024 is a magic number(max_batch_tokens) self.bsz_tensor_buf[0] = batch_size diff --git a/ktransformers/server/balance_serve/sched_rpc.py b/ktransformers/server/balance_serve/sched_rpc.py index a2759ae4..ac0bbe7d 100644 --- a/ktransformers/server/balance_serve/sched_rpc.py +++ b/ktransformers/server/balance_serve/sched_rpc.py @@ -10,7 +10,8 @@ # sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) import pickle import argparse -from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe, create_sched_settings_glm4moe, create_sched_settings_smallthinker, create_sched_settings_qwen3next +from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe, create_sched_settings_hunyuan, create_sched_settings_glm4moe, create_sched_settings_smallthinker, create_sched_settings_qwen3next + @@ -217,6 +218,8 @@ def get_inference_context_raw(self): settings = 
create_sched_settings_glm4moe(main_args) elif main_args.architectures == "SmallThinkerForCausalLM": settings = create_sched_settings_smallthinker(main_args) + elif main_args.architectures == "HunYuanMoEV1ForCausalLM": + settings = create_sched_settings_hunyuan(main_args) elif main_args.architectures == "Qwen3NextForCausalLM": settings = create_sched_settings_qwen3next(main_args) else: diff --git a/ktransformers/server/balance_serve/settings.py b/ktransformers/server/balance_serve/settings.py index b3bd5c2e..8e3ee8c8 100644 --- a/ktransformers/server/balance_serve/settings.py +++ b/ktransformers/server/balance_serve/settings.py @@ -175,6 +175,61 @@ def create_sched_settings_qwen3moe(args): settings.auto_derive() return settings +def create_sched_settings_hunyuan(args): + """Create scheduler settings for HunYuan MoE models""" + default_sample_options = sched_ext.SampleOptions() + model_name = os.path.basename(os.path.normpath(args.model_dir)) + input_model_settings = sched_ext.ModelSettings() + input_model_settings.model_path = args.model_dir + input_model_settings.params_count = int(0) + model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) + input_model_settings.layer_count = model_config.num_hidden_layers + input_model_settings.num_k_heads = model_config.num_key_value_heads # Hunyuan: 8 KV heads + # Hunyuan specific: hidden_size=4096, num_attention_heads=32, num_key_value_heads=8 + # Each KV head dimension: hidden_size / num_attention_heads = 4096 / 32 = 128 + head_dim = getattr(model_config, 'head_dim', model_config.hidden_size // model_config.num_attention_heads) + input_model_settings.k_head_dim = head_dim # Should be 128 for Hunyuan + input_model_settings.bytes_per_params = 2 + input_model_settings.bytes_per_kv_cache_element = 2 + settings = sched_ext.Settings() + settings.model_name = model_name + settings.quant_type = "BF16" + settings.model_settings = input_model_settings + settings.page_size = args.page_size + settings.gpu_device_count = 1 # tp + settings.gpu_device_id = [i for i in range(settings.gpu_device_count)] + settings.gpu_memory_size = args.gpu_memory_size + settings.memory_utilization_percentage = args.utilization_percentage + max_batch_size = args.max_batch_size + chunk_size = args.chunk_size + + max_decode_batch_size = max_batch_size - 2 + + settings.max_batch_size = max_batch_size + settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2 + settings.sample_options = default_sample_options + settings.sched_metrics_port = args.sched_metrics_port + settings.gpu_only = args.memory_gpu_only + settings.use_self_defined_head_dim = False + settings.self_defined_head_dim = head_dim + settings.full_kv_cache_on_each_gpu = True + # Critical: Enable both k_cache and v_cache for Hunyuan + settings.k_cache_on = True + settings.v_cache_on = True + + settings.kvc2_root_path = args.kvc2_disk_path + settings.kvc2_config_path = args.kvc2_config_dir + settings.memory_pool_size_GB = args.cpu_memory_size_GB + settings.evict_count = 40 + settings.kvc2_metrics_port = args.kvc2_metrics_port + settings.load_from_disk = False + settings.save_to_disk = True + + settings.strategy_name = args.sched_strategy + settings.auto_derive() + return settings + + def create_sched_settings_glm4moe(args): default_sample_options = sched_ext.SampleOptions() model_name = os.path.basename(os.path.normpath(args.model_dir)) diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py index ee08e479..25ffa4e4 100644 --- 
a/ktransformers/util/custom_loader.py +++ b/ktransformers/util/custom_loader.py @@ -90,7 +90,17 @@ def load_tensor(self, key: str, device: str="cpu"): elif key in self.tensor_file_map: pass else: - raise KeyError(f"Key {key} not found in Safetensor files") + # Handle weight tying/sharing for models like Hunyuan + if key == "lm_head.weight": + alternative_key = "model.embed_tokens.weight" + if alternative_key in self.tensor_file_map: + print(f"Key '{key}' not found, using '{alternative_key}' (tied weights)") + key = alternative_key + else: + raise KeyError(f"Key {key} not found in Safetensor files, and alternative key {alternative_key} also not found") + else: + raise KeyError(f"Key {key} not found in Safetensor files") + file = self.tensor_file_map[key] f = self.file_handle_map.get(file) if f is None: diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 98a44f28..01702bad 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -164,6 +164,8 @@ def xpu_fp16_model(config): # Qwen3-30B seems have precision issue with FP16 # so we only use FP16 for Qwen3-235B now return True + if config.architectures[0] == "HunYuanMoEV1ForCausalLM": + return True return False def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"): @@ -302,6 +304,8 @@ def chunk_prefill(inputs, cache_position, past_key_values): from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache if model.config.architectures[0] in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]: past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None) + elif model.config.architectures[0] in ["HunYuanMoEV1ForCausalLM"]: + past_key_values = DynamicNormalCache.from_legacy_cache(None) else: past_key_values = DynamicNormalCache.from_legacy_cache(None) elif mode != 'long_context':