diff --git a/src/targets/gpu/fuse_ops.cpp b/src/targets/gpu/fuse_ops.cpp
index 50ab5d184f3..0d433aa6e27 100644
--- a/src/targets/gpu/fuse_ops.cpp
+++ b/src/targets/gpu/fuse_ops.cpp
@@ -753,16 +753,8 @@ struct find_hipblas_gemm_pointwise : gemm_pointwise
 };
 #endif
 
-struct find_contiguous_tranpose_gemm
+struct contiguous_transpose_gemm
 {
-    auto matcher() const
-    {
-        return match::name("gpu::contiguous")(match::arg(0)(
-            match::name("transpose")(
-                match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
-                .bind("transpose")));
-    }
-
     template <class Vector>
     static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
     {
@@ -773,6 +765,17 @@ struct find_contiguous_tranpose_gemm
         std::swap(perm2[i], perm2[j]);
         return perm2 == perm;
     }
+};
+
+struct find_contiguous_transpose_rocblas_gemm : contiguous_transpose_gemm
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
+                .bind("transpose")));
+    }
 
     void apply(module& m, const match::matcher_result& r) const
     {
@@ -811,6 +814,67 @@
     }
 };
 
+#if MIGRAPHX_USE_HIPBLASLT
+struct find_contiguous_transpose_hip_gemm : contiguous_transpose_gemm
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(
+                    match::name("gpu::hipblaslt_op")(match::used_once()).bind("hip_gemm")))
+                .bind("transpose")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins      = r.result;
+        auto gemm_ins = r.instructions["hip_gemm"];
+        auto gemm_op  = any_cast<hipblaslt_op>(gemm_ins->get_operator()).op;
+
+        if(gemm_op.name() != "gpu::hip_gemm")
+            return;
+
+        auto gemm = any_cast<hip_gemm<op::dot>>(gemm_op);
+
+        auto alloc     = gemm_ins->inputs().back();
+        auto transpose = r.instructions["transpose"];
+        auto perm = transpose->get_operator().to_value()["permutation"].to_vector<std::int64_t>();
+        auto iperm = invert_permutation(perm);
+
+        if(perm.size() < 3)
+            return;
+
+        if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
+            return;
+
+        auto lens = gemm_ins->get_shape().lens();
+        if(lens.size() > 3 and
+           not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
+            return;
+
+        gemm.trans_batch = 1;
+
+        auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
+        auto new_alloc =
+            m.insert_instruction(gemm_ins, make_op("allocate", {{"shape", to_value(s)}}));
+
+        auto alloc_transpose = m.insert_instruction(
+            gemm_ins, make_op("transpose", {{"permutation", perm}}), new_alloc);
+
+        auto inputs   = gemm_ins->inputs();
+        inputs.back() = alloc_transpose;
+        operation new_gemm_op = gemm;
+        auto new_gemm         = m.insert_instruction(
+            gemm_ins, make_op("gpu::hipblaslt_op", {{"op", to_value(new_gemm_op)}}), inputs);
+
+        auto gemm_transpose = m.insert_instruction(gemm_ins, transpose->get_operator(), new_gemm);
+
+        m.replace_instruction(ins, gemm_transpose);
+    }
+};
+#endif
+
 struct find_commutative_broadcast
 {
     auto matcher() const
     {
@@ -980,7 +1044,10 @@ void fuse_ops::apply(module& m) const
     match::find_matches(m,
                         find_layernorm_pointwise{},
                         find_concat_pointwise{},
-                        find_contiguous_tranpose_gemm{},
+                        find_contiguous_transpose_rocblas_gemm{},
+#if MIGRAPHX_USE_HIPBLASLT
+                        find_contiguous_transpose_hip_gemm{},
+#endif
                         find_commutative_broadcast{});
     match::find_matches(m, find_contiguous{});
 }
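Review note: both matchers above rewrite `gpu::contiguous(transpose(gemm))` so that the GEMM writes directly into a transposed allocation and the copy made by `gpu::contiguous` disappears. The rewrite is only legal when the transpose does nothing more than exchange the batch axis with the M axis, which is what the shared `is_swapped` helper checks. Below is a minimal standalone sketch of that check; the `main` driver and the example permutations are illustrative only, not part of the patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// True iff perm equals the identity permutation with positions i and j swapped.
template <class Vector>
static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
{
    if(i >= perm.size() or j >= perm.size())
        return false;
    std::vector<std::int64_t> perm2(perm.size());
    std::iota(perm2.begin(), perm2.end(), 0);
    std::swap(perm2[i], perm2[j]);
    return std::equal(perm2.begin(), perm2.end(), perm.begin(), perm.end());
}

int main()
{
    // A 3-D output {B, M, N}: exchanging the batch and M axes (the last-3 and
    // last-2 dims) is the only transpose these matchers will fuse.
    std::vector<std::int64_t> fusable{1, 0, 2};
    assert(is_swapped(fusable, fusable.size() - 3, fusable.size() - 2));

    // A full reversal is not a single swap of those two axes, so it is rejected.
    std::vector<std::int64_t> reversal{2, 1, 0};
    assert(not is_swapped(reversal, reversal.size() - 3, reversal.size() - 2));
    return 0;
}
```

The extra `lens` guard in `apply` then restricts the fusion to effectively 3-D outputs (any higher dims must be 1), since `trans_batch` can only describe one batch axis.
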
diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp
index 03bda69081e..22f6a92b3af 100644
--- a/src/targets/gpu/hip_gemm_impl.cpp
+++ b/src/targets/gpu/hip_gemm_impl.cpp
@@ -31,6 +31,7 @@
 #include
 #include
 #include
+#include
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -110,6 +111,19 @@ void blas_shape_hip(const shape& s)
         MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
 }
 
+shape transpose_batch_hip(const shape& s, unsigned trans_batch)
+{
+    if(trans_batch == 0)
+        return s;
+    if(s.lens().size() < 3)
+        return s;
+    auto batch = s.lens().size() - 3;
+    std::vector<int64_t> perm(s.lens().size());
+    std::iota(perm.begin(), perm.end(), 0);
+    std::swap(perm[batch], perm[batch + trans_batch]);
+    return shape::from_permutation(s.type(), s.lens(), perm);
+}
+
 static bool is_transposed_hip(const shape& s) { return s.transposed() and s.strides().back() != 1; }
 
 static int32_t get_batch_stride_hip(const shape& s)
diff --git a/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp b/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp
index 9f74bc02813..8c3d67bcd93 100644
--- a/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp
+++ b/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp
@@ -41,6 +41,7 @@ namespace gpu {
 struct context;
 
 void blas_shape_hip(const shape& s);
+shape transpose_batch_hip(const shape& s, unsigned trans_batch);
 
 template <class Op>
 struct hip_gemm
@@ -48,13 +49,16 @@ struct hip_gemm
     Op op;
     float alpha = 1;
     float beta  = 0;
+    unsigned trans_batch = 0;
     int32_t solution_idx = 0;
+
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
         return pack_join(migraphx::reflect(self.op, f),
                          pack(f(self.alpha, "alpha"),
                               f(self.beta, "beta"),
+                              f(self.trans_batch, "trans_batch"),
                               f(self.solution_idx, "solution_idx")));
     }
@@ -98,10 +102,10 @@ struct hip_gemm
                             to_string(cmat_shape.type()) +
                             ", it must be: " + to_string(op_out_shape.type()));
             }
-            return op_out_shape;
+            return transpose_batch_hip(op_out_shape, trans_batch);
         }
 
-        return op.compute_shape(in_shapes);
+        return transpose_batch_hip(op.compute_shape(in_shapes), trans_batch);
     }
 
     argument
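On the implementation side, `transpose_batch_hip` keeps the output lengths unchanged and only alters the strides: the shape it returns describes a buffer packed as if the batch and M axes were exchanged, viewed back in the original axis order. Here is a small self-contained sketch of that layout for `trans_batch == 1`; it is independent of MIGraphX, and `packed_strides`, `print_transposed_batch_view`, and the example dimensions are mine, standing in for what I understand `shape::from_permutation` to compute.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Strides of a densely packed (row-major) buffer with the given lengths.
static std::vector<std::int64_t> packed_strides(const std::vector<std::int64_t>& lens)
{
    std::vector<std::int64_t> strides(lens.size(), 1);
    for(std::size_t i = lens.size(); i > 1; --i)
        strides[i - 2] = strides[i - 1] * lens[i - 1];
    return strides;
}

// Simplified stand-in for shape::from_permutation with a batch<->M swap: pack
// the buffer as if dims `batch` and `batch + trans_batch` were exchanged, then
// view it in the original dim order, so only the strides change.
static void print_transposed_batch_view(std::vector<std::int64_t> lens, unsigned trans_batch)
{
    if(trans_batch == 0 or lens.size() < 3)
        return;
    auto batch    = lens.size() - 3;
    auto permuted = lens;
    std::swap(permuted[batch], permuted[batch + trans_batch]);
    auto strides = packed_strides(permuted);
    // Undo the (self-inverse) swap on the strides so they align with `lens`.
    std::swap(strides[batch], strides[batch + trans_batch]);
    std::cout << "lens:";
    for(auto l : lens)
        std::cout << ' ' << l;
    std::cout << "  strides:";
    for(auto s : strides)
        std::cout << ' ' << s;
    std::cout << '\n';
}

int main()
{
    // {B=4, M=8, N=16}: the buffer is packed as {M=8, B=4, N=16}, so the view
    // reports strides {16, 64, 1} instead of the packed {128, 16, 1}.
    print_transposed_batch_view({4, 8, 16}, 1);
    return 0;
}
```

This matches how the fusion above consumes the shape: the fused GEMM writes through the transposed allocation, while `hip_gemm::compute_shape` reports the batch-transposed view, which the reinserted `transpose` instruction then presents in the order the rest of the graph expects.
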