Skip to content

Conversation

@ngxson
Copy link
Collaborator

@ngxson ngxson commented Jan 8, 2026

I was quite curious why there was a function called fix_query_key_value_ordering in transformers code (which was mirrored into llama.cpp implementation). Just wondering what are they trying to fix.

Turns out, the projected QKVZ was in a wrong order. Not sure why they don't fix the original weight instead of permuting the results. But that's not very not important.

I took my trusty pen & paper to see what can be done:

image

Since the projected matrix is big, I supposed it will give a significant boost in perf. But I was quite disappointed to see only 1% improvement in pp512. So I don't even know if it's worth the fix. GGUF will need to reconverted to take advantage of this.

(Update: there are more improvements that allow from 5% all the way to 20% boost depending on backend, see my comment below)

master:

model size params backend threads test t/s
qwen3next 80B.A3B F16 148.50 GiB 79.67 B Metal,BLAS 24 pp512 790.87 ± 17.93
qwen3next 80B.A3B F16 148.50 GiB 79.67 B Metal,BLAS 24 tg128 22.33 ± 0.08

PR:

model size params backend threads test t/s
qwen3next 80B.A3B F16 148.50 GiB 79.67 B Metal,BLAS 24 pp512 965.15 ± 8.00
qwen3next 80B.A3B F16 148.50 GiB 79.67 B Metal,BLAS 24 tg128 28.87 ± 0.06

Hardware: mac studio m3 256gb ram


I uploaded a q8_0 here: https://huggingface.co/ngxson/qwen3_next_fixed/tree/main (converted from Instruct version)

It can be used to test against q8_0 already exist on the internet: https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct-GGUF/blob/main/Qwen3-Next-80B-A3B-Instruct-Q8_0.gguf

@ngxson ngxson requested a review from CISC as a code owner January 8, 2026 00:00
@ngxson ngxson requested review from pwilkin and removed request for CISC January 8, 2026 00:00
@ngxson ngxson marked this pull request as draft January 8, 2026 00:00
@pwilkin
Copy link
Collaborator

pwilkin commented Jan 8, 2026

I like the idea (and it surely does make the computation clearer), but I'm not sure if people would appreciate having to recreate and redownload GGUFs at this point for a 1% performance increase. I don't mind, but would like to see what people who use the model more think first.

EDIT: never mind, I only now looked at the code in detail and you added a compatibility path. Yeah, I think it's worth to do it :) maybe it'll help more on other backends?

@ngxson
Copy link
Collaborator Author

ngxson commented Jan 8, 2026

Yeah I would appreciate helps on testing other backend too. Curious to see if it's better on CUDA. The apple unified memory thingy can make things a bit complicated to measure.

(My optimization should only take effects on system with constrained memory bandwidth)

@jeffbolznv
Copy link
Collaborator

Do I need to do anything special to test it (e.g. regenerate ggufs)?

@ngxson
Copy link
Collaborator Author

ngxson commented Jan 8, 2026

@jeffbolznv yes, you will need to generate the GGUF with the PR:

Download the model here: https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct

Then convert it as normal:

python ../../llama.cpp/convert_hf_to_gguf.py . --outtype f16 --outfile model.gguf

@pwilkin
Copy link
Collaborator

pwilkin commented Jan 8, 2026

It's actually BF16 :)

"torch_dtype": "bfloat16"

@jacekpoplawski
Copy link
Contributor

I understand that many 1% improvements add up to something significant.
If someone publishes a quantized GGUF, it will be much easier for people to benchmark it on different configurations.

@am17an
Copy link
Collaborator

am17an commented Jan 8, 2026

a merged QKV in normal models gives about ~3-4% improvement in PP last time I checked. #16813. However, many backends already support computing these things in parallel so the METAL backend perhaps won't benefit as much. You can try the CUDA backend, I suspect the gains will be slightly more pronounced

@jeffbolznv
Copy link
Collaborator

I have limited cycles for the next week or so, if somebody can make a quantized model easily accessible it'll be easier for me to try it.

@IIIIIllllIIIIIlllll
Copy link

I don't have a CUDA device at the moment, so I simply did the test on AI MAX+ 395.

Master

command: /home/mark/llama.cpp/llama.cpp-master/build/bin/llama-bench -m /home/mark/Models/Q8/Qwen3-Next-80B-A3B-Instruct-Q8_0/next-master.gguf -r 5 -p 2048 -n 32 -b 2048 -ub 2048 -fa 1 -mmp 0

ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        579.45 ± 1.62 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         28.98 ± 0.04 |

build: unknown (0)

