Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 76 additions & 35 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,12 @@ enum vk_conv_shapes {
CONV_SHAPE_COUNT,
};

uint32_t conv_shapes_wg_denoms[][3] = {
{ 128, 128, 1 },
{ 64, 32, 1 },
{ 32, 256, 1 },
};

enum dmmv_wg_sizes {
DMMV_WG_SIZE_SUBGROUP,
DMMV_WG_SIZE_LARGE,
Expand Down Expand Up @@ -379,6 +385,18 @@ struct vk_fa_pipeline_state {
}
};

struct vk_conv2d_pipeline_state {
vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH)
: s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {}

uint32_t s0, s1, p0, p1, d0, d1, KW, KH;

bool operator<(const vk_conv2d_pipeline_state &b) const {
return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) <
std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH);
}
};

enum shader_reduction_mode {
SHADER_REDUCTION_MODE_SHMEM,
SHADER_REDUCTION_MODE_HYBRID,
Expand Down Expand Up @@ -668,10 +686,10 @@ struct vk_device_struct {
vk_pipeline pipeline_ssm_conv_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_opt_step_sgd_f32;
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;

Expand Down Expand Up @@ -1244,17 +1262,13 @@ struct vk_op_conv2d_push_constants {
uint32_t nb2;
uint32_t nb3;

// init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH
uint32_t KWmp; uint32_t KWL;
uint32_t KWKHmp; uint32_t KWKHL;
// init_fastdiv_values constants for dividing by OW, OW*OH
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
};

template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
// Compute magic values to divide by KW, KW*KH, OW, OW*OH
init_fastdiv_values(p.KW, p.KWmp, p.KWL);
init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
// Compute magic values to divide by OW, OW*OH
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
}
Expand Down Expand Up @@ -1290,23 +1304,15 @@ struct vk_op_conv_transpose_2d_push_constants {
uint32_t nb2;
uint32_t nb3;

// init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1
uint32_t KWmp; uint32_t KWL;
uint32_t KWKHmp; uint32_t KWKHL;
// init_fastdiv_values constants for dividing by OW, OW*OH
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
uint32_t s0mp; uint32_t s0L;
uint32_t s1mp; uint32_t s1L;
};

template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) {
// Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1
init_fastdiv_values(p.KW, p.KWmp, p.KWL);
init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
// Compute magic values to divide by OW, OW*OH
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
init_fastdiv_values(p.s0, p.s0mp, p.s0L);
init_fastdiv_values(p.s1, p.s1mp, p.s1L);
}

struct vk_op_conv2d_dw_push_constants {
Expand Down Expand Up @@ -3828,22 +3834,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
switch (s) {
default:
case CONV_SHAPE_128x128:
conv2d_BS_K = 128;
conv2d_BS_NPQ = 128;
conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_128x128][0];
conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_128x128][1];
conv2d_BS_CRS = 16;
if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) {
conv2d_UNROLL = false;
}
break;
case CONV_SHAPE_64x32:
conv2d_BS_K = 64;
conv2d_BS_NPQ = 32;
conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_64x32][0];
conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_64x32][1];
conv2d_BS_CRS = 32;
conv2d_TS_K = 4;
break;
case CONV_SHAPE_32x256:
conv2d_BS_K = 32;
conv2d_BS_NPQ = 256;
conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_32x256][0];
conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_32x256][1];
conv2d_BS_CRS = 16;
break;
}
Expand Down Expand Up @@ -3877,10 +3883,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };

#define CREATE_CONV(name, type_suffix, spv_suffix) \
ggml_vk_create_pipeline( \
device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \
name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
for (auto &c : device->pipeline_##name##type_suffix[s]) { \
const vk_conv2d_pipeline_state &state = c.first; \
std::vector<uint32_t> spec_constants_cpy = spec_constants; \
spec_constants_cpy.push_back(state.s0); \
spec_constants_cpy.push_back(state.s1); \
spec_constants_cpy.push_back(state.p0); \
spec_constants_cpy.push_back(state.p1); \
spec_constants_cpy.push_back(state.d0); \
spec_constants_cpy.push_back(state.d1); \
spec_constants_cpy.push_back(state.KW); \
spec_constants_cpy.push_back(state.KH); \
ggml_vk_create_pipeline( \
device, c.second, #name #type_suffix, \
name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \
}
#define CREATE_CONVS(spv_suffix) \
CREATE_CONV(conv2d, _f32, spv_suffix) \
CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
Expand Down Expand Up @@ -8551,7 +8569,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const

uint32_t tiles[CONV_SHAPE_COUNT];
for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) {
tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]);
tiles[i] = CEIL_DIV(elements[0], conv_shapes_wg_denoms[i][0]) * CEIL_DIV(elements[1], conv_shapes_wg_denoms[i][1]);
}

// We can't query number of shader cores on Intel, use 32 as a placeholder
Expand All @@ -8566,19 +8584,42 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
shape = CONV_SHAPE_64x32;
}

uint32_t KW = static_cast<uint32_t>(src0->ne[0]);
uint32_t KH = static_cast<uint32_t>(src0->ne[1]);
uint32_t s0 = static_cast<uint32_t>(dst->op_params[0]);
uint32_t s1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[1]) : static_cast<uint32_t>(dst->op_params[0]);
uint32_t p0 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[2]) : 0;
uint32_t p1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[3]) : 0;
uint32_t d0 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[4]) : 1;
uint32_t d1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[5]) : 1;

vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH);

std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr;
if (op == GGML_OP_CONV_2D) {
if (src0->type == GGML_TYPE_F32) {
return ctx->device->pipeline_conv2d_f32[shape];
pipelines = &ctx->device->pipeline_conv2d_f32[shape];
} else if (src0->type == GGML_TYPE_F16) {
return ctx->device->pipeline_conv2d_f16_f32[shape];
pipelines = &ctx->device->pipeline_conv2d_f16_f32[shape];
}
} else if (op == GGML_OP_CONV_TRANSPOSE_2D) {
if (src0->type == GGML_TYPE_F32) {
return ctx->device->pipeline_conv_transpose_2d_f32[shape];
pipelines = &ctx->device->pipeline_conv_transpose_2d_f32[shape];
} else if (src0->type == GGML_TYPE_F16) {
return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
pipelines = &ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
}
}

vk_pipeline pipeline = nullptr;

auto it = pipelines->find(conv2d_pipeline_state);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will need to be rebased on #17024 and hold a lock when searching the map.

if (it != pipelines->end()) {
pipeline = it->second;
} else {
(*pipelines)[conv2d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
}

return pipeline;
}
return nullptr;
case GGML_OP_CONV_2D_DW:
Expand Down
75 changes: 39 additions & 36 deletions ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,8 @@ layout(push_constant) uniform parameter {
uint32_t nb3;

// fastdiv helper values
uint32_t KWmp; uint32_t KWL;
uint32_t KWKHmp; uint32_t KWKHL;
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
#ifdef TRANSPOSE
uint32_t s0mp; uint32_t s0L;
uint32_t s1mp; uint32_t s1L;
#endif
}

p;
Expand All @@ -84,6 +78,15 @@ layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint use_collectives = 1;
layout(constant_id = 6) const uint SHMEM_PAD = 4;

layout(constant_id = 7) const uint s0 = 1;
layout(constant_id = 8) const uint s1 = 1;
layout(constant_id = 9) const uint p0 = 0;
layout(constant_id = 10) const uint p1 = 0;
layout(constant_id = 11) const uint d0 = 1;
layout(constant_id = 12) const uint d1 = 1;
layout(constant_id = 13) const uint KW = 1;
layout(constant_id = 14) const uint KH = 1;

uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;

Expand All @@ -92,7 +95,7 @@ uint splitWork(uint work_size, uint block_size) {
}

uint32_t K = p.Cout;
uint32_t CRS = p.Cin * p.KH * p.KW;
uint32_t CRS = p.Cin * KH * KW;
uint32_t NPQ = p.N * p.OH * p.OW;

uint32_t n_elems_out = K * NPQ;
Expand Down Expand Up @@ -187,7 +190,7 @@ void main() {
}
#endif
/* Advance block in CRS dim */
for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
[[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
uint32_t CRS_idx_a;
uint32_t Cin_idx_a;
uint32_t KH_idx_a;
Expand All @@ -200,32 +203,32 @@ void main() {
uint32_t cached_KW_idx;
if (use_collectives == 1) {
cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
cached_Cin_idx = cached_CRS_idx / (KW * KH);
uint32_t cached_CRS_remainder = cached_CRS_idx % (KW * KH);
cached_KH_idx = cached_CRS_remainder / KW;
cached_KW_idx = cached_CRS_remainder % KW;

CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
} else {
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
Cin_idx_a = CRS_idx_a / (KW * KH);
uint32_t CRS_remainder = CRS_idx_a % (KW * KH);
KH_idx_a = CRS_remainder / KW;
KW_idx_a = CRS_remainder % KW;
}
#else
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
Cin_idx_a = CRS_idx_a / (KW * KH);
CRS_remainder = CRS_idx_a % (KW * KH);
KH_idx_a = CRS_remainder / KW;
KW_idx_a = CRS_remainder % KW;
#endif

/* Load kernel to A_block: (BS_K x BS_CRS)*/
for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
uint32_t B_ly = r_offset + Ar;
uint32_t B_lx = Ac;
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
Expand Down Expand Up @@ -262,35 +265,35 @@ void main() {
KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
} else {
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
Cin_idx_b = CRS_idx_b / (KW * KH);
uint32_t CRS_remainder = CRS_idx_b % (KW * KH);
KH_idx_b = CRS_remainder / KW;
KW_idx_b = CRS_remainder % KW;
}
#else
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
Cin_idx_b = CRS_idx_b / (KW * KH);
uint32_t CRS_remainder = CRS_idx_b % (KW * KH);
KH_idx_b = CRS_remainder / KW;
KW_idx_b = CRS_remainder % KW;
#endif

#ifdef TRANSPOSE
uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1;
uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0;
uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L);
uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L);
uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * d1 + p1;
uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * d0 + p0;
uint32_t H_idx = H_idx_x_s1 / s1;
uint32_t W_idx = W_idx_x_s0 / s0;
#else
uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
uint32_t H_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
uint32_t W_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
#endif
uint32_t src_idx =
min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
float val = src_data[src_idx];
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
|| H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case)
#ifdef TRANSPOSE
|| (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
|| (H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0)
#endif
) {
val = 0.0;
Expand Down
Loading