Skip to content

Commit

Permalink
Improve performance by using hipBlasLt
Browse files Browse the repository at this point in the history
  • Loading branch information
causten authored Mar 4, 2025
2 parents c0d5b78 + bff0b1c commit fe2b2e7
Show file tree
Hide file tree
Showing 51 changed files with 1,317 additions and 86 deletions.
3 changes: 1 addition & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -323,7 +323,6 @@ rocm_enable_cppcheck(
${CMAKE_CURRENT_SOURCE_DIR}/test/include
DEFINE
MIGRAPHX_MLIR=1
MIGRAPHX_ENABLE_HIPBLASLT_GEMM=1
MIGRAPHX_HAS_EXECUTORS=0
CPPCHECK=1
MIGRAPHX_USE_MIOPEN=1
Expand Down
4 changes: 3 additions & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def rocmnodename(name) {
node_name = "${rocmtest_name} && navi21";
} else if(name == "mi100+") {
node_name = "${rocmtest_name} && (gfx908 || gfx90a) && !vm";
} else if(name == "mi200+") {
node_name = "${rocmtest_name} && (gfx90a || gfx942) && !vm";
} else if(name == "cdna") {
node_name = "${rocmtest_name} && (gfx908 || gfx90a || vega20) && !vm";
} else if(name == "navi32") {
Expand Down Expand Up @@ -160,7 +162,7 @@ node("(rocmtest || migraphx)") {
}
}

rocmtest clang_debug: rocmnode('mi100+') { cmake_build ->
rocmtest clang_debug: rocmnode('mi200+') { cmake_build ->
stage('hipRTC Debug') {
// Disable MLIR since it doesn't work with all UB sanitizers
withEnv(['MIGRAPHX_DISABLE_MLIR=1']) {
Expand Down
7 changes: 4 additions & 3 deletions docs/dev/env_vars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,10 @@ Use it in conjunction with ``MIGRAPHX_DISABLE_MLIR=1``.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Disables use of the rocMLIR library.

.. envvar:: MIGRAPHX_ENABLE_HIPBLASLT_GEMM
Set to "1", "enable", "enabled", "yes", or "true" to use.
Enables use of hipBLASLt.
.. envvar:: MIGRAPHX_SET_GEMM_PROVIDER

Set to "hipblaslt" to use hipBLASLt.
Set to "rocblas" to use rocBLAS.

.. envvar:: MIGRAPHX_COPY_LITERALS

Expand Down
4 changes: 2 additions & 2 deletions src/include/migraphx/op/binary.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -88,7 +88,7 @@ struct binary : op_name<Derived>
}
else
{
return {s0.type(), s0.lens()};
return shape::from_permutation(s0.type(), s0.lens(), find_permutation({s0, s1}));
}
}

Expand Down
15 changes: 12 additions & 3 deletions src/include/migraphx/op/quantizelinear.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -87,8 +87,17 @@ struct quantizelinear
auto min_value = std::numeric_limits<quant_type>::lowest();
auto max_value = std::numeric_limits<quant_type>::max();
par_for(output_shape.elements(), [&](auto i) {
double quantized = static_cast<double>(std::nearbyint(input[i] / scales[i])) +
static_cast<double>(zero_pts[i]);
double quantized;
if constexpr(std::is_integral<quant_type>{})
{
quantized = static_cast<double>(std::nearbyint(input[i] / scales[i])) +
static_cast<double>(zero_pts[i]);
}
else
{
quantized = static_cast<double>(input[i]) / static_cast<double>(scales[i]) +
static_cast<double>(zero_pts[i]);
}
output[i] = std::max(static_cast<double>(min_value),
std::min(static_cast<double>(max_value), quantized));
});
Expand Down
14 changes: 10 additions & 4 deletions src/layout_convolution.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -50,6 +50,13 @@ std::vector<int64_t> get_permutation(instruction_ref ins, const layout_convoluti
return find_permutation(ins->inputs().front()->get_shape());
}

// Returns the identity permutation {0, 1, ..., ndim-1} for the output
// shape of `ins`, i.e. a layout that leaves every dimension in place.
std::vector<int64_t> get_default_permutation(instruction_ref ins)
{
    const auto ndim = ins->get_shape().ndim();
    std::vector<int64_t> result;
    result.reserve(ndim);
    for(std::size_t axis = 0; axis < ndim; ++axis)
        result.push_back(static_cast<int64_t>(axis));
    return result;
}

bool skip_layout(const shape& s)
{
return s.ndim() == 1 or s.dynamic() or s.type() == shape::tuple_type;
Expand Down Expand Up @@ -91,10 +98,9 @@ void transform_convolutions(module& m, const layout_convolution& lc)
if(ins->get_shape().lens().size() != 4)
continue;
auto v = ins->get_operator().to_value();
if(v.at("group").to<int>() > 1)
continue;
bool is_group_conv = v.at("group").to<int>() > 1;
auto args = ins->inputs();
auto perm = get_permutation(ins, lc);
auto perm = is_group_conv ? get_default_permutation(ins) : get_permutation(ins, lc);
std::transform(args.begin(), args.end(), args.begin(), [&](const auto& i) {
return m.insert_instruction(ins, make_op("layout", {{"permutation", perm}}), i);
});
Expand Down
25 changes: 24 additions & 1 deletion src/onnx/conv.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -22,6 +22,9 @@
* THE SOFTWARE.
*/
#include <migraphx/onnx/conv.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/permutation.hpp>
#include <algorithm>

namespace migraphx {
Expand All @@ -47,6 +50,26 @@ void recalc_conv_attributes(value& v, size_t kdims)
}
}

// Inserts a transpose on `ins` between NCHW and channels-last (NHWC) layout.
// The forward permutation keeps the batch axis first, shifts each spatial
// axis one position left, and moves the channel axis (1) to the back; when
// `invert` is true the inverse permutation (NHWC -> NCHW) is applied instead.
static instruction_ref
apply_nhwc_perm(const onnx_parser::node_info& info, instruction_ref ins, bool invert)
{
    std::vector<int64_t> perm(ins->get_shape().ndim());
    // perm[0] stays 0 from value-initialization; fill the middle with 2,3,...
    std::generate(perm.begin() + 1, perm.end() - 1, [axis = int64_t{2}]() mutable {
        return axis++;
    });
    perm.back() = 1;
    auto layout = invert ? invert_permutation(perm) : perm;
    return info.add_instruction(make_op("transpose", {{"permutation", layout}}), ins);
}

// Transposes `ins` from channels-last (NHWC) back to NCHW by applying the
// inverse of the NCHW->NHWC permutation built in apply_nhwc_perm.
instruction_ref from_nhwc(const onnx_parser::node_info& info, instruction_ref ins)
{
    return apply_nhwc_perm(info, ins, true);
}

// Transposes `ins` from NCHW to channels-last (NHWC) layout using the
// forward permutation built in apply_nhwc_perm.
instruction_ref to_nhwc(const onnx_parser::node_info& info, instruction_ref ins)
{
    return apply_nhwc_perm(info, ins, false);
}

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
7 changes: 6 additions & 1 deletion src/onnx/include/migraphx/onnx/conv.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -26,13 +26,18 @@

#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/instruction_ref.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

void recalc_conv_attributes(value& v, size_t kdims);

instruction_ref from_nhwc(const onnx_parser::node_info& info, instruction_ref ins);
instruction_ref to_nhwc(const onnx_parser::node_info& info, instruction_ref ins);

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Expand Down
16 changes: 15 additions & 1 deletion src/onnx/parse_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ struct parse_convolution : op_parser<parse_convolution>
{
std::vector<op_desc> operators() const
{
return {{"Conv", "convolution"}, {"ConvInteger", "quant_convolution"}};
return {{"Conv", "convolution"},
{"ConvInteger", "quant_convolution"},
{"NhwcConv", "convolution"}};
}

// Convert to half prior to a shift to ensure we preserve accuracy here then
Expand Down Expand Up @@ -240,6 +242,13 @@ struct parse_convolution : op_parser<parse_convolution>
auto values = op.to_value();
auto x = args[0];
auto weights = args[1];

if(opd.onnx_name == "NhwcConv")
{
x = from_nhwc(info, x);
weights = from_nhwc(info, weights);
}

auto x_shape = x->get_shape();
auto w_shape = weights->get_shape();
auto in_lens = x_shape.max_lens();
Expand Down Expand Up @@ -362,6 +371,11 @@ struct parse_convolution : op_parser<parse_convolution>
ret = info.add_bias(args, ret, 1);
}

if(opd.onnx_name == "NhwcConv")
{
ret = to_nhwc(info, ret);
}

return ret;
}
};
Expand Down
84 changes: 75 additions & 9 deletions src/onnx/parse_groupnorm.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -25,43 +25,96 @@
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/permutation.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

// Inserts a transpose on `ins` that converts between NCHW and channels-last
// layout. The forward permutation is {0, 2, 3, ..., ndim-1, 1}: batch stays
// first and the channel axis moves to the back. With `invert` set, the
// inverse permutation (channels-last -> NCHW) is applied instead.
static instruction_ref
apply_channels_last_perm(const onnx_parser::node_info& info, instruction_ref ins, bool invert)
{
    const auto ndim = ins->get_shape().ndim();
    std::vector<int64_t> perm(ndim, 0); // perm[0] remains 0 (batch axis)
    for(std::size_t i = 1; i + 1 < ndim; ++i)
        perm[i] = static_cast<int64_t>(i) + 1;
    perm[ndim - 1] = 1;
    auto layout = invert ? invert_permutation(perm) : perm;
    return info.add_instruction(make_op("transpose", {{"permutation", layout}}), ins);
}

struct parse_groupnorm : op_parser<parse_groupnorm>
{
std::vector<op_desc> operators() const { return {{"GroupNormalization"}}; }
std::vector<op_desc> operators() const
{
return {{"GroupNormalization", "GroupNorm"}, {"GroupNorm", "Contrib_GroupNorm"}};
}

instruction_ref parse(const op_desc& /*opd*/,
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
bool is_contrib = (opd.op_name == ("Contrib_GroupNorm"));

float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
size_t num_groups;
if(contains(info.attributes, "num_groups"))
if(contains(info.attributes, "num_groups") or contains(info.attributes, "groups"))
{
num_groups = parser.parse_value(info.attributes.at("num_groups")).at<size_t>();
if(is_contrib)
{
num_groups =
std::abs(parser.parse_value(info.attributes.at("groups")).at<int64_t>());
}
else
{
num_groups =
std::abs(parser.parse_value(info.attributes.at("num_groups")).at<int64_t>());
}
}
else
{
MIGRAPHX_THROW("PARSE_GROUPNORM: num_groups must be available");
}

bool is_channels_last = false;
if(is_contrib)
{ // default state for GroupNorm Contrib op
is_channels_last = true;
if(contains(info.attributes, "channels_last"))
{
is_channels_last =
(1 == parser.parse_value(info.attributes.at("channels_last")).at<size_t>());
}
}

bool silu_activation = false;
if(contains(info.attributes, "activation") and is_contrib)
{
silu_activation =
(1 == parser.parse_value(info.attributes.at("activation")).at<size_t>());
}
else if(is_contrib)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: activation must be available");
}

if(args.size() != 3)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: invalid input count");
}

auto x = args.at(0);
auto scale = args.at(1);
auto bias = args.at(2);
// Adjust channels from channels_last -> NCHW if channels_last is set for the contrib op
auto x = args.at(0);
if(is_channels_last and is_contrib)
{
x = apply_channels_last_perm(info, x, true);
}

auto scale = args.at(1); // gamma in the GroupNorm contrib case
auto bias = args.at(2); // beta in the GroupNorm contrib case

auto x_shape = x->get_shape();
auto x_dtype = x_shape.type();
Expand Down Expand Up @@ -120,7 +173,20 @@ struct parse_groupnorm : op_parser<parse_groupnorm>
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto scaled = info.add_instruction(make_op("mul"), result, scale_bcast);
auto y = info.add_instruction(make_op("add"), scaled, bias_bcast);
return info.add_instruction(make_op("reshape", {{"dims", x_dims}}), y);
auto output = info.add_instruction(make_op("reshape", {{"dims", x_dims}}), y);

// Convert to NCHW -> channels_last for contrib GroupNorm
if(is_channels_last and is_contrib)
{
output = apply_channels_last_perm(info, output, false);
}
if(silu_activation)
{
// SiLU activation is just out = x * sigmoid(x)
auto sigmoid = info.add_instruction(make_op("sigmoid"), output);
output = info.add_instruction(make_op("mul"), output, sigmoid);
}
return output;
}
};

Expand Down
Loading

0 comments on commit fe2b2e7

Please sign in to comment.