Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fuse multiple outputs for pointwise and reductions #3752

Draft
wants to merge 4 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions src/fuse_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,33 @@ static module::with_inputs append_pointwise_module(instruction_ref ins, instruct
input_map[input] = map_ins[param];
}
}
pm.replace_return(pm.insert_instructions(last, xm, &map_ins));
auto returns = pm.insert_instructions(last, xm, &map_ins);
if (ins->outputs().size() > 1)
{
auto ireturns = pm.get_returns();
returns.insert(returns.end(), ireturns.begin(), ireturns.end());
}
pm.replace_return(returns);
return {std::move(pm), inputs};
}

static bool find_pointwise_modules(module_pass_manager& mpm)
// Pick the input of `ins` that should be fused into it.
//
// A "pointwise" input whose only consumer is `ins` is preferred. When none
// exists and `multi_out` is set, fall back to a "pointwise" input with several
// consumers, provided none of its other consumers reaches `ins`.
// Returns an iterator into ins->inputs(); it equals ins->inputs().end() when
// no input qualifies.
static auto find_input_pointwise(instruction_ref ins, bool multi_out)
{
    const auto& args = ins->inputs();

    auto single_use = std::find_if(args.begin(), args.end(), [&](auto arg) {
        return arg->name() == "pointwise" and arg->outputs().size() == 1;
    });
    if(single_use != args.end() or not multi_out)
        return single_use;

    return std::find_if(args.begin(), args.end(), [&](auto arg) {
        if(arg->name() != "pointwise")
            return false;
        const auto& users = arg->outputs();
        // Every consumer must either be `ins` itself or be unable to reach `ins`.
        return std::all_of(users.begin(), users.end(), [&](auto user) {
            return user == ins or not reaches(user, ins);
        });
    });
}

static bool find_pointwise_modules(module_pass_manager& mpm, bool multi_out)
{
bool changed = false;
auto last = std::prev(mpm.get_module().end());
Expand All @@ -176,18 +198,25 @@ static bool find_pointwise_modules(module_pass_manager& mpm)
continue;
if(ins->outputs().empty() and ins != last)
continue;
auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return i->name() == "pointwise" and i->outputs().size() == 1;
});
auto it = find_input_pointwise(ins, multi_out);
if(it == ins->inputs().end())
continue;
auto input = *it;

const bool has_multi_out = input->outputs().size() > 1;
auto fused = append_pointwise_module(input, ins);
auto name = fused.mod.name();
mpm.rename_module(name, name + ":" + ins->module_inputs().front()->name() + "-deleted");
auto* new_pm = mpm.create_module(name, std::move(fused.mod));
mpm.get_module().replace_instruction(ins, input->get_operator(), fused.inputs, {new_pm});
auto fins = mpm.get_module().insert_instruction(ins, input->get_operator(), fused.inputs, {new_pm});
if(has_multi_out)
{
auto noutputs = std::max<std::size_t>(1, ins->get_shape().sub_shapes().size());
auto finput = mpm.get_module().insert_instruction(ins, make_op("get_tuple_elem", {{"index", noutputs}}), fins);
mpm.get_module().replace_instruction(input, finput);
if(noutputs == 1)
fins = mpm.get_module().insert_instruction(ins, make_op("get_tuple_elem", {{"index", 0}}), fins);
}
mpm.get_module().replace_instruction(ins, fins);

changed = true;
}
Expand Down Expand Up @@ -252,7 +281,7 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
mpm.run_pass(rewrite_reshapes<pointwise_reshape>{});
if(enable_rewrite_broadcasts)
rewrite_broadcasts(mpm);
if(not find_pointwise_modules(mpm))
if(not find_pointwise_modules(mpm, enable_multi_output))
break;
mpm.run_pass(dead_code_elimination{});
}
Expand Down
21 changes: 15 additions & 6 deletions src/fuse_pointwise_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,21 @@ static std::size_t get_split_size(std::size_t default_split)

void fuse_pointwise_reduce::apply(module_pass_manager& mpm) const
{
mpm.run_pass(fuse_pointwise{.enable_rewrite_reshapes = false});
mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = false});
mpm.run_pass(fuse_pointwise{.enable_rewrite_reshapes = true});
mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = true});
mpm.run_pass(split_reduce{.split_size = get_split_size(split_size)});
mpm.run_pass(fuse_pointwise{.enable_rewrite_broadcasts = true});
if(enable_multi_output)
{
mpm.run_pass(fuse_pointwise{.enable_multi_output = true});
mpm.run_pass(fuse_reduce{.enable_multi_output = true});

}
else
{
mpm.run_pass(fuse_pointwise{.enable_rewrite_reshapes = false});
mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = false});
mpm.run_pass(fuse_pointwise{.enable_rewrite_reshapes = true});
mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = true});
mpm.run_pass(split_reduce{.split_size = get_split_size(split_size)});
mpm.run_pass(fuse_pointwise{.enable_rewrite_broadcasts = true});
}
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
125 changes: 99 additions & 26 deletions src/fuse_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,18 @@
if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule.");
const auto* sm = mods.front();
if(sm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("Only one output supported");
if(not sm->bypass())
MIGRAPHX_THROW("fused_reduce: bypass flag is not set");
auto names = sm->get_parameter_names();
check_shapes{inputs, *this}.has(names.size()).same_ndims();
std::sort(names.begin(), names.end());
auto shapes = sm->get_parameter_shapes();
// Check dimension matches for each input
if(not equal(names, inputs, [&](const auto& name, const auto& input) {
return shapes.at(name).lens() == input.lens();
}))
MIGRAPHX_THROW("Input dimension does not match the submodule.");

return shape::from_permutation(sm->get_output_shapes().front().type(),
sm->get_output_shapes().front().lens(),
find_permutation(inputs));

auto result = sm->compute_shapes(
inputs,
{.name = name(), .strict_type = true, .strict_lens = true});
if(result.size() == 1)
return result.front();
return shape{result};
}

std::string name() const { return "fused_reduce"; }
Expand Down Expand Up @@ -215,22 +210,99 @@
});
}

// Builds a matcher that looks at the inputs of an instruction for one named
// `op`, and binds that `op` instruction under `name`.
// Two alternatives are combined with match::any_of: the `op` instruction as an
// input directly, or the `op` instruction wrapped by match_broadcast (that
// alternative is additionally passed through match_broadcast_axes).
// NOTE(review): used_once_except_broadcast, any_input, match_broadcast and
// match_broadcast_axes are defined outside this view -- confirm their exact
// semantics there before relying on this summary.
static auto match_broadcastable_input(const std::string& op, const std::string& name)
{
// The bound `op` instruction, filtered by used_once_except_broadcast().
auto match_op = match::name(op)(used_once_except_broadcast()).bind(name);
// Alternative 1: `op` appears as an input.
auto match_op_input = any_input(match_op, match::used_once());
// Alternative 2: `op` appears behind match_broadcast(...).
auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
return match::any_of(match_op_input, match_broadcast_axes(broadcast_match_op_input));
}

// Clean up a reduce submodule after instructions were fused into it: first run
// eliminate_common_subexpression, then dead_code_elimination, directly on the
// module `m` points to.
static void finalize_reduce_module(module_ref m)
{
    auto& mod = *m;
    eliminate_common_subexpression{}.apply(mod);
    dead_code_elimination{}.apply(mod);
}

// Relocate every instruction that transitively consumes `src` and currently
// sits between `src` and `dst` so that it ends up after `dst`.
//
// Consumers are collected by walking outputs() starting at `src`; anything at
// or beyond `dst` (measured by iterator distance from `src`) is left alone and
// not descended into. The collected instructions are moved in order of their
// original distance from `src`, each one inserted in front of std::next(dst).
static void move_output_instructions_after(module& m, instruction_ref src, instruction_ref dst)
{
    const auto limit = std::distance(src, dst);
    std::vector<std::pair<std::size_t, instruction_ref>> pending;

    fix([&](auto self, instruction_ref current) {
        for(auto user : current->outputs())
        {
            const bool already_queued =
                any_of(pending, [&](const auto& entry) { return entry.second == user; });
            if(already_queued)
                continue;
            const auto offset = std::distance(src, user);
            if(offset >= limit)
                continue;
            pending.emplace_back(offset, user);
            self(user);
        }
    })(src);

    // Offsets are unique per instruction, so ordering by offset is a total order.
    std::sort(pending.begin(), pending.end(), [](const auto& lhs, const auto& rhs) {
        return lhs.first < rhs.first;
    });

    const auto insert_pos = std::next(dst);
    for(const auto& entry : pending)
        m.move_instruction(entry.second, insert_pos);
}

