From fe4a5689ee57688f0147aa8783d09f735d5b3e6c Mon Sep 17 00:00:00 2001
From: "xuhaojie.2025"
Date: Thu, 28 Aug 2025 10:30:02 +0800
Subject: [PATCH] Adapt Eagle3 for Deepseek architecture

---
 configs/deepseek-r1-eagle3.json              |  55 ++
 examples/run_dpsk_r1_eagle3_offline.sh       |  21 +
 specforge/data/template.py                   |  31 +
 specforge/modeling/__init__.py               |   2 +
 specforge/modeling/auto.py                   |   4 +
 specforge/modeling/draft/__init__.py         |   3 +-
 specforge/modeling/draft/deepseekv3_eagle.py | 569 +++++++++++++++++++
 7 files changed, 684 insertions(+), 1 deletion(-)
 create mode 100644 configs/deepseek-r1-eagle3.json
 create mode 100644 examples/run_dpsk_r1_eagle3_offline.sh
 create mode 100644 specforge/modeling/draft/deepseekv3_eagle.py

diff --git a/configs/deepseek-r1-eagle3.json b/configs/deepseek-r1-eagle3.json
new file mode 100644
index 00000000..ede21aef
--- /dev/null
+++ b/configs/deepseek-r1-eagle3.json
@@ -0,0 +1,55 @@
+{
+    "architectures": [
+        "DeepseekV3ForCausalLMEagle3"
+    ],
+    "attention_bias": false,
+    "attention_dropout": 0.0,
+    "bos_token_id": 0,
+    "eos_token_id": 1,
+    "ep_size": 1,
+    "first_k_dense_replace": 3,
+    "hidden_act": "silu",
+    "hidden_size": 7168,
+    "initializer_range": 0.02,
+    "intermediate_size": 18432,
+    "kv_lora_rank": 512,
+    "max_position_embeddings": 2048,
+    "model_type": "deepseek_v3",
+    "moe_intermediate_size": 2048,
+    "moe_layer_freq": 1,
+    "n_group": 8,
+    "n_routed_experts": 256,
+    "n_shared_experts": 1,
+    "norm_topk_prob": true,
+    "num_attention_heads": 128,
+    "num_experts_per_tok": 8,
+    "num_hidden_layers": 1,
+    "num_key_value_heads": 128,
+    "num_nextn_predict_layers": 1,
+    "pad_token_id": 0,
+    "q_lora_rank": 1536,
+    "qk_nope_head_dim": 128,
+    "qk_rope_head_dim": 64,
+    "rms_norm_eps": 1e-06,
+    "rope_scaling": {
+        "beta_fast": 32,
+        "beta_slow": 1,
+        "factor": 40,
+        "mscale": 1.0,
+        "mscale_all_dim": 1.0,
+        "original_max_position_embeddings": 4096,
+        "type": "yarn"
+    },
+    "rope_theta": 10000,
+    "routed_scaling_factor": 2.5,
+    "scoring_func": "sigmoid",
+    "tie_word_embeddings": false,
+    "topk_group": 4,
+    "topk_method": "noaux_tc",
+    "torch_dtype": "float16",
+    "transformers_version": "4.28.1",
+    "use_cache": true,
+    "v_head_dim": 128,
+    "vocab_size": 129280,
+    "draft_vocab_size": 32000
+}
\ No newline at end of file
diff --git a/examples/run_dpsk_r1_eagle3_offline.sh b/examples/run_dpsk_r1_eagle3_offline.sh
new file mode 100644
index 00000000..ffccdcd5
--- /dev/null
+++ b/examples/run_dpsk_r1_eagle3_offline.sh
@@ -0,0 +1,21 @@
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+ROOT_DIR=$(dirname $SCRIPT_DIR)
+
+# train eagle3 for DeepSeek-R1 offline
+NUM_GPUS=${1:-8}
+
+torchrun \
+    --standalone \
+    --nproc_per_node $NUM_GPUS \
+    $ROOT_DIR/scripts/train_eagle3_offline.py \
+    --target-model-path deepseek-ai/DeepSeek-R1-0528 \
+    --draft-model-config $ROOT_DIR/configs/deepseek-r1-eagle3.json \
+    --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
+    --train-hidden-states-path $ROOT_DIR/cache/hidden_states/ \
+    --output-dir $ROOT_DIR/outputs/Deepseek-r1-eagle3 \
+    --num-epochs 10 \
+    --batch-size 2 \
+    --learning-rate 1e-5 \
+    --max-length 2048 \
+    --chat-template deepseek_r1 \
+    --cache-dir $ROOT_DIR/cache
diff --git a/specforge/data/template.py b/specforge/data/template.py
index 12241113..b7ca011f 100644
--- a/specforge/data/template.py
+++ b/specforge/data/template.py
@@ -163,3 +163,34 @@ def get_all_template_names(self) -> List[str]:
         end_of_turn_token="<|end|>",
     ),
 )
+
+TEMPLATE_REGISTRY.register(
+    name="kimi_k2",
+    template=ChatTemplate(
+        
system_prompt="You are a helpful assistant.", + user_header="<|im_user|>user<|im_middle|>", + assistant_header="<|im_assistant|>assistant<|im_middle|>", + end_of_turn_token="<|im_end|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="deepseek_v3", + template=ChatTemplate( + system_prompt="You are a helpful assistant.", + user_header="<|User|>", + assistant_header="<|Assistant|>", + end_of_turn_token="<|end▁of▁sentence|>", + ), +) + + +TEMPLATE_REGISTRY.register( + name="deepseek_r1", + template=ChatTemplate( + system_prompt="You are a helpful assistant.", + user_header="<|User|>", + assistant_header="<|Assistant|>", + end_of_turn_token="<|end▁of▁sentence|>", + ), +) \ No newline at end of file diff --git a/specforge/modeling/__init__.py b/specforge/modeling/__init__.py index 074d1e65..734705be 100644 --- a/specforge/modeling/__init__.py +++ b/specforge/modeling/__init__.py @@ -1,9 +1,11 @@ from .auto import AutoDistributedTargetModel, AutoDraftModelConfig, AutoEagle3DraftModel from .draft.llama3_eagle import LlamaForCausalLMEagle3 +from .draft.deepseekv3_eagle import DeepseekV3ForCausalLMEagle3 __all__ = [ "AutoDraftModelConfig", "AutoEagle3DraftModel", "AutoDistributedTargetModel", "LlamaForCausalLMEagle3", + "DeepseekV3ForCausalLMEagle3", ] diff --git a/specforge/modeling/auto.py b/specforge/modeling/auto.py index 0e54777b..75dc196d 100644 --- a/specforge/modeling/auto.py +++ b/specforge/modeling/auto.py @@ -16,12 +16,14 @@ Qwen2Config, Qwen3Config, Qwen3MoeConfig, + DeepseekV3Config, modeling_utils, ) from specforge.utils import default_torch_dtype from .draft.llama3_eagle import LlamaForCausalLMEagle3 +from .draft.deepseekv3_eagle import DeepseekV3ForCausalLMEagle3 from .target.llama import LlamaForCausalLM from .target.llama4 import Llama4ForCausalLM from .target.phi3 import Phi3ForCausalLM @@ -34,6 +36,7 @@ class AutoEagle3DraftModel(AutoModelForCausalLMBase): # the model mapping is currently hardcoded, we should support lazy model mapping via registry _model_mapping = { LlamaConfig: LlamaForCausalLMEagle3, + DeepseekV3Config: DeepseekV3ForCausalLMEagle3, } @classmethod @@ -134,6 +137,7 @@ class AutoDraftModelConfig: _config_mapping = { "LlamaForCausalLMEagle3": LlamaConfig, + "DeepseekV3ForCausalLMEagle3": DeepseekV3Config, } @classmethod diff --git a/specforge/modeling/draft/__init__.py b/specforge/modeling/draft/__init__.py index f32ce0ef..5c4f3139 100644 --- a/specforge/modeling/draft/__init__.py +++ b/specforge/modeling/draft/__init__.py @@ -1,4 +1,5 @@ from .base import Eagle3DraftModel from .llama3_eagle import LlamaForCausalLMEagle3 +from .deepseekv3_eagle import DeepseekV3ForCausalLMEagle3 -__all__ = ["Eagle3DraftModel", "LlamaForCausalLMEagle3"] +__all__ = ["Eagle3DraftModel", "LlamaForCausalLMEagle3", "DeepseekV3ForCausalLMEagle3"] diff --git a/specforge/modeling/draft/deepseekv3_eagle.py b/specforge/modeling/draft/deepseekv3_eagle.py new file mode 100644 index 00000000..b0deca0c --- /dev/null +++ b/specforge/modeling/draft/deepseekv3_eagle.py @@ -0,0 +1,569 @@ +import logging +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import GenerationMixin, PreTrainedModel +from transformers.activations import ACT2FN + +from ..utils import padding +from .base import Eagle3DraftModel +from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config + +logger = logging.getLogger(__name__) + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def 
_make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + Applies Rotary Position Embedding to the query and key tensors with interleaved mode. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + +class DeepseekV3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class DeepseekV3RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor + + def __init__(self, config: DeepseekV3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + # 简化的rope初始化 + self.dim = config.qk_rope_head_dim + self.base = config.rope_theta + + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + self.attention_scaling = 1.0 + + @torch.no_grad() + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + +class DeepseekV3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +class Eagle3MLP(nn.Module): + """Eagle3专用的MLP层,用于将2*hidden_size转换为hidden_size""" + def __init__(self, config): + super().__init__() + self.config = config + self.input_size = config.hidden_size * 2 + self.hidden_size = config.hidden_size + + # 添加输入norm + self.input_norm = DeepseekV3RMSNorm(self.input_size, 
eps=config.rms_norm_eps) + self.proj = nn.Linear(self.input_size, self.hidden_size, bias=False) + + def forward(self, x): + # 在线性投影前进行norm + x = self.input_norm(x) + return self.proj(x) + +class DeepseekV3Attention(nn.Module): + """Multi-headed attention from DeepSeek V3 - 使用原版实现,输入为hidden_size""" + + def __init__(self, config: DeepseekV3Config): + super().__init__() + self.config = config + self.layer_idx = 0 # Eagle3中只有一层 + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_attention_heads + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_head_dim = config.qk_head_dim + + self.is_causal = True + + # 标准的DeepSeek V3 attention投影层 - 输入为hidden_size + if self.q_lora_rank is None: + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=getattr(config, 'attention_bias', False)) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=getattr(config, 'attention_bias', False), + ) + self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + config.hidden_size, + bias=getattr(config, 'attention_bias', False), + ) + + self.scaling = self.qk_head_dim ** (-0.5) + if hasattr(config, 'rope_scaling') and config.rope_scaling is not None: + mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scaling = self.scaling * mscale * mscale + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + cache_hidden: Optional[List[torch.Tensor]] = None, + position_ids: Optional[torch.LongTensor] = None, + **kwargs, + ) -> torch.Tensor: + + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + + # 完全复制官方实现的逻辑 + if self.q_lora_rank is None: + q_states = self.q_proj(hidden_states) + else: + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q_states = q_states.view(query_shape).transpose(1, 2) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + cos, sin = position_embeddings + if hasattr(self.config, 
'rope_interleave') and self.config.rope_interleave: + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) + + # Eagle3的缓存处理 + if cache_hidden is None: + # Standard attention without cache + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=attention_mask is None, + dropout_p=0.0, + ) + else: + # Eagle3-style cached attention + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + cache_hidden[0] = cache_hidden[0] + [key_states] + cache_hidden[1] = cache_hidden[1] + [value_states] + + cache_k = cache_hidden[0] + cache_v = cache_hidden[1] + + k0 = cache_k[0] + v0 = cache_v[0] + + attn_weights = torch.matmul(query_states, k0.transpose(2, 3)) * self.scaling + lck = len(cache_k) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + for i in range(1, lck): + ki = cache_k[i] + qi = query_states + kiq = ki + + attn_weightsi = (qi * kiq).sum(-1) * self.scaling + attn_weights = torch.cat((attn_weights, attn_weightsi[..., None]), dim=-1) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights0 = attn_weights[..., :seq_length] + + attn_output = torch.matmul(attn_weights0, v0) + + for i in range(1, lck): + vi = cache_v[i] + attn_weightsi = attn_weights[..., seq_length + i - 1] + attn_outputi = attn_weightsi[..., None] * vi + attn_output = attn_output + attn_outputi + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output + +class DeepseekV3DecoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + + # Eagle3特有:将2*hidden_size转换为hidden_size的MLP + self.input_proj = Eagle3MLP(config) + + # 标准的DeepSeek V3组件 + self.self_attn = DeepseekV3Attention(config) + self.mlp = DeepseekV3MLP(config) + + # Layer norms + self.hidden_norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + # RoPE embedding + self.rotary_emb = DeepseekV3RotaryEmbedding(config) + + def forward( + self, + input_emb: torch.Tensor, + hidden_states: torch.Tensor, + cache_hidden: List[List[torch.Tensor]] = [], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> torch.Tensor: + + residual = hidden_states + + # 对输入进行norm + hidden_states = self.hidden_norm(hidden_states) + input_emb = self.input_layernorm(input_emb) + + # Eagle3特性:拼接输入并通过投影层转换为hidden_size + concat_hidden = torch.cat((input_emb, hidden_states), dim=-1) + projected_hidden = self.input_proj(concat_hidden) + + # 生成position embeddings + position_embeddings 
= self.rotary_emb(projected_hidden, position_ids) + + # Self Attention - 现在输入是标准的hidden_size + attn_output = self.self_attn( + hidden_states=projected_hidden, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + cache_hidden=cache_hidden, + position_ids=position_ids, + ) + hidden_states = residual + attn_output + + # MLP + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + +class DeepseekV3ForCausalLMEagle3(Eagle3DraftModel): + + config_class = DeepseekV3Config + + def __init__(self, config, quant_config=None) -> None: + super().__init__(config) + self.config = config + self.quant_config = quant_config + + self.vocab_size = config.vocab_size + self.draft_vocab_size = config.draft_vocab_size + + # 安全处理pad_token_id + pad_token_id = getattr(config, 'pad_token_id', None) + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, pad_token_id + ) + self.midlayer = DeepseekV3DecoderLayer(config) + + if hasattr(config, "target_hidden_size"): + self.fc = torch.nn.Linear( + config.target_hidden_size * 3, config.hidden_size, bias=False + ) + else: + self.fc = torch.nn.Linear( + config.hidden_size * 3, config.hidden_size, bias=False + ) + + self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.lm_head = nn.Linear( + config.hidden_size, config.draft_vocab_size, bias=False + ) + + # create vocab buffers + t2d = torch.zeros(self.vocab_size, dtype=torch.bool) + d2t = torch.zeros(self.draft_vocab_size, dtype=torch.int64) + self.register_buffer("t2d", t2d) + self.register_buffer("d2t", d2t) + + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ttt_length: int = 1, + ): + """ + Arguments: + hidden_states (`torch.FloatTensor`): input to the layer, cat low, mid high hidden_states of shape `(batch, seq_len, hidden_states * 3)` + inputs_embeds (`torch.FloatTensor`): input embeddings + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
+ """ + if ttt_length == 1: + logger.debug("using ttt_length 1, no need to cache hidden states") + cache_hidden = None + else: + logger.debug(f"using ttt_length {ttt_length}, caching hidden states") + cache_hidden = [[], []] + + batch_size, seq_length, _ = hidden_states.size() + + # make position ids + device = hidden_states.device + position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) + + # make attention mask + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, 0 + ) + + # fc + hidden_states = self.fc(hidden_states) + hidden_states = self.midlayer( + input_emb=inputs_embeds, + hidden_states=hidden_states, + cache_hidden=cache_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=None, + output_attentions=False, + use_cache=False, + ) + + # norm + hidden_states = self.norm(hidden_states) + + return hidden_states + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: + # eagle 3 requires hidden states from 3 layers + assert hidden_states.size(-1) == self.config.hidden_size * 3 + return self.fc(hidden_states) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + norm_hidden_states = self.norm(hidden_states) + return self.lm_head(norm_hidden_states) + + def backbone( + self, + input_embeds: torch.Tensor, + hidden_states: torch.Tensor, + cache_hidden: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + use_cache: bool = True, + ) -> torch.Tensor: + return self.midlayer( + input_emb=input_embeds, + hidden_states=hidden_states, + cache_hidden=cache_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=None, + output_attentions=False, + use_cache=False, + ) \ No newline at end of file