This

command: /home/mark/llama.cpp/llama.cpp-xsn-qwen3next_improve/build/bin/llama-bench -m /home/mark/Models/Q8/Test-Qwen3-Next/model.gguf -r 5 -p 2048 -n 32 -b 2048 -ub 2048 -fa 1 -mmp 0

ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        593.37 ± 1.45 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         29.18 ± 0.03 |

build: unknown (0)

@ggerganov
Copy link
Member

Btw, I'm also planning to apply the #18550 functionality to Qwen3 Next when it is ready in order to make the ggml graphs static in memory and avoid the graph reallocations that currently occur because of the chunking. It still won't allow CUDA graphs to be used, but at least we should be able to avoid a lot of overhead that we currently have from the changing number of nodes based on the batch size.

@ngxson
Copy link
Collaborator Author

ngxson commented Jan 8, 2026

@am17an the problem with qwen3 next is that qkv is not actually used for attention. Instead, this qkvz tensor is used for ssm conv (naming is a bit confusing indeed). The old logic does: qkvz projection --> permute --> concat which is redundant (compared to attention: projection --> rope --> permute --> attention)

As I expected, seems like this really have an improvement for system with less memory bandwidth as @IIIIIllllIIIIIlllll confirmed. The projected qkvz tensors is large, so it should be faster if we don't need to do any permutation on it.

@ngxson
Copy link
Collaborator Author

ngxson commented Jan 8, 2026

@jeffbolznv I'll try to upload a q8_0 later today

@ngxson
Copy link
Collaborator Author

ngxson commented Jan 8, 2026

I uploaded a q8_0 here: https://huggingface.co/ngxson/qwen3_next_fixed/tree/main (converted from Instruct version)

It can be used to test against q8_0 already exist on the internet: https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct-GGUF/blob/main/Qwen3-Next-80B-A3B-Instruct-Q8_0.gguf

Note: @pwilkin we usually test on f16 because it's the most compatible type cross all backend. Some backend like CPU internally convert bf16 to f16 if the hardware doesn't support it; for hardware that supports bf16, there should be no differences between the 2

But that's not important for this PR: we are comparing perf before/after so the importance is to make sure we're using the same type

@pwilkin
Copy link
Collaborator

pwilkin commented Jan 8, 2026

Ah, okay, got it.

@ngxson
Copy link
Collaborator Author

ngxson commented Jan 8, 2026

@ggerganov I'm thinking of this idea: can we enforce the number of chunks? For example, we can enforce cgraph to always allocate n_chunks = ubatch_size / CHUNK_SIZE, and unused chunks will have 0 elements, which allow backend to skip them.

This should make the cgraph topology to be more static, although, it is unclear for me if CUDA graph expects the tensor shapes to be unchanged (the case of 0 elements for unused chunk that I talked above)

One extra question: I remember we had a mechanism in ggml to detect no-op (like GGML_OP_VIEW). I'm wondering if it can be beneficial to extend it to consider singleton as no-op (for example: transpose a tensor with n_elements = 1)

@jeffbolznv
Copy link
Collaborator

My test system is an RTX 5090 with 32GB, so I'm not sure I can get useful data from this 80GB model. I ran with -ngl 19 which is the most layers I can fit, and see like a 10x slowdown with the new model/changes but it's probably either due to paging or due to worse CPU inferencing performance. I've done my previous qwen3next testing on a Q2_K_L model which is around 29GB.

@ggerganov
Copy link
Member

@ggerganov I'm thinking of this idea: can we enforce the number of chunks? For example, we can enforce cgraph to always allocate n_chunks = ubatch_size / CHUNK_SIZE, and unused chunks will have 0 elements, which allow backend to skip them.

Not sure how robust ggml is for zero-sized tensors. Think it will need some significant changes to be compatible. But we can definitely give this idea a try.

One extra question: I remember we had a mechanism in ggml to detect no-op (like GGML_OP_VIEW). I'm wondering if it can be beneficial to extend it to consider singleton as no-op (for example: transpose a tensor with n_elements = 1)

Hm, not sure I understand. Transpose is already a noop:

static bool ggml_op_is_empty(enum ggml_op op) {
switch (op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_TRANSPOSE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
return true;
default:
return false;
}
}

@ngxson
Copy link
Collaborator Author

ngxson commented Jan 8, 2026

@ggerganov what I mean is that there can be also cases where ggml_cont on a contiguous tensor, or ggml_sum_rows on a tensor with t->ne[0] == 1, etc, which can all be consider as no-op. That's just an idea for maybe merge 2 branches: chunk and autoregressive together in the future (although, it can be much more difficult to actually do it)

