
Correction to TP logic for Mamba Mixer 2 when Num Groups not divisible by TP Size #13660

Merged · 2 commits · Feb 22, 2025

Conversation

@fabianlim (Contributor) commented Feb 21, 2025

The current logic is incorrect for the case when n_groups cannot be divided by the TP size; handling that case properly would require further changes to the kernels. The current code uses a simple ratio to map each head to the correct group, but one can imagine that in the more general case the mapping is more complicated.

For example, take n_groups=3, n_heads=15, and TP_size=5. In this case we end up with the following split of the heads across shards (the digits 0 - 2 indicate which group each of the 15 heads maps to):

000 | 001 | 111 | 122 | 222

In this case we end up with a very heterogeneous situation, where we need to

  • duplicate groups, since some shards need at least 2 groups, and
  • abandon the ratio logic if we want to stay at 2 groups per shard.

Of course, if we duplicated groups to the extent that they equal the number of heads, then it is possible, but I am not sure there is a more efficient method.
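As an illustrative sketch of the problem (not code from the PR), the ratio logic assigns head `h` to group `h // (n_heads // n_groups)`; slicing the heads into TP shards then reproduces the heterogeneous split above:

```python
# Illustrative sketch (not vLLM code): map heads to groups with the
# simple ratio logic, then slice the heads across TP shards for
# n_groups=3, n_heads=15, tp_size=5.
n_groups, n_heads, tp_size = 3, 15, 5
heads_per_group = n_heads // n_groups   # 5 heads share each group
heads_per_shard = n_heads // tp_size    # 3 heads per TP shard

head_to_group = [h // heads_per_group for h in range(n_heads)]
shards = [head_to_group[r * heads_per_shard:(r + 1) * heads_per_shard]
          for r in range(tp_size)]
print(shards)
# → [[0, 0, 0], [0, 0, 1], [1, 1, 1], [1, 2, 2], [2, 2, 2]]
```

Shards 1 and 3 each span two different groups, so a rank cannot recover its group with a single rank-wide ratio.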

Current Strategy in this PR

For now, it may be easier to patch the logic to specifically support only the following two special cases:

  1. TP size divides n_groups.
  2. TP size does not divide n_groups, but n_groups == 1.

These two scenarios cover existing models such as Codestral, Bamba, Zamba, etc., where n_groups is either 1 or some power of 2.
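The two supported cases could be sketched as follows (a hypothetical helper for illustration, not the actual vLLM implementation):

```python
# Hypothetical sketch of the two supported sharding cases.
def groups_per_shard(n_groups: int, tp_size: int) -> int:
    if n_groups % tp_size == 0:
        # Case 1: each TP rank owns a contiguous slice of the groups.
        return n_groups // tp_size
    if n_groups == 1:
        # Case 2: the single group is replicated on every rank.
        return 1
    raise ValueError(
        "unsupported: tp_size must divide n_groups, or n_groups must be 1")

print(groups_per_shard(8, 2))  # Codestral-style: 4 groups per rank
print(groups_per_shard(1, 4))  # single group replicated on all ranks
```

Any other combination (e.g. n_groups=3 with tp_size=5) is rejected rather than sharded incorrectly.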

cc: @tlrmchlsmth @yury-tokpanov


@DarkLight1337 (Member) commented:

Thanks for the fix. Can we add some tests to avoid future regressions?

@fabianlim (Contributor, Author) commented:

Yes, we currently have a TP test for selected models, but AFAIK the tests for mamba2 are not yet automated. There is also a problem running them in a suite, because they are not set up properly with torch multiprocessing. Let me discuss this with @tlrmchlsmth.

@tlrmchlsmth self-assigned this on Feb 21, 2025
@tlrmchlsmth (Collaborator) commented:

> The current logic is incorrect for the case when n_groups cannot be divided by TP size

I'm a little confused by this because we're seeing poor gsm8k results on TP size 2 on Mamba Codestral, which has n_groups == 8, so n_groups is divisible in this case.

For reference, these are GSM8k results on Codestral-7B from @yury-tokpanov
TP=1

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4731|±  |0.0138|
|     |       |strict-match    |     5|exact_match|↑  |0.4610|±  |0.0137|

vs TP=2

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.0167|±  |0.0035|
|     |       |strict-match    |     5|exact_match|↑  |0.0076|±  |0.0024|

@tlrmchlsmth (Collaborator) commented:

> For now, it may be easier to patch the logic to specifically support only the following two special cases:
>
>   1. TP size divides n_groups.
>   2. TP size does not divide n_groups, but n_groups == 1.
>
> These two scenarios cover existing models such as Codestral, Bamba, Zamba, etc., where n_groups is either 1 or some power of 2.

This does make sense to me, and I'm on board with this approach.

```python
self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()

assert num_heads % self.tp_size == 0, \
    "Tensor parallel world size must divide num heads."

assert (n_groups % self.tp_size) != 0 and n_groups == 0, \
```
@tlrmchlsmth (Collaborator) commented on the diff:
Should this be:

```python
assert (n_groups % self.tp_size) == 0 or n_groups == 1, ...
```
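A quick illustrative check (self.tp_size replaced by a plain parameter) shows why the fix matters: the original condition can never hold, so the assert always fires, while the suggested condition accepts exactly the two intended configurations:

```python
# Illustrative comparison of the two assert conditions.
def original(n_groups: int, tp_size: int) -> bool:
    # Unsatisfiable: n_groups == 0 forces n_groups % tp_size == 0.
    return (n_groups % tp_size) != 0 and n_groups == 0

def suggested(n_groups: int, tp_size: int) -> bool:
    return (n_groups % tp_size) == 0 or n_groups == 1

cases = [(8, 2), (1, 4), (3, 5)]  # Codestral TP=2, single group, exotic
print([original(g, t) for g, t in cases])   # → [False, False, False]
print([suggested(g, t) for g, t in cases])  # → [True, True, False]
```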

@tlrmchlsmth (Collaborator) commented:

I hit this assert when running

```shell
lm_eval --model vllm --model_args pretrained=mistralai/Mamba-Codestral-7B-v0.1,gpu_memory_utilization=0.8,max_model_len=4096,tensor_parallel_size=2 --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k
```

but with my suggestion, the TP==2 gsm8k results look good:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4701|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.4549|±  |0.0137|

@fabianlim (Contributor, Author) replied:

Oh sorry, you are right.

@fabianlim (Contributor, Author) commented:

> I'm a little confused by this because we're seeing poor gsm8k results on TP size 2 on Mamba Codestral, which has n_groups == 8

@tlrmchlsmth sorry, that was a typo; I fixed it in the main description.

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@yury-tokpanov commented:

I think groups were originally introduced to support easy TP for mamba2, right? So in the original logic, n_groups should always be divisible by the TP size.

I wonder whether the general case, where n_groups is not divisible by the TP size, would be quite exotic. Supporting the special case of n_groups == 1 is probably enough for now.

@yury-tokpanov commented:

Thanks for the fix and all the work on TP for mamba2!

@tlrmchlsmth added the `ready` label on Feb 21, 2025
@tlrmchlsmth (Collaborator) left a review comment:

Thanks a lot for the fix!

@tlrmchlsmth enabled auto-merge (squash) on February 21, 2025 at 22:20
@simon-mo merged commit fca2084 into vllm-project:main on Feb 22, 2025
55 of 58 checks passed