8 changes: 0 additions & 8 deletions examples/speculative_decoding/eagle_config.json
@@ -1,11 +1,3 @@
{
"rope_scaling": {
"factor": 32.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"initializer_range": 0.02,
"_attn_implementation": "sdpa"
}
17 changes: 12 additions & 5 deletions examples/speculative_decoding/eagle_utils.py
@@ -518,18 +518,25 @@ def on_log(self, args, state, control, **kwargs):
average_acc = np.mean(state.training_accs, axis=0)
if self.estimate_ar:
# Calculate mean training AR since last log
# NOTE: This is only a estimate of the real AR.
# NOTE: This is only an estimate of the real AR.
est_ar = 1
acc_cumprod = 1
for step_acc in average_acc:
est_ar += acc_cumprod * step_acc
for step_acc in average_acc[0]:
acc_cumprod *= step_acc
est_ar += acc_cumprod
# Parallel draft tokens only used after all eagle tokens
for draft_acc in average_acc[1:]:
acc_cumprod *= draft_acc[-1]
est_ar += acc_cumprod
print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}")

# log to wandb
if wandb and is_master():
for i, step_acc in enumerate(average_acc):
wandb.log({f"step_{i}_train_acc": step_acc}, step=state.global_step)
for i, draft_acc in enumerate(average_acc):
for j, step_acc in enumerate(draft_acc):
wandb.log(
{f"parallel_{i}_step_{j}_train_acc": step_acc}, step=state.global_step
)
if self.estimate_ar:
wandb.log({"estimated_training_ar": est_ar}, step=state.global_step)

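The estimate above compounds per-step acceptance: a drafted token at step k is only kept if every earlier drafted token was accepted, so the expected accepted length is 1 (the verified token) plus the running cumulative products, with each parallel draft head chained after the last EAGLE step. A minimal, self-contained Python sketch with made-up acceptance numbers (row 0 of the toy array stands for the sequential EAGLE steps, later rows for parallel draft heads, mirroring average_acc above):

import numpy as np

# Toy per-step acceptance rates: row 0 = sequential EAGLE steps,
# rows 1.. = parallel draft heads (only their last step is chained).
average_acc = np.array([
    [0.8, 0.7, 0.6],
    [0.5, 0.5, 0.4],
])

est_ar = 1.0        # the verified token always counts
acc_cumprod = 1.0
for step_acc in average_acc[0]:
    acc_cumprod *= step_acc
    est_ar += acc_cumprod
for draft_acc in average_acc[1:]:
    acc_cumprod *= draft_acc[-1]
    est_ar += acc_cumprod
print(f"Estimated AR: {est_ar:.4f}")  # 1 + 0.8 + 0.56 + 0.336 + 0.1344 = 2.8304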
8 changes: 6 additions & 2 deletions examples/speculative_decoding/launch_train.sh
@@ -90,6 +90,10 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
VLM_IMG_DIR="${1#*=}"
;;
--estimate_ar*)
if [[ "$1" != *=* ]]; then shift; fi
ESTIMATE_AR="${1#*=}"
;;
--ar_validate_steps*)
if [[ "$1" != *=* ]]; then shift; fi
AR_VALIDATE_STEPS="${1#*=}"
@@ -120,8 +124,6 @@ LR=${LR:-"1e-4"}
TRAIN_BS=${TRAIN_BS:-4}
MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1}
REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1}
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
NUM_GPU=${NUM_GPU:-1}
TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
@@ -130,6 +132,7 @@ DISABLE_TQDM=${DISABLE_TQDM:-False}
VLM_PROCESSOR=${VLM_PROCESSOR:-}
VLM_IMG_DIR=${VLM_IMG_DIR:-}
AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000}
ESTIMATE_AR=${ESTIMATE_AR:-False}

if [[ "$MODE" == "medusa" ]]; then
SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
@@ -192,6 +195,7 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
--tf32 True \
--data_path $DATA \
--disable_tqdm $DISABLE_TQDM \
--estimate_ar $ESTIMATE_AR \
--ar_validate_steps $AR_VALIDATE_STEPS \
$VLM_ARGS \
$OFFLINE_TRAINING_ARGS \
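With these script changes, the new flag defaults to False and is simply forwarded: passing --estimate_ar True to launch_train.sh (alongside the usual model and data arguments, elided here) sets ESTIMATE_AR, which is then handed to main.py as --estimate_ar.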
21 changes: 4 additions & 17 deletions examples/speculative_decoding/main.py
@@ -92,6 +92,9 @@ class TrainingArguments(transformers.TrainingArguments):
dataloader_drop_last: bool = field(default=True)
bf16: bool = field(default=True)
mode: Literal["eagle1", "eagle3", "medusa"] = "eagle3"
estimate_ar: bool = field(
default=False, metadata={"help": "Whether to estimate AR during training for logging."}
)
ar_validate_steps: int = field(default=1000, metadata={"help": "Steps between AR validation."})
disable_tqdm: bool = field(default=False, metadata={"help": "Disable tqdm progress bar."})
remove_unused_columns: bool = field(
@@ -193,22 +196,6 @@ def train():
custom_config = json.load(f)
config["eagle_architecture_config"].update(custom_config)

# Hidden size and vocab size must match base model
llm_config = (
model.config.llm_config if hasattr(model.config, "llm_config") else model.config
)
config["eagle_architecture_config"].update(
{
"hidden_size": llm_config.hidden_size,
"vocab_size": llm_config.vocab_size,
# we also overwrite max_pos_embedding for deployment compatibility
"max_position_embeddings": llm_config.max_position_embeddings,
"draft_vocab_size": custom_config["draft_vocab_size"]
if eagle_args.eagle_config and "draft_vocab_size" in custom_config
else llm_config.vocab_size,
}
)

mtsp.convert(model, [("eagle", config)])

# read draft vocab cache
@@ -238,7 +225,7 @@ def train():
model=model,
processing_class=tokenizer,
args=training_args,
callbacks=[EagleTrainingPlot(training_args.ar_validate_steps)],
callbacks=[EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)],
**data_module,
)

104 changes: 80 additions & 24 deletions modelopt/torch/export/plugins/hf_spec_export.py
@@ -18,26 +18,62 @@
import torch
import torch.nn as nn

EAGLE_MODELOPT_TO_OFFICIAL = {
"required": {
"layers.0.self_attn.q_proj.weight": "midlayer.self_attn.q_proj.weight",
"layers.0.self_attn.k_proj.weight": "midlayer.self_attn.k_proj.weight",
"layers.0.self_attn.v_proj.weight": "midlayer.self_attn.v_proj.weight",
"layers.0.self_attn.o_proj.weight": "midlayer.self_attn.o_proj.weight",
"layers.0.mlp.gate_proj.weight": "midlayer.mlp.gate_proj.weight",
"layers.0.mlp.up_proj.weight": "midlayer.mlp.up_proj.weight",
"layers.0.mlp.down_proj.weight": "midlayer.mlp.down_proj.weight",
"hidden_norm.weight": "midlayer.hidden_norm.weight",
"input_embeds_norm.weight": "midlayer.input_layernorm.weight",
"layers.0.post_attention_layernorm.weight": "midlayer.post_attention_layernorm.weight",
"norm.weight": "norm.weight",
"fc.weight": "fc.weight",
},
"optional": {
"d2t": "d2t",
"eagle_lm_head.weight": "lm_head.weight",
},
}

def eagle_state_dict_key_convert(num_hidden_layers: int = 1) -> dict[str, dict[str, str]]:
"""Convert our eagle model state dict key to official format key(s)."""
assert num_hidden_layers >= 1, "num_hidden_layers should be at least 1."
eagle_modelopt_to_official = {
"required": {
"norm.weight": "norm.weight",
"fc.weight": "fc.weight",
},
"optional": {
"d2t": "d2t",
"eagle_lm_head.weight": "lm_head.weight",
},
}
if num_hidden_layers == 1:
eagle_modelopt_to_official["required"].update(
{
"hidden_norm.weight": "midlayer.hidden_norm.weight",
"input_embeds_norm.weight": "midlayer.input_layernorm.weight",
}
)
else:
eagle_modelopt_to_official["required"].update(
{
"hidden_norm.weight": "midlayer.0.hidden_norm.weight",
"input_embeds_norm.weight": "midlayer.0.input_layernorm.weight",
}
)
for i in range(num_hidden_layers):
if num_hidden_layers == 1:
index = ""
else:
index = f".{i}"
eagle_modelopt_to_official["required"].update(
{
f"layers.{i}.self_attn.q_proj.weight": "midlayer"
+ index
+ ".self_attn.q_proj.weight",
f"layers.{i}.self_attn.k_proj.weight": "midlayer"
+ index
+ ".self_attn.k_proj.weight",
f"layers.{i}.self_attn.v_proj.weight": "midlayer"
+ index
+ ".self_attn.v_proj.weight",
f"layers.{i}.self_attn.o_proj.weight": "midlayer"
+ index
+ ".self_attn.o_proj.weight",
f"layers.{i}.mlp.gate_proj.weight": "midlayer" + index + ".mlp.gate_proj.weight",
f"layers.{i}.mlp.up_proj.weight": "midlayer" + index + ".mlp.up_proj.weight",
f"layers.{i}.mlp.down_proj.weight": "midlayer" + index + ".mlp.down_proj.weight",
f"layers.{i}.post_attention_layernorm.weight": "midlayer"
+ index
+ ".post_attention_layernorm.weight",
}
)
return eagle_modelopt_to_official
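
For reference, a short hypothetical usage sketch of the helper above; the expected strings follow directly from the mapping it builds:

mapping = eagle_state_dict_key_convert(num_hidden_layers=2)
# With more than one layer, per-layer weights gain a ".{i}" index under "midlayer".
assert mapping["required"]["layers.1.mlp.up_proj.weight"] == "midlayer.1.mlp.up_proj.weight"
assert mapping["required"]["hidden_norm.weight"] == "midlayer.0.hidden_norm.weight"

# With a single layer, the flat "midlayer.*" names are kept.
single = eagle_state_dict_key_convert(num_hidden_layers=1)
assert single["required"]["layers.0.mlp.up_proj.weight"] == "midlayer.mlp.up_proj.weight"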


def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
@@ -61,15 +97,16 @@ def export_spec_ckpt_state_dict(model: nn.Module):
# check the model has only speculative decoding
assert spec_opt_only(model), "Not purely eagle model."

eagle_modelopt_to_official = eagle_state_dict_key_convert(model.eagle_config.num_hidden_layers)
# Check if the state dict keys match
_check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
_check_state_dict_keys_match(model.eagle_module, eagle_modelopt_to_official["required"])

# Convert key names and save the state dict
eagle_state = model.eagle_module.state_dict()
export_state_dict = {}
for ours_key, export_key in {
**EAGLE_MODELOPT_TO_OFFICIAL["required"],
**EAGLE_MODELOPT_TO_OFFICIAL["optional"],
**eagle_modelopt_to_official["required"],
**eagle_modelopt_to_official["optional"],
}.items():
if ours_key in eagle_state:
export_state_dict[export_key] = eagle_state[ours_key]
@@ -78,6 +115,21 @@ def export_spec_ckpt_state_dict(model: nn.Module):
if "eagle_lm_head.weight" not in eagle_state:
export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]

# Add parallel draft weights
if model.eagle_config.parallel_draft_step > 1:
for i in range(model.eagle_config.parallel_draft_step - 1):
for j in range(model.eagle_config.parallel_draft_heads_num_layers):
export_state_dict[f"parallel_draft_heads.{i}.medusa_layers.{j}.linear.weight"] = (
eagle_state[f"parallel_draft_heads.{i}.{j}.linear.weight"]
)
if f"parallel_draft_heads.{i}.{j}.linear.bias" in eagle_state:
export_state_dict[f"parallel_draft_heads.{i}.medusa_layers.{j}.linear.bias"] = (
eagle_state[f"parallel_draft_heads.{i}.{j}.linear.bias"]
)
export_state_dict[f"parallel_draft_heads.{i}.lm_head.weight"] = eagle_state[
f"parallel_draft_heads.{i}.{model.eagle_config.parallel_draft_heads_num_layers}.weight"
]

return export_state_dict
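
For one extra draft head (parallel_draft_step=2) with a single linear layer (parallel_draft_heads_num_layers=1), the loops above amount to the following renaming, shown here as a small illustrative dict (the bias entry is exported only if present):

renamed = {
    "parallel_draft_heads.0.0.linear.weight": "parallel_draft_heads.0.medusa_layers.0.linear.weight",
    "parallel_draft_heads.0.0.linear.bias": "parallel_draft_heads.0.medusa_layers.0.linear.bias",
    # the head's lm_head is stored at index parallel_draft_heads_num_layers (here 1):
    "parallel_draft_heads.0.1.weight": "parallel_draft_heads.0.lm_head.weight",
}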


@@ -120,6 +172,9 @@ def export_spec_ckpt_config(model: nn.Module):
"use_input_layernorm_in_first_layer": None,
"use_last_layernorm": None,
"use_mtp_layernorm": None,
"next_layer_regular": True,
"parallel_draft_step": None,
"parallel_draft_heads_num_layers": None,
},
}

@@ -136,7 +191,8 @@ def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
if isinstance(value, dict):
# for eagle config, we find it in model.eagle_config
for sub_key in value:
value[sub_key] = _get_config_from_eagle_config_or_base_config(sub_key, model)
if value[sub_key] is None:
value[sub_key] = _get_config_from_eagle_config_or_base_config(sub_key, model)
elif value is None:
# First, we try to load from eagle config.
new_value = _get_config_from_eagle_config_or_base_config(key, model)
12 changes: 9 additions & 3 deletions modelopt/torch/export/plugins/mcore_llama.py
@@ -98,7 +98,9 @@
"final_layernorm": NameRemapping("norm."),
"d2t": NameRemapping("d2t"),
"output_layer": NameRemapping("lm_head."),
"parallel_draft_heads.medusa_layers": NameRemapping("parallel_draft_heads.{}.{}.linear."),
"parallel_draft_heads.medusa_layers": NameRemapping(
"parallel_draft_heads.{}.medusa_layers.{}.linear."
),
"parallel_draft_heads.lm_head": NameRemapping("parallel_draft_heads.{}.lm_head."),
}

@@ -115,7 +117,9 @@
"final_layernorm": NameRemapping("norm."),
"d2t": NameRemapping("d2t"),
"output_layer": NameRemapping("lm_head."),
"parallel_draft_heads.medusa_layers": NameRemapping("parallel_draft_heads.{}.{}.linear."),
"parallel_draft_heads.medusa_layers": NameRemapping(
"parallel_draft_heads.{}.medusa_layers.{}.linear."
),
"parallel_draft_heads.lm_head": NameRemapping("parallel_draft_heads.{}.lm_head."),
}

@@ -133,7 +137,9 @@
"final_layernorm": NameRemapping("norm."),
"d2t": NameRemapping("d2t"),
"output_layer": NameRemapping("lm_head."),
"parallel_draft_heads.medusa_layers": NameRemapping("parallel_draft_heads.{}.{}.linear."),
"parallel_draft_heads.medusa_layers": NameRemapping(
"parallel_draft_heads.{}.medusa_layers.{}.linear."
),
"parallel_draft_heads.lm_head": NameRemapping("parallel_draft_heads.{}.lm_head."),
}

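A small sketch of how the updated template resolves, assuming the two placeholders are filled with the parallel-draft-head index and the layer index (matching the HF export keys above):

head_idx, layer_idx = 0, 1
prefix = "parallel_draft_heads.{}.medusa_layers.{}.linear.".format(head_idx, layer_idx)
# -> "parallel_draft_heads.0.medusa_layers.1.linear."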
32 changes: 8 additions & 24 deletions modelopt/torch/export/unified_export_megatron.py
@@ -212,26 +212,14 @@ def __init__(
)

eagle_config = {
"use_input_layernorm_in_first_layer": mode_cfg["config"][
"eagle_architecture_config"
]["use_input_layernorm_in_first_layer"],
"use_last_layernorm": mode_cfg["config"]["eagle_architecture_config"][
"use_last_layernorm"
],
"use_mtp_layernorm": mode_cfg["config"]["eagle_architecture_config"][
"use_mtp_layernorm"
],
"use_aux_hidden_state": mode_cfg["config"]["eagle_architecture_config"][
"use_aux_hidden_state"
],
"use_input_layernorm_in_first_layer": model.eagle_config.use_input_layernorm_in_first_layer,
"use_last_layernorm": model.eagle_config.use_last_layernorm,
"use_mtp_layernorm": model.eagle_config.use_mtp_layernorm,
"use_aux_hidden_state": model.eagle_config.use_aux_hidden_state,
"eagle_aux_hidden_state_layer_ids": model.eagle_config.eagle_aux_hidden_state_layer_ids,
"next_layer_regular": True,
"parallel_draft_step": mode_cfg["config"]["eagle_architecture_config"][
"parallel_draft_step"
],
"parallel_draft_heads_num_layers": mode_cfg["config"][
"eagle_architecture_config"
]["parallel_draft_heads_num_layers"],
"parallel_draft_step": model.eagle_config.parallel_draft_step,
"parallel_draft_heads_num_layers": model.eagle_config.parallel_draft_heads_num_layers,
}

eagle_config_update = {
@@ -243,9 +231,7 @@
"max_position_embeddings": self._hf_text_config.max_position_embeddings,
"num_attention_heads": model.eagle_module.config.num_attention_heads,
"num_key_value_heads": model.eagle_module.config.num_query_groups,
"num_hidden_layers": mode_cfg["config"]["eagle_architecture_config"][
"num_hidden_layers"
],
"num_hidden_layers": model.eagle_config.num_layers,
"vocab_size": self._hf_text_config.vocab_size,
# Unset any special token ids given that the tokenizer can change here.
"bos_token_id": None,
Expand All @@ -254,9 +240,7 @@ def __init__(
"sep_token_id": None,
# The following attributes are EAGLE specific
"eagle_config": eagle_config,
"draft_vocab_size": mode_cfg["config"]["eagle_architecture_config"][
"draft_vocab_size"
],
"draft_vocab_size": model.eagle_config.draft_vocab_size,
}

self._hf_extra_config.update(eagle_config_update)
4 changes: 0 additions & 4 deletions modelopt/torch/speculative/eagle/default_config.py
@@ -18,9 +18,6 @@
default_eagle_config = {
"hidden_act": "silu",
"torch_dtype": "bfloat16",
"vocab_size": 128256,
"draft_vocab_size": 128256,
"max_position_embeddings": 8192,
"position_embedding_type": "rope",
"rope_scaling": {
"factor": 8.0,
@@ -31,7 +28,6 @@
},
"rope_theta": 500000.0,
"num_hidden_layers": 1,
"hidden_size": 4096,
"intermediate_size": 14336,
"num_attention_heads": 32,
"num_key_value_heads": 8,
6 changes: 0 additions & 6 deletions modelopt/torch/speculative/eagle/eagle_model.py
@@ -15,8 +15,6 @@

"""Eagle model to support eagle decoding."""

import torch

from modelopt.torch.opt.dynamic import DynamicModule


@@ -45,7 +43,3 @@ def modify(
self.eagle_report_acc = eagle_report_acc
self.eagle_reuse_base_decoder = eagle_reuse_base_decoder
self.eagle_loss_decay_factor = eagle_loss_decay_factor

if eagle_architecture_config.get("parallel_draft_step", 1) > 1:
for i in range(eagle_architecture_config.get("parallel_draft_step") - 1):
self.register_buffer(f"mask_token_{i}", torch.tensor(-1))