diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index f66e449dfc7..1456c1ceac2 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -46,7 +46,7 @@ static literal get_scalar(instruction_ref ins) if(contains({"contiguous", "broadcast", "multibroadcast"}, ins->name())) return get_scalar(ins->inputs().front()); const auto& s = ins->get_shape(); - if(s.elements() != 1 and not(s.scalar())) + if(s.dynamic() or (s.elements() != 1 and not(s.scalar()))) return {}; if(not ins->can_eval()) return {}; @@ -330,16 +330,20 @@ struct pointwise_reshape : rewrite_reshapes_base static std::string name() { return "pointwise"; } }; -struct pointwise_broadcast_pointwise +struct pointwise_broadcast_pointwise : match::supports_dynamic_shapes { auto matcher() const { + auto pointwise = match::name("pointwise")(match::used_once()).bind("x"); auto broadcast_pointwise = - match::name("multibroadcast")( - match::used_once(), - match::args(match::name("pointwise")(match::used_once()).bind("x"))) + match::name("multibroadcast")(match::used_once(), match::args(pointwise)) .bind("broadcast"); - return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise)); + auto dyn_broadcast_pointwise = match::name("multibroadcast")(match::used_once(), + match::nargs(2), + match::arg(1)(pointwise)) + .bind("broadcast"); + return match::name("pointwise")(match::any_of[match::inputs()]( + match::any_of(broadcast_pointwise, dyn_broadcast_pointwise))); } void apply(module& m, const match::matcher_result& r) const @@ -359,11 +363,39 @@ struct pointwise_broadcast_pointwise } }; +// Use pointwise instruction input as reference for dynamic multibroadcast rather than +// the pointwise instruction itself +struct dyn_pointwise_broadcast : match::supports_dynamic_shapes +{ + auto matcher() const + { + auto broadcast_pointwise = + match::name("multibroadcast")( + match::nargs(2), + match::arg(0)(match::name("pointwise").bind("x"))) + .bind("broadcast"); + return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise)); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto broadcast_ins = r.instructions["broadcast"]; + auto x_ins = r.instructions["x"]; + + auto broadcast_inps = broadcast_ins->inputs(); + broadcast_inps[0] = x_ins->inputs().front(); + + m.replace_instruction(broadcast_ins, broadcast_ins->get_operator(), broadcast_inps); + } +}; + } // namespace static void rewrite_broadcasts(module_pass_manager& mpm) { - match::find_matches(mpm.get_module(), pointwise_broadcast_pointwise{}); + mpm.get_module().debug_print(); + match::find_matches( + mpm.get_module(), dyn_pointwise_broadcast{}, pointwise_broadcast_pointwise{}); mpm.run_pass(dead_code_elimination{}); } diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 8bd771b8a2e..2f0c95da4df 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -51,6 +51,10 @@ inline namespace MIGRAPHX_INLINE_NS { namespace match { +struct supports_dynamic_shapes +{ +}; + struct matcher_context { matcher_context(module& m) : mod(&m) {} @@ -407,10 +411,28 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES_FOR) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_MATCHERS) +MIGRAPHX_PRED_MATCHER(not_dynamic_shape, instruction_ref ins) +{ + return not ins->get_shape().dynamic(); +} + +template +auto get_matcher(const Finder& f) +{ + if constexpr(std::is_base_of{}) + { + return f.matcher(); + } + else + { + return not_dynamic_shape(f.matcher()); + } +} + template auto make_match_runner_with_trace(source_location location, Finder& f) { - auto m = f.matcher(); + auto m = get_matcher(f); const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{}); @@ -485,7 +507,7 @@ auto make_match_runner_with_trace(source_location location, Finder& f) template auto make_match_runner(Finder& f) { - auto m = f.matcher(); + auto m = get_matcher(f); return [=, &f](auto& mod, instruction_ref ins) -> bool { match::matcher_result r = match::match_instruction(get_module(mod), ins, m); if(r.result == get_module(mod).end()) diff --git a/src/include/migraphx/op/pointwise.hpp b/src/include/migraphx/op/pointwise.hpp index 2f5116dc654..2372e71c7f7 100644 --- a/src/include/migraphx/op/pointwise.hpp +++ b/src/include/migraphx/op/pointwise.hpp @@ -49,11 +49,14 @@ struct pointwise MIGRAPHX_THROW("pointwise should have at least one input"); auto* pm = mods.front(); auto pnames = pm->get_parameter_names(); - check_shapes{inputs, *this}.has(pnames.size()).same_dims(); + check_shapes{inputs, *this, true}.has(pnames.size()).same_dims(); + + std::vector scalar_const_out_lens = + inputs.front().dynamic() ? std::vector{} : inputs.front().lens(); auto result = pm->compute_shapes( inputs, - {.name = name(), .strict_type = true, .scalar_const_out_lens = inputs.front().lens()}); + {.name = name(), .strict_type = true, .scalar_const_out_lens = scalar_const_out_lens}); if(result.size() == 1) return result.front(); return shape{result}; diff --git a/src/simplify_dyn_ops.cpp b/src/simplify_dyn_ops.cpp index fb6b41d3b7c..91eaad6c60f 100644 --- a/src/simplify_dyn_ops.cpp +++ b/src/simplify_dyn_ops.cpp @@ -39,7 +39,7 @@ inline namespace MIGRAPHX_INLINE_NS { * into multibroadcast op with a static output shape attribute. * */ -struct find_broadcast_with_dims_static +struct find_broadcast_with_dims_static : match::supports_dynamic_shapes { auto matcher() const { @@ -80,7 +80,7 @@ struct find_broadcast_with_dims_static * At time of writing, Resize allows either 1 or 2 inputs * but the 1-input case is never created by Onnx parsing. */ -struct find_resize_static +struct find_resize_static : match::supports_dynamic_shapes { auto matcher() const @@ -168,7 +168,7 @@ struct find_resize_static * To: * broadcast_op(argument_with_static_shape); broadcast_op.out_lens = constant_output_dims */ -struct find_static_2in_broadcasts +struct find_static_2in_broadcasts : match::supports_dynamic_shapes { auto matcher() const { @@ -201,7 +201,7 @@ struct find_static_2in_broadcasts * To: * slice(data); slice.starts, slice.ends. slice.axes set */ -struct find_const_2in_slice +struct find_const_2in_slice : match::supports_dynamic_shapes { auto matcher() const { @@ -255,7 +255,7 @@ struct find_const_2in_slice * To: * slice(data); slice.starts, slice.ends. slice.axes set */ -struct find_const_3in_slice +struct find_const_3in_slice : match::supports_dynamic_shapes { auto matcher() const { @@ -266,10 +266,10 @@ struct find_const_3in_slice void apply(module& m, const match::matcher_result& mr) const { - auto ins = mr.result; - auto inputs = ins->inputs(); - auto slice_op = any_cast(ins->get_operator()); - auto set_attrs = slice_op.get_set_attributes(); + auto ins = mr.result; + auto inputs = ins->inputs(); + auto slice_op = any_cast(ins->get_operator()); + auto set_attrs = slice_op.get_set_attributes(); std::vector starts_vec; std::vector ends_vec; std::vector axes_vec; @@ -314,7 +314,7 @@ struct find_const_3in_slice * To: * slice(data); slice.starts, slice.ends. slice.axes set */ -struct find_const_4in_slice +struct find_const_4in_slice : match::supports_dynamic_shapes { auto matcher() const { @@ -351,7 +351,7 @@ struct find_const_4in_slice * Simplify dimensions_of to a literal when the input arugment has a static shape * or the dynamic dimensions from `start` to `end` are fixed. */ -struct find_static_dimensions_of +struct find_static_dimensions_of : match::supports_dynamic_shapes { auto matcher() const { return match::name("dimensions_of")(); } @@ -396,7 +396,7 @@ struct find_static_dimensions_of * To: * reshape(data); reshape.dims = constant_output_dims */ -struct find_const_alloc_reshapes +struct find_const_alloc_reshapes : match::supports_dynamic_shapes { auto matcher() const { @@ -430,7 +430,7 @@ struct find_const_alloc_reshapes * To: * literal */ -struct find_const_alloc_fill +struct find_const_alloc_fill : match::supports_dynamic_shapes { auto matcher() const { @@ -454,7 +454,7 @@ struct find_const_alloc_fill * To: * multibroadcast(static_shape_arg); output_lens = static_broadcast_for_doted_shape */ -struct find_static_broadcast_for_dot +struct find_static_broadcast_for_dot : match::supports_dynamic_shapes { auto matcher() const { @@ -496,7 +496,7 @@ struct find_static_broadcast_for_dot * (on_value - off_value) * mask + off_value when we have `fill` working * on the GPU. */ -struct find_static_onehot +struct find_static_onehot : match::supports_dynamic_shapes { auto matcher() const { @@ -530,7 +530,7 @@ struct find_static_onehot depth_ins->eval().visit([&](auto d) { depth_val = d[0]; }); values_ins = onehot_inputs[2]; } - shape values_shape = values_ins->get_shape(); + shape values_shape = values_ins->get_shape(); std::vector static_output_lens = indices_shape.lens(); auto normalized_axis = (onehot_op.axis < 0) ? onehot_op.axis + indices_shape.ndim() + 1 : onehot_op.axis; @@ -574,7 +574,7 @@ struct find_static_onehot * This version ignores dynamic_dimension opt values. * Intended to be run after the other simplify_dyn_ops passes. */ -struct simplify_select_module_output_shape +struct simplify_select_module_output_shape : match::supports_dynamic_shapes { auto matcher() const { return match::name("select_module"); } diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 5844c934259..4a8b3a986d8 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -89,6 +89,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REWRITE_LRN) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_SET_GEMM_PROVIDER) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_FULL_DYNAMIC) std::vector target::get_passes(migraphx::context& gctx, const compile_options& options) const { @@ -179,7 +180,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti // clang-format off return { - split_single_dyn_dim{}, + enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), split_single_dyn_dim{}), dead_code_elimination{}, simplify_dyn_ops{}, dead_code_elimination{}, @@ -204,7 +205,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti insert_pad{{"convolution"}}, dead_code_elimination{}, inline_module{}, - rewrite_pooling{.rewrite_lrn = enabled(MIGRAPHX_REWRITE_LRN{})}, + enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), rewrite_pooling{.rewrite_lrn = enabled(MIGRAPHX_REWRITE_LRN{})}), dead_code_elimination{}, rewrite_gelu{options.fast_math}, optimize_module{}, @@ -233,7 +234,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}), #endif dead_code_elimination{}, - enable_pass(mlir_enabled(), fuse_mlir{&ctx}), + enable_pass(mlir_enabled() and disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_mlir{&ctx}), dead_code_elimination{}, fuse_concat{}, dead_code_elimination{},