Skip to content

[ET-VK] Enable int8 tiled compute shader to be used with buffer tensors #10302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: gh/SS-JIA/214/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@

${define_required_extensions(DTYPE)}

$if STORAGE == "buffer":
$if WEIGHT_STORAGE == "buffer":
${define_required_extensions("int8")}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight", "int8", STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_scales", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)}


layout(push_constant) uniform restrict Block {
Expand All @@ -50,7 +50,7 @@ void main() {
VEC4_T b[4];
VEC4_T c[TILE_ROWS];

$if STORAGE == "buffer":
$if SCALES_STORAGE == "buffer":
const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
$else:
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0));
Expand All @@ -62,15 +62,15 @@ void main() {
for (int pos = 0; pos < in_sizes.x; pos += 4) {
// Preload weight tensor
[[unroll]] for (int i = 0; i < 4; i++) {
$if STORAGE == "buffer":
b[i] = t_weight[((pos + i) * B_sizes.x + out_col) >> 2];
$if WEIGHT_STORAGE == "buffer":
b[i] = t_weight[((pos + i) * out_sizes.x + out_col) >> 2];
$else:
b[i] = VEC4_T(texelFetch(t_weight, ivec3(out_col >> 2, pos + i, 0), 0));
b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0));
}

// Preload input tensor
[[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
$if STORAGE == "buffer":
$if IN_STORAGE == "buffer":
a[i] = t_in[((out_row + i) * in_sizes.x + (pos)) >> 2];
$else:
a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
Expand All @@ -84,8 +84,10 @@ void main() {

// Store output tensor
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
$if STORAGE == "buffer":
t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
$if OUT_STORAGE == "buffer":
if (out_row + i < out_sizes.y) {
t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
}
$else:
imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
}
Expand Down
27 changes: 20 additions & 7 deletions backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,25 @@
q_8w_linear_tiled:
parameter_names_with_default_values:
DTYPE: float
STORAGE: texture3d
IN_STORAGE: texture3d
OUT_STORAGE: texture3d
WEIGHT_STORAGE: texture2d
SCALES_STORAGE: buffer
TILE_ROWS: 4
generate_variant_forall:
TILE_ROWS:
- VALUE: 1
SUFFIX: o4x1
- VALUE: 4
SUFFIX: o4x4
- VALUE: 6
SUFFIX: o4x6
shader_variants:
- NAME: q_8w_linear_tiled_o4x4_texture3d_float
STORAGE: texture3d
TILE_ROWS: 4
- NAME: q_8w_linear_tiled_o4x6_texture3d_float
STORAGE: texture3d
TILE_ROWS: 6
- NAME: q_8w_linear_tiled_texture3d_texture3d_texture2d_float
- NAME: q_8w_linear_tiled_buffer_buffer_texture2d_float
IN_STORAGE: buffer
OUT_STORAGE: buffer
- NAME: q_8w_linear_tiled_buffer_buffer_buffer_float
IN_STORAGE: buffer
OUT_STORAGE: buffer
WEIGHT_STORAGE: buffer
42 changes: 27 additions & 15 deletions backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,28 +146,45 @@ void add_q_8w_linear_tiled_node(
const ValueRef q_mat2_data,
const ValueRef scales_data,
const ValueRef out) {
utils::StorageType stype = graph.storage_type_of(out);
utils::StorageType q_mat2_storage = utils::kTexture2D;

uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim();
std::vector<int64_t> qmat2_orig_sizes = graph.sizes_of(q_mat2_data);
const int64_t ndim = graph.dim_of(q_mat2_data);
const int64_t K = qmat2_orig_sizes.at(ndim - 1);
const int64_t N = qmat2_orig_sizes.at(ndim - 2);

if (N > max_extent * 4 || K > max_extent) {
q_mat2_storage = utils::kBuffer;
}

ValueRef q_mat2 = prepack_standard_hw_transposed(
graph, q_mat2_data, stype, utils::kWidthPacked);
graph, q_mat2_data, q_mat2_storage, utils::kWidthPacked);

ValueRef scales =
prepack_standard(graph, scales_data, stype, utils::kWidthPacked);
prepack_standard(graph, scales_data, utils::kBuffer, utils::kWidthPacked);

std::string kernel_name = "q_8w_linear_tiled";
kernel_name.reserve(kShaderNameReserve);
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
add_storage_type_suffix(kernel_name, graph.storage_type_of(q_mat2));
add_dtype_suffix(kernel_name, graph.dtype_of(out));

std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
const int64_t M = utils::val_at(-2, mat1_sizes);
int out_tile_nrows = 4;
if (M % 6 == 0) {
kernel_name += "_o4x6";
out_tile_nrows = 6;
} else if (M % 1 == 0) {
kernel_name += "_o4x1";
out_tile_nrows = 1;
} else {
kernel_name += "_o4x4";
out_tile_nrows = 4;
}

add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
add_dtype_suffix(kernel_name, graph.dtype_of(out));

utils::uvec3 global_wg_size = graph.logical_limits_of(out);
global_wg_size[1] = global_wg_size[1] / out_tile_nrows;

Expand Down Expand Up @@ -209,18 +226,13 @@ bool can_use_tiled_impl(
if (graph.size_at<int>(-1, mat1) % 4 != 0) {
return false;
}
// Check that M is a multiple of 4 or 6
if (graph.size_at<int>(-2, mat1) % 4 != 0 &&
graph.size_at<int>(-2, mat1) % 6 != 0) {
return false;
}
// Check that the storage type is texture
// TODO(ssjia): Add support for buffer storage in the tiled impl
if (graph.storage_type_of(out) != utils::kTexture3D) {
// Check that N is a multiple of 4
if (graph.size_at<int>(-1, out) % 4 != 0) {
return false;
}
// Check that the packed dim is the width dim
if (graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
if (graph.packed_dim_of(mat1) != WHCN::kWidthDim &&
graph.packed_dim_of(out) != WHCN::kWidthDim) {
return false;
}
// Check that no special axis mapping is used for the input
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def get_linear_inputs():
@register_test_suite("aten._weight_int8pack_mm.default")
def get_weight_int8pack_mm_inputs():
MKN_list = [
[3, 480, 256],
[6, 480, 256],
[6, 256, 1024],
[6, 1024, 256],
Expand Down
Loading