namespace {
struct find_pointwise_reduce
struct find_reduce_base
{
bool multi_output = false;
template<class... Ms>
auto any_fusable_input_usage(Ms... ms) const
{
auto m = match::any(ms...);
return match::make_basic_fun_matcher(
[=](match::matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return ctx.matched(m, i) and i->outputs().size() == 1;
});
if(it == ins->inputs().end() and multi_output)
{
it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return ctx.matched(m, i) and std::none_of(i->outputs().begin(), i->outputs().end(), [&](auto output) {
return output != ins and reaches(output, ins);
});
});
}
if(it == ins->inputs().end())
return nullopt;
return *it;
});
}
auto match_broadcastable_input(const std::string& op, const std::string& name) const
{
auto match_op = match::name(op)(used_once_except_broadcast()).bind(name);
auto match_op_input = any_input(match_op, match::used_once());
auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
return match::any_of(match_op_input, match_broadcast_axes(broadcast_match_op_input), any_fusable_input_usage(match::name(op).bind(name)));
}
void replace_return(instruction_ref input, module_ref sm, std::vector<instruction_ref> returns) const
{
if (multi_output and input->outputs().size() > 1)
{
auto r = sm->get_returns();
returns.insert(returns.end(), r.begin(), r.end());
}
sm->replace_return(returns);
}

void replace_instruction(instruction_ref input, module& m, instruction_ref ins, operation op, const std::vector<instruction_ref>& args, const std::vector<module_ref>& module_inputs) const

Check warning on line 285 in src/fuse_reduce.cpp

View workflow job for this annotation

GitHub Actions / tidy

the parameter 'op' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param,-warnings-as-errors]
{
if(multi_output and input->outputs().size() > 1)
{
move_output_instructions_after(m, input, ins);
auto fins = m.insert_instruction(ins, op, args, module_inputs);
auto noutputs = std::max<std::size_t>(1, ins->get_shape().sub_shapes().size());
auto finput = m.insert_instruction(ins, make_op("get_tuple_elem", {{"index", noutputs}}), fins);
m.replace_instruction(input, finput);
if(noutputs == 1)
fins = m.insert_instruction(ins, make_op("get_tuple_elem", {{"index", 0}}), fins);
m.replace_instruction(ins, fins);
}
else
{
m.replace_instruction(ins, op, args, module_inputs);
}
}

};
struct find_pointwise_reduce : find_reduce_base
{
auto matcher() const
{
Expand All @@ -251,6 +323,7 @@
// Insert pointwise
auto rins = rm->fuse({input}, &map_ins).front();
map_ins[input] = rins;
rm->add_return({rins});

if(contains(r.instructions, "broadcast"))
{
Expand All @@ -262,15 +335,15 @@
}

// Insert fused_reduce
rm->add_return(insert_module_in_submodule(rm, reduce, &map_ins));
replace_return(input, rm, insert_module_in_submodule(rm, reduce, &map_ins));
finalize_reduce_module(rm);

auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
replace_instruction(input, mpm.get_module(), reduce, reduce->get_operator(), new_inputs, {rm});
}
};

struct find_reduce_pointwise
struct find_reduce_pointwise : find_reduce_base
{

auto matcher() const
Expand Down Expand Up @@ -312,7 +385,7 @@
}
};

struct find_reduce_reduce
struct find_reduce_reduce : find_reduce_base
{
auto matcher() const
{
Expand Down Expand Up @@ -429,9 +502,9 @@
for(int i = 0; i < 4; i++)
{
if(enable_rewrite_reshapes)
mpm.run_pass(rewrite_reshapes<reduce_reshape>{});
mpm.run_pass(rewrite_reshapes<reduce_reshape>{.enable_multi_output = enable_multi_output});
match::find_matches(
mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
mpm, find_reduce_pointwise{}, find_pointwise_reduce{{.multi_output = enable_multi_output}}, find_reduce_reduce{});
mpm.run_pass(dead_code_elimination{});
}
}
Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/fuse_pointwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ struct MIGRAPHX_EXPORT fuse_pointwise

bool enable_rewrite_reshapes = true;
bool enable_rewrite_broadcasts = false;
bool enable_multi_output = false;
};

} // namespace MIGRAPHX_INLINE_NS
Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/fuse_pointwise_reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct module_pass_manager;
struct MIGRAPHX_EXPORT fuse_pointwise_reduce
{
std::size_t split_size = 32768;
bool enable_multi_output = false;
std::string name() const { return "fuse_pointwise_reduce"; }
void apply(module_pass_manager& mpm) const;
};
Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/fuse_reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct MIGRAPHX_EXPORT fuse_reduce
void apply(module_pass_manager& mpm) const;

bool enable_rewrite_reshapes = true;
bool enable_multi_output = false;
};

} // namespace MIGRAPHX_INLINE_NS
Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ struct MIGRAPHX_EXPORT module
std::size_t size() const;
instruction_ref begin() const;
instruction_ref end() const;
instruction_ref insert_end() const;

struct compute_shapes_options
{
Expand Down
26 changes: 21 additions & 5 deletions src/include/migraphx/rewrite_reshapes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,22 @@ struct rewrite_reshapes_base
template <class T>
struct rewrite_reshapes
{
bool enable_multi_output = false;
std::string name() const { return "rewrite_reshapes"; }
struct find_op_reshape_op
{
std::string op1;
std::string op2;
bool multi_out = false;

auto used_once_or_multi_out() const
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(multi_out)
return true;
return ins->outputs().size() == 1;
});
}

auto matcher() const
{
Expand All @@ -81,7 +92,7 @@ struct rewrite_reshapes
"contiguous",
"multibroadcast",
"broadcast")(match::used_once());
auto pointwise = match::name(op1)(match::used_once());
auto pointwise = match::name(op1)(used_once_or_multi_out());
auto reshapes_pointwise =
reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x"))));
return match::name(op2)(
Expand Down Expand Up @@ -198,6 +209,11 @@ struct rewrite_reshapes
auto rins =
reshape_input(ins, &shape_transform_descriptor::generate_dst_from_common)(pw);
mpm.get_module().replace_instruction(ins, rins);
if(x_ins->outputs().size() > 1)
{
auto r_x_ins = reshape_input(x_ins, &shape_transform_descriptor::generate_src_from_common)(new_x_ins);
mpm.get_module().replace_instruction(x_ins, r_x_ins);
}
}

static bool same_dims(instruction_ref ins)
Expand All @@ -224,14 +240,14 @@ struct rewrite_reshapes
{
if(T::name() == "pointwise")
{
match::find_matches(mpm, find_op_reshape_op{"pointwise", T::name()});
match::find_matches(mpm, find_op_reshape_op{"pointwise", T::name(), enable_multi_output});
}
else
{
match::find_matches(mpm,
find_op_reshape_op{"pointwise", T::name()},
find_op_reshape_op{T::name(), "pointwise"},
find_op_reshape_op{T::name(), T::name()});
find_op_reshape_op{"pointwise", T::name(), enable_multi_output},
find_op_reshape_op{T::name(), "pointwise", enable_multi_output},
find_op_reshape_op{T::name(), T::name(), enable_multi_output});
}
mpm.run_pass(simplify_reshapes{1});
mpm.run_pass(eliminate_common_subexpression{});
Expand Down
2 changes: 2 additions & 0 deletions src/include/migraphx/shape_transform_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor
generate_common_from_dst(const std::vector<std::size_t>& input_dims = {}) const;
std::vector<operation>
generate_dst_from_common(const std::vector<std::size_t>& input_dims = {}) const;
std::vector<operation>
generate_src_from_common(const std::vector<std::size_t>& input_dims = {}) const;
std::vector<std::vector<std::size_t>> common_axes_map_from_src() const;
std::vector<std::vector<std::size_t>> common_axes_map_from_dst() const;

Expand Down
Loading
Loading