#define T ${buffer_scalar_type(DTYPE)}
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

-#define TILE_ROWS ${TILE_ROWS}
-
-#define NGROUPS 8
-#define NWORKERS 8
+#define WGS ${WGS}

${define_required_extensions(DTYPE)}
-$if WEIGHT_STORAGE == "buffer":
-  ${define_required_extensions("uint8")}
+${define_required_extensions("uint8")}

#extension GL_EXT_control_flow_attributes : require
+#extension GL_EXT_debug_printf : require

layout(std430) buffer;

-${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
+#include "indexing_utils.h"
+
+${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_qmat2", "uint", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)}

layout(push_constant) uniform restrict Block {
-  ivec4 out_sizes;
-  ivec4 mat1_sizes;
-  ivec4 qmat2_sizes;
+  ivec4 output_sizes;
+  ivec4 input_sizes;
+  ivec4 weight_sizes;
};

-layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+layout(local_size_x = WGS, local_size_y = 1, local_size_z = 1) in;

layout(constant_id = 3) const int group_size = 64;

-shared VEC4_T partial_sums[NGROUPS][NWORKERS][TILE_ROWS][2];
-
-/*
- * This shader computes a linear operator between a floating point input matrix
- * x and a weights matrix that is quantized to 4 bits. Please refer to the
- * q_4w_linear shader for more details.
- *
- * This shader implements a co-operative algorithm to compute the output. The
- * work group size is {NGROUP, 1, NWORKERS}, and each group of NWORKERS threads
- * cooperative to compute TILE_ROWS * 2 output texels. Therefore,
- * NGROUP * TILE_ROWS * 2 output texels are computed across one work group.
- *
- * The threads co-operate by each thread computing a partial reduction along the
- * K dimension. To illustrate the computation, consider a scalar variant of the
- * algorithm that computes the dot product of 2 vectors. Also assume that
- * NWORKERS is 8.
- *
- * Thread 1 in each group will compute:
- * (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ...
- *
- * Thread 2 in each group will compute:
- * (mat1[1] * mat2[1]) + (mat2[9] * mat2[9]) + (mat1[17] * mat2[17]) + ...
- *
- * Thread 3 in each group will compute:
- * (mat1[2] * mat2[2]) + (mat2[10] * mat2[10]) + (mat1[18] * mat2[18]) + ...
- *
- * The partial accumulations is structured such that memory accesses in each
- * loop iteration can be coalesced.
- *
- * Then, at the end first thread in each group will accumulate the partial
- * accumulations computed by each thread to obtain the final result.
- *
- * Note that this shader assumes that all tensors are width packed.
- */
-void main() {
-  const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
-  // Each thread writes out 2 texels along the width axis, equivalent to 8
-  // scalar elements. Therefore multiply the thread_idx.x by 8.
-  const uint out_col = gl_GlobalInvocationID.x << 3;
-  // Similar reasoning to the above, each thread works on 2 texels along the
-  // width axis so multiply thread_idx.x by 2.
-  const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;
-
-  const uint gid = gl_LocalInvocationID.x; // group id
-  const uint wid = gl_LocalInvocationID.z; // worker id
-
-  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
-    return;
-  }
+shared VEC4_T partial_sums[WGS][2];

-  const int num_blocks = mat1_sizes.x / group_size;
-
-  VEC4_T mat1[TILE_ROWS];
-  VEC4_T qmat2[4][2];
-  VEC4_T local_sums[TILE_ROWS][2];
-
-  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-    local_sums[r][0] = VEC4_T(0);
-    local_sums[r][1] = VEC4_T(0);
-  }
-
-  VEC4_T scales[2];
-  VEC4_T zeros[2];
-
-  $if WEIGHT_STORAGE == "buffer":
-    const int qmat2_stride = qmat2_sizes.x >> 2;
-  $if PARAMS_STORAGE == "buffer":
-    const int qparams_y_stride = out_sizes.x >> 2;
-    const int qparams_z_stride = qparams_y_stride * 2;
-
-  for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
-    $if PARAMS_STORAGE == "buffer":
-      scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx];
-      zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride];
-
-      scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1];
-      zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride];
-    $else:
-      scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
-      zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);
-
-      scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
-      zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);
-
-    for (uint g_idx = 4 * wid; g_idx < group_size; g_idx += (4 * NWORKERS)) {
-      const uint k = block_idx * group_size + g_idx;
-
-      // Preload B
-      [[unroll]] for (int r = 0; r < 4; ++r) {
-        $if WEIGHT_STORAGE == "buffer":
-          const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x];
-        $else:
-          const uvec4 packed_weight_tex = texelFetch(
-            t_qmat2,
-            ivec2(gl_GlobalInvocationID.x, k + r),
-            0);
-
-        qmat2[r][0] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0) * scales[0] + zeros[0];
-        qmat2[r][1] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0) * scales[1] + zeros[1];
-      }
+$if IO_STORAGE == "buffer":
+  #define BUFFER_IO
+$if WEIGHT_STORAGE == "buffer":
+  #define BUFFER_WEIGHT

-      // Preload A
-      [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-        $if IN_STORAGE == "buffer":
-          mat1[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2];
-        $else:
-          mat1[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0);
-      }
+#include "qlinear_utils.glslh"

-      // Accumulate local output tile
-      [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-        local_sums[r][0] += mat1[r].x * qmat2[0][0]
-                          + mat1[r].y * qmat2[1][0]
-                          + mat1[r].z * qmat2[2][0]
-                          + mat1[r].w * qmat2[3][0];
-
-        local_sums[r][1] += mat1[r].x * qmat2[0][1]
-                          + mat1[r].y * qmat2[1][1]
-                          + mat1[r].z * qmat2[2][1]
-                          + mat1[r].w * qmat2[3][1];
-      }
+void main() {
+  const uint lid = gl_LocalInvocationID.x;
+  const uint n8 = gl_GlobalInvocationID.y;
+  // The output tensor will have a shape of [n, 1, 1, 1]. Each thread computes
+  // 8 output elements, so each thread will write to 8 elements starting at the
+  // tensor index (gid.x * 8, 0, 0, 0).
+  const uint n = MUL_8(n8);
+  const uint K4 = DIV_UP_4(input_sizes.x);
+
+  const int block_num = input_sizes.x / group_size;
+
+  VEC4_T out_texels[2];
+  out_texels[0] = VEC4_T(0);
+  out_texels[1] = VEC4_T(0);
+
+  // initialize the group index to a value larger than the largest possible
+  uint cur_group_idx = input_sizes.x;
+
+  // Each thread in the work group accumulates a partial result.
+  for (uint k4 = lid; k4 < DIV_UP_4(input_sizes.x); k4 += WGS) {
+    const uint k = MUL_4(k4);
+    const uint group_idx = k / group_size;
+
+    VEC4_T scales[2];
+    VEC4_T zeros[2];
+
+    // Only update the scales/zeros if the current iteration is now working on a
+    // new quantization group.
+    if (group_idx != cur_group_idx) {
+      // The qparams tensor contains the quantization scales and zeros, with
+      // shape [2, N, K / group_size, 1].
+      // Loading a texel from the qparams tensor will return 2 scales and 2
+      // zeros for 2 adjacent output channels.
+      uint qparams_bufi = group_idx * DIV_2(output_sizes.x) + DIV_2(n);
+      VEC4_T scales_zeros_texels[4];
+      $for comp in range(4):
+        scales_zeros_texels[${comp}] = t_qparams[qparams_bufi++];
+
+      scales[0] = VEC4_T(scales_zeros_texels[0].xz, scales_zeros_texels[1].xz);
+      zeros[0] = VEC4_T(scales_zeros_texels[0].yw, scales_zeros_texels[1].yw);
+
+      scales[1] = VEC4_T(scales_zeros_texels[2].xz, scales_zeros_texels[3].xz);
+      zeros[1] = VEC4_T(scales_zeros_texels[2].yw, scales_zeros_texels[3].yw);
+
+      cur_group_idx = group_idx;
    }
+    // The input tensor will have a shape of [K, 1, 1, 1]; in each iteration,
+    // load 4 elements starting from the tensor index (k, 0, 0, 0).
+    VEC4_T in_texel = load_input_texel(k4);
+    // Extract each element of the in_texel into a separate vectorized variable;
+    // these are used to "broadcast" the input values in subsequent fma calls.
+    VEC4_T in_texel_val[4];
+    $for comp in range(4):
+      in_texel_val[${comp}] = VEC4_T(in_texel[${comp}]);
+
+    uvec4 packed_weight_block = load_transposed_weight_block(k4, n8, K4);
+
+    VEC4_T weight_texels[2];
+    $for comp in range(4):
+      {
+        weight_texels[0].x = extract_4bit_from_transposed_block(packed_weight_block, 0, ${comp});
+        weight_texels[0].y = extract_4bit_from_transposed_block(packed_weight_block, 1, ${comp});
+        weight_texels[0].z = extract_4bit_from_transposed_block(packed_weight_block, 2, ${comp});
+        weight_texels[0].w = extract_4bit_from_transposed_block(packed_weight_block, 3, ${comp});
+
+        weight_texels[1].x = extract_4bit_from_transposed_block(packed_weight_block, 4, ${comp});
+        weight_texels[1].y = extract_4bit_from_transposed_block(packed_weight_block, 5, ${comp});
+        weight_texels[1].z = extract_4bit_from_transposed_block(packed_weight_block, 6, ${comp});
+        weight_texels[1].w = extract_4bit_from_transposed_block(packed_weight_block, 7, ${comp});
+
+        weight_texels[0] = fma(weight_texels[0], scales[0], zeros[0]);
+        weight_texels[1] = fma(weight_texels[1], scales[1], zeros[1]);
+
+        out_texels[0] = fma(in_texel_val[${comp}], weight_texels[0], out_texels[0]);
+        out_texels[1] = fma(in_texel_val[${comp}], weight_texels[1], out_texels[1]);
+      }
  }

-  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-    partial_sums[gid][wid][r][0] = local_sums[r][0];
-    partial_sums[gid][wid][r][1] = local_sums[r][1];
-  }
+  partial_sums[lid][0] = out_texels[0];
+  partial_sums[lid][1] = out_texels[1];

  memoryBarrierShared();
  barrier();

-  if (wid != 0) {
-    return;
-  }
-
-  VEC4_T sums[TILE_ROWS][2];
-
-  for (int r = 0; r < TILE_ROWS; ++r) {
-    sums[r][0] = VEC4_T(0);
-    sums[r][1] = VEC4_T(0);
-    [[unroll]] for (int worker = 0; worker < NWORKERS; ++ worker) {
-      sums[r][0] += partial_sums[gid][worker][r][0];
-      sums[r][1] += partial_sums[gid][worker][r][1];
+  // Tree reduction to compute the overall result.
+  for (int i = WGS / 2; i > 0; i /= 2) {
+    if (lid < i) {
+      partial_sums[lid][0] = partial_sums[lid][0] + partial_sums[lid + i][0];
+      partial_sums[lid][1] = partial_sums[lid][1] + partial_sums[lid + i][1];
    }
+    memoryBarrierShared();
+    barrier();
  }

-  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-    $if OUT_STORAGE == "buffer":
-      t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = sums[r][0];
-      t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = sums[r][1];
-    $else:
-      imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), sums[r][0]);
-      imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), sums[r][1]);
+  // Only the first thread will write out result
+  if (lid == 0) {
+    out_texels[0] = partial_sums[0][0];
+    out_texels[1] = partial_sums[0][1];
+
+    uint n4 = DIV_4(n);
+    write_output_texel(out_texels[0], n4);
+    write_output_texel(out_texels[1], n4 + 1);
  }
}
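Note on the weight encoding: the removed code spells the dequantization out inline as (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0) * scales[0] + zeros[0], i.e. two 4-bit weights per byte, high nibble first, with the unsigned [0, 15] range recentred by 8 before the per-group scale and zero are applied. The new code performs the same affine step via fma(weight_texels[i], scales[i], zeros[i]) after unpacking through load_transposed_weight_block / extract_4bit_from_transposed_block from qlinear_utils.glslh, which is not shown in this diff. The sketch below only illustrates the per-byte arithmetic; dequantize_packed_byte is a hypothetical helper, not part of either shader.

#version 450

// Illustrative sketch, assuming the packing convention visible in the removed
// code above: each byte stores two 4-bit weights (high nibble first) in the
// unsigned range [0, 15], recentred to [-8, 7] before the per-group scale and
// zero are applied. This is not the actual qlinear_utils.glslh implementation.

layout(local_size_x = 1) in;

vec2 dequantize_packed_byte(uint packed_byte, vec2 scales, vec2 zeros) {
  const float q_hi = float((packed_byte >> 4) & 0x0Fu);
  const float q_lo = float(packed_byte & 0x0Fu);
  return vec2(
      fma(q_hi - 8.0, scales.x, zeros.x),
      fma(q_lo - 8.0, scales.y, zeros.y));
}

void main() {
  // Example: 0xB3 packs q_hi = 11 and q_lo = 3; with scale 0.5 and zero 0.0
  // this dequantizes to (1.5, -2.5).
  const vec2 w = dequantize_packed_byte(0xB3u, vec2(0.5), vec2(0.0));
}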
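The overall structure of the new shader is a co-operative reduction: each of the WGS threads strides over the K dimension accumulating a partial sum, the partial sums are combined with a shared-memory tree reduction, and thread 0 alone writes the result. A minimal standalone sketch of that pattern, reduced to a scalar dot product, follows; the buffer layout, the fixed WGS value, and the push constant are illustrative assumptions and not part of this PR.

#version 450

// Minimal sketch of the co-operative reduction pattern used by the new shader:
// every thread accumulates a strided partial sum over K, the partials are
// combined with a shared-memory tree reduction, and thread 0 writes the result.
// The computation is reduced to a scalar dot product for clarity; the real
// shader accumulates two VEC4_T output texels per invocation instead.

#define WGS 64

layout(local_size_x = WGS, local_size_y = 1, local_size_z = 1) in;

layout(std430, set = 0, binding = 0) writeonly buffer OutBuf { float out_val; };
layout(std430, set = 0, binding = 1) readonly buffer ABuf { float a[]; };
layout(std430, set = 0, binding = 2) readonly buffer BBuf { float b[]; };

layout(push_constant) uniform Block { int K; };

shared float partial_sums[WGS];

void main() {
  const uint lid = gl_LocalInvocationID.x;

  // Each thread handles elements lid, lid + WGS, lid + 2 * WGS, ...
  float acc = 0.0;
  for (uint k = lid; k < uint(K); k += uint(WGS)) {
    acc = fma(a[k], b[k], acc);
  }
  partial_sums[lid] = acc;

  memoryBarrierShared();
  barrier();

  // Tree reduction: halve the number of active threads each iteration.
  for (uint i = uint(WGS) / 2u; i > 0u; i /= 2u) {
    if (lid < i) {
      partial_sums[lid] += partial_sums[lid + i];
    }
    memoryBarrierShared();
    barrier();
  }

  // Only the first thread writes the final result.
  if (lid == 0u) {
    out_val = partial_sums[0];
  }
}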