[Misc] Standardize RoPE handling for Qwen2-VL (vllm-project#9250)
DarkLight1337 authored Oct 16, 2024
1 parent ed92013 commit 7e7eae3
Showing 16 changed files with 102 additions and 200 deletions.
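The unifying theme of the diff below is that every rope_scaling dictionary in vLLM now uses the `rope_type` key; the legacy `type` key in Hugging Face configs is rewritten to `rope_type` when the config is loaded (see `patch_rope_scaling` in vllm/transformers_utils/config.py further down). A minimal sketch of the two spellings, with an illustrative linear-scaling example:

```python
# Legacy spelling found in older HF configs (rewritten at config-load time):
legacy = {"type": "linear", "factor": 2.0}

# Standardized spelling used throughout vLLM after this change:
standardized = {"rope_type": "linear", "factor": 2.0}
```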
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_rope.py
@@ -31,7 +31,7 @@ def benchmark_rope_kernels_multi_lora(
# batched RoPE can take multiple scaling factors
batched_rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_style, {
"type": "linear",
"rope_type": "linear",
"factor": tuple(scaling_factors)
})
# non-batched RoPE takes only one scaling factor, we create multiple
@@ -41,7 +41,7 @@ def benchmark_rope_kernels_multi_lora(
non_batched_ropes.append(
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
{
"type": "linear",
"rope_type": "linear",
"factor": (scaling_factor, )
}))

2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -4,7 +4,7 @@ numpy < 2.0.0
requests >= 2.26.0
tqdm
py-cpuinfo
- transformers >= 4.45.0 # Required for Llama 3.2.
+ transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
8 changes: 4 additions & 4 deletions tests/kernels/test_pos_encoding.py
@@ -105,7 +105,7 @@ def test_batched_rotary_embedding(
if rotary_dim is None:
rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
"type": "linear",
"rope_type": "linear",
"factor": (1, )
})
rope = rope.to(dtype=dtype)
@@ -166,7 +166,7 @@ def test_batched_rotary_embedding_multi_lora(
rotary_dim = head_size
scaling_factors: List[int] = [1, 2, 4]
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
"type": "linear",
"rope_type": "linear",
"factor": tuple(scaling_factors)
})
rope = rope.to(dtype=dtype)
@@ -211,10 +211,10 @@ def test_rope_module_cache():
MAX_POSITIONS = [123, 1234]
BASES = [10000, 1000000]
ROPE_SCALINGS = (None, {
"type": "linear",
"rope_type": "linear",
"factor": (1, )
}, {
"type": "dynamic",
"rope_type": "dynamic",
"factor": 1
})
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
2 changes: 1 addition & 1 deletion tests/lora/test_layers.py
@@ -951,7 +951,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
lora_rope.create_lora_weights(max_loras, lora_config)
linear_rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_style, {
"type": "linear",
"rope_type": "linear",
"factor": scaling_factors
})
linear_rope = linear_rope.to(dtype=dtype)
4 changes: 2 additions & 2 deletions tests/test_config.py
@@ -64,9 +64,9 @@ def test_get_sliding_window():


def test_rope_customization():
TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
TEST_ROPE_THETA = 16_000_000.0
LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}
LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}

llama_model_config = ModelConfig(
"meta-llama/Meta-Llama-3-8B-Instruct",
21 changes: 7 additions & 14 deletions vllm/config.py
@@ -1739,16 +1739,10 @@ def _get_and_verify_max_len(

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
if "type" in rope_scaling:
rope_type = rope_scaling["type"]
elif "rope_type" in rope_scaling:
rope_type = rope_scaling["rope_type"]
else:
raise ValueError(
"rope_scaling must have a 'type' or 'rope_type' key.")
# No need to consider "type" key because of patch_rope_scaling when
# loading HF config
rope_type = rope_scaling["rope_type"]

- # The correct one should be "longrope", kept "su" here
- # to be backward compatible
if rope_type not in ("su", "longrope", "llama3"):
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
@@ -1758,11 +1752,10 @@
"with rope_scaling. Please raise an issue so we can "
"investigate.")

if rope_type == "mrope":
scaling_factor = 1
else:
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
# NOTE: rope_type == "default" does not define factor
# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
scaling_factor = rope_scaling.get("factor", 1.0)

if rope_type == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
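The old code special-cased mrope with `scaling_factor = 1`; after patching, an mrope-style config carries `rope_type == "default"` and simply has no `"factor"` key, so the `.get` fallback yields the same value. A small illustration (the mrope_section values are placeholders, not taken from this diff):

```python
# Sketch: a patched mrope-style rope_scaling dict has no "factor" key,
# so the new fallback reproduces the old hard-coded value of 1.
rope_scaling = {"rope_type": "default", "mrope_section": [16, 24, 24]}
scaling_factor = rope_scaling.get("factor", 1.0)
assert scaling_factor == 1.0
```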
11 changes: 6 additions & 5 deletions vllm/engine/arg_utils.py
@@ -454,11 +454,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'None, we assume the model weights are not '
'quantized and use `dtype` to determine the data '
'type of the weights.')
- parser.add_argument('--rope-scaling',
-                     default=None,
-                     type=json.loads,
-                     help='RoPE scaling configuration in JSON format. '
-                     'For example, {"type":"dynamic","factor":2.0}')
+ parser.add_argument(
+     '--rope-scaling',
+     default=None,
+     type=json.loads,
+     help='RoPE scaling configuration in JSON format. '
+     'For example, {"rope_type":"dynamic","factor":2.0}')
parser.add_argument('--rope-theta',
default=None,
type=float,
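Since the flag is parsed with `json.loads`, the example value from the updated help text becomes a plain dict; a quick sanity check of that round trip (illustrative only):

```python
import json

# The JSON string from the --rope-scaling help text parses to a plain dict.
value = json.loads('{"rope_type":"dynamic","factor":2.0}')
assert value == {"rope_type": "dynamic", "factor": 2.0}
```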
47 changes: 28 additions & 19 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -920,13 +920,10 @@ def get_rope(
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
else:
- scaling_type = rope_scaling[
-     "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
- # The correct one should be "longrope" but keep "su" here
- # for backward compatible
- if scaling_type not in {"su", "longrope"}:
-     scaling_factor = rope_scaling.get("factor", 1.0)
+ scaling_type = rope_scaling["rope_type"]

if scaling_type == "llama3":
+ scaling_factor = rope_scaling["factor"]
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
original_max_position = rope_scaling[
@@ -937,16 +934,39 @@
scaling_factor, low_freq_factor,
high_freq_factor,
original_max_position)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
else:
rotary_emb = RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "linear":
+ scaling_factor = rope_scaling["factor"]
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style,
scaling_factor, dtype)
elif scaling_type == "dynamic":
+ scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor, dtype)
elif scaling_type == "yarn":
+ scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
extra_kwargs = {
@@ -961,6 +981,7 @@
scaling_factor, dtype,
**extra_kwargs)
elif scaling_type == "deepseek_yarn":
+ scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
@@ -973,9 +994,7 @@
rotary_emb = DeepseekScalingRotaryEmbedding(
head_size, rotary_dim, original_max_position, base,
is_neox_style, scaling_factor, dtype, **extra_kwargs)
- # The correct one should be "longrope" but keep "su" here
- # for backward compatible
- elif scaling_type == "su" or scaling_type == "longrope":
+ elif scaling_type == "longrope":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling[
Expand All @@ -989,16 +1008,6 @@ def get_rope(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs)
elif scaling_type == "mrope":
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
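With the dedicated mrope branch removed, a Qwen2-VL-style rope_scaling dict now reaches MRotaryEmbedding through the "default" branch because it contains "mrope_section". A hedged sketch of such a call; the head size, rotary dim, max position and base below are placeholder values, not taken from any particular model:

```python
from vllm.model_executor.layers.rotary_embedding import get_rope

# Placeholder geometry; only the rope_scaling dict matters for the dispatch.
rope = get_rope(128, 128, 32768, 1000000, True,
                {"rope_type": "default", "mrope_section": [16, 24, 24]})
# The "default" branch sees "mrope_section" and builds an MRotaryEmbedding;
# without that key it would fall back to a plain RotaryEmbedding.
```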
2 changes: 1 addition & 1 deletion vllm/model_executor/models/deepseek_v2.py
@@ -242,7 +242,7 @@ def __init__(
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
- rope_scaling['type'] = 'deepseek_yarn'
+ rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
2 changes: 1 addition & 1 deletion vllm/model_executor/models/phi3_small.py
@@ -179,7 +179,7 @@ def __init__(
rope_scaling["factor"] = self.rope_position_scale
else:
rope_scaling = {
"type": "linear",
"rope_type": "linear",
"factor": self.rope_position_scale,
}

8 changes: 4 additions & 4 deletions vllm/model_executor/models/qwen2_vl.py
@@ -34,6 +34,8 @@
from transformers.image_utils import (get_image_size,
infer_channel_dimension_format,
to_numpy_array)
+ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
+     Qwen2VLConfig, Qwen2VLVisionConfig)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
make_batched_images, make_batched_videos, smart_resize)

@@ -62,8 +64,7 @@
from vllm.multimodal.image import cached_get_image_processor
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, SequenceData
- from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
-     Qwen2VLVisionConfig)
+ from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import get_processor
from vllm.utils import is_cpu

@@ -1061,8 +1062,7 @@ def forward(
if image_input is None and video_input is None:
inputs_embeds = None
else:
- rope_scaling = getattr(self.config, "rope_scaling", {})
- if rope_scaling.get("type", None) == "mrope":
+ if uses_mrope(self.config):
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}")
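The forward pass now asks `uses_mrope(self.config)` instead of inspecting rope_scaling directly; when it answers True, `positions` must already be the (3, seq_len) tensor of temporal/height/width indices that M-RoPE expects, as the retained assertion checks. A minimal shape sketch with placeholder values:

```python
import torch

seq_len = 16  # placeholder
# One row each for the temporal, height and width position indices.
positions = torch.zeros(3, seq_len, dtype=torch.long)
assert positions.ndim == 2 and positions.size(0) == 3
```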
44 changes: 41 additions & 3 deletions vllm/transformers_utils/config.py
@@ -23,8 +23,8 @@
MedusaConfig, MllamaConfig,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, NVLM_D_Config,
- Qwen2VLConfig, RWConfig,
- SolarConfig, UltravoxConfig)
+ RWConfig, SolarConfig,
+ UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file

@@ -57,7 +57,6 @@
"NVLM_D": NVLM_D_Config,
"solar": SolarConfig,
"ultravox": UltravoxConfig,
"qwen2_vl": Qwen2VLConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
}

Expand Down Expand Up @@ -91,6 +90,43 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
return False


+ def patch_rope_scaling(config: PretrainedConfig) -> None:
+     """Provide backwards compatibility for RoPE."""
+     text_config = getattr(config, "text_config", None)
+     if text_config is not None:
+         patch_rope_scaling(text_config)
+
+     rope_scaling = getattr(config, "rope_scaling", None)
+     if rope_scaling is not None:
+         patch_rope_scaling_dict(rope_scaling)
+
+
+ def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None:
+     if "rope_type" not in rope_scaling and "type" in rope_scaling:
+         rope_scaling["rope_type"] = rope_scaling["type"]
+         logger.info("Replacing legacy 'type' key with 'rope_type'")
+
+     if "rope_type" not in rope_scaling:
+         raise ValueError("rope_scaling should have a 'rope_type' key")
+
+     if rope_scaling["rope_type"] == "su":
+         rope_scaling["rope_type"] = "longrope"
+         logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
+     elif rope_scaling["rope_type"] == "mrope":
+         assert "mrope_section" in rope_scaling
+         rope_scaling["rope_type"] = "default"
+         logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
+
+
+ def uses_mrope(config: PretrainedConfig) -> bool:
+     """Detect if the model with this config uses M-ROPE."""
+     rope_scaling = getattr(config, "rope_scaling", None)
+     if rope_scaling is None:
+         return False
+
+     return "mrope_section" in rope_scaling


def get_config(
model: Union[str, Path],
trust_remote_code: bool,
@@ -191,6 +227,8 @@ def get_config(
)
config.update({key: value})

+ patch_rope_scaling(config)

return config


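A hedged usage sketch of the new helpers: a legacy Qwen2-VL-style dict is rewritten in place, and `uses_mrope` keys off the presence of `mrope_section` rather than a rope type (the mrope_section values are illustrative):

```python
from vllm.transformers_utils.config import patch_rope_scaling_dict

rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}
patch_rope_scaling_dict(rope_scaling)

# "rope_type" is now "default"; the legacy "type" key is left in place,
# and "mrope_section" is what uses_mrope() looks for.
assert rope_scaling["rope_type"] == "default"
assert "mrope_section" in rope_scaling
```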
4 changes: 0 additions & 4 deletions vllm/transformers_utils/configs/__init__.py
@@ -14,8 +14,6 @@
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
- from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
-     Qwen2VLVisionConfig)
from vllm.transformers_utils.configs.solar import SolarConfig
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

@@ -35,6 +33,4 @@
"NVLM_D_Config",
"SolarConfig",
"UltravoxConfig",
"Qwen2VLConfig",
"Qwen2VLVisionConfig",
]