Commit 2fb9cc6

yeyu-nvidia authored and soodoshll committed

Yeyu/hf eagle medusa (NVIDIA#664)
## What does this PR do?

**Type of change:** new feature

**Overview:** This PR implements HF parallel draft by combining eagle and medusa. In training, multiple medusa heads are added and trained together with eagle. In inference, the medusa heads are used to generate draft tokens after all eagle tokens.

## Usage

Set `parallel_draft_step > 1` in `eagle_config` to enable parallel draft.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

---------

Signed-off-by: Ye Yu <[email protected]>
1 parent d9b706d commit 2fb9cc6
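As the Usage note above says, parallel draft is enabled through `parallel_draft_step` in the eagle config. A minimal sketch, assuming the JSON file is passed to `main.py` via `--eagle_config` (file name and values here are illustrative):

```python
import json

# Minimal sketch (illustrative values): parallel_draft_step > 1 enables parallel
# draft; step N adds N - 1 medusa-style heads that draft after all eagle tokens.
custom_eagle_config = {
    "parallel_draft_step": 4,
    "parallel_draft_heads_num_layers": 1,  # layers per parallel draft head
}

with open("eagle_config.json", "w") as f:  # hypothetical file name
    json.dump(custom_eagle_config, f, indent=2)
```

Per the main.py diff below, this JSON is merged into `config["eagle_architecture_config"]` before `mtsp.convert` is called.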

File tree

14 files changed: +216 -507 lines changed
Lines changed: 0 additions & 8 deletions

```diff
@@ -1,11 +1,3 @@
 {
-  "rope_scaling": {
-    "factor": 32.0,
-    "low_freq_factor": 1.0,
-    "high_freq_factor": 4.0,
-    "original_max_position_embeddings": 8192,
-    "rope_type": "llama3"
-  },
-  "initializer_range": 0.02,
   "_attn_implementation": "sdpa"
 }
```

examples/speculative_decoding/eagle_utils.py

Lines changed: 12 additions & 5 deletions

```diff
@@ -518,18 +518,25 @@ def on_log(self, args, state, control, **kwargs):
         average_acc = np.mean(state.training_accs, axis=0)
         if self.estimate_ar:
             # Calculate mean training AR since last log
-            # NOTE: This is only a estimate of the real AR.
+            # NOTE: This is only an estimate of the real AR.
             est_ar = 1
             acc_cumprod = 1
-            for step_acc in average_acc:
-                est_ar += acc_cumprod * step_acc
+            for step_acc in average_acc[0]:
                 acc_cumprod *= step_acc
+                est_ar += acc_cumprod
+            # Parallel draft tokens only used after all eagle tokens
+            for draft_acc in average_acc[1:]:
+                acc_cumprod *= draft_acc[-1]
+                est_ar += acc_cumprod
             print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}")

         # log to wandb
         if wandb and is_master():
-            for i, step_acc in enumerate(average_acc):
-                wandb.log({f"step_{i}_train_acc": step_acc}, step=state.global_step)
+            for i, draft_acc in enumerate(average_acc):
+                for j, step_acc in enumerate(draft_acc):
+                    wandb.log(
+                        {f"parallel_{i}_step_{j}_train_acc": step_acc}, step=state.global_step
+                    )
             if self.estimate_ar:
                 wandb.log({"estimated_training_ar": est_ar}, step=state.global_step)
```
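To make the updated estimate concrete, here is a hedged, self-contained sketch of the arithmetic above (the accuracy values are made up): row 0 of `average_acc` holds the eagle head's per-step acceptance, and each later row a parallel draft head whose last-step value extends the running cumulative product.

```python
import numpy as np

# Hedged sketch of the AR estimate in the diff above; accuracies are invented.
average_acc = np.array(
    [
        [0.8, 0.7, 0.6],  # eagle head: acceptance per draft step
        [0.5, 0.5, 0.5],  # parallel (medusa-style) draft head 1
    ]
)

est_ar = 1.0  # the verified target token always counts
acc_cumprod = 1.0
for step_acc in average_acc[0]:  # eagle tokens are accepted in sequence
    acc_cumprod *= step_acc
    est_ar += acc_cumprod
for draft_acc in average_acc[1:]:  # parallel draft tokens come after all eagle tokens
    acc_cumprod *= draft_acc[-1]
    est_ar += acc_cumprod

print(f"Estimated AR: {est_ar:.4f}")  # 1 + 0.8 + 0.56 + 0.336 + 0.168 = 2.864
```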

examples/speculative_decoding/launch_train.sh

Lines changed: 6 additions & 2 deletions

```diff
@@ -90,6 +90,10 @@ while [ $# -gt 0 ]; do
       if [[ "$1" != *=* ]]; then shift; fi
       VLM_IMG_DIR="${1#*=}"
       ;;
+    --estimate_ar*)
+      if [[ "$1" != *=* ]]; then shift; fi
+      ESTIMATE_AR="${1#*=}"
+      ;;
     --ar_validate_steps*)
       if [[ "$1" != *=* ]]; then shift; fi
       AR_VALIDATE_STEPS="${1#*=}"
@@ -120,8 +124,6 @@ LR=${LR:-"1e-4"}
 TRAIN_BS=${TRAIN_BS:-4}
 MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
 MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
-REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1}
-REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1}
 FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
 NUM_GPU=${NUM_GPU:-1}
 TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
@@ -130,6 +132,7 @@ DISABLE_TQDM=${DISABLE_TQDM:-False}
 VLM_PROCESSOR=${VLM_PROCESSOR:-}
 VLM_IMG_DIR=${VLM_IMG_DIR:-}
 AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000}
+ESTIMATE_AR=${ESTIMATE_AR:-False}

 if [[ "$MODE" == "medusa" ]]; then
   SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
@@ -192,6 +195,7 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
     --tf32 True \
     --data_path $DATA \
     --disable_tqdm $DISABLE_TQDM \
+    --estimate_ar $ESTIMATE_AR \
     --ar_validate_steps $AR_VALIDATE_STEPS \
     $VLM_ARGS \
     $OFFLINE_TRAINING_ARGS \
```

examples/speculative_decoding/main.py

Lines changed: 4 additions & 17 deletions

```diff
@@ -92,6 +92,9 @@ class TrainingArguments(transformers.TrainingArguments):
     dataloader_drop_last: bool = field(default=True)
     bf16: bool = field(default=True)
     mode: Literal["eagle1", "eagle3", "medusa"] = "eagle3"
+    estimate_ar: bool = field(
+        default=False, metadata={"help": "Whether to estimate AR during training for logging."}
+    )
     ar_validate_steps: int = field(default=1000, metadata={"help": "Steps between AR validation."})
     disable_tqdm: bool = field(default=False, metadata={"help": "Disable tqdm progress bar."})
     remove_unused_columns: bool = field(
@@ -193,22 +196,6 @@ def train():
                 custom_config = json.load(f)
             config["eagle_architecture_config"].update(custom_config)

-        # Hidden size and vocab size must match base model
-        llm_config = (
-            model.config.llm_config if hasattr(model.config, "llm_config") else model.config
-        )
-        config["eagle_architecture_config"].update(
-            {
-                "hidden_size": llm_config.hidden_size,
-                "vocab_size": llm_config.vocab_size,
-                # we also overwrite max_pos_embedding for deployment compatibility
-                "max_position_embeddings": llm_config.max_position_embeddings,
-                "draft_vocab_size": custom_config["draft_vocab_size"]
-                if eagle_args.eagle_config and "draft_vocab_size" in custom_config
-                else llm_config.vocab_size,
-            }
-        )
-
         mtsp.convert(model, [("eagle", config)])

         # read draft vocab cache
@@ -238,7 +225,7 @@ def train():
         model=model,
         processing_class=tokenizer,
         args=training_args,
-        callbacks=[EagleTrainingPlot(training_args.ar_validate_steps)],
+        callbacks=[EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)],
         **data_module,
     )
```
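A hedged sketch of what the new dataclass field buys: transformers' `HfArgumentParser` derives an `--estimate_ar` flag from it automatically (assuming main.py's `TrainingArguments` subclass is importable; `--output_dir` is required by the base class).

```python
import transformers

# TrainingArguments is the subclass defined in main.py above; the import path
# below is a hypothetical assumption for a standalone sketch.
from main import TrainingArguments

parser = transformers.HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "/tmp/ckpt", "--estimate_ar", "True"]
)
print(training_args.estimate_ar)  # True
```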

modelopt/torch/export/plugins/hf_spec_export.py

Lines changed: 80 additions & 24 deletions

```diff
@@ -18,26 +18,62 @@
 import torch
 import torch.nn as nn

-EAGLE_MODELOPT_TO_OFFICIAL = {
-    "required": {
-        "layers.0.self_attn.q_proj.weight": "midlayer.self_attn.q_proj.weight",
-        "layers.0.self_attn.k_proj.weight": "midlayer.self_attn.k_proj.weight",
-        "layers.0.self_attn.v_proj.weight": "midlayer.self_attn.v_proj.weight",
-        "layers.0.self_attn.o_proj.weight": "midlayer.self_attn.o_proj.weight",
-        "layers.0.mlp.gate_proj.weight": "midlayer.mlp.gate_proj.weight",
-        "layers.0.mlp.up_proj.weight": "midlayer.mlp.up_proj.weight",
-        "layers.0.mlp.down_proj.weight": "midlayer.mlp.down_proj.weight",
-        "hidden_norm.weight": "midlayer.hidden_norm.weight",
-        "input_embeds_norm.weight": "midlayer.input_layernorm.weight",
-        "layers.0.post_attention_layernorm.weight": "midlayer.post_attention_layernorm.weight",
-        "norm.weight": "norm.weight",
-        "fc.weight": "fc.weight",
-    },
-    "optional": {
-        "d2t": "d2t",
-        "eagle_lm_head.weight": "lm_head.weight",
-    },
-}
+
+def eagle_state_dict_key_convert(num_hidden_layers: int = 1) -> dict[str, dict[str, str]]:
+    """Convert our eagle model state dict key to official format key(s)."""
+    assert num_hidden_layers >= 1, "num_hidden_layers should be at least 1."
+    eagle_modelopt_to_official = {
+        "required": {
+            "norm.weight": "norm.weight",
+            "fc.weight": "fc.weight",
+        },
+        "optional": {
+            "d2t": "d2t",
+            "eagle_lm_head.weight": "lm_head.weight",
+        },
+    }
+    if num_hidden_layers == 1:
+        eagle_modelopt_to_official["required"].update(
+            {
+                "hidden_norm.weight": "midlayer.hidden_norm.weight",
+                "input_embeds_norm.weight": "midlayer.input_layernorm.weight",
+            }
+        )
+    else:
+        eagle_modelopt_to_official["required"].update(
+            {
+                "hidden_norm.weight": "midlayer.0.hidden_norm.weight",
+                "input_embeds_norm.weight": "midlayer.0.input_layernorm.weight",
+            }
+        )
+    for i in range(num_hidden_layers):
+        if num_hidden_layers == 1:
+            index = ""
+        else:
+            index = f".{i}"
+        eagle_modelopt_to_official["required"].update(
+            {
+                f"layers.{i}.self_attn.q_proj.weight": "midlayer"
+                + index
+                + ".self_attn.q_proj.weight",
+                f"layers.{i}.self_attn.k_proj.weight": "midlayer"
+                + index
+                + ".self_attn.k_proj.weight",
+                f"layers.{i}.self_attn.v_proj.weight": "midlayer"
+                + index
+                + ".self_attn.v_proj.weight",
+                f"layers.{i}.self_attn.o_proj.weight": "midlayer"
+                + index
+                + ".self_attn.o_proj.weight",
+                f"layers.{i}.mlp.gate_proj.weight": "midlayer" + index + ".mlp.gate_proj.weight",
+                f"layers.{i}.mlp.up_proj.weight": "midlayer" + index + ".mlp.up_proj.weight",
+                f"layers.{i}.mlp.down_proj.weight": "midlayer" + index + ".mlp.down_proj.weight",
+                f"layers.{i}.post_attention_layernorm.weight": "midlayer"
+                + index
+                + ".post_attention_layernorm.weight",
+            }
+        )
+    return eagle_modelopt_to_official


 def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
@@ -61,15 +97,16 @@ def export_spec_ckpt_state_dict(model: nn.Module):
     # check the model has only speculative decoding
     assert spec_opt_only(model), "Not purely eagle model."

+    eagle_modelopt_to_official = eagle_state_dict_key_convert(model.eagle_config.num_hidden_layers)
     # Check if the state dict keys match
-    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
+    _check_state_dict_keys_match(model.eagle_module, eagle_modelopt_to_official["required"])

     # Convert key names and save the state dict
     eagle_state = model.eagle_module.state_dict()
     export_state_dict = {}
     for ours_key, export_key in {
-        **EAGLE_MODELOPT_TO_OFFICIAL["required"],
-        **EAGLE_MODELOPT_TO_OFFICIAL["optional"],
+        **eagle_modelopt_to_official["required"],
+        **eagle_modelopt_to_official["optional"],
     }.items():
         if ours_key in eagle_state:
             export_state_dict[export_key] = eagle_state[ours_key]
@@ -78,6 +115,21 @@ def export_spec_ckpt_state_dict(model: nn.Module):
     if "eagle_lm_head.weight" not in eagle_state:
         export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]

+    # Add parallel draft weights
+    if model.eagle_config.parallel_draft_step > 1:
+        for i in range(model.eagle_config.parallel_draft_step - 1):
+            for j in range(model.eagle_config.parallel_draft_heads_num_layers):
+                export_state_dict[f"parallel_draft_heads.{i}.medusa_layers.{j}.linear.weight"] = (
+                    eagle_state[f"parallel_draft_heads.{i}.{j}.linear.weight"]
+                )
+                if f"parallel_draft_heads.{i}.{j}.linear.bias" in eagle_state:
+                    export_state_dict[f"parallel_draft_heads.{i}.medusa_layers.{j}.linear.bias"] = (
+                        eagle_state[f"parallel_draft_heads.{i}.{j}.linear.bias"]
+                    )
+            export_state_dict[f"parallel_draft_heads.{i}.lm_head.weight"] = eagle_state[
+                f"parallel_draft_heads.{i}.{model.eagle_config.parallel_draft_heads_num_layers}.weight"
+            ]
+
     return export_state_dict


@@ -120,6 +172,9 @@ def export_spec_ckpt_config(model: nn.Module):
             "use_input_layernorm_in_first_layer": None,
             "use_last_layernorm": None,
             "use_mtp_layernorm": None,
+            "next_layer_regular": True,
+            "parallel_draft_step": None,
+            "parallel_draft_heads_num_layers": None,
         },
     }

@@ -136,7 +191,8 @@ def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
     if isinstance(value, dict):
         # for eagle config, we find it in model.eagle_config
         for sub_key in value:
-            value[sub_key] = _get_config_from_eagle_config_or_base_config(sub_key, model)
+            if value[sub_key] is None:
+                value[sub_key] = _get_config_from_eagle_config_or_base_config(sub_key, model)
     elif value is None:
         # First, we try to load fron eagle config.
         new_value = _get_config_from_eagle_config_or_base_config(key, model)
```
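A small usage sketch of the new helper (hedged; assumes it is importable from `modelopt.torch.export.plugins.hf_spec_export`): with more than one decoder layer, the official `midlayer` keys gain a numeric index.

```python
from modelopt.torch.export.plugins.hf_spec_export import eagle_state_dict_key_convert

# Single-layer drafts keep the flat "midlayer." prefix ...
one = eagle_state_dict_key_convert(num_hidden_layers=1)
print(one["required"]["layers.0.self_attn.q_proj.weight"])  # midlayer.self_attn.q_proj.weight

# ... while multi-layer drafts index each midlayer block.
two = eagle_state_dict_key_convert(num_hidden_layers=2)
print(two["required"]["layers.1.mlp.up_proj.weight"])  # midlayer.1.mlp.up_proj.weight
print(two["required"]["hidden_norm.weight"])  # midlayer.0.hidden_norm.weight
```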

modelopt/torch/export/plugins/mcore_llama.py

Lines changed: 9 additions & 3 deletions

```diff
@@ -98,7 +98,9 @@
     "final_layernorm": NameRemapping("norm."),
     "d2t": NameRemapping("d2t"),
     "output_layer": NameRemapping("lm_head."),
-    "parallel_draft_heads.medusa_layers": NameRemapping("parallel_draft_heads.{}.{}.linear."),
+    "parallel_draft_heads.medusa_layers": NameRemapping(
+        "parallel_draft_heads.{}.medusa_layers.{}.linear."
+    ),
     "parallel_draft_heads.lm_head": NameRemapping("parallel_draft_heads.{}.lm_head."),
 }

@@ -115,7 +117,9 @@
     "final_layernorm": NameRemapping("norm."),
     "d2t": NameRemapping("d2t"),
     "output_layer": NameRemapping("lm_head."),
-    "parallel_draft_heads.medusa_layers": NameRemapping("parallel_draft_heads.{}.{}.linear."),
+    "parallel_draft_heads.medusa_layers": NameRemapping(
+        "parallel_draft_heads.{}.medusa_layers.{}.linear."
+    ),
     "parallel_draft_heads.lm_head": NameRemapping("parallel_draft_heads.{}.lm_head."),
 }

@@ -133,7 +137,9 @@
     "final_layernorm": NameRemapping("norm."),
     "d2t": NameRemapping("d2t"),
     "output_layer": NameRemapping("lm_head."),
-    "parallel_draft_heads.medusa_layers": NameRemapping("parallel_draft_heads.{}.{}.linear."),
+    "parallel_draft_heads.medusa_layers": NameRemapping(
+        "parallel_draft_heads.{}.medusa_layers.{}.linear."
+    ),
     "parallel_draft_heads.lm_head": NameRemapping("parallel_draft_heads.{}.lm_head."),
 }
```
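For clarity on the fix above: the remapping template carries two placeholders which, per the HF export keys in this PR, take the parallel draft head index and the medusa layer index. A hedged illustration of the string itself:

```python
# Hedged illustration: the two "{}" slots take the head index and the layer index.
template = "parallel_draft_heads.{}.medusa_layers.{}.linear."
print(template.format(0, 1))  # parallel_draft_heads.0.medusa_layers.1.linear.
```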

modelopt/torch/export/unified_export_megatron.py

Lines changed: 8 additions & 24 deletions

```diff
@@ -212,26 +212,14 @@ def __init__(
         )

         eagle_config = {
-            "use_input_layernorm_in_first_layer": mode_cfg["config"][
-                "eagle_architecture_config"
-            ]["use_input_layernorm_in_first_layer"],
-            "use_last_layernorm": mode_cfg["config"]["eagle_architecture_config"][
-                "use_last_layernorm"
-            ],
-            "use_mtp_layernorm": mode_cfg["config"]["eagle_architecture_config"][
-                "use_mtp_layernorm"
-            ],
-            "use_aux_hidden_state": mode_cfg["config"]["eagle_architecture_config"][
-                "use_aux_hidden_state"
-            ],
+            "use_input_layernorm_in_first_layer": model.eagle_config.use_input_layernorm_in_first_layer,
+            "use_last_layernorm": model.eagle_config.use_last_layernorm,
+            "use_mtp_layernorm": model.eagle_config.use_mtp_layernorm,
+            "use_aux_hidden_state": model.eagle_config.use_aux_hidden_state,
             "eagle_aux_hidden_state_layer_ids": model.eagle_config.eagle_aux_hidden_state_layer_ids,
             "next_layer_regular": True,
-            "parallel_draft_step": mode_cfg["config"]["eagle_architecture_config"][
-                "parallel_draft_step"
-            ],
-            "parallel_draft_heads_num_layers": mode_cfg["config"][
-                "eagle_architecture_config"
-            ]["parallel_draft_heads_num_layers"],
+            "parallel_draft_step": model.eagle_config.parallel_draft_step,
+            "parallel_draft_heads_num_layers": model.eagle_config.parallel_draft_heads_num_layers,
         }

         eagle_config_update = {
@@ -243,9 +231,7 @@ def __init__(
             "max_position_embeddings": self._hf_text_config.max_position_embeddings,
             "num_attention_heads": model.eagle_module.config.num_attention_heads,
             "num_key_value_heads": model.eagle_module.config.num_query_groups,
-            "num_hidden_layers": mode_cfg["config"]["eagle_architecture_config"][
-                "num_hidden_layers"
-            ],
+            "num_hidden_layers": model.eagle_config.num_layers,
             "vocab_size": self._hf_text_config.vocab_size,
             # Unset any special token ids given that the tokenizer can change here.
             "bos_token_id": None,
@@ -254,9 +240,7 @@ def __init__(
             "sep_token_id": None,
             # The following attributes are EAGLE specific
             "eagle_config": eagle_config,
-            "draft_vocab_size": mode_cfg["config"]["eagle_architecture_config"][
-                "draft_vocab_size"
-            ],
+            "draft_vocab_size": model.eagle_config.draft_vocab_size,
         }

         self._hf_extra_config.update(eagle_config_update)
```

modelopt/torch/speculative/eagle/default_config.py

Lines changed: 0 additions & 4 deletions

```diff
@@ -18,9 +18,6 @@
 default_eagle_config = {
     "hidden_act": "silu",
     "torch_dtype": "bfloat16",
-    "vocab_size": 128256,
-    "draft_vocab_size": 128256,
-    "max_position_embeddings": 8192,
     "position_embedding_type": "rope",
     "rope_scaling": {
         "factor": 8.0,
@@ -31,7 +28,6 @@
     },
     "rope_theta": 500000.0,
     "num_hidden_layers": 1,
-    "hidden_size": 4096,
     "intermediate_size": 14336,
     "num_attention_heads": 32,
     "num_key_value_heads": 8,
```

modelopt/torch/speculative/eagle/eagle_model.py

Lines changed: 0 additions & 6 deletions

```diff
@@ -15,8 +15,6 @@

 """Eagle model to support eagle decoding."""

-import torch
-
 from modelopt.torch.opt.dynamic import DynamicModule


@@ -45,7 +43,3 @@ def modify(
         self.eagle_report_acc = eagle_report_acc
         self.eagle_reuse_base_decoder = eagle_reuse_base_decoder
         self.eagle_loss_decay_factor = eagle_loss_decay_factor
-
-        if eagle_architecture_config.get("parallel_draft_step", 1) > 1:
-            for i in range(eagle_architecture_config.get("parallel_draft_step") - 1):
-                self.register_buffer(f"mask_token_{i}", torch.tensor(-1))
```
