25 changes: 24 additions & 1 deletion tpu_inference/models/common/model_loader.py
@@ -37,6 +37,8 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
from tpu_inference.models.jax.qwen2_5_vl import \
Qwen2_5_VLForConditionalGeneration
from tpu_inference.models.jax.qwen2_vl import \
Qwen2VLForConditionalGeneration
from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM

if os.getenv("NEW_MODEL_DESIGN", False):
@@ -52,6 +54,8 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
_MODEL_REGISTRY["Qwen3ForCausalLM"] = Qwen3ForCausalLM
_MODEL_REGISTRY[
"Qwen2_5_VLForConditionalGeneration"] = Qwen2_5_VLForConditionalGeneration
_MODEL_REGISTRY[
"Qwen2VLForConditionalGeneration"] = Qwen2VLForConditionalGeneration
_MODEL_REGISTRY["Phi3ForCausalLM"] = Phi3ForCausalLM
_MODEL_REGISTRY["Eagle3LlamaForCausalLM"] = EagleLlama3ForCausalLM

@@ -290,11 +294,30 @@ def get_vllm_model(
)
params, lora_manager = model.load_weights()

# Get M-ROPE function if the model supports it (must be after load_weights)
get_mrope_input_positions_fn = None
if hasattr(model, 'model') and hasattr(model.model, 'vllm_model'):
vllm_model = model.model.vllm_model
if hasattr(vllm_model, 'get_mrope_input_positions'):
get_mrope_input_positions_fn = vllm_model.get_mrope_input_positions
logger.info(
f"Found get_mrope_input_positions function in {type(vllm_model).__name__}"
)
else:
logger.info(
f"No get_mrope_input_positions function found in {type(vllm_model).__name__}"
)
else:
logger.info(
"Could not access vllm_model to check for get_mrope_input_positions"
)

jit_model = model.jit_step_func()
compute_logits_fn = model.jit_compute_logits_func()
# The model itself must be returned because LoRA weights are neither a torch.nn.Parameter nor a torch.nn.Buffer; only after the LoRA weights are loaded and set on the torch.nn.Module can they be sharded and moved to TPU.
combine_hidden_states_fn = None
return jit_model, compute_logits_fn, combine_hidden_states_fn, None, None, None, params, lora_manager, model

return jit_model, compute_logits_fn, combine_hidden_states_fn, None, None, get_mrope_input_positions_fn, params, lora_manager, model


def get_model(
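To summarize what the last hunk wires up: get_vllm_model now probes the wrapped vLLM model for an optional get_mrope_input_positions method after the weights are loaded, and returns it in the tuple slot that was previously hard-coded to None. A standalone sketch of that capability probe (the model.model.vllm_model attribute chain and the method name mirror the diff; the helper name and the usage comment are illustrative):

```python
# Illustrative sketch of the optional M-ROPE capability lookup added in get_vllm_model.
# The attribute chain and get_mrope_input_positions mirror the diff;
# resolve_mrope_fn itself is a hypothetical helper, not part of the PR.
from typing import Any, Callable, Optional


def resolve_mrope_fn(model: Any) -> Optional[Callable]:
    """Return the model's get_mrope_input_positions if it exposes one, else None."""
    vllm_model = getattr(getattr(model, "model", None), "vllm_model", None)
    if vllm_model is None:
        return None
    return getattr(vllm_model, "get_mrope_input_positions", None)


# Callers can then branch on the capability without model-specific imports, e.g.:
#     mrope_fn = resolve_mrope_fn(model)
#     if mrope_fn is not None:
#         positions = mrope_fn(...)  # signature depends on the wrapped vLLM model
```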