model: try to improve Qwen3 Next #18683
base: master
Conversation
I like the idea (and it surely does make the computation clearer), but I'm not sure if people would appreciate having to recreate and redownload GGUFs at this point for a 1% performance increase. I don't mind, but would like to see what people who use the model more think first. EDIT: never mind, I only now looked at the code in detail and you added a compatibility path. Yeah, I think it's worth doing :) maybe it'll help more on other backends?
Yeah, I would appreciate help with testing other backends too. Curious to see if it's better on CUDA. The Apple unified memory thingy can make things a bit complicated to measure. (My optimization should only take effect on systems with constrained memory bandwidth.)
Do I need to do anything special to test it (e.g. regenerate GGUFs)?
@jeffbolznv yes, you will need to generate the GGUF with the PR. Download the model here: https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct then convert it as normal:

python ../../llama.cpp/convert_hf_to_gguf.py . --outtype f16 --outfile model.gguf
It's actually BF16 :)
I understand that many 1% improvements add up to something significant.
A merged QKV in normal models gives about a 3-4% improvement in PP, last time I checked (#16813). However, many backends already support computing these things in parallel, so the Metal backend perhaps won't benefit as much. You can try the CUDA backend; I suspect the gains will be slightly more pronounced.
I have limited cycles for the next week or so; if somebody can make a quantized model easily accessible, it'll be easier for me to try it.
I don't have a CUDA device at the moment, so I simply did the test on an AI MAX+ 395.
Btw, I'm also planning to apply the #18550 functionality to Qwen3 Next when it is ready, in order to make the ggml graphs static in memory and avoid the graph reallocations that currently occur because of the chunking. It still won't allow CUDA graphs to be used, but at least we should be able to avoid a lot of the overhead that we currently have from the changing number of nodes based on the batch size.
@am17an the problem with Qwen3 Next is that qkv is not actually used for attention. Instead, this qkvz tensor is used for the ssm conv (the naming is a bit confusing indeed). The old logic does: qkvz projection --> permute --> concat, which is redundant (compared to attention: projection --> rope --> permute --> attention). As I expected, it seems like this really does improve things on systems with less memory bandwidth, as @IIIIIllllIIIIIlllll confirmed. The projected qkvz tensor is large, so it should be faster if we don't need to do any permutation on it.
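For illustration, here is a minimal sketch of the idea in ggml terms (the helper `build_qkvz_views`, the weight name `wqkvz`, and the `d_*` sizes are hypothetical; this is not the actual PR code): if the fused weight rows are already reordered at GGUF conversion time, the projected tensor can be split with cheap strided views instead of being permuted and copied.

```cpp
#include "ggml.h"

// Sketch only -- hypothetical helper, not the PR code. Assumes the fused qkvz
// weight rows were reordered at GGUF conversion time, so the projection comes
// out as [ q | k | v | z ] and cheap views replace the old permute + concat.
static void build_qkvz_views(
        ggml_context * ctx0, ggml_tensor * wqkvz, ggml_tensor * cur,
        int64_t d_q, int64_t d_k, int64_t d_v, int64_t d_z,
        ggml_tensor ** q, ggml_tensor ** k, ggml_tensor ** v, ggml_tensor ** z) {
    // one projection over the whole batch: [d_q + d_k + d_v + d_z, n_tokens]
    ggml_tensor * qkvz = ggml_mul_mat(ctx0, wqkvz, cur);

    const int64_t n_tokens = cur->ne[1];
    const size_t  es       = ggml_element_size(qkvz);

    // strided views into the projection -- no permute/cont on the full tensor
    *q = ggml_view_2d(ctx0, qkvz, d_q, n_tokens, qkvz->nb[1], 0);
    *k = ggml_view_2d(ctx0, qkvz, d_k, n_tokens, qkvz->nb[1],  d_q              *es);
    *v = ggml_view_2d(ctx0, qkvz, d_v, n_tokens, qkvz->nb[1], (d_q + d_k)       *es);
    *z = ggml_view_2d(ctx0, qkvz, d_z, n_tokens, qkvz->nb[1], (d_q + d_k + d_v) *es);
}
```

The expensive mul_mat is unchanged in this sketch; only view metadata replaces the permute/concat on its large output.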
@jeffbolznv I'll try to upload a q8_0 later today.
I uploaded a q8_0 here: https://huggingface.co/ngxson/qwen3_next_fixed/tree/main (converted from the Instruct version). It can be used to test against the q8_0 that already exists on the internet: https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct-GGUF/blob/main/Qwen3-Next-80B-A3B-Instruct-Q8_0.gguf Note: @pwilkin we usually test on f16 because it's the most compatible type across all backends. Some backends like CPU internally convert bf16 to f16 if the hardware doesn't support it; for hardware that supports bf16, there should be no difference between the two. But that's not important for this PR: since we are comparing perf before/after, what matters is making sure we're using the same type.
Ah, okay, got it.
@ggerganov I'm thinking of this idea: can we enforce the number of chunks? For example, we can enforce the cgraph to always allocate a fixed number of chunks. This should make the cgraph topology more static, although it is unclear to me whether CUDA graphs expect the tensor shapes to be unchanged (the case of 0 elements for an unused chunk mentioned above). One extra question: I remember we had a mechanism in ggml to detect no-ops (like transpose)?
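To make the idea concrete, a rough sketch of always emitting a fixed number of chunk views (the constant `N_CHUNKS_MAX` and helper `build_fixed_chunks` are hypothetical, and whether ggml tolerates the resulting zero-sized views is exactly the open question raised below):

```cpp
#include "ggml.h"

#define N_CHUNKS_MAX 8 // hypothetical upper bound, for illustration only
#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Always emit N_CHUNKS_MAX chunk views over x = [d_inner, n_tokens], so the
// graph topology does not depend on how many chunks the batch actually needs.
// Chunks past the end of the batch become views with 0 rows.
static void build_fixed_chunks(ggml_context * ctx, ggml_tensor * x,
                               int64_t chunk_size, ggml_tensor * chunks[N_CHUNKS_MAX]) {
    const int64_t n_tokens = x->ne[1];
    for (int i = 0; i < N_CHUNKS_MAX; ++i) {
        const int64_t i0 = MIN((int64_t) i*chunk_size, n_tokens); // clamped chunk start
        const int64_t n  = MIN(chunk_size, n_tokens - i0);        // 0 rows for unused chunks
        chunks[i] = ggml_view_2d(ctx, x, x->ne[0], n, x->nb[1], i0*x->nb[1]);
    }
}
```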
My test system is an RTX 5090 with 32 GB, so I'm not sure I can get useful data from this 80 GB model. I ran with
Not sure how robust ggml is for zero-sized tensors. Think it will need some significant changes to be compatible. But we can definitely give this idea a try.
Hm, not sure I understand. Transpose is already a no-op: llama.cpp/ggml/src/ggml-impl.h, lines 87 to 100 in 2524c26.
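As a quick illustration of what the no-op means in practice (this is not the referenced ggml-impl.h code): ggml_transpose only swaps the shape/stride metadata of the tensor and moves no data, so the cost only appears when a later op such as ggml_cont needs the data laid out contiguously.

```cpp
#include "ggml.h"

// Illustration only -- ggml_transpose returns a view with ne/nb swapped.
static void transpose_is_free(ggml_context * ctx) {
    ggml_tensor * t  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 64);
    ggml_tensor * tt = ggml_transpose(ctx, t); // metadata only: effectively free at runtime
    ggml_tensor * tc = ggml_cont(ctx, tt);     // the copy here is what actually costs time
    (void) tc;
}
```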
@ggerganov what I mean is that there can also be cases where
System Config: DGX Spark (Grace CPU + Blackwell GPU)

# OS & Kernel
Description: Ubuntu 24.04.3 LTS
Kernel: 6.11.0-1016-nvidia

# CPU (Grace)
Architecture: aarch64
Byte Order: Little Endian
Vendor ID: ARM
Model name: Cortex-X925
Model: 1
Thread(s) per core: 1
Model name: Cortex-A725
Model: 1
Thread(s) per core: 1

# GPU (Blackwell)
name, driver_version, memory.total [MiB], compute_cap
NVIDIA GB10, 580.95.05, [N/A], 12.1

# CUDA Version
Cuda compilation tools, release 13.0, V13.0.88

I ran the following command for the benchmarks (same as used in the comment):

./build/bin/llama-bench -m models/Qwen3-Next-80B-A3B-Instruct-bf16-Q8_0.gguf -r 5 -p 2048 -n 32 -b 2048 -ub 2048 -fa 1 -mmp 0

Let me know if there are other configurations that you'd like me to run.

master:

| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0 | 78.98 GiB | 79.67 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 899.72 ± 2.06 |
| qwen3next 80B.A3B Q8_0 | 78.98 GiB | 79.67 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 29.77 ± 0.04 |
build: 8ece3836b (7681)
PR:

| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0 | 78.98 GiB | 79.67 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 924.32 ± 2.21 |
| qwen3next 80B.A3B Q8_0 | 78.98 GiB | 79.67 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 31.20 ± 0.52 |
build: 65602e899 (7686)
I removed 2 redundant transposes, which improved things from 793 --> 820 t/s in my test (the GGUF stays the same). Would appreciate it if you can rerun the test with the latest commit.
With d96eb69:
Here's a quick before/after for Q2_K on Vulkan on an RTX 5090. It seems like it's at the limits of vidmem and occasionally spills and gets very slow, both with and without this change. Looks like ~5% faster.
Results using d96eb69:

| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0 | 78.98 GiB | 79.67 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 940.81 ± 2.31 |
| qwen3next 80B.A3B Q8_0 | 78.98 GiB | 79.67 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 31.44 ± 0.36 |
build: 8c1336cb9 (7697)
I ran
Working on optimizing the chunking logic today, hope to get a bit more perf out of it.
So I just pushed some changes that bumped the perf in my tests from 820 to 960 t/s on pp512, a very significant improvement. So in total, we're getting a ~20% boost between master and this PR. The 2 most important commits are:
But first, would appreciate it if you can run a simple
On my Mac M3 Ultra:
TG improved by 1 t/s, the rest also got a small bump.
The test results are still on the AI MAX+ 395. The performance has indeed improved significantly.
//ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
@ggerganov do you have any idea why the commented-out version can be slower? (I'm testing on Metal.)
Both tensors are contiguous; here are the shapes:
k_gdiff ne: 128 64 1 32
v_new ne: 128 64 1 32
Maybe another 5-10% depending on the mode with the latest changes. The numbers continue to be noisy (likely due to spilling), but it's clearly improved.
@jeffbolznv could you show the log of graph splits? (I don't quite remember which env var enables it.) On Metal, I got some splits because ggml_pad is not supported, but that can be added in another PR. I'm curious to see why the Vulkan graph is split.
Alright, this is as far as I can go for now. My last commits don't improve performance much (at least on Metal), but they further reduce the number of nodes. On the master branch we currently have 11k nodes in total, and this PR brings it down to ~9k nodes.
There is no further performance improvement on my AI MAX+ 395 either.
I'm not at my computer for a while, but I don't think there should be any splits, except for the initial layer(s). To enable the logging, set GGML_SCHED_DEBUG=2 and use -v.
I was quite curious why there was a function called fix_query_key_value_ordering in the transformers code (which was mirrored into the llama.cpp implementation). Just wondering what they were trying to fix.
Turns out, the projected QKVZ was in the wrong order. Not sure why they don't fix the original weight instead of permuting the results. But that's not very important.
I took my trusty pen & paper to see what can be done:
Since the projected matrix is big, I supposed it would give a significant boost in perf. But I was quite disappointed to see only a 1% improvement in pp512, so I don't even know if it's worth the fix. The GGUF will need to be reconverted to take advantage of this.
(Update: there are further improvements that give anywhere from a 5% to a 20% boost depending on the backend; see my comments below.)
master:
PR:
Hardware: Mac Studio M3 Ultra, 256 GB RAM
I uploaded a q8_0 here: https://huggingface.co/ngxson/qwen3_next_fixed/tree/main (converted from the Instruct version)
It can be used to test against the q8_0 that already exists on the internet: https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct-GGUF/blob/main/Qwen3-Next-80B-A3B-Instruct-Q8_0.gguf