From 2e455cd95c1592db2f5562ab668f6d3eb46e966f Mon Sep 17 00:00:00 2001
From: Paul
Date: Thu, 9 Jan 2025 15:19:11 -0600
Subject: [PATCH 1/3] Add exhaustive tune to reduce

---
 src/targets/gpu/jit/reduce.cpp | 57 ++++++++++++++++++++++++++++------
 1 file changed, 47 insertions(+), 10 deletions(-)

diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp
index bdf7313f5f1..012fc9d89e6 100644
--- a/src/targets/gpu/jit/reduce.cpp
+++ b/src/targets/gpu/jit/reduce.cpp
@@ -79,6 +79,22 @@ static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>&
     return reduce_lens;
 }
 
+static shape get_input_shape(const std::vector<shape>& inputs)
+{
+    auto it = std::max_element(inputs.begin(),
+                               inputs.end(),
+                               by(std::less<>{}, [](const shape& s) { return s.elements(); }));
+    return *it;
+}
+
+static shape get_reduce_shape(const std::vector<shape>& inputs)
+{
+    auto it = std::min_element(inputs.begin(),
+                               inputs.end(),
+                               by(std::less<>{}, [](const shape& s) { return s.elements(); }));
+    return *it;
+}
+
 template <...>
 static shape get_reduced_shape(const shape& s, const std::vector<std::size_t>& axes)
 {
@@ -310,14 +326,6 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
 {
     std::vector<std::string> names() const { return {"fused_reduce", "split_fused_reduce"}; }
 
-    static shape get_input_shape(const std::vector<shape>& inputs)
-    {
-        auto it = std::max_element(inputs.begin(),
-                                   inputs.end(),
-                                   by(std::less<>{}, [](const shape& s) { return s.elements(); }));
-        return *it;
-    }
-
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
         auto assign = v.get("assign", "assign_none");
@@ -352,7 +360,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
         auto relements = reduction_shape.elements() / vec.size;
         if(algo == "block")
         {
-            auto block_size = compute_block_size(ctx, relements, 256);
+            auto block_size = v.get("block_size", compute_block_size(ctx, relements, 256));
             if(relements >= block_size * 256)
                 algo = "block_large";
             options.set_launch_params(
@@ -392,16 +400,45 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
         return compile_hip_code_object(ctx, src, options);
     }
 
-    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
+    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op, const value& solution) const
     {
         assert(not ins->module_inputs().empty());
         auto v = op.to_value();
+        for(const auto& x:v)
+            v.insert(x);
         auto* rm = ins->module_inputs().front();
         v["preamble"] = generate_reduce(*rm, "fused_reduce_op");
         v["lambda"] = "MIGRAPHX_LIFT(fused_reduce_op)";
         v["kernel"] = generate_name_from_ops(*rm) + "_kernel";
         return compile_op(ctx, to_shapes(ins->inputs()), v);
     }
+
+    optional<tuning_config> get_tuning_config(const context& ctx,
+                                              instruction_ref ins,
+                                              const operation& op,
+                                              bool exhaustive) const
+    {
+        if(not exhaustive)
+            return nullopt;
+        if(op.name() != "fused_reduce")
+            return nullopt;
+        tuning_config tc;
+        auto shapes = to_shapes(ins->inputs());
+        tc.problem = to_value(shapes);
+        auto input_shape = get_input_shape(shapes);
+        auto reduce_shape = get_reduce_shape(shapes);
+        auto relements = reduce_shape.elements();
+        for(auto block_size:{64, 128, 256, 512, 1024})
+        {
+            if(relements < block_size)
+                continue;
+            tc.solutions.push_back({{"algo", "block"}, {"block_size", block_size}});
+        }
+        tc.solutions.push_back({{"algo", "lane"}});
+        tc.solutions.push_back({{"algo", "wave"}});
+        return tc;
+    }
+
 };
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
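
Note on the tuning flow added above: get_tuning_config enumerates candidate configurations as tc.solutions (several block sizes plus the lane and wave algorithms), the compile pipeline benchmarks each candidate, and the winning value is handed back to compile(), where it is merged into the op value so that compile_op picks the tuned size up via v.get("block_size", ...). The sketch below illustrates that enumerate-and-time pattern in isolation; candidate, time_candidate, and run_kernel are names invented for this sketch, not MIGraphX APIs.

    #include <chrono>
    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <limits>
    #include <string>
    #include <vector>

    struct candidate
    {
        std::string algo;
        std::size_t block_size = 0; // 0 = not applicable (lane/wave algorithms)
    };

    // Hypothetical stand-in for "compile and launch the kernel built from one candidate".
    static double time_candidate(const std::function<void()>& run_kernel, int iterations = 20)
    {
        auto start = std::chrono::steady_clock::now();
        for(int i = 0; i < iterations; i++)
            run_kernel();
        auto stop = std::chrono::steady_clock::now();
        return std::chrono::duration<double, std::milli>(stop - start).count() / iterations;
    }

    int main()
    {
        std::size_t relements = 4096; // elements reduced per output element
        std::vector<candidate> solutions;
        for(std::size_t block_size : {64, 128, 256, 512, 1024})
        {
            if(relements < block_size)
                continue; // a block wider than the reduction only adds idle lanes
            solutions.push_back({"block", block_size});
        }
        solutions.push_back({"lane"});
        solutions.push_back({"wave"});

        // Benchmark every candidate and keep the fastest one.
        candidate best;
        double best_time = std::numeric_limits<double>::max();
        for(const auto& c : solutions)
        {
            double t = time_candidate([&] { /* placeholder: build and run the kernel for c */ });
            if(t < best_time)
            {
                best_time = t;
                best      = c;
            }
        }
        std::cout << "fastest: " << best.algo << " block_size=" << best.block_size << "\n";
        return 0;
    }

Skipping candidates with relements < block_size mirrors the guard in the patch: a block wider than the reduction would leave most lanes without work.
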
From d24f634cb7dd07d5b674e19418d2bb4b946a7f4c Mon Sep 17 00:00:00 2001
From: Paul
Date: Thu, 9 Jan 2025 15:58:54 -0600
Subject: [PATCH 2/3] Fix benchmarking

---
 src/targets/gpu/compile_ops.cpp | 6 ++++--
 src/targets/gpu/jit/reduce.cpp  | 7 ++++---
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp
index cc5a7fc24d7..0a6b245bd15 100644
--- a/src/targets/gpu/compile_ops.cpp
+++ b/src/targets/gpu/compile_ops.cpp
@@ -28,6 +28,8 @@
 #include <...>
 #include <...>
 #include <...>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/pass_manager.hpp>
 #include <...>
 #include <...>
 #include <...>
@@ -225,8 +227,8 @@ struct compile_plan
         auto bench_ins = bench_mm->add_instruction(
             cr->ins->get_operator(), bench_ins_inputs, cr->ins->module_inputs());
         cr->replace.replace(*bench_mm, bench_ins);
-        // do dead code elimination by directly removing instruction
-        bench_mm->remove_instruction(bench_ins);
+        // do dead code elimination
+        run_passes(*bench_mm, {dead_code_elimination{}});
         auto t = time_program(*ctx, bench_prog, 20);
         if(trace_level > 1)
             std::cout << t << "ms" << std::endl;
diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp
index 012fc9d89e6..5512e0d812a 100644
--- a/src/targets/gpu/jit/reduce.cpp
+++ b/src/targets/gpu/jit/reduce.cpp
@@ -404,7 +404,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
     {
         assert(not ins->module_inputs().empty());
         auto v = op.to_value();
-        for(const auto& x:v)
+        for(const auto& x:solution)
             v.insert(x);
         auto* rm = ins->module_inputs().front();
         v["preamble"] = generate_reduce(*rm, "fused_reduce_op");
@@ -413,7 +413,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
         return compile_op(ctx, to_shapes(ins->inputs()), v);
     }
 
-    optional<tuning_config> get_tuning_config(const context& ctx,
+    optional<tuning_config> get_tuning_config(const context&,
                                               instruction_ref ins,
                                               const operation& op,
                                               bool exhaustive) const
@@ -435,7 +435,8 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
             tc.solutions.push_back({{"algo", "block"}, {"block_size", block_size}});
         }
         tc.solutions.push_back({{"algo", "lane"}});
-        tc.solutions.push_back({{"algo", "wave"}});
+        if (relements < 16384)
+            tc.solutions.push_back({{"algo", "wave"}});
         return tc;
     }
 

From 925baf1d9a59337a9e4ed65f074f7e1e8c74a859 Mon Sep 17 00:00:00 2001
From: Paul
Date: Fri, 10 Jan 2025 16:21:55 -0600
Subject: [PATCH 3/3] Fix shape computation

---
 src/targets/gpu/jit/reduce.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp
index 5512e0d812a..3e7431a3689 100644
--- a/src/targets/gpu/jit/reduce.cpp
+++ b/src/targets/gpu/jit/reduce.cpp
@@ -425,8 +425,9 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
         tuning_config tc;
         auto shapes = to_shapes(ins->inputs());
         tc.problem = to_value(shapes);
+        auto axes = op.to_value().at("axes").to_vector<std::size_t>();
         auto input_shape = get_input_shape(shapes);
-        auto reduce_shape = get_reduce_shape(shapes);
+        auto reduce_shape = get_reduced_shape(input_shape, axes);
         auto relements = reduce_shape.elements();
         for(auto block_size:{64, 128, 256, 512, 1024})
         {
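
Note on the shape fix above: the block-size candidates are filtered against relements, the number of elements reduced per output element. Patch 3 derives that count from the largest input restricted to the reduce axes (get_reduced_shape(input_shape, axes)) rather than from the element count of the smallest input (get_reduce_shape), which is not, in general, the per-output reduction size. A minimal standalone sketch of the assumed arithmetic (not the MIGraphX implementation):

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Per-output reduction size: the product of the dimensions named by the reduce axes.
    static std::size_t reduce_elements(const std::vector<std::size_t>& input_lens,
                                       const std::vector<std::int64_t>& axes)
    {
        std::size_t relements = 1;
        for(auto axis : axes)
            relements *= input_lens[static_cast<std::size_t>(axis)];
        return relements;
    }

    int main()
    {
        // A {2, 32, 4096} input reduced over axis 2: every output element folds 4096 inputs,
        // no matter how small any other (for example broadcast) input buffer happens to be.
        std::cout << reduce_elements({2, 32, 4096}, {2}) << "\n"; // prints 4096
        return 0;
    }
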