Skip to content

Commit 7ca2c06

Browse files
committed
vulkan: fuse rms_norm + mul + rope (+ view + set_rows)
This change combines the rms_norm+mul and rope+view+set_rows fusions to allow fusing the whole sequence together. This comes up in Qwen3, Bailing, and some other models.
1 parent ad51c0a commit 7ca2c06

File tree

12 files changed

+1012
-606
lines changed

12 files changed

+1012
-606
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 572 additions & 354 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
#include "rte.glsl"
55
#include "utils.glsl"
6+
#if RMS_NORM_ROPE_FUSION
7+
#include "rope_params.glsl"
8+
#endif
69

710
layout (push_constant) uniform parameter
811
{
@@ -12,11 +15,16 @@ layout (push_constant) uniform parameter
1215
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
1316
uint misalign_offsets;
1417
float param1; float param2; int param3;
18+
#if RMS_NORM_ROPE_FUSION
19+
rope_params rope;
20+
#endif
1521
} p;
1622

23+
#if !RMS_NORM_ROPE_FUSION
1724
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
1825
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
1926
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
27+
#endif
2028

2129
// true if src0/src1 are the same shape and the indices can be reused without additional modulus
2230
layout(constant_id = 0) const bool norepeat = false;

ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,32 @@
33
#include "generic_binary_head.glsl"
44
#include "types.glsl"
55

6+
#if RMS_NORM_ROPE_FUSION
7+
8+
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
9+
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
10+
11+
// data is passed from rms_norm -> rope through shared memory.
12+
// rms_norm calls this data_d, rope calls this rope_data_a.
13+
// Binding 2 is not used
14+
shared FLOAT_TYPE rope_data_a[1024];
15+
#define data_d rope_data_a
16+
17+
layout (binding = 3) readonly buffer R_Y {int rope_data_pos[];};
18+
layout (binding = 4) readonly buffer R_Z {float rope_data_ff[];};
19+
layout (binding = 5) writeonly buffer R_D {ROPE_D_TYPE rope_data_d[];};
20+
layout (binding = 6) readonly buffer R_I {uvec2 rope_data_i[];}; // indices for set_rows
21+
22+
#include "rope_params.glsl"
23+
#include "rope_funcs.glsl"
24+
25+
#define GGML_ROPE_TYPE_NORMAL 0
26+
#define GGML_ROPE_TYPE_NEOX 2
27+
#define GGML_ROPE_TYPE_MROPE 8
28+
#define GGML_ROPE_TYPE_VISION 24
29+
30+
#endif
31+
632
#extension GL_EXT_control_flow_attributes : enable
733
#define BLOCK_SIZE 512
834

@@ -28,8 +54,12 @@ void rms_norm(uint num_iters) {
2854

2955
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
3056
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
57+
#if RMS_NORM_ROPE_FUSION
58+
// Per-row offset in shared memory
59+
uint32_t d_offset = 0;
60+
#else
3161
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
32-
62+
#endif
3363
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
3464

3565
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
@@ -79,6 +109,18 @@ void rms_norm(uint num_iters) {
79109
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
80110
}
81111
}
112+
#if RMS_NORM_ROPE_FUSION
113+
barrier();
114+
rope_params rp = p.rope;
115+
uint rope_row = (samp*nchannels + channel)*nrows + row;
116+
for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
117+
if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
118+
rope_neox(t, rope_row, rp);
119+
} else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
120+
rope_norm(t, rope_row, rp);
121+
}
122+
}
123+
#endif
82124
}
83125

