Commit 5a305ff

[ET-VK][qlinear] Faster weight only quantized linear gemv kernel
Pull Request resolved: #12444

## Changes

* Introduce a new compute shader for int4 linear's gemv cases that performs much better than the existing shader. The shader is inspired by MNN's gemv_1x1_conv_buf.cl shader. With this compute kernel, transformer models' text generation executes much faster than before. On a Samsung Galaxy S24 running Llama 3.2 1B and generating 128 tokens:
  * Before: ~25 tok/s
  * After: ~49 tok/s

## Why this new shader is faster

The biggest factor is vectorized loading of the uint4 weight buffer. The new shader loads the weight buffer as a buffer/image of `uvec4`, whereas the old shader loads it as a buffer/image of `u8vec4`. Using the Adreno Offline Compiler, I found that the former needs only one load instruction to read from the weight tensor, whereas the latter uses 16 load instructions. It appears that the data loading was not being vectorized at the assembly level; this is behaviour that could potentially be improved in the SPIR-V shader compiler. A minimal sketch of the difference follows this commit message.

An additional factor is a better weight packing layout: the new prepacking routine results in better memory coalescing between threads in a work group.

The final major factor is the use of a tree-based reduction to co-operatively reduce partial results into the final output. Previously, a single thread was responsible for the entire final reduction.

## Future Work

* Introduce a faster shader for int4 linear gemm cases
* Update QCSNW to also use these updated shaders

ghstack-source-id: 296437697
Differential Revision: [D78275584](https://our.internmc.facebook.com/intern/diff/D78275584/)
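To make the vectorized-load point concrete, the following is a minimal, self-contained GLSL sketch contrasting the two access patterns. It is not the PR's kernel: the buffer names, bindings, and nibble arithmetic are illustrative assumptions, and the real declarations come from the ExecuTorch shader templates (`linear_qga4w_coop.glsl`, `qlinear_utils.glslh`).

```glsl
#version 450
#extension GL_EXT_shader_8bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require

// Hypothetical bindings for illustration only; the real shader declares its
// tensors via the layout_declare_tensor templates.
layout(std430, set = 0, binding = 0) readonly buffer WeightsU8 { u8vec4 w_u8[]; };
layout(std430, set = 0, binding = 1) readonly buffer WeightsU32 { uvec4 w_u32[]; };
layout(std430, set = 0, binding = 2) writeonly buffer Out { vec4 o[]; };

layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

void main() {
  const uint i = gl_GlobalInvocationID.x;
  vec4 acc = vec4(0.0);

  // Old access pattern: 16 bytes of packed 4-bit weights are gathered through
  // four u8vec4 element reads, which compiled to many narrow load instructions.
  for (uint r = 0; r < 4u; ++r) {
    const u8vec4 packed_bytes = w_u8[4u * i + r];
    acc += vec4((uvec4(packed_bytes) >> 4) & 0x0Fu) - 8.0;
  }

  // New access pattern: the same 16-byte block is read as a single uvec4,
  // which compiles to one vectorized 128-bit load; nibbles are then unpacked
  // with shifts and masks.
  const uvec4 block = w_u32[i];
  acc += vec4((block >> 4) & 0x0Fu) - 8.0;

  o[i] = acc;
}
```

Declaring the SSBO element type as `uvec4` lets the compiler issue one 128-bit load per packed block, whereas gathering the same 16 bytes through `u8vec4` elements was observed (with the Adreno Offline Compiler) to produce many narrow loads.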
1 parent ee3797e, commit 5a305ff

11 files changed: +528 -180 lines

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 12 additions & 0 deletions
@@ -68,6 +68,18 @@
  */
 #define mod4(x) ((x) & 3)
 
+#define ALIGN_UP_4(x) (((x) + 3) & ~3)
+
+#define DIV_UP_8(x) (((x) + 7) >> 3)
+#define DIV_UP_4(x) (((x) + 3) >> 2)
+
+#define DIV_4(x) ((x) >> 2)
+#define DIV_2(x) ((x) >> 1)
+
+#define MUL_8(x) ((x) << 3)
+#define MUL_4(x) ((x) << 2)
+#define MUL_2(x) ((x) << 1)
+
 /*
  * Get the staging buffer indices that contain the data of the texel that
  * corresponds to the provided tensor index. Since the texel have 4 elements,

backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl

Lines changed: 110 additions & 154 deletions
@@ -13,187 +13,143 @@
 #define T ${buffer_scalar_type(DTYPE)}
 #define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
 
-#define TILE_ROWS ${TILE_ROWS}
-
-#define NGROUPS 8
-#define NWORKERS 8
+#define WGS ${WGS}
 
 ${define_required_extensions(DTYPE)}
-$if WEIGHT_STORAGE == "buffer":
-  ${define_required_extensions("uint8")}
+${define_required_extensions("uint8")}
 
 #extension GL_EXT_control_flow_attributes : require
+#extension GL_EXT_debug_printf : require
 
 layout(std430) buffer;
 
-${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
+#include "indexing_utils.h"
+
+${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_qmat2", "uint", WEIGHT_STORAGE, is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)}
 
 layout(push_constant) uniform restrict Block {
-  ivec4 out_sizes;
-  ivec4 mat1_sizes;
-  ivec4 qmat2_sizes;
+  ivec4 output_sizes;
+  ivec4 input_sizes;
+  ivec4 weight_sizes;
 };
 
-layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+layout(local_size_x = WGS, local_size_y = 1, local_size_z = 1) in;
 
 layout(constant_id = 3) const int group_size = 64;
 
-shared VEC4_T partial_sums[NGROUPS][NWORKERS][TILE_ROWS][2];
-
-/*
- * This shader computes a linear operator between a floating point input matrix
- * x and a weights matrix that is quantized to 4 bits. Please refer to the
- * q_4w_linear shader for more details.
- *
- * This shader implements a co-operative algorithm to compute the output. The
- * work group size is {NGROUP, 1, NWORKERS}, and each group of NWORKERS threads
- * cooperative to compute TILE_ROWS * 2 output texels. Therefore,
- * NGROUP * TILE_ROWS * 2 output texels are computed across one work group.
- *
- * The threads co-operate by each thread computing a partial reduction along the
- * K dimension. To illustrate the computation, consider a scalar variant of the
- * algorithm that computes the dot product of 2 vectors. Also assume that
- * NWORKERS is 8.
- *
- * Thread 1 in each group will compute:
- * (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ...
- *
- * Thread 2 in each group will compute:
- * (mat1[1] * mat2[1]) + (mat2[9] * mat2[9]) + (mat1[17] * mat2[17]) + ...
- *
- * Thread 3 in each group will compute:
- * (mat1[2] * mat2[2]) + (mat2[10] * mat2[10]) + (mat1[18] * mat2[18]) + ...
- *
- * The partial accumulations is structured such that memory accesses in each
- * loop iteration can be coalesced.
- *
- * Then, at the end first thread in each group will accumulate the partial
- * accumulations computed by each thread to obtain the final result.
- *
- * Note that this shader assumes that all tensors are width packed.
- */
-void main() {
-  const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
-  // Each thread writes out 2 texels along the width axis, equivalent to 8
-  // scalar elements. Therefore multiply the thread_idx.x by 8.
-  const uint out_col = gl_GlobalInvocationID.x << 3;
-  // Similar reasoning to the above, each thread works on 2 texels along the
-  // width axis so multiply thread_idx.x by 2.
-  const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;
-
-  const uint gid = gl_LocalInvocationID.x; // group id
-  const uint wid = gl_LocalInvocationID.z; // worker id
-
-  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
-    return;
-  }
+shared VEC4_T partial_sums[WGS][2];
 
-  const int num_blocks = mat1_sizes.x / group_size;
-
-  VEC4_T mat1[TILE_ROWS];
-  VEC4_T qmat2[4][2];
-  VEC4_T local_sums[TILE_ROWS][2];
-
-  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-    local_sums[r][0] = VEC4_T(0);
-    local_sums[r][1] = VEC4_T(0);
-  }
-
-  VEC4_T scales[2];
-  VEC4_T zeros[2];
-
-  $if WEIGHT_STORAGE == "buffer":
-    const int qmat2_stride = qmat2_sizes.x >> 2;
-  $if PARAMS_STORAGE == "buffer":
-    const int qparams_y_stride = out_sizes.x >> 2;
-    const int qparams_z_stride = qparams_y_stride * 2;
-
-  for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
-    $if PARAMS_STORAGE == "buffer":
-      scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx];
-      zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride];
-
-      scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1];
-      zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride];
-    $else:
-      scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
-      zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);
-
-      scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
-      zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);
-
-    for (uint g_idx = 4 * wid; g_idx < group_size; g_idx += (4 * NWORKERS)) {
-      const uint k = block_idx * group_size + g_idx;
-
-      // Preload B
-      [[unroll]] for (int r = 0; r < 4; ++r) {
-        $if WEIGHT_STORAGE == "buffer":
-          const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x];
-        $else:
-          const uvec4 packed_weight_tex = texelFetch(
-              t_qmat2,
-              ivec2(gl_GlobalInvocationID.x, k + r),
-              0);
-
-        qmat2[r][0] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0) * scales[0] + zeros[0];
-        qmat2[r][1] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0) * scales[1] + zeros[1];
-      }
+$if IO_STORAGE == "buffer":
+  #define BUFFER_IO
+$if WEIGHT_STORAGE == "buffer":
+  #define BUFFER_WEIGHT
 
-      // Preload A
-      [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-        $if IN_STORAGE == "buffer":
-          mat1[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2];
-        $else:
-          mat1[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0);
-      }
+#include "qlinear_utils.glslh"
 
-      // Accumulate local output tile
-      [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-        local_sums[r][0] += mat1[r].x * qmat2[0][0]
-                          + mat1[r].y * qmat2[1][0]
-                          + mat1[r].z * qmat2[2][0]
-                          + mat1[r].w * qmat2[3][0];
-
-        local_sums[r][1] += mat1[r].x * qmat2[0][1]
-                          + mat1[r].y * qmat2[1][1]
-                          + mat1[r].z * qmat2[2][1]
-                          + mat1[r].w * qmat2[3][1];
-      }
+void main() {
+  const uint lid = gl_LocalInvocationID.x;
+  const uint n8 = gl_GlobalInvocationID.y;
+  // The output tensor will have a shape of [n, 1, 1, 1]. Each thread computes
+  // 8 output elements, so each thread will write to 8 elements starting at the
+  // tensor index (gid.x * 8, 0, 0, 0).
+  const uint n = MUL_8(n8);
+  const uint K4 = DIV_UP_4(input_sizes.x);
+
+  const int block_num = input_sizes.x / group_size;
+
+  VEC4_T out_texels[2];
+  out_texels[0] = VEC4_T(0);
+  out_texels[1] = VEC4_T(0);
+
+  // initialize the group index to a value larger than the largest possible
+  uint cur_group_idx = input_sizes.x;
+
+  // Each thread in the work group accumulates a partial result.
+  for (uint k4 = lid; k4 < DIV_UP_4(input_sizes.x); k4 += WGS) {
+    const uint k = MUL_4(k4);
+    const uint group_idx = k / group_size;
+
+    VEC4_T scales[2];
+    VEC4_T zeros[2];
+
+    // Only update the scales/zeros if the current iteration is now working on a
+    // new quantization group.
+    if (group_idx != cur_group_idx) {
+      // The qparams tensor contains the quantization scales and zeros, with
+      // shape [2, N, K / group_size, 1].
+      // Loading a texel from the qparams tensor will return 2 scales and 2
+      // zeros for 2 adjacent output channels.
+      uint qparams_bufi = group_idx * DIV_2(output_sizes.x) + DIV_2(n);
+      VEC4_T scales_zeros_texels[4];
+      $for comp in range(4):
+        scales_zeros_texels[${comp}] = t_qparams[qparams_bufi++];
+
+      scales[0] = VEC4_T(scales_zeros_texels[0].xz, scales_zeros_texels[1].xz);
+      zeros[0] = VEC4_T(scales_zeros_texels[0].yw, scales_zeros_texels[1].yw);
+
+      scales[1] = VEC4_T(scales_zeros_texels[2].xz, scales_zeros_texels[3].xz);
+      zeros[1] = VEC4_T(scales_zeros_texels[2].yw, scales_zeros_texels[3].yw);
+
+      cur_group_idx = group_idx;
     }
+    // The input tensor will have a shape of [K, 1, 1, 1]; in each iteration,
+    // load 4 elements starting from the tensor index (k, 0, 0, 0).
+    VEC4_T in_texel = load_input_texel(k4);
+    // Extract each element of the in_texel into a separate vectorized variable;
+    // these are used to "broadcast" the input values in subsequent fma calls.
+    VEC4_T in_texel_val[4];
+    $for comp in range(4):
+      in_texel_val[${comp}] = VEC4_T(in_texel[${comp}]);
+
+    uvec4 packed_weight_block = load_transposed_weight_block(k4, n8, K4);
+
+    VEC4_T weight_texels[2];
+    $for comp in range(4):
+      {
+        weight_texels[0].x = extract_4bit_from_transposed_block(packed_weight_block, 0, ${comp});
+        weight_texels[0].y = extract_4bit_from_transposed_block(packed_weight_block, 1, ${comp});
+        weight_texels[0].z = extract_4bit_from_transposed_block(packed_weight_block, 2, ${comp});
+        weight_texels[0].w = extract_4bit_from_transposed_block(packed_weight_block, 3, ${comp});
+
+        weight_texels[1].x = extract_4bit_from_transposed_block(packed_weight_block, 4, ${comp});
+        weight_texels[1].y = extract_4bit_from_transposed_block(packed_weight_block, 5, ${comp});
+        weight_texels[1].z = extract_4bit_from_transposed_block(packed_weight_block, 6, ${comp});
+        weight_texels[1].w = extract_4bit_from_transposed_block(packed_weight_block, 7, ${comp});
+
+        weight_texels[0] = fma(weight_texels[0], scales[0], zeros[0]);
+        weight_texels[1] = fma(weight_texels[1], scales[1], zeros[1]);
+
+        out_texels[0] = fma(in_texel_val[${comp}], weight_texels[0], out_texels[0]);
+        out_texels[1] = fma(in_texel_val[${comp}], weight_texels[1], out_texels[1]);
+      }
   }
 
-  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-    partial_sums[gid][wid][r][0] = local_sums[r][0];
-    partial_sums[gid][wid][r][1] = local_sums[r][1];
-  }
+  partial_sums[lid][0] = out_texels[0];
+  partial_sums[lid][1] = out_texels[1];
 
   memoryBarrierShared();
   barrier();
 
-  if (wid != 0) {
-    return;
-  }
-
-  VEC4_T sums[TILE_ROWS][2];
-
-  for (int r = 0; r < TILE_ROWS; ++r) {
-    sums[r][0] = VEC4_T(0);
-    sums[r][1] = VEC4_T(0);
-    [[unroll]] for (int worker = 0; worker < NWORKERS; ++ worker) {
-      sums[r][0] += partial_sums[gid][worker][r][0];
-      sums[r][1] += partial_sums[gid][worker][r][1];
+  // Tree reduction to compute the overall result.
+  for (int i = WGS / 2; i > 0; i /= 2) {
+    if (lid < i) {
+      partial_sums[lid][0] = partial_sums[lid][0] + partial_sums[lid + i][0];
+      partial_sums[lid][1] = partial_sums[lid][1] + partial_sums[lid + i][1];
     }
+    memoryBarrierShared();
+    barrier();
   }
 
-  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-    $if OUT_STORAGE == "buffer":
-      t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = sums[r][0];
-      t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = sums[r][1];
-    $else:
-      imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), sums[r][0]);
-      imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), sums[r][1]);
+  // Only the first thread will write out result
+  if (lid == 0) {
+    out_texels[0] = partial_sums[0][0];
+    out_texels[1] = partial_sums[0][1];
+
+    uint n4 = DIV_4(n);
+    write_output_texel(out_texels[0], n4);
+    write_output_texel(out_texels[1], n4 + 1);
   }
 }
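
The new kernel consumes each packed block through helpers from qlinear_utils.glslh (load_transposed_weight_block, extract_4bit_from_transposed_block), which is not part of this diff. Purely as a hedged sketch, one way such an extraction could be written is shown below; the bit layout assumed here (one 32-bit word per k position, eight 4-bit output-channel values per word) is a guess that matches the 8 x 4 access pattern above, not necessarily the actual prepacked layout.

```glsl
// Hypothetical sketch only; the real helper lives in qlinear_utils.glslh and
// may use a different bit layout.
float extract_4bit_from_transposed_block(const uvec4 block, const uint n, const uint k) {
  // Select the 32-bit word holding the k-th column of the block, then shift the
  // 4-bit value belonging to output channel n into the low bits and mask it.
  // Recentering the unsigned nibble (e.g. subtracting 8) could be done here or
  // folded into the zero points during prepacking.
  return float((block[k] >> (4u * n)) & 0x0Fu);
}
```

Whatever the exact layout, the important property is that all 32 nibbles of an 8 (N) x 4 (K) weight tile arrive in a single `uvec4` load, so unpacking becomes register arithmetic rather than additional memory traffic.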

backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml

Lines changed: 4 additions & 8 deletions
@@ -7,17 +7,13 @@
 linear_qga4w_coop:
   parameter_names_with_default_values:
     DTYPE: float
-    OUT_STORAGE: texture3d
-    IN_STORAGE: texture3d
+    IO_STORAGE: texture3d
     WEIGHT_STORAGE: texture2d
-    PARAMS_STORAGE: buffer
-    TILE_ROWS: 1
+    WGS: 64
   shader_variants:
     - NAME: linear_qga4w_coop_texture3d_texture3d_texture2d_float
     - NAME: linear_qga4w_coop_buffer_buffer_texture2d_float
-      OUT_STORAGE: buffer
-      IN_STORAGE: buffer
+      IO_STORAGE: buffer
     - NAME: linear_qga4w_coop_buffer_buffer_buffer_float
-      OUT_STORAGE: buffer
-      IN_STORAGE: buffer
+      IO_STORAGE: buffer
       WEIGHT_STORAGE: buffer

backends/vulkan/runtime/graph/ops/glsl/no_op.yaml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ no_op:
       - VALUE: half
       - VALUE: float
      - VALUE: int32
+      - VALUE: uint32
       - VALUE: int8
       - VALUE: uint8
     STORAGE:
