diff --git a/src/targets/cpu/lowering.cpp b/src/targets/cpu/lowering.cpp
index d63df7c25a4..5dd1225a94e 100644
--- a/src/targets/cpu/lowering.cpp
+++ b/src/targets/cpu/lowering.cpp
@@ -340,7 +340,7 @@ struct cpu_apply
         extend_op("softmax", "dnnl::softmax");
         extend_op("im2col", "cpu::im2col", false);
-        extend_op("leaky_relu", "cpu::leaky_relu", false);
+        // extend_op("leaky_relu", "cpu::leaky_relu", false);
         extend_op("pad", "cpu::pad", false);
         extend_op("rnn_var_sl_last_output", "cpu::rnn_var_sl_last_output", false);
     }
diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt
index ee725b1d638..cba0a0f6786 100644
--- a/src/targets/gpu/CMakeLists.txt
+++ b/src/targets/gpu/CMakeLists.txt
@@ -175,6 +175,7 @@ add_library(migraphx_gpu
     nonzero.cpp
     pack_args.cpp
     prefuse_ops.cpp
+    prepare_mlir.cpp
     prepare_reduce.cpp
     perfdb.cpp
     pooling.cpp
diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp
index cf3cf6f3416..108a0c7483e 100644
--- a/src/targets/gpu/fuse_mlir.cpp
+++ b/src/targets/gpu/fuse_mlir.cpp
@@ -463,22 +463,23 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
     }
     const std::initializer_list any_type_ops = {"@literal", "@param", "@return"};
     const std::initializer_list no_bool_ops = {
+        "abs",
+        "add",
+        "clip",
         "convolution",
-        "quant_convolution",
+        "dequantizelinear",
+        "div",
         "dot",
+        "leaky_relu",
+        "mul",
+        "neg",
+        "pow",
+        "quant_convolution",
         "quant_dot",
-        "add",
-        "clip",
+        "quantizelinear",
         "relu",
         "sub",
-        "mul",
-        "div",
-        "pow",
         "where",
-        "quantizelinear",
-        "dequantizelinear",
-        "abs",
-        "neg",
     };
     const std::initializer_list fp_only_ops = {
         "ceil",
@@ -552,6 +553,8 @@ bool is_pointwise_op_supported_by_mlir_for_input(const instruction& i)
     return is_pointwise_op_supported_by_mlir(i);
 }
 
+static bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); }
+
 MIGRAPHX_PRED_MATCHER(mlir_split_reduce, instruction_ref ins)
 {
     if(ins->name() != "split_fused_reduce")
diff --git a/src/targets/gpu/include/migraphx/gpu/mlir.hpp b/src/targets/gpu/include/migraphx/gpu/mlir.hpp
index d1f19c1e8ef..4a09d94dbe8 100644
--- a/src/targets/gpu/include/migraphx/gpu/mlir.hpp
+++ b/src/targets/gpu/include/migraphx/gpu/mlir.hpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -53,8 +53,6 @@ struct MIGRAPHX_GPU_EXPORT mlir_code_object
     std::vector prefill_values = {};
 };
 
-MIGRAPHX_GPU_EXPORT bool is_reduce(const instruction& ins);
-
 MIGRAPHX_GPU_EXPORT mlir_code_object compile_mlir(const context& migraphx_ctx,
                                                   module m,
                                                   const std::vector& in_shapes,
diff --git a/src/targets/gpu/include/migraphx/gpu/prepare_mlir.hpp b/src/targets/gpu/include/migraphx/gpu/prepare_mlir.hpp
new file mode 100644
index 00000000000..230d9b62d81
--- /dev/null
+++ b/src/targets/gpu/include/migraphx/gpu/prepare_mlir.hpp
@@ -0,0 +1,47 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ *
+ */
+#ifndef MIGRAPHX_GUARD_GPU_PREPARE_MLIR_HPP
+#define MIGRAPHX_GUARD_GPU_PREPARE_MLIR_HPP
+
+#include
+#include
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+
+struct module;
+
+namespace gpu {
+
+struct prepare_mlir
+{
+    std::string name() const { return "gpu::prepare_mlir"; }
+    void apply(module& m) const;
+};
+
+} // namespace gpu
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
+#endif // MIGRAPHX_GUARD_GPU_PREPARE_MLIR_HPP
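The rewrite_reduce helper removed from mlir.cpp below is not deleted outright; the same axis-collapsing rewrite reappears as the find_reduce matcher in the new gpu::prepare_mlir pass. It reshapes the input so a contiguous run of reduce axes becomes a single -1 dimension, reduces over that one axis, and reshapes back to the original reduce output dims. A minimal standalone sketch of the resulting graph, using only API calls that appear elsewhere in this patch; the shapes and the choice of reduce_sum are illustrative, not part of the change:

#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/make_op.hpp>

// Sketch only: the reshape/reduce/reshape form produced for a reduce over contiguous axes.
migraphx::module collapsed_reduce_sketch()
{
    migraphx::module m;
    // Original graph: reduce_sum(axes={1, 2}) on a {2, 32, 4096} input -> {2, 1, 1}
    auto x = m.add_parameter("x", {migraphx::shape::float_type, {2, 32, 4096}});
    // Rewritten form: collapse the contiguous reduce axes into one -1 dimension...
    auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1}}}), x);
    // ...reduce over the single collapsed axis...
    auto red = m.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), rsp);
    // ...and reshape back to the original reduce output dims.
    auto back = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 1}}}), red);
    m.add_return({back});
    return m;
}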
diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp
index 9ca5c35330c..e9b80ac5e68 100644
--- a/src/targets/gpu/mlir.cpp
+++ b/src/targets/gpu/mlir.cpp
@@ -31,6 +31,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -1070,58 +1071,12 @@ struct mlir_program
     std::string sym_name;
 };
 
-bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); }
-
-static void rewrite_reduce(module& m)
-{
-    for(auto i : iterator_for(m))
-    {
-        if(is_reduce(*i))
-        {
-            auto reduce_op = i->get_operator().to_value();
-            auto reduce_op_name = i->get_operator().name();
-            auto reduce_axes = reduce_op["axes"].to_vector();
-            auto reduce_lens = i->get_shape().lens();
-            auto in_shape = i->inputs().front()->get_shape();
-            const auto& in_lens = in_shape.lens();
-            assert(in_shape.standard());
-            assert(reduce_lens.size() == in_lens.size());
-            assert(std::adjacent_find(
-                       reduce_axes.begin(), reduce_axes.end(), [](auto axis_1, auto axis_2) {
-                           return axis_2 - axis_1 > 1;
-                       }) == reduce_axes.end());
-
-            std::vector new_rsp_dims;
-            std::vector new_reduce_axes;
-            for(const auto axis : range(in_shape.ndim()))
-            {
-                if(reduce_lens[axis] == in_lens[axis])
-                {
-                    new_rsp_dims.push_back(in_lens[axis]);
-                }
-                else if(new_reduce_axes.empty())
-                {
-                    assert(reduce_lens[axis] == 1);
-                    new_rsp_dims.push_back(-1);
-                    new_reduce_axes.push_back(axis);
-                }
-            }
-            auto rsp_ins = m.insert_instruction(
-                i, migraphx::make_op("reshape", {{"dims", new_rsp_dims}}), i->inputs().front());
-            auto collapsed_reduce = m.insert_instruction(
-                i, migraphx::make_op(reduce_op_name, {{"axes", new_reduce_axes}}), rsp_ins);
-            auto rsp_back = m.insert_instruction(
-                i, migraphx::make_op("reshape", {{"dims", reduce_lens}}), collapsed_reduce);
-            m.replace_instruction(i, rsp_back);
-        }
-    }
-    migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
-}
+static void prepare(module& m) { run_passes(m, {prepare_mlir{}}); }
 
 bool is_module_fusible(const module& m, const context& migraphx_ctx, const value& solution)
 {
     auto mm = m;
-    rewrite_reduce(mm);
+    prepare(mm);
     mlir_program mp;
     mp.set_gpu_properties(migraphx_ctx);
     mp.parse(mm);
@@ -1171,7 +1126,7 @@ std::string dump_mlir(module m, const std::vector& inputs)
     {
         adjust_param_shapes(m, inputs);
     }
-    rewrite_reduce(m);
+    prepare(m);
     mlir_program mp;
     mp.parse(*mr);
     auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
@@ -1232,7 +1187,7 @@ void dump_mlir_to_file(module m, const std::vector& inputs, const fs::pat
     {
         adjust_param_shapes(m, inputs);
     }
-    rewrite_reduce(m);
+    prepare(m);
 
     auto name = compute_dump_name(m, ".mlir");
     auto f = location / name;
@@ -1255,7 +1210,7 @@ mlir_code_object compile_mlir(const context& migraphx_ctx,
                               const value& solution)
 {
     adjust_param_shapes(m, in_shapes);
-    rewrite_reduce(m);
+    prepare(m);
     const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
 
     static std::mutex mutex;
@@ -1336,7 +1291,7 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
                                      bool exhaustive)
 {
     adjust_param_shapes(m, inputs);
-    rewrite_reduce(m);
+    prepare(m);
     mlir_program mp;
     mp.set_gpu_properties(migraphx_ctx);
     mp.parse(m);
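In the new pass below, find_leaky_relu rewrites leaky_relu(x) as where(x > 0, x, alpha * x), and find_where then converts the where condition to a bool tensor before the module is handed to MLIR. A rough standalone sketch of the graph that results, again using only API calls shown in this patch; the fixed shape and explicit multibroadcasts are illustrative (the pass itself inserts the broadcasts via insert_common_op):

#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>

// y = leaky_relu(x) decomposed as where(x > 0, x, alpha * x) with a bool condition.
migraphx::module leaky_relu_decomposition_sketch(float alpha)
{
    migraphx::module m;
    auto x    = m.add_parameter("x", {migraphx::shape::float_type, {2, 3}});
    auto a    = m.add_literal(migraphx::literal{{migraphx::shape::float_type, {1}}, {alpha}});
    auto zero = m.add_literal(migraphx::literal{{migraphx::shape::float_type, {1}}, {0.0f}});
    // Broadcast the scalar literals to the shape of x.
    auto zb   = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), zero);
    auto ab   = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), a);
    auto gt   = m.add_instruction(migraphx::make_op("greater"), x, zb); // x > 0
    auto ax   = m.add_instruction(migraphx::make_op("mul"), x, ab);     // alpha * x
    // find_where adds this convert so the condition fed to where is a bool.
    auto cond = m.add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), gt);
    auto y    = m.add_instruction(migraphx::make_op("where"), cond, x, ax);
    m.add_return({y});
    return m;
}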
diff --git a/src/targets/gpu/prepare_mlir.cpp b/src/targets/gpu/prepare_mlir.cpp
new file mode 100644
index 00000000000..f2f0de341bb
--- /dev/null
+++ b/src/targets/gpu/prepare_mlir.cpp
@@ -0,0 +1,143 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ *
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace gpu {
+
+namespace {
+
+struct find_reduce
+{
+    auto matcher() const { return match::name_contains("reduce"); }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto reduce_op = ins->get_operator().to_value();
+        auto reduce_op_name = ins->get_operator().name();
+        auto reduce_axes = reduce_op["axes"].to_vector();
+        auto reduce_lens = ins->get_shape().lens();
+        auto in_shape = ins->inputs().front()->get_shape();
+        const auto& in_lens = in_shape.lens();
+        assert(in_shape.standard());
+        assert(reduce_lens.size() == in_lens.size());
+        assert(std::adjacent_find(
+                   reduce_axes.begin(), reduce_axes.end(), [](auto axis_1, auto axis_2) {
+                       return axis_2 - axis_1 > 1;
+                   }) == reduce_axes.end());
+
+        std::vector new_rsp_dims;
+        std::vector new_reduce_axes;
+        for(const auto axis : range(in_shape.ndim()))
+        {
+            if(reduce_lens[axis] == in_lens[axis])
+            {
+                new_rsp_dims.push_back(in_lens[axis]);
+            }
+            else if(new_reduce_axes.empty())
+            {
+                assert(reduce_lens[axis] == 1);
+                new_rsp_dims.push_back(-1);
+                new_reduce_axes.push_back(axis);
+            }
+        }
+        auto rsp_ins = m.insert_instruction(
+            ins, migraphx::make_op("reshape", {{"dims", new_rsp_dims}}), ins->inputs().front());
+        auto collapsed_reduce = m.insert_instruction(
+            ins, migraphx::make_op(reduce_op_name, {{"axes", new_reduce_axes}}), rsp_ins);
+        auto rsp_back = m.insert_instruction(
+            ins, migraphx::make_op("reshape", {{"dims", reduce_lens}}), collapsed_reduce);
+        m.replace_instruction(ins, rsp_back);
+    }
+};
+
+struct find_leaky_relu
+{
+    auto matcher() const { return match::name("leaky_relu"); }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto x_ins = ins->inputs().front();
+
+        float alpha_f = ins->get_operator().to_value()["alpha"].to();
+        auto alpha = m.add_literal(literal{{x_ins->get_shape().type(), {1}}, {alpha_f}});
+        auto zero = m.add_literal(literal{{x_ins->get_shape().type(), {1}}, {0.0}});
+
+        auto greater = insert_common_op(m, ins, make_op("greater"), {x_ins, zero});
+        auto mul_alpha = insert_common_op(m, ins, make_op("mul"), {x_ins, alpha});
+
+        m.replace_instruction(ins, make_op("where"), {greater, x_ins, mul_alpha});
+    }
+};
+
+// MLIR sometimes has issues when the condition to `where` is not a bool, so this will convert the
+// condition to a bool.
+struct find_where
+{
+    auto matcher() const { return match::name("where"); }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto cond_ins = ins->inputs().front();
+
+        if(cond_ins->get_shape().type() == shape::bool_type)
+            return;
+
+        auto bool_cond_ins = m.insert_instruction(
+            ins, make_op("convert", {{"target_type", shape::bool_type}}), cond_ins);
+
+        m.replace_instruction(
+            ins, make_op("where"), {bool_cond_ins, ins->inputs()[1], ins->inputs()[2]});
+    }
+};
+
+} // namespace
+
+void prepare_mlir::apply(module& m) const
+{
+    match::find_matches(m, find_reduce{}, find_leaky_relu{});
+    match::find_matches(m, find_where{});
+    run_passes(m, {dead_code_elimination{}});
+}
+
+} // namespace gpu
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp
index 387b323e4f2..89534ddbbb9 100644
--- a/test/gpu/mlir.cpp
+++ b/test/gpu/mlir.cpp
@@ -247,6 +247,44 @@ module {
     EXPECT(verify_mlir(m));
 }
 
+TEST_CASE(conv_add_leaky_relu)
+{
+    std::string mlir_output = R"__migraphx__(
+module {
+  func.func @mlir_convolution_add_greater_mul_convert_where(%arg0: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg2: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes ${attrs} {
+    %0 = migraphx.literal(dense<0.000000e+00> : tensor<1xf32>) : <1xf32, 1>
+    %1 = migraphx.literal(dense<0.00999999977> : tensor<1xf32>) : <1xf32, 1>
+    %2 = migraphx.convolution %arg2, %arg1 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
+    %3 = migraphx.add %2, %arg0 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
+    %4 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 2, 2, 2]} : <1xf32, 1> -> <1x2x2x2xf32, 0x0x0x0>
+    %5 = migraphx.greater %3, %4 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 0x0x0x0> -> <1x2x2x2xf32, 8x4x2x1>
+    %6 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [1, 2, 2, 2]} : <1xf32, 1> -> <1x2x2x2xf32, 0x0x0x0>
+    %7 = migraphx.mul %3, %6 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 0x0x0x0> -> <1x2x2x2xf32, 8x4x2x1>
+    %8 = migraphx.convert %5 {target_type = 0 : i64} : <1x2x2x2xf32, 8x4x2x1> to <1x2x2x2xsi8, 8x4x2x1>
+    %9 = migraphx.where %8, %3, %7 : <1x2x2x2xsi8, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
+    return %9 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
+  }
+}
+)__migraphx__";
+    migraphx::module m;
+    auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}});
+    auto w = m.add_parameter("w", {migraphx::shape::float_type, {2, 8, 3, 3}});
+    auto b = m.add_parameter("b", {migraphx::shape::float_type, {1, 2, 2, 2}});
+    auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w);
+    auto add = m.add_instruction(migraphx::make_op("add"), conv, b);
+    auto relu = m.add_instruction(migraphx::make_op("leaky_relu"), add);
+    m.add_return({relu});
+    auto s = migraphx::gpu::dump_mlir(m);
+    // Skip test if MLIR is not enabled
+    if(s.empty())
+        return;
+    auto mlir_output_with_attrs =
+        migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
+    CHECK(encode(s) == encode(mlir_output_with_attrs));
+
+    EXPECT(verify_mlir(m));
+}
+
 // The following test checks that a dimension -1, within reshape operator is handled properly..
 TEST_CASE(conv_reshape_dim_minus_one)
 {
diff --git a/test/verify/test_conv_add_leaky_relu.cpp b/test/verify/test_conv_add_leaky_relu.cpp
new file mode 100644
index 00000000000..1b9fc4e3837
--- /dev/null
+++ b/test/verify/test_conv_add_leaky_relu.cpp
@@ -0,0 +1,55 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "verify_program.hpp"
+#include
+#include
+#include
+
+template
+struct test_conv_add_leaky_relu : verify_program>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
+        auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
+        auto bias_literal =
+            migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}};
+        auto bias = mm->add_literal(bias_literal);
+        auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
+        auto bcast_bias = mm->add_instruction(
+            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
+            bias);
+        auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv, bcast_bias);
+        mm->add_instruction(migraphx::make_op("leaky_relu"), bias_add);
+        return p;
+    }
+    std::string section() const { return "conv"; }
+};
+
+template struct test_conv_add_leaky_relu;
+template struct test_conv_add_leaky_relu;
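For reference, the element-wise behaviour these tests exercise is the usual leaky ReLU definition; the 0.00999999977 literal in the MLIR test above corresponds to an alpha of 0.01, the operator's default here. A trivial standalone sketch (not part of the patch):

// leaky_relu reference semantics: keep positive values, scale negative ones by alpha.
inline float leaky_relu_ref(float x, float alpha = 0.01f) { return x > 0.0f ? x : alpha * x; }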