From 04b82dfa725268e88b83ecd29dcf5deda2a01c54 Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Tue, 4 Mar 2025 14:36:27 -0600 Subject: [PATCH] Dont fuse broadcast after conv/gemm in mlir (#3863) --- src/targets/gpu/fuse_mlir.cpp | 12 ++++++++++- test/gpu/fuse_mlir.cpp | 38 +++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 5d142470023..c260ab40815 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -641,9 +641,19 @@ struct find_mlir_fused_ops { mlir_mode conv_mode = mlir_mode::none; mlir_mode dot_mode = mlir_mode::none; + + static auto make_conv_dot_reshaper_names() + { + auto names = reshaper_names(); + names.erase("broadcast"); + names.erase("multibroadcast"); + return names; + } + auto matcher() const { - auto dot_or_conv = match::skip(match::name(reshaper_names()))( + static const auto conv_dot_reshaper_names = make_conv_dot_reshaper_names(); + auto dot_or_conv = match::skip(match::name(conv_dot_reshaper_names))( match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op")); return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x"))); } diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 47a2dd31031..59a00b78598 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -199,6 +199,44 @@ TEST_CASE(dot_transpose_reshape_add) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(conv_broadcast_mul) +{ + migraphx::shape os{migraphx::shape::float_type, {4, 56, 122, 122}}; + migraphx::shape is{migraphx::shape::float_type, {4, 14, 1, 1}}; + migraphx::shape ws{migraphx::shape::float_type, {56, 14, 1, 1}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", os); + auto w = mm->add_parameter("w", ws); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), x, w); + auto convb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", os.lens()}}), conv); + auto mul = add_pointwise(p1, "main:pointwise0", {convb, y}, single_pointwise("mul")); + mm->add_return({mul}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", os); + auto w = mm->add_parameter("w", ws); + auto conv = add_mlir( + p2, "mlir_convolution0", {x, w}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) { + auto c = + pm->add_instruction(migraphx::make_op("convolution"), inputs[0], inputs[1]); + return std::make_tuple(c->get_operator(), c); + }); + auto convb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", os.lens()}}), conv); + auto mul = add_pointwise(p2, "main:pointwise0", {convb, y}, single_pointwise("mul")); + mm->add_return({mul}); + } + EXPECT(p1.sort() == p2.sort()); +} + TEST_CASE(multi_use_dot_trans_add_pooling_sub) { migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 5}};