81 changes: 59 additions & 22 deletions tensorrt_llm/_torch/models/modeling_phi3.py
@@ -217,26 +217,46 @@ def filter_weights(prefix: str, weights: dict):
if "self_attn.qkv_proj" in name:
# The weights need to be split correctly before sharding to support tp_size >1.
qkv_weight = module_weights['weight'][:]
q_weight = qkv_weight[:hidden_size, :]
k_weight = qkv_weight[hidden_size:hidden_size +
num_kv_heads * head_dim, :]
v_weight = qkv_weight[hidden_size +
num_kv_heads * head_dim:, :]
qk_split_index = hidden_size
kv_split_index = hidden_size + num_kv_heads * head_dim

q_dict = {'weight': qkv_weight[:qk_split_index, :]}
k_dict = {
'weight':
qkv_weight[qk_split_index:kv_split_index, :]
}
v_dict = {'weight': qkv_weight[kv_split_index:, :]}

# Get the scale factor for the fused QKV projection
qkv_scale = module_weights.get('weight_scale', None)

q_dict = {'weight': q_weight}
if qkv_scale is not None:
q_dict['weight_scale'] = qkv_scale

k_dict = {'weight': k_weight}
if qkv_scale is not None:
k_dict['weight_scale'] = qkv_scale # Use same scale

v_dict = {'weight': v_weight}
if qkv_scale is not None:
v_dict['weight_scale'] = qkv_scale # Use same scale
if qkv_scale.shape and qkv_scale.shape[
0] == qkv_weight.shape[0]:
q_dict[
'weight_scale'] = qkv_scale[:
qk_split_index, :]
k_dict['weight_scale'] = qkv_scale[
qk_split_index:kv_split_index, :]
v_dict['weight_scale'] = qkv_scale[
kv_split_index:, :]
else: # use same scale
q_dict['weight_scale'] = qkv_scale
k_dict['weight_scale'] = qkv_scale
v_dict['weight_scale'] = qkv_scale

input_scale = module_weights.get('input_scale', None)
if input_scale is not None:
q_dict['input_scale'] = input_scale
k_dict['input_scale'] = input_scale
v_dict['input_scale'] = input_scale

weight_scale_2 = module_weights.get(
'weight_scale_2', None)
if weight_scale_2 is not None:
q_dict['weight_scale_2'] = weight_scale_2
k_dict['weight_scale_2'] = weight_scale_2
v_dict['weight_scale_2'] = weight_scale_2

module.load_weights(weights=[q_dict, k_dict, v_dict])
elif "mlp.gate_up_proj" in name:
@@ -246,16 +266,33 @@ def filter_weights(prefix: str, weights: dict):
gate_weight = gate_up_weight[:intermediate_size, :]
up_weight = gate_up_weight[intermediate_size:, :]

# Get the scale factors if they exist
gate_up_scale = module_weights.get('weight_scale', None)

gate_dict = {'weight': gate_weight}
if gate_up_scale is not None:
gate_dict['weight_scale'] = gate_up_scale

up_dict = {'weight': up_weight}

# Get the scale factors if they exist
gate_up_scale = module_weights.get('weight_scale', None)
if gate_up_scale is not None:
up_dict['weight_scale'] = gate_up_scale
if gate_up_scale.shape and gate_up_scale.shape[
0] == gate_up_weight.shape[0]:
gate_dict[
'weight_scale'] = gate_up_scale[:
intermediate_size, :]
up_dict['weight_scale'] = gate_up_scale[
intermediate_size:, :]
else: # use same scale
gate_dict['weight_scale'] = gate_up_scale
up_dict['weight_scale'] = gate_up_scale

input_scale = module_weights.get('input_scale', None)
if input_scale is not None:
gate_dict['input_scale'] = input_scale
up_dict['input_scale'] = input_scale

weight_scale_2 = module_weights.get(
'weight_scale_2', None)
if weight_scale_2 is not None:
gate_dict['weight_scale_2'] = weight_scale_2
up_dict['weight_scale_2'] = weight_scale_2

module.load_weights(weights=[gate_dict, up_dict])
else:
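Note on the hunks above: the fused QKV (and gate/up) weights are split by row index before sharding, and the quantization scale is sliced only when its leading dimension matches the fused weight (row-wise or block scales); otherwise the single per-tensor scale is shared across the splits. A minimal standalone sketch of that splitting logic, using hypothetical Phi-3-style dimensions and plain PyTorch tensors rather than the repo's loader:

```python
import torch

# Hypothetical dimensions, for illustration only.
hidden_size, num_kv_heads, head_dim = 3072, 8, 96

qkv_weight = torch.randn(hidden_size + 2 * num_kv_heads * head_dim, hidden_size)
qk_split_index = hidden_size
kv_split_index = hidden_size + num_kv_heads * head_dim

q_dict = {'weight': qkv_weight[:qk_split_index, :]}
k_dict = {'weight': qkv_weight[qk_split_index:kv_split_index, :]}
v_dict = {'weight': qkv_weight[kv_split_index:, :]}

# Pretend the checkpoint carries a row-wise weight scale; a per-tensor scale
# (0-dim, or leading dim != fused rows) falls through to the else branch.
qkv_scale = torch.rand(qkv_weight.shape[0], 1)
if qkv_scale.shape and qkv_scale.shape[0] == qkv_weight.shape[0]:
    # Row-wise scale: slice it exactly like the weight.
    q_dict['weight_scale'] = qkv_scale[:qk_split_index, :]
    k_dict['weight_scale'] = qkv_scale[qk_split_index:kv_split_index, :]
    v_dict['weight_scale'] = qkv_scale[kv_split_index:, :]
else:
    # Per-tensor scale: reuse the same value for Q, K and V.
    q_dict['weight_scale'] = k_dict['weight_scale'] = v_dict['weight_scale'] = qkv_scale
```

The same pattern applies to `mlp.gate_up_proj` in the second hunk, with `intermediate_size` as the split point.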
19 changes: 11 additions & 8 deletions tensorrt_llm/_torch/models/modeling_phi4mm.py
@@ -88,10 +88,10 @@ def _load_phi4mm_classes(local_path):
# Add parent folder to sys.path to enable relative import.
original_sys_path = sys.path.copy()
package_folder = Path(local_path)
package_name = package_folder.name
parent_folder = str(package_folder.parent)
if parent_folder not in sys.path:
sys.path.insert(0, parent_folder)

try:
# Import Phi4MMConfig from configuration_phi4mm.py.
config_path = os.path.join(local_path, 'configuration_phi4mm.py')
@@ -111,8 +111,7 @@ def _load_phi4mm_classes(local_path):
# `Phi-4-multimodal-instruct` as the package name to avoid relative import errors.
# `hf_modeling_phi4mm` as the module name to avoid name conflicts.
spec = importlib.util.spec_from_file_location(
"Phi-4-multimodal-instruct.hf_modeling_phi4mm",
modeling_phi4mm_path)
f"{package_name}.hf_modeling_phi4mm", modeling_phi4mm_path)
hf_modeling_phi4mm = importlib.util.module_from_spec(spec)
spec.loader.exec_module(hf_modeling_phi4mm)
Phi4MMAudioEmbedding = hf_modeling_phi4mm.Phi4MMAudioEmbedding
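The import change above derives the module's package name from the checkout folder instead of hard-coding "Phi-4-multimodal-instruct", so the HF modeling file can be loaded when the checkpoint directory has a different name (for example the -FP8 and -FP4 variants). A rough sketch of the pattern, standard library only; `load_hf_modeling` and `file_name` are illustrative names, not the repo's API:

```python
import importlib.util
import os
import sys
from pathlib import Path


def load_hf_modeling(local_path: str, file_name: str = "modeling_phi4mm.py"):
    """Load a modeling file from a local HF checkout under its real folder name (sketch)."""
    package_folder = Path(local_path)
    package_name = package_folder.name            # e.g. "Phi-4-multimodal-instruct-FP8"
    parent_folder = str(package_folder.parent)
    if parent_folder not in sys.path:
        sys.path.insert(0, parent_folder)         # make the parent folder findable

    modeling_path = os.path.join(local_path, file_name)
    spec = importlib.util.spec_from_file_location(
        f"{package_name}.hf_modeling_phi4mm", modeling_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
```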
@@ -989,12 +988,16 @@ def load_weights(self, weights):
weights = {k: v for k, v in weights.items() if '.lora_' not in k}
# Rename base layer weights.
updated_weights = {}
base_layer_weight_names = [
'weight', 'input_scale', 'weight_scale', 'weight_scale_2'
]
for k in weights.keys():
if 'base_layer.weight' in k:
new_k = k.replace('base_layer.weight', 'weight')
updated_weights[new_k] = weights[k]
else:
updated_weights[k] = weights[k]
new_k = k
for weight_name in base_layer_weight_names:
if f'base_layer.{weight_name}' in k:
new_k = k.replace(f'base_layer.{weight_name}', weight_name)
break
updated_weights[new_k] = weights[k]
weights = updated_weights
self.llm.load_weights(weights)

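The `load_weights` change above generalizes the old `base_layer.weight` rename to all quantization tensors, so FP8/NVFP4 checkpoints keep their `input_scale`, `weight_scale`, and `weight_scale_2` entries after the LoRA weights are dropped. A self-contained sketch of that renaming pass, assuming a flat `{name: tensor}` dict (hypothetical keys, not the repo's method):

```python
def strip_base_layer_prefix(weights: dict) -> dict:
    """Rename '...base_layer.<name>' keys to '...<name>'."""
    base_layer_weight_names = [
        'weight', 'input_scale', 'weight_scale', 'weight_scale_2'
    ]
    updated_weights = {}
    for k, v in weights.items():
        new_k = k
        for weight_name in base_layer_weight_names:
            if f'base_layer.{weight_name}' in k:
                new_k = k.replace(f'base_layer.{weight_name}', weight_name)
                break
        updated_weights[new_k] = v
    return updated_weights


# e.g. (hypothetical key)
# 'model.layers.0.self_attn.qkv_proj.base_layer.weight_scale'
#   -> 'model.layers.0.self_attn.qkv_proj.weight_scale'
```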
6 changes: 3 additions & 3 deletions tensorrt_llm/inputs/utils.py
@@ -580,10 +580,10 @@ def convert_to_conversation_message(
# Check if mdata is a MultimodalData
if isinstance(mdata,
dict) and "modality" in mdata and "data" in mdata:
modality = mdata["modality"]
mdata_modality = mdata["modality"]
if modality == "multiple_image":
modality = "image"
mm_data_tracker.add_data(modality, mdata["data"])
mdata_modality = "image"
mm_data_tracker.add_data(mdata_modality, mdata["data"])
else:
# Add embeddings to the tracker for placeholder handling
mm_data_tracker.add_data(mdata["modality"],
4 changes: 4 additions & 0 deletions tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -187,6 +187,10 @@ mistralai/Mistral-Small-3.1-24B-Instruct-2503:
accuracy: 89.23
microsoft/Phi-4-multimodal-instruct:
- accuracy: 81.19
- quant_algo: FP8
accuracy: 80.82
- quant_algo: NVFP4
accuracy: 69.33
microsoft/Phi-4-multimodal-instruct-long-rope:
- accuracy: 75.85
microsoft/Phi-4-mini-instruct:
4 changes: 4 additions & 0 deletions tests/integration/defs/accuracy/references/mmlu.yaml
@@ -294,6 +294,10 @@ mistralai/Ministral-8B-Instruct-2410:
accuracy: 65.96
microsoft/Phi-4-multimodal-instruct:
- accuracy: 69.69
- quant_algo: FP8
accuracy: 68.86
- quant_algo: NVFP4
accuracy: 64.04
microsoft/Phi-4-multimodal-instruct-long-rope:
- accuracy: 65.98
microsoft/phi-4:
18 changes: 18 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -3314,6 +3314,24 @@ def test_auto_dtype_long_rope(self):
task = GSM8K(model_name)
task.evaluate(llm)

@skip_pre_blackwell
def test_fp4(self):
model_path = f"{self.MODEL_PATH}-FP4"
with LLM(model_path, max_seq_len=4096) as llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_hopper
def test_fp8(self):
model_path = f"{self.MODEL_PATH}-FP8"
with LLM(model_path, max_seq_len=4096) as llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)


@skip_pre_hopper
@pytest.mark.skip_less_device_memory(80000)
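The new `test_fp4` / `test_fp8` cases load pre-quantized `-FP4` / `-FP8` checkpoints through the PyTorch LLM API and score them against the GSM8K/MMLU references added above. Outside the test harness, a comparable smoke check might look like the sketch below; the checkpoint path and prompt are placeholders, and the accuracy tasks live in the test utilities, so only plain generation is shown:

```python
from tensorrt_llm import LLM

# Placeholder path to a locally downloaded FP8 checkpoint.
with LLM("Phi-4-multimodal-instruct-FP8", max_seq_len=4096) as llm:
    outputs = llm.generate(["Question: What is 17 * 23? Answer:"])
    for output in outputs:
        print(output.outputs[0].text)
```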
16 changes: 16 additions & 0 deletions tests/integration/defs/perf/test_perf.py
@@ -127,6 +127,14 @@
"phi_4_multimodal_instruct": "multimodals/Phi-4-multimodal-instruct",
"phi_4_multimodal_instruct_image": "multimodals/Phi-4-multimodal-instruct",
"phi_4_multimodal_instruct_audio": "multimodals/Phi-4-multimodal-instruct",
"phi_4_multimodal_instruct_fp4_image":
"multimodals/Phi-4-multimodal-instruct-FP4",
"phi_4_multimodal_instruct_fp4_audio":
"multimodals/Phi-4-multimodal-instruct-FP4",
"phi_4_multimodal_instruct_fp8_image":
"multimodals/Phi-4-multimodal-instruct-FP8",
"phi_4_multimodal_instruct_fp8_audio":
"multimodals/Phi-4-multimodal-instruct-FP8",
"bielik_11b_v2.2_instruct": "Bielik-11B-v2.2-Instruct",
"bielik_11b_v2.2_instruct_fp8": "Bielik-11B-v2.2-Instruct-FP8",
"mistral_small_v3.1_24b": "Mistral-Small-3.1-24B-Instruct-2503",
@@ -177,6 +185,14 @@
"multimodals/Phi-4-multimodal-instruct/vision-lora",
"phi_4_multimodal_instruct_audio":
"multimodals/Phi-4-multimodal-instruct/speech-lora",
"phi_4_multimodal_instruct_fp4_image":
"multimodals/Phi-4-multimodal-instruct-FP4/vision-lora",
"phi_4_multimodal_instruct_fp4_audio":
"multimodals/Phi-4-multimodal-instruct-FP4/speech-lora",
"phi_4_multimodal_instruct_fp8_image":
"multimodals/Phi-4-multimodal-instruct-FP8/vision-lora",
"phi_4_multimodal_instruct_fp8_audio":
"multimodals/Phi-4-multimodal-instruct-FP8/speech-lora",
}

TIMING_CACHE_DIR = os.environ.get("TIMING_CACHE_DIR", "")