@danbev
Copy link
Member

danbev commented Jan 9, 2026

System Config: DGX Spark (Grace CPU + Blackwell GPU)
# OS & Kernel
Description:	Ubuntu 24.04.3 LTS
6.11.0-1016-nvidia

# CPU (Grace)
Architecture:                         aarch64
Byte Order:                           Little Endian
Vendor ID:                            ARM
Model name:                           Cortex-X925
Model:                                1
Thread(s) per core:                   1
Model name:                           Cortex-A725
Model:                                1
Thread(s) per core:                   1

# GPU (Blackwell)
name, driver_version, memory.total [MiB], compute_cap
NVIDIA GB10, 580.95.05, [N/A], 12.1

# CUDA Version
Cuda compilation tools, release 13.0, V13.0.88

I ran the following command for the benchmarks (same as used in comment):

./build/bin/llama-bench -m models/Qwen3-Next-80B-A3B-Instruct-bf16-Q8_0.gguf -r 5 -p 2048 -n 32 -b 2048 -ub 2048 -fa 1 -mmp 0

Let me know if there are other configurations that you'd like me to run.

master:

| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |     2048 |  1 |    0 |          pp2048 |        899.72 ± 2.06 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |     2048 |  1 |    0 |            tg32 |         29.77 ± 0.04 |

build: 8ece3836b (7681)

PR:

| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |     2048 |  1 |    0 |          pp2048 |        924.32 ± 2.21 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |     2048 |  1 |    0 |            tg32 |         31.20 ± 0.52 |

build: 65602e899 (7686)

@ngxson
Copy link
Collaborator Author

ngxson commented Jan 10, 2026

I removed 2 redundant transposes which improved from 793 --> 820 t/s in my test (GGUF stays the same). Would appreciate if you can rerun the test with latest commit

@lemmi
Copy link

lemmi commented Jan 10, 2026

With d96eb69 --mmap 0 works again and the numbers have also improved a bit.

model size params backend ngl n_ubatch fa mmap test t/s
qwen3next 80B.A3B Q8_0 78.98 GiB 79.67 B Vulkan 99 256 1 0 pp4096 392.86 ± 2.40
qwen3next 80B.A3B Q8_0 78.98 GiB 79.67 B Vulkan 99 256 1 0 tg128 33.71 ± 0.02
qwen3next 80B.A3B Q8_0 78.98 GiB 79.67 B Vulkan 99 512 1 0 pp4096 457.53 ± 1.80
qwen3next 80B.A3B Q8_0 78.98 GiB 79.67 B Vulkan 99 512 1 0 tg128 33.64 ± 0.03
qwen3next 80B.A3B Q8_0 78.98 GiB 79.67 B Vulkan 99 1024 1 0 pp4096 392.29 ± 2.50
qwen3next 80B.A3B Q8_0 78.98 GiB 79.67 B Vulkan 99 1024 1 0 tg128 33.68 ± 0.03

@jeffbolznv
Copy link
Collaborator

Here's a quick before/after for Q2_K on Vulkan RTX 5090. Seems like it's at the limits of vidmem and occasionally spills and gets very slow, both with and without this change. Looks like ~5% faster.

before

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -m c:\models\model-orig-Q2_K.gguf -n 128 -p 512,1024,2048 -ub 512,1024,2048
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
load_backend: loaded Vulkan backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-vulkan.dll
load_backend: loaded CPU backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-cpu.dll
| model                          |       size |     params | backend    | ngl | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | --------------: | -------------------: |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |      512 |  1 |           pp512 |      4350.55 ± 85.44 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |      512 |  1 |          pp1024 |      4181.56 ± 11.97 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |      512 |  1 |          pp2048 |       4281.07 ± 9.39 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |      512 |  1 |           tg128 |        147.02 ± 2.24 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     1024 |  1 |           pp512 |      4408.04 ± 20.33 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     1024 |  1 |          pp1024 |      5080.18 ± 30.54 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     1024 |  1 |          pp2048 |    3942.66 ± 1148.31 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     1024 |  1 |           tg128 |        147.01 ± 3.10 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     2048 |  1 |           pp512 |      4390.52 ± 24.24 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     2048 |  1 |          pp1024 |      5196.31 ± 10.23 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     2048 |  1 |          pp2048 |      5338.67 ± 18.21 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     2048 |  1 |           tg128 |        142.55 ± 3.30 |

build: 53eb9435d (7684)