84126
void main() {
ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
2+
float rope_yarn_ramp(const float low, const float high, const uint i0) {
3+
const float y = (i0 / 2 - low) / max(0.001f, high - low);
4+
return 1.0f - min(1.0f, max(0.0f, y));
5+
}
6+
7+
uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) {
8+
#if RMS_NORM_ROPE_FUSION
9+
// Per-row offset in shared memory
10+
const uint ix = i0;
11+
#else
12+
const uint ix = i02*p.nb02 + i01*p.nb01 + i0;
13+
#endif
14+
return ix;
15+
}
16+
17+
void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta, rope_params p) {
18+
float mscale = p.attn_factor;
19+
// Get n-d rotational scaling corrected for extrapolation
20+
float theta_interp = p.freq_scale * theta_extrap;
21+
float theta = theta_interp;
22+
if (p.ext_factor != 0.0f) {
23+
float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
24+
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
25+
26+
// Get n-d magnitude scaling corrected for interpolation
27+
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
28+
}
29+
// Backpropagation uses inverted rotation
30+
if (p.is_back != 0) {
31+
theta = -theta;
32+
}
33+
cos_theta = cos(theta) * mscale;
34+
sin_theta = sin(theta) * mscale;
35+
}
36+
37+
void rope_norm(const uint i0, const uint i1, rope_params p) {
38+
uint ne0 = p.ncols;
39+
uint ne1 = p.p_delta_rows;
40+
41+
if (i0 >= ne0) {
42+
return;
43+
}
44+
45+
// i1 is actually i2*nb2+i1, but the rows are contiguous
46+
const uint i01 = i1 % ne1;
47+
const uint i02 = i1 / ne1;
48+
49+
uint idst = i1*ne0 + i0;
50+
const uint ix = rope_a_coord(i0, i01, i02, p);
51+
52+
// Fusion optimization: ROPE + VIEW + SET_ROWS.
53+
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
54+
if (p.set_rows_stride != 0) {
55+
idst = i01*ne0 + i0;
56+
idst += rope_data_i[i02].x * p.set_rows_stride;
57+
}
58+
59+
if (i0 >= p.n_dims) {
60+
rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]);
61+
rope_data_d[idst + 1] = ROPE_D_TYPE(rope_data_a[ix + 1]);
62+
63+
return;
64+
}
65+
66+
const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
67+
68+
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
69+
70+
float cos_theta, sin_theta;
71+
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
72+
73+
const float x0 = float(rope_data_a[ix + 0]);
74+
const float x1 = float(rope_data_a[ix + 1]);
75+
76+
rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
77+
rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
78+
}
79+
80+
void rope_neox(const uint i0, const uint i1, rope_params p) {
81+
uint ne0 = p.ncols;
82+
uint ne1 = p.p_delta_rows;
83+
84+
if (i0 >= ne0) {
85+
return;
86+
}
87+
88+
const uint i01 = i1 % ne1;
89+
const uint i02 = i1 / ne1;
90+
91+
uint idst = i1*ne0 + i0/2;
92+
const uint ix = rope_a_coord(i0/2, i01, i02, p);
93+
94+
// Fusion optimization: ROPE + VIEW + SET_ROWS.
95+
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
96+
if (p.set_rows_stride != 0) {
97+
idst = i01*ne0 + i0/2;
98+
idst += rope_data_i[i02].x * p.set_rows_stride;
99+
}
100+
101+
if (i0 >= p.n_dims) {
102+
rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
103+
rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);
104+
105+
return;
106+
}
107+
108+
const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
109+
110+
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
111+
112+
float cos_theta, sin_theta;
113+
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
114+
115+
const float x0 = float(rope_data_a[ix + 0]);
116+
const float x1 = float(rope_data_a[ix + p.n_dims/2]);
117+
118+
rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
119+
rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
120+
}
121+
122+
123+
void rope_multi(const uint i0, const uint i1, rope_params p) {
124+
uint ne0 = p.ncols;
125+
uint ne1 = p.p_delta_rows;
126+
uint ne2 = p.ne02;
127+
128+
if (i0 >= ne0) {
129+
return;
130+
}
131+
132+
const uint i01 = i1 % ne1;
133+
const uint i02 = i1 / ne1;
134+
135+
const uint idst = i1*ne0 + i0/2;
136+
const uint ix = rope_a_coord(i0/2, i01, i02, p);
137+
138+
if (i0 >= p.n_dims) {
139+
rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
140+
rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);
141+
142+
return;
143+
}
144+
145+
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
146+
const int sec_w = p.sections[1] + p.sections[0];
147+
const uint sector = (i0 / 2) % sect_dims;
148+
149+
float theta_base = 0.0;
150+
if (p.is_imrope != 0) {
151+
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
152+
theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
153+
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
154+
theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
155+
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
156+
theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
157+
} else {
158+
theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
159+
}
160+
} else {
161+
if (sector < p.sections[0]) {
162+
theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
163+
}
164+
else if (sector >= p.sections[0] && sector < sec_w) {
165+
theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
166+
}
167+
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
168+
theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
169+
}
170+
else if (sector >= sec_w + p.sections[2]) {
171+
theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
172+
}
173+
}
174+
175+
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
176+
177+
float cos_theta, sin_theta;
178+
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
179+
180+
const float x0 = float(rope_data_a[ix + 0]);
181+
const float x1 = float(rope_data_a[ix + p.n_dims/2]);
182+
183+
rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
184+
rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
185+
}
186+
187+
void rope_vision(const uint i0, const uint i1, rope_params p) {
188+
uint ne0 = p.ncols;
189+
uint ne1 = p.p_delta_rows;
190+
uint ne2 = p.ne02;
191+
192+
if (i0 >= ne0) {
193+
return;
194+
}
195+
196+
const uint i01 = i1 % ne1;
197+
const uint i02 = i1 / ne1;
198+
199+
const uint idst = i1*ne0 + i0/2;
200+
const uint ix = rope_a_coord(i0/2, i01, i02, p);
201+
202+
const int sect_dims = p.sections[0] + p.sections[1];
203+
const int sec_w = p.sections[1] + p.sections[0];
204+
const uint sector = (i0 / 2) % sect_dims;
205+
206+
float theta_base = 0.0;
207+
if (sector < p.sections[0]) {
208+
const uint p0 = sector;
209+
theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0);
210+
}
211+
else if (sector >= p.sections[0] && sector < sec_w) {
212+
const uint p0 = sector - p.sections[0];
213+
theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0);
214+
}
215+
216+
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
217+
218+
float cos_theta, sin_theta;
219+
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
220+
221+
const float x0 = float(rope_data_a[ix + 0]);
222+
const float x1 = float(rope_data_a[ix + p.n_dims]);
223+
224+
rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
225+
rope_data_d[idst + p.n_dims] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
226+
}
227+

ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl

Lines changed: 9 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,56 +3,18 @@
33
#extension GL_EXT_shader_16bit_storage : require
44

55
#include "rte.glsl"
6+
#include "rope_params.glsl"
67

78
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
89

9-
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
10-
layout (binding = 1) readonly buffer Y {int data_pos[];};
11-
layout (binding = 2) readonly buffer Z {float data_ff[];};
12-
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
13-
layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows
10+
layout (binding = 0) readonly buffer X {A_TYPE rope_data_a[];};
11+
layout (binding = 1) readonly buffer Y {int rope_data_pos[];};
12+
layout (binding = 2) readonly buffer Z {float rope_data_ff[];};
13+
layout (binding = 3) writeonly buffer D {ROPE_D_TYPE rope_data_d[];};
14+
layout (binding = 4) readonly buffer I {uvec2 rope_data_i[];}; // indices for set_rows
1415

15-
layout (push_constant) uniform parameter {
16-
uint ncols;
17-
uint n_dims;
18-
float freq_scale;
19-
uint p_delta_rows;
20-
float freq_base;
21-
float ext_factor;
22-
float attn_factor;
23-
float corr_dims[2];
24-
float theta_scale;
25-
uint has_ff;
26-
uint ne02;
27-
uint s1;
28-
uint s2;
29-
int sections[4];
30-
uint is_imrope;
31-
uint is_back;
32-
uint set_rows_stride;
33-
} p;
34-
35-
float rope_yarn_ramp(const float low, const float high, const uint i0) {
36-
const float y = (i0 / 2 - low) / max(0.001f, high - low);
37-
return 1.0f - min(1.0f, max(0.0f, y));
38-
}
3916

40-
void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
41-
float mscale = p.attn_factor;
42-
// Get n-d rotational scaling corrected for extrapolation
43-
float theta_interp = p.freq_scale * theta_extrap;
44-
float theta = theta_interp;
45-
if (p.ext_factor != 0.0f) {
46-
float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
47-
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
17+
layout (push_constant) uniform parameter {
18+
rope_params pc;
19+
};
4820

49-
// Get n-d magnitude scaling corrected for interpolation
50-
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
51-
}
52-
// Backpropagation uses inverted rotation
53-
if (p.is_back != 0) {
54-
theta = -theta;
55-
}
56-
cos_theta = cos(theta) * mscale;
57-
sin_theta = sin(theta) * mscale;
58-
}

0 commit comments

Comments
 (0)