Skip to content

Commit

Permalink
Dont fuse broadcast after conv/gemm in mlir (#3863)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Mar 4, 2025
1 parent e66eadb commit 04b82df
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -641,9 +641,19 @@ struct find_mlir_fused_ops
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;

static auto make_conv_dot_reshaper_names()
{
auto names = reshaper_names();
names.erase("broadcast");
names.erase("multibroadcast");
return names;
}

auto matcher() const
{
auto dot_or_conv = match::skip(match::name(reshaper_names()))(
static const auto conv_dot_reshaper_names = make_conv_dot_reshaper_names();
auto dot_or_conv = match::skip(match::name(conv_dot_reshaper_names))(
match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
Expand Down
38 changes: 38 additions & 0 deletions test/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,44 @@ TEST_CASE(dot_transpose_reshape_add)
EXPECT(p1.sort() == p2.sort());
}

TEST_CASE(conv_broadcast_mul)
{
migraphx::shape os{migraphx::shape::float_type, {4, 56, 122, 122}};
migraphx::shape is{migraphx::shape::float_type, {4, 14, 1, 1}};
migraphx::shape ws{migraphx::shape::float_type, {56, 14, 1, 1}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", is);
auto y = mm->add_parameter("y", os);
auto w = mm->add_parameter("w", ws);
auto conv = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto convb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", os.lens()}}), conv);
auto mul = add_pointwise(p1, "main:pointwise0", {convb, y}, single_pointwise("mul"));
mm->add_return({mul});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", is);
auto y = mm->add_parameter("y", os);
auto w = mm->add_parameter("w", ws);
auto conv = add_mlir(
p2, "mlir_convolution0", {x, w}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) {
auto c =
pm->add_instruction(migraphx::make_op("convolution"), inputs[0], inputs[1]);
return std::make_tuple(c->get_operator(), c);
});
auto convb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", os.lens()}}), conv);
auto mul = add_pointwise(p2, "main:pointwise0", {convb, y}, single_pointwise("mul"));
mm->add_return({mul});
}
EXPECT(p1.sort() == p2.sort());
}

TEST_CASE(multi_use_dot_trans_add_pooling_sub)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 5}};
Expand Down

0 comments on commit 04b82df

Please sign in to comment.