Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions example/qwen3_5/hf_fwd_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import argparse

import torch

try:
    from transformers import Qwen3_5MoeForConditionalGeneration
except ImportError as err:
    # The Qwen3.5 MoE classes only exist in recent transformers releases.
    # Fail fast with an actionable message instead of continuing and hitting
    # a confusing NameError at model-load time.
    raise ImportError(
        "Qwen3_5MoeForConditionalGeneration is unavailable: please install "
        "transformers>=5.2.0 or install transformers from source."
    ) from err

from example.qwen3_5.load_model_and_forward import get_sample_for_forward

if __name__ == "__main__":
    # Local import: only needed by this script body, keeps the top-of-file
    # import block untouched.
    from pathlib import Path

    # Parse command line arguments.
    parser = argparse.ArgumentParser(description="Load model and generate text")
    parser.add_argument(
        "--model_path", type=str, required=True, help="HuggingFace model path"
    )
    parser.add_argument(
        "--sample_type",
        type=str,
        default="image",
        choices=["image", "video", "mix"],
        help="sample type",
    )
    args = parser.parse_args()

    # Forward-only reference run: disable autograd globally to save memory.
    torch.set_grad_enabled(False)
    model = Qwen3_5MoeForConditionalGeneration.from_pretrained(
        args.model_path,
        dtype="auto",
        device_map="cuda:0",
    )

    # Dump per-parameter shape/dtype/sum so weights can be compared against
    # another implementation loading the same checkpoint.
    for pname, params in model.named_parameters():
        print(f"Model weight {pname=} {params.shape} {params.dtype} {params.sum()}")

    # Preparation for inference: build one sample batch for the chosen modality.
    inputs = get_sample_for_forward(args.model_path, args.sample_type)

    # Move every input tensor onto the GPU the model was placed on.
    for k in inputs:
        inputs[k] = inputs[k].cuda()

    # Single forward pass (no generation loop); logits are the raw outputs.
    hf_output = model.forward(**inputs)

    print(hf_output.logits.shape, hf_output.logits.device, hf_output.logits.dtype)

    # Persist reference logits for later numerical comparison. Create the
    # output directory first — torch.save raises FileNotFoundError otherwise.
    out_path = Path("qwen3_5_save") / "hf_qwen3_5.pt"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(hf_output.logits.cpu(), out_path)

    print("hf Done")
Loading
Loading