From c78363f173010de9d32d0935353bd56072372434 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 24 Sep 2025 00:37:54 +0000 Subject: [PATCH 1/7] Refactor find_matches --- src/include/migraphx/matcher.hpp | 124 ++++++++++++++++++++++++++++++- 1 file changed, 122 insertions(+), 2 deletions(-) diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index f5cd5682b39..c1e4d92b5bd 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -406,6 +406,110 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES_FOR) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_MATCHERS) +template +auto make_match_runner_with_trace(source_location location, Finder& f) +{ + auto m = f.matcher(); + const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); + const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); + const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{}); + const auto& finder_name = get_type_name(f); + const bool trace_enabled = trace > 0 and (trace_filter.empty() or + contains(std::string{location.file_name()}, trace_filter) or + contains(std::string{location.function_name()}, trace_filter) or + contains(finder_name, trace_filter)); + return [=, &f](auto& mod, instruction_ref ins) -> bool { + using microseconds = std::chrono::duration; + if(trace > 1 and trace_enabled) + std::cout << "Running matcher: " << finder_name << std::endl; + + + match::matcher_result r; + double match_time = 0.0; + if(trace_enabled) + { + match_time = time([&] { + r = match::match_instruction(get_module(mod), ins, m); + }); + } + else + { + r = match::match_instruction(get_module(mod), ins, m); + } + + if(trace > 1 and trace_enabled) + { + std::cout << "Matcher time for " << finder_name << ": " << match_time << "us" + << std::endl; + } + + // did not match any instruction + if(r.result == get_module(mod).end()) + return false; + + if(trace > 0 or trace_enabled) + { + std::cout << "Matched by: " << finder_name << std::endl; + get_module(mod).debug_print(ins); + } + // If its already invalid dont validate it again + bool invalidated = validate and get_module(mod).validate() != get_module(mod).end(); + if(trace_enabled) + { + if(trace > 1) + std::cout << "Applying matcher: " << finder_name << std::endl; + auto apply_time = time([&] { f.apply(mod, r); }); + std::cout << "Apply time for " << finder_name << ": " << apply_time << "us" + << std::endl; + } + else + { + f.apply(mod, r); + } + + if(validate and not invalidated) + { + auto invalid = get_module(mod).validate(); + if(invalid != get_module(mod).end()) + { + std::cout << "Invalid program from match: " << finder_name << std::endl; + std::cout << "Invalid instructions: " << std::endl; + get_module(mod).debug_print(invalid->inputs()); + get_module(mod).debug_print(invalid); + } + } + return true; + }; +} + +template +auto make_match_runner(Finder& f) +{ + auto m = f.matcher(); + return [=, &f](auto& mod, instruction_ref ins) -> bool { + match::matcher_result r = match::match_instruction(get_module(mod), ins, m); + if(r.result == get_module(mod).end()) + return false; + f.apply(mod, r); + return true; + }; +} + +template +void find_matches_for(Mod& mod, instruction_ref ins, RunnerPack rp) +{ + rp([&](auto&&... rs) { + bool matched = false; + each_args( + [&](auto&& r) { + if(matched) + return; + matched = r(mod, ins); + }, + rs...); + }); +} + /// Find matches for an instruction in the module for per section of matchers template void find_matches_for(source_location location, Mod& mod, instruction_ref ins, Ms&&... ms) @@ -484,9 +588,25 @@ struct find_matches { find_matches(Mod& mod, Ms&&... ms, source_location location = source_location::current()) { - for(auto ins : iterator_for(get_module(mod))) + const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); + const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); + const bool need_trace = trace > 0 or validate; + + if(need_trace) { - find_matches_for(location, mod, ins, ms...); + auto runners = pack(make_match_runner_with_trace(location, ms)...); + for(auto ins : iterator_for(get_module(mod))) + { + find_matches_for(mod, ins, runners); + } + } + else + { + auto runners = pack(make_match_runner(ms)...); + for(auto ins : iterator_for(get_module(mod))) + { + find_matches_for(mod, ins, runners); + } } } }; From 40eea7b4780bca264ea70e70f8d6fb004bc9619e Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 24 Sep 2025 00:38:16 +0000 Subject: [PATCH 2/7] Format --- src/include/migraphx/matcher.hpp | 35 ++++++++++++++++---------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index c1e4d92b5bd..a7298cb6054 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -406,31 +406,30 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES_FOR) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_MATCHERS) -template +template auto make_match_runner_with_trace(source_location location, Finder& f) { - auto m = f.matcher(); - const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); - const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); - const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{}); + auto m = f.matcher(); + const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); + const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); + const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{}); const auto& finder_name = get_type_name(f); - const bool trace_enabled = trace > 0 and (trace_filter.empty() or - contains(std::string{location.file_name()}, trace_filter) or - contains(std::string{location.function_name()}, trace_filter) or - contains(finder_name, trace_filter)); + const bool trace_enabled = + trace > 0 and + (trace_filter.empty() or contains(std::string{location.file_name()}, trace_filter) or + contains(std::string{location.function_name()}, trace_filter) or + contains(finder_name, trace_filter)); return [=, &f](auto& mod, instruction_ref ins) -> bool { using microseconds = std::chrono::duration; if(trace > 1 and trace_enabled) std::cout << "Running matcher: " << finder_name << std::endl; - match::matcher_result r; double match_time = 0.0; if(trace_enabled) { - match_time = time([&] { - r = match::match_instruction(get_module(mod), ins, m); - }); + match_time = + time([&] { r = match::match_instruction(get_module(mod), ins, m); }); } else { @@ -460,7 +459,7 @@ auto make_match_runner_with_trace(source_location location, Finder& f) std::cout << "Applying matcher: " << finder_name << std::endl; auto apply_time = time([&] { f.apply(mod, r); }); std::cout << "Apply time for " << finder_name << ": " << apply_time << "us" - << std::endl; + << std::endl; } else { @@ -482,7 +481,7 @@ auto make_match_runner_with_trace(source_location location, Finder& f) }; } -template +template auto make_match_runner(Finder& f) { auto m = f.matcher(); @@ -498,7 +497,7 @@ auto make_match_runner(Finder& f) template void find_matches_for(Mod& mod, instruction_ref ins, RunnerPack rp) { - rp([&](auto&&... rs) { + rp([&](auto&&... rs) { bool matched = false; each_args( [&](auto&& r) { @@ -588,8 +587,8 @@ struct find_matches { find_matches(Mod& mod, Ms&&... ms, source_location location = source_location::current()) { - const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); - const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); + const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); + const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); const bool need_trace = trace > 0 or validate; if(need_trace) From 6000a7991b921b9f9099cb50834d635de43a8594 Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Wed, 24 Sep 2025 17:23:11 -0500 Subject: [PATCH 3/7] Update src/include/migraphx/matcher.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/include/migraphx/matcher.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index a7298cb6054..3961bfdda12 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -446,7 +446,7 @@ auto make_match_runner_with_trace(source_location location, Finder& f) if(r.result == get_module(mod).end()) return false; - if(trace > 0 or trace_enabled) + if(trace_enabled) { std::cout << "Matched by: " << finder_name << std::endl; get_module(mod).debug_print(ins); From 3626c2da205a828dc61a8af662fffc9bb1c1f36a Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 7 Oct 2025 16:01:21 -0700 Subject: [PATCH 4/7] disable matching for dynamic shapes --- src/include/migraphx/matcher.hpp | 12 +++++++++++ src/simplify_dyn_ops.cpp | 34 ++++++++++++++++---------------- src/targets/gpu/target.cpp | 9 +++++---- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 3961bfdda12..8e49d5c8cb2 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -50,6 +50,10 @@ inline namespace MIGRAPHX_INLINE_NS { namespace match { +struct supports_dynamic_shapes +{ +}; + struct matcher_context { matcher_context(module& m) : mod(&m) {} @@ -423,6 +427,10 @@ auto make_match_runner_with_trace(source_location location, Finder& f) using microseconds = std::chrono::duration; if(trace > 1 and trace_enabled) std::cout << "Running matcher: " << finder_name << std::endl; + + constexpr bool dynamic_supported = std::is_base_of::value; + if(not dynamic_supported and ins->get_shape().dynamic()) + return false; match::matcher_result r; double match_time = 0.0; @@ -486,6 +494,10 @@ auto make_match_runner(Finder& f) { auto m = f.matcher(); return [=, &f](auto& mod, instruction_ref ins) -> bool { + constexpr bool dynamic_supported = std::is_base_of::value; + if(not dynamic_supported and ins->get_shape().dynamic()) + return false; + match::matcher_result r = match::match_instruction(get_module(mod), ins, m); if(r.result == get_module(mod).end()) return false; diff --git a/src/simplify_dyn_ops.cpp b/src/simplify_dyn_ops.cpp index fb6b41d3b7c..91eaad6c60f 100644 --- a/src/simplify_dyn_ops.cpp +++ b/src/simplify_dyn_ops.cpp @@ -39,7 +39,7 @@ inline namespace MIGRAPHX_INLINE_NS { * into multibroadcast op with a static output shape attribute. * */ -struct find_broadcast_with_dims_static +struct find_broadcast_with_dims_static : match::supports_dynamic_shapes { auto matcher() const { @@ -80,7 +80,7 @@ struct find_broadcast_with_dims_static * At time of writing, Resize allows either 1 or 2 inputs * but the 1-input case is never created by Onnx parsing. */ -struct find_resize_static +struct find_resize_static : match::supports_dynamic_shapes { auto matcher() const @@ -168,7 +168,7 @@ struct find_resize_static * To: * broadcast_op(argument_with_static_shape); broadcast_op.out_lens = constant_output_dims */ -struct find_static_2in_broadcasts +struct find_static_2in_broadcasts : match::supports_dynamic_shapes { auto matcher() const { @@ -201,7 +201,7 @@ struct find_static_2in_broadcasts * To: * slice(data); slice.starts, slice.ends. slice.axes set */ -struct find_const_2in_slice +struct find_const_2in_slice : match::supports_dynamic_shapes { auto matcher() const { @@ -255,7 +255,7 @@ struct find_const_2in_slice * To: * slice(data); slice.starts, slice.ends. slice.axes set */ -struct find_const_3in_slice +struct find_const_3in_slice : match::supports_dynamic_shapes { auto matcher() const { @@ -266,10 +266,10 @@ struct find_const_3in_slice void apply(module& m, const match::matcher_result& mr) const { - auto ins = mr.result; - auto inputs = ins->inputs(); - auto slice_op = any_cast(ins->get_operator()); - auto set_attrs = slice_op.get_set_attributes(); + auto ins = mr.result; + auto inputs = ins->inputs(); + auto slice_op = any_cast(ins->get_operator()); + auto set_attrs = slice_op.get_set_attributes(); std::vector starts_vec; std::vector ends_vec; std::vector axes_vec; @@ -314,7 +314,7 @@ struct find_const_3in_slice * To: * slice(data); slice.starts, slice.ends. slice.axes set */ -struct find_const_4in_slice +struct find_const_4in_slice : match::supports_dynamic_shapes { auto matcher() const { @@ -351,7 +351,7 @@ struct find_const_4in_slice * Simplify dimensions_of to a literal when the input arugment has a static shape * or the dynamic dimensions from `start` to `end` are fixed. */ -struct find_static_dimensions_of +struct find_static_dimensions_of : match::supports_dynamic_shapes { auto matcher() const { return match::name("dimensions_of")(); } @@ -396,7 +396,7 @@ struct find_static_dimensions_of * To: * reshape(data); reshape.dims = constant_output_dims */ -struct find_const_alloc_reshapes +struct find_const_alloc_reshapes : match::supports_dynamic_shapes { auto matcher() const { @@ -430,7 +430,7 @@ struct find_const_alloc_reshapes * To: * literal */ -struct find_const_alloc_fill +struct find_const_alloc_fill : match::supports_dynamic_shapes { auto matcher() const { @@ -454,7 +454,7 @@ struct find_const_alloc_fill * To: * multibroadcast(static_shape_arg); output_lens = static_broadcast_for_doted_shape */ -struct find_static_broadcast_for_dot +struct find_static_broadcast_for_dot : match::supports_dynamic_shapes { auto matcher() const { @@ -496,7 +496,7 @@ struct find_static_broadcast_for_dot * (on_value - off_value) * mask + off_value when we have `fill` working * on the GPU. */ -struct find_static_onehot +struct find_static_onehot : match::supports_dynamic_shapes { auto matcher() const { @@ -530,7 +530,7 @@ struct find_static_onehot depth_ins->eval().visit([&](auto d) { depth_val = d[0]; }); values_ins = onehot_inputs[2]; } - shape values_shape = values_ins->get_shape(); + shape values_shape = values_ins->get_shape(); std::vector static_output_lens = indices_shape.lens(); auto normalized_axis = (onehot_op.axis < 0) ? onehot_op.axis + indices_shape.ndim() + 1 : onehot_op.axis; @@ -574,7 +574,7 @@ struct find_static_onehot * This version ignores dynamic_dimension opt values. * Intended to be run after the other simplify_dyn_ops passes. */ -struct simplify_select_module_output_shape +struct simplify_select_module_output_shape : match::supports_dynamic_shapes { auto matcher() const { return match::name("select_module"); } diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index ffc3a803565..499f7db3036 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -88,6 +88,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_REWRITE_DOT) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_SET_GEMM_PROVIDER) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_FULL_DYNAMIC) std::vector target::get_passes(migraphx::context& gctx, const compile_options& options) const { @@ -178,7 +179,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti // clang-format off return { - split_single_dyn_dim{}, + enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), split_single_dyn_dim{}), dead_code_elimination{}, simplify_dyn_ops{}, dead_code_elimination{}, @@ -203,7 +204,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti insert_pad{{"convolution"}}, dead_code_elimination{}, inline_module{}, - rewrite_pooling{}, + enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), rewrite_pooling{}), dead_code_elimination{}, rewrite_gelu{options.fast_math}, optimize_module{}, @@ -224,13 +225,13 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti enable_pass(mlir_attention_enabled(&ctx), fuse_attention{}), dead_code_elimination{}, optimize_module{}, - fuse_pointwise_reduce{}, + enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_pointwise_reduce{}), dead_code_elimination{}, #ifndef _WIN32 enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}), #endif dead_code_elimination{}, - enable_pass(mlir_enabled(), fuse_mlir{&ctx}), + enable_pass(mlir_enabled() and disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_mlir{&ctx}), dead_code_elimination{}, fuse_concat{}, dead_code_elimination{}, From 903ab834a88feede3ba985ba6d3ad5c0ce6fe15b Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 9 Oct 2025 16:06:52 -0700 Subject: [PATCH 5/7] Pass by const ref --- src/include/migraphx/matcher.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 3961bfdda12..2877f45f882 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -495,7 +495,7 @@ auto make_match_runner(Finder& f) } template -void find_matches_for(Mod& mod, instruction_ref ins, RunnerPack rp) +void find_matches_for(Mod& mod, instruction_ref ins, const RunnerPack& rp) { rp([&](auto&&... rs) { bool matched = false; From eb471ba97097c93f9506439eb478abafc0c80e8e Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 13 Oct 2025 15:39:52 -0700 Subject: [PATCH 6/7] initial enablement --- src/fuse_pointwise.cpp | 46 +++++++++++++++++++++++---- src/include/migraphx/op/pointwise.hpp | 7 ++-- src/targets/gpu/target.cpp | 2 +- 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index f66e449dfc7..1456c1ceac2 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -46,7 +46,7 @@ static literal get_scalar(instruction_ref ins) if(contains({"contiguous", "broadcast", "multibroadcast"}, ins->name())) return get_scalar(ins->inputs().front()); const auto& s = ins->get_shape(); - if(s.elements() != 1 and not(s.scalar())) + if(s.dynamic() or (s.elements() != 1 and not(s.scalar()))) return {}; if(not ins->can_eval()) return {}; @@ -330,16 +330,20 @@ struct pointwise_reshape : rewrite_reshapes_base static std::string name() { return "pointwise"; } }; -struct pointwise_broadcast_pointwise +struct pointwise_broadcast_pointwise : match::supports_dynamic_shapes { auto matcher() const { + auto pointwise = match::name("pointwise")(match::used_once()).bind("x"); auto broadcast_pointwise = - match::name("multibroadcast")( - match::used_once(), - match::args(match::name("pointwise")(match::used_once()).bind("x"))) + match::name("multibroadcast")(match::used_once(), match::args(pointwise)) .bind("broadcast"); - return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise)); + auto dyn_broadcast_pointwise = match::name("multibroadcast")(match::used_once(), + match::nargs(2), + match::arg(1)(pointwise)) + .bind("broadcast"); + return match::name("pointwise")(match::any_of[match::inputs()]( + match::any_of(broadcast_pointwise, dyn_broadcast_pointwise))); } void apply(module& m, const match::matcher_result& r) const @@ -359,11 +363,39 @@ struct pointwise_broadcast_pointwise } }; +// Use pointwise instruction input as reference for dynamic multibroadcast rather than +// the pointwise instruction itself +struct dyn_pointwise_broadcast : match::supports_dynamic_shapes +{ + auto matcher() const + { + auto broadcast_pointwise = + match::name("multibroadcast")( + match::nargs(2), + match::arg(0)(match::name("pointwise").bind("x"))) + .bind("broadcast"); + return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise)); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto broadcast_ins = r.instructions["broadcast"]; + auto x_ins = r.instructions["x"]; + + auto broadcast_inps = broadcast_ins->inputs(); + broadcast_inps[0] = x_ins->inputs().front(); + + m.replace_instruction(broadcast_ins, broadcast_ins->get_operator(), broadcast_inps); + } +}; + } // namespace static void rewrite_broadcasts(module_pass_manager& mpm) { - match::find_matches(mpm.get_module(), pointwise_broadcast_pointwise{}); + mpm.get_module().debug_print(); + match::find_matches( + mpm.get_module(), dyn_pointwise_broadcast{}, pointwise_broadcast_pointwise{}); mpm.run_pass(dead_code_elimination{}); } diff --git a/src/include/migraphx/op/pointwise.hpp b/src/include/migraphx/op/pointwise.hpp index 2f5116dc654..2372e71c7f7 100644 --- a/src/include/migraphx/op/pointwise.hpp +++ b/src/include/migraphx/op/pointwise.hpp @@ -49,11 +49,14 @@ struct pointwise MIGRAPHX_THROW("pointwise should have at least one input"); auto* pm = mods.front(); auto pnames = pm->get_parameter_names(); - check_shapes{inputs, *this}.has(pnames.size()).same_dims(); + check_shapes{inputs, *this, true}.has(pnames.size()).same_dims(); + + std::vector scalar_const_out_lens = + inputs.front().dynamic() ? std::vector{} : inputs.front().lens(); auto result = pm->compute_shapes( inputs, - {.name = name(), .strict_type = true, .scalar_const_out_lens = inputs.front().lens()}); + {.name = name(), .strict_type = true, .scalar_const_out_lens = scalar_const_out_lens}); if(result.size() == 1) return result.front(); return shape{result}; diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index fbc540b20d9..4a8b3a986d8 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -228,7 +228,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti enable_pass(mlir_attention_enabled(&ctx), fuse_attention{}), dead_code_elimination{}, optimize_module{}, - enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_pointwise_reduce{}), + fuse_pointwise_reduce{}, dead_code_elimination{}, #ifndef _WIN32 enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}), From a1befd5fff461edb8856407eb7ca08e7d23e4446 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 14 Oct 2025 13:35:12 -0700 Subject: [PATCH 7/7] add get_matcher wrapper --- src/include/migraphx/matcher.hpp | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 4c4cec2c0d6..2f0c95da4df 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -411,10 +411,28 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES_FOR) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_MATCHERS) +MIGRAPHX_PRED_MATCHER(not_dynamic_shape, instruction_ref ins) +{ + return not ins->get_shape().dynamic(); +} + +template +auto get_matcher(const Finder& f) +{ + if constexpr(std::is_base_of{}) + { + return f.matcher(); + } + else + { + return not_dynamic_shape(f.matcher()); + } +} + template auto make_match_runner_with_trace(source_location location, Finder& f) { - auto m = f.matcher(); + auto m = get_matcher(f); const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{}); @@ -428,10 +446,6 @@ auto make_match_runner_with_trace(source_location location, Finder& f) using microseconds = std::chrono::duration; if(trace > 1 and trace_enabled) std::cout << "Running matcher: " << finder_name << std::endl; - - constexpr bool dynamic_supported = std::is_base_of::value; - if(not dynamic_supported and ins->get_shape().dynamic()) - return false; match::matcher_result r; double match_time = 0.0; @@ -493,12 +507,8 @@ auto make_match_runner_with_trace(source_location location, Finder& f) template auto make_match_runner(Finder& f) { - auto m = f.matcher(); + auto m = get_matcher(f); return [=, &f](auto& mod, instruction_ref ins) -> bool { - constexpr bool dynamic_supported = std::is_base_of::value; - if(not dynamic_supported and ins->get_shape().dynamic()) - return false; - match::matcher_result r = match::match_instruction(get_module(mod), ins, m); if(r.result == get_module(mod).end()) return false;