Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)

MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
MOE_MODEL = "Qwen/Qwen1.5-MoE-A2.7B"
MOE_MODEL = "allenai/OLMoE-1B-7B-0125"


def get_test_actor_config(model: str = MODEL) -> SkyRLTrainConfig:
Expand Down
24 changes: 16 additions & 8 deletions tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,16 @@
# TODO (erictang000): we would prefer to use this smaller MoE model for testing, but seeing incorrect logprobs when using EP > 1
# this might be a model specific mbridge issue - see if this persists when we transition to Megatron-Bridge
# MOE_MODEL_NAME = "Qwen/Qwen1.5-MoE-A2.7B"
MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B"
QWEN_MOE_FALLBACK_MODEL_NAME = "Qwen/Qwen1.5-MoE-A2.7B"
MOE_MODEL_NAME = "allenai/OLMoE-1B-7B-0125"


def get_test_megatron_worker_model(ep: int) -> str:
    """Return the HF model name to use for a Megatron worker test.

    TODO (devpatel): keep OLMoE as the CI MoE default, but use the Qwen MoE
    fallback for the EP worker tests since TP > 1 is causing issues in bridge.
    Remove this once OLMoE's bridge path is fixed.
    """
    return QWEN_MOE_FALLBACK_MODEL_NAME if ep > 1 else MODEL_NAME


def get_test_actor_config(model_name=MODEL_NAME) -> SkyRLTrainConfig:
Expand Down Expand Up @@ -232,7 +241,8 @@ async def test_megatron_forward(
"""
Test that the Megatron forward pass is numerically equivalent to just running a huggingface model forward.
"""
cfg = get_test_actor_config(model_name=MOE_MODEL_NAME if ep > 1 else MODEL_NAME)
model_name = get_test_megatron_worker_model(ep)
cfg = get_test_actor_config(model_name=model_name)
#### Megatron forward pass ####
cfg.trainer.strategy = "megatron"
cfg.trainer.placement.policy_num_gpus_per_node = gpus_per_node
Expand Down Expand Up @@ -300,9 +310,7 @@ def run_hf_forward(batch, model_name):

return attention_mask.to("cpu").detach(), action_log_probs.to("cpu").detach(), num_actions

attention_mask, action_log_probs, num_actions = ray.get(
run_hf_forward.remote(batch, MOE_MODEL_NAME if ep > 1 else MODEL_NAME)
)
attention_mask, action_log_probs, num_actions = ray.get(run_hf_forward.remote(batch, model_name))

#### Compare results ####
# compare just non-padding tokens
Expand Down Expand Up @@ -352,7 +360,7 @@ async def test_megatron_lora_forward(ray_init_fixture, tp, pp, cp, ep, etp, gpus
"""
Test that the Megatron + lora forward pass is numerically equivalent to just running a megatron model forward.
"""
cfg = get_test_actor_config(model_name=MOE_MODEL_NAME if ep > 1 else MODEL_NAME)
cfg = get_test_actor_config(model_name=get_test_megatron_worker_model(ep))
#### Megatron forward pass ####
cfg.trainer.strategy = "megatron"
cfg.trainer.placement.policy_num_gpus_per_node = gpus_per_node
Expand Down Expand Up @@ -461,7 +469,7 @@ async def test_megatron_train(
"""
Full test: initialize actor group, send dummy experience to training_step, validate output.
"""
cfg = get_test_actor_config(model_name=MODEL_NAME if ep == 1 else MOE_MODEL_NAME)
cfg = get_test_actor_config(model_name=get_test_megatron_worker_model(ep))
batch = get_test_training_batch(batch_size=gpus_per_node)

cfg.trainer.strategy = "megatron"
Expand Down Expand Up @@ -724,7 +732,7 @@ async def test_megatron_offload_memory_and_correctness(ray_init_fixture, worker_
6. Backload model to GPU and check memory usage.
7. Run another training step and ensure output consistency.
"""
cfg = get_test_actor_config(MOE_MODEL_NAME) # use MoE model for testing
cfg = get_test_actor_config(QWEN_MOE_FALLBACK_MODEL_NAME) # use Qwen MoE fallback for this worker test
cfg.trainer.strategy = "megatron"
# 0 learning rate and wd so we can optimizer step to free gradients but still check results are the same
getattr(cfg.trainer, worker_type).optimizer_config.lr = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
init_worker_with_type,
)

MOE_MODEL_NAME = "moonshotai/Moonlight-16B-A3B-Instruct"
MOE_MODEL_NAME = "allenai/OLMoE-1B-7B-0125"
NUM_PROMPTS = 5
N_SAMPLES_PER_PROMPT = 2
MAX_GENERATE_LENGTH = 128
Expand Down
Loading