after

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -m c:\models\model-Q2_K.gguf -n 128 -p 512,1024,2048 -ub 512,1024,2048
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
load_backend: loaded Vulkan backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-vulkan.dll
load_backend: loaded CPU backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-cpu.dll
| model                          |       size |     params | backend    | ngl | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | --------------: | -------------------: |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |      512 |  1 |           pp512 |      4520.18 ± 19.25 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |      512 |  1 |          pp1024 |      4344.81 ± 12.70 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |      512 |  1 |          pp2048 |      4432.59 ± 20.83 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |      512 |  1 |           tg128 |        146.07 ± 7.39 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     1024 |  1 |           pp512 |      4527.51 ± 21.06 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     1024 |  1 |          pp1024 |      5324.04 ± 30.80 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     1024 |  1 |          pp2048 |      5083.77 ± 17.57 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     1024 |  1 |           tg128 |        156.39 ± 0.50 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     2048 |  1 |           pp512 |      4463.32 ± 25.80 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     2048 |  1 |          pp1024 |      5278.55 ± 31.29 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     2048 |  1 |          pp2048 |        343.07 ± 1.07 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     2048 |  1 |           tg128 |        154.33 ± 2.70 |

build: d96eb69e2 (7694)

@danbev
Copy link
Member

danbev commented Jan 10, 2026

Results using d96eb69:

| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |     2048 |  1 |    0 |          pp2048 |        940.81 ± 2.31 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |     2048 |  1 |    0 |            tg32 |         31.44 ± 0.36 |

build: 8c1336cb9 (7697)

@ngxson
Copy link
Collaborator Author

ngxson commented Jan 10, 2026

I ran llama-eval-callback to make sure that I didn't break anything, looks good so far. So until this point I can confirm 4-5% perf boost.

Working on optimizing chunking logic today, hope to get a bit more perf out of it

@ngxson
Copy link
Collaborator Author

ngxson commented Jan 10, 2026

So I just pushed some changes that bumped the perf in my tests from 820 to 960 t/s on pp512, very significant improvement. So in total, we're getting ~20% boost between master and PR.

The 2 most important commits are:

  • e1f8ad2 gives 2% boost; I removed some redundant ggml_cont. I tested on metal and see no side effects on metal. Though, it's best to test it properly. To test it, use llama-cli -f prompt.txt --temp 0 --top-k 1 with a large enough prompt.txt (around 2048 tokens is enough). If the results stay unchanged between PR/master, that means nothing is broken
  • f8ad742 gives another ~10% perf boost. The old code path returns 2 tensors by ggml_concat them, then later separate using ggml_view. Here we can return them as std::pair

But first, would appreciate if you can run a simple llama-bench on your side @danbev @jeffbolznv @lemmi

On my mac m3 ultra:

| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| qwen3next 80B.A3B F16          | 148.50 GiB |    79.67 B | Metal,BLAS |      24 |           pp512 |        965.15 ± 8.00 |
| qwen3next 80B.A3B F16          | 148.50 GiB |    79.67 B | Metal,BLAS |      24 |           tg128 |         28.87 ± 0.06 |

build: f8ad742ae (7699)

@lemmi
Copy link

lemmi commented Jan 10, 2026

model size params backend ngl fa mmap test t/s
qwen3next 80B.A3B Q8_0 78.98 GiB 79.67 B Vulkan 99 1 0 pp512 481.87 ± 3.65
qwen3next 80B.A3B Q8_0 78.98 GiB 79.67 B Vulkan 99 1 0 pp4096 460.94 ± 1.39
qwen3next 80B.A3B Q8_0 78.98 GiB 79.67 B Vulkan 99 1 0 tg128 34.49 ± 0.03

TG improved by 1t/s, rest also got a small bump.

@IIIIIllllIIIIIlllll
Copy link

The test results are still AI MAX+ 395. The performance has indeed improved significantly.

==============================
This PR:

cmd:
/home/mark/llama.cpp/llama.cpp-xsn-qwen3next_improve/build/bin/llama-bench -m /home/mark/Models/Q8/Test-Qwen3-Next/new.gguf --repetitions 5 --output md --mmap 0 --delay 0 --n-prompt 512 --n-gen 128 --n-depth 0 --batch-size 2048 --ubatch-size 512 --cache-type-k f16 --cache-type-v f16 --threads 16 --n-gpu-layers 99 --n-cpu-moe 0 --flash-attn 1

output:

ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |  1 |    0 |           pp512 |        605.77 ± 4.80 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |  1 |    0 |           tg128 |         29.86 ± 0.03 |

build: unknown (0)

==============================
Master PR:

cmd:
/home/mark/llama.cpp/llama.cpp-master-rocm/llama-bench -m /home/mark/Models/Q8/Qwen3-Next-80B-A3B-Instruct-Q8_0/next-master.gguf --repetitions 5 --output md --mmap 0 --delay 0 --n-prompt 512 --n-gen 128 --n-depth 0 --batch-size 2048 --ubatch-size 512 --cache-type-k f16 --cache-type-v f16 --threads 16 --n-gpu-layers 99 --n-cpu-moe 0 --flash-attn 1

output:

ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |  1 |    0 |           pp512 |        576.19 ± 4.31 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |  1 |    0 |           tg128 |         28.99 ± 0.01 |

build: unknown (0)

Comment on lines +320 to +321
//ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
Copy link
Collaborator Author

@ngxson ngxson Jan 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ggerganov do you have any ideas why the commented version can be slower? (I'm testing on Metal)

both tensors are contiguous, here are the shapes:

k_gdiff ne: 128 64 1 32
v_new ne: 128 64 1 32

@jeffbolznv
Copy link
Collaborator

Maybe another 5-10% depending on mode with the latest. The numbers continue to be noisy (likely due to spilling) but it's clearly improved.


Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -m c:\models\model-Q2_K.gguf -n 128 -p 512,1024,2048 -ub 512,1024,2048
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
load_backend: loaded Vulkan backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-vulkan.dll
load_backend: loaded CPU backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-cpu.dll
| model                          |       size |     params | backend    | ngl | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | --------------: | -------------------: |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |      512 |  1 |           pp512 |     4820.56 ± 102.46 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |      512 |  1 |          pp1024 |      4716.74 ± 16.33 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |      512 |  1 |          pp2048 |      4809.92 ± 14.32 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |      512 |  1 |           tg128 |        164.14 ± 1.20 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     1024 |  1 |           pp512 |      4885.30 ± 27.15 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     1024 |  1 |          pp1024 |      5657.14 ± 27.04 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     1024 |  1 |          pp2048 |      5484.45 ± 19.96 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     1024 |  1 |           tg128 |        163.62 ± 1.17 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     2048 |  1 |           pp512 |      4883.83 ± 33.79 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     2048 |  1 |          pp1024 |      5660.00 ± 18.91 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     2048 |  1 |          pp2048 |      536.38 ± 264.19 |
| qwen3next 80B.A3B Q2_K - Medium |  27.12 GiB |    79.67 B | Vulkan     |  99 |     2048 |  1 |           tg128 |        160.92 ± 3.24 |

build: 9299ced6b (7701)

@ngxson
Copy link
Collaborator Author

ngxson commented Jan 10, 2026

@jeffbolznv could you show the log of graph splits? (I don't quite remember which env var to enable it)

On metal, I got some splits because ggml_pad is not supported, but it can be added in another PR. I'm curious to see why vulkan graph is split

@ngxson
Copy link
Collaborator Author

ngxson commented Jan 10, 2026

Alright, this is as far as I can go for now. My last commits don't improve much on performance (at least on metal), but it further reduced the number of nodes. On master branch, we current have 11k nodes in total, and this PR brought it down to ~9k nodes.

PR:
llama_context: Flash Attention was auto, set to enabled
llama_context:      Metal compute buffer size =   846.80 MiB
llama_context:        CPU compute buffer size =   524.01 MiB
llama_context: graph nodes  = 9374 (with bs=512), 5918 (with bs=1)
llama_context: graph splits = 76 (with bs=512), 4 (with bs=1)

master:
llama_context: Flash Attention was auto, set to enabled
llama_context:      Metal compute buffer size =   870.80 MiB
llama_context:        CPU compute buffer size =   524.01 MiB
llama_context: graph nodes  = 11186 (with bs=512), 6614 (with bs=1)
llama_context: graph splits = 76 (with bs=512), 4 (with bs=1)

@ngxson ngxson marked this pull request as ready for review January 10, 2026 17:36
@ngxson ngxson requested review from CISC and ggerganov January 10, 2026 17:38
@IIIIIllllIIIIIlllll
Copy link

There is no further performance improvement on my AI MAX+ 395 either.

ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |  1 |    0 |           pp512 |        605.15 ± 5.10 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |  1 |    0 |           tg128 |         30.00 ± 0.01 |

build: unknown (0)

@jeffbolznv
Copy link
Collaborator

I'm not at my computer for a while, but I dont think there should be any splits, except for the initial layer(s).

To enable the logging set GGML_SCHED_DEBUG=2 and use -v.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

model Model specific python python script changes

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants