Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add changes for contiguous transpose gemm fusion for hipblaslt #3706

Merged
merged 2 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 77 additions & 10 deletions src/targets/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,16 +753,8 @@ struct find_hipblas_gemm_pointwise : gemm_pointwise
};
#endif

struct find_contiguous_tranpose_gemm
struct contiguous_transpose_gemm
{
auto matcher() const
{
return match::name("gpu::contiguous")(match::arg(0)(
match::name("transpose")(
match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
.bind("transpose")));
}

template <class Vector>
static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
{
Expand All @@ -773,6 +765,17 @@ struct find_contiguous_tranpose_gemm
std::swap(perm2[i], perm2[j]);
return perm2 == perm;
}
};

struct find_contiguous_transpose_rocblas_gemm : contiguous_transpose_gemm
{
// Match a gpu::contiguous whose input is a transpose of a rocBLAS
// gpu::gemm with no other users, so the explicit copy can be fused
// into the gemm's output layout.
auto matcher() const
{
    return match::name("gpu::contiguous")(match::arg(0)(
        match::name("transpose")(
            match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
            .bind("transpose")));
}

void apply(module& m, const match::matcher_result& r) const
{
Expand Down Expand Up @@ -811,6 +814,67 @@ struct find_contiguous_tranpose_gemm
}
};

#if MIGRAPHX_USE_HIPBLASLT
// Fuse a gpu::contiguous(transpose(hipblaslt gemm)) chain by making the gemm
// write directly into an allocation laid out so the transpose becomes a no-op
// view, eliminating the copy. Mirrors find_contiguous_transpose_rocblas_gemm
// for the hipBLASLt backend.
struct find_contiguous_transpose_hip_gemm : contiguous_transpose_gemm
{
    auto matcher() const
    {
        return match::name("gpu::contiguous")(match::arg(0)(
            match::name("transpose")(
                match::arg(0)(
                    match::name("gpu::hipblaslt_op")(match::used_once()).bind("hip_gemm")))
                .bind("transpose")));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins      = r.result;
        auto gemm_ins = r.instructions["hip_gemm"];
        auto gemm_op  = any_cast<hipblaslt_op>(gemm_ins->get_operator()).op;

        // Only the gemm flavor of hipblaslt_op is eligible for this rewrite.
        if(gemm_op.name() != "gpu::hip_gemm")
            return;

        auto gemm = any_cast<hip_gemm<op::dot>>(gemm_op);

        auto alloc     = gemm_ins->inputs().back();
        auto transpose = r.instructions["transpose"];
        auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();

        // Need at least a batch dimension plus the two matrix dimensions.
        if(perm.size() < 3)
            return;

        // Only handle a pure swap of the batch dim with the row dim.
        if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
            return;

        // Any higher batch dimensions must be singleton for trans_batch to apply.
        auto lens = gemm_ins->get_shape().lens();
        if(lens.size() > 3 and
           not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
            return;

        auto iperm = invert_permutation(perm);

        gemm.trans_batch = 1;

        // Allocate the output buffer in the permuted layout so the gemm writes
        // memory that is already contiguous after the transpose.
        auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
        auto new_alloc =
            m.insert_instruction(gemm_ins, make_op("allocate", {{"shape", to_value(s)}}));

        auto alloc_transpose = m.insert_instruction(
            gemm_ins, make_op("transpose", {{"permutation", perm}}), new_alloc);

        auto inputs   = gemm_ins->inputs();
        inputs.back() = alloc_transpose;
        operation new_gemm_op = gemm;
        auto new_gemm         = m.insert_instruction(
            gemm_ins, make_op("gpu::hipblaslt_op", {{"op", to_value(new_gemm_op)}}), inputs);

        // The transpose of the new gemm output is now a zero-copy view, so it
        // can replace the original contiguous instruction directly.
        auto gemm_transpose = m.insert_instruction(gemm_ins, transpose->get_operator(), new_gemm);

        m.replace_instruction(ins, gemm_transpose);
    }
};
#endif

struct find_commutative_broadcast
{
auto matcher() const
Expand Down Expand Up @@ -980,7 +1044,10 @@ void fuse_ops::apply(module& m) const
match::find_matches(m,
find_layernorm_pointwise{},
find_concat_pointwise{},
find_contiguous_tranpose_gemm{},
find_contiguous_transpose_rocblas_gemm{},
#if MIGRAPHX_USE_HIPBLASLT
find_contiguous_transpose_hip_gemm{},
#endif
find_commutative_broadcast{});
match::find_matches(m, find_contiguous{});
}
Expand Down
14 changes: 14 additions & 0 deletions src/targets/gpu/hip_gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <migraphx/reduce_dims.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
#include <migraphx/permutation.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -110,6 +111,19 @@ void blas_shape_hip(const shape& s)
MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
}

// Return `s` with its batch dimension permuted against the matrix row
// dimension, expressed as a non-standard (permuted-stride) shape. A
// trans_batch of 0 requests no transpose, and shapes with fewer than three
// dimensions have no batch dimension to move; both are returned unchanged.
shape transpose_batch_hip(const shape& s, unsigned trans_batch)
{
    const auto ndim = s.lens().size();
    if(trans_batch == 0 or ndim < 3)
        return s;
    // The batch dimension sits three positions from the end (batch, M, N).
    const auto batch_dim = ndim - 3;
    std::vector<int64_t> permutation(ndim);
    std::iota(permutation.begin(), permutation.end(), 0);
    std::swap(permutation[batch_dim], permutation[batch_dim + trans_batch]);
    return shape::from_permutation(s.type(), s.lens(), permutation);
}

static bool is_transposed_hip(const shape& s) { return s.transposed() and s.strides().back() != 1; }

static int32_t get_batch_stride_hip(const shape& s)
Expand Down
8 changes: 6 additions & 2 deletions src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,24 @@ namespace gpu {

struct context;
void blas_shape_hip(const shape& s);
shape transpose_batch_hip(const shape& s, unsigned trans_batch);

template <class Op>
struct hip_gemm
{
Op op;
float alpha = 1;
float beta = 0;
unsigned trans_batch = 0;
int32_t solution_idx = 0;

// Expose the wrapped dot op's fields joined with the gemm-specific tuning
// fields for serialization/hashing; trans_batch (new in this change) records
// a requested batch/row transpose of the output shape.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
    return pack_join(migraphx::reflect(self.op, f),
                     pack(f(self.alpha, "alpha"),
                          f(self.beta, "beta"),
                          f(self.trans_batch, "trans_batch"),
                          f(self.solution_idx, "solution_idx")));
}

Expand Down Expand Up @@ -98,10 +102,10 @@ struct hip_gemm
to_string(cmat_shape.type()) +
", it must be: " + to_string(op_out_shape.type()));
}
return op_out_shape;
return transpose_batch_hip(op_out_shape, trans_batch);
}

return op.compute_shape(in_shapes);
return transpose_batch_hip(op.compute_shape(in_shapes), trans_batch);
}

argument
Expand Down
Loading