From 8cfd942b27b24fbbbe769cc1ec0fa5dc4b0fe90a Mon Sep 17 00:00:00 2001 From: kahmed10 <15948690+kahmed10@users.noreply.github.com> Date: Tue, 25 Nov 2025 13:14:28 -0600 Subject: [PATCH 1/6] update quantization to support dynamic shapes --- src/quantization.cpp | 33 +++++++++++++++++++++++++++++---- src/split_single_dyn_dim.cpp | 5 +++++ src/truncate_float.cpp | 5 +++++ 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/src/quantization.cpp b/src/quantization.cpp index 9ab06cc6128..b323a5cf94a 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -30,6 +30,8 @@ #include #include #include +#include +#include #include #include #include @@ -68,7 +70,11 @@ static tracer quant_tracer() void quantize_fp16(program& prog, const std::vector& ins_names) { run_passes(prog, - {normalize_ops{}, + {split_single_dyn_dim{}, + dead_code_elimination{}, + simplify_dyn_ops{}, + dead_code_elimination{}, + normalize_ops{}, optimize_module{{"quantizelinear", "dequantizelinear"}}, truncate_float_pass{ins_names, shape::half_type}, optimize_module{{"quantizelinear", "dequantizelinear"}}}, @@ -78,7 +84,11 @@ void quantize_fp16(program& prog, const std::vector& ins_names) void quantize_bf16(program& prog, const std::vector& ins_names) { run_passes(prog, - {normalize_ops{}, + {split_single_dyn_dim{}, + dead_code_elimination{}, + simplify_dyn_ops{}, + dead_code_elimination{}, + normalize_ops{}, optimize_module{{"quantizelinear", "dequantizelinear"}}, truncate_float_pass{ins_names, shape::bf16_type}, optimize_module{{"quantizelinear", "dequantizelinear"}}}, @@ -93,7 +103,16 @@ static void quantize_8bits(program& prog, { // Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to // avoid loss of precision. - run_passes(prog, {rewrite_rnn{}, normalize_ops{}, optimize_module{}}, quant_tracer()); + run_passes(prog, + {split_single_dyn_dim{}, + dead_code_elimination{}, + simplify_dyn_ops{}, + dead_code_elimination{}, + rewrite_rnn{}, + dead_code_elimination{}, + normalize_ops{}, + optimize_module{}}, + quant_tracer()); std::shared_ptr>> quant_8bit_params = std::make_shared>>(); @@ -188,7 +207,13 @@ void quantize_int8(program& prog, void quantize_int4_weights(program& prog) { - run_passes(prog, {normalize_ops{}, optimize_module{}, quantize_int4_pass{}}, quant_tracer()); + run_passes(prog, {split_single_dyn_dim{}, + dead_code_elimination{}, + simplify_dyn_ops{}, + dead_code_elimination{}, + normalize_ops{}, + optimize_module{}, + quantize_int4_pass{}}, quant_tracer()); } void quantize_fp8(program& prog, const target& t, const std::vector& calibration) diff --git a/src/split_single_dyn_dim.cpp b/src/split_single_dyn_dim.cpp index 66974cd8e59..e2996bc96b3 100644 --- a/src/split_single_dyn_dim.cpp +++ b/src/split_single_dyn_dim.cpp @@ -95,6 +95,11 @@ has_one_unique_dyn_dim(const std::unordered_map& param_shape */ static bool any_sm_next(const_module_ref mm, const std::vector& ddcs) { + // skip main module that contains select_module (meaning this pass already ran) + if(any_of(mm->begin(), mm->end(), [](auto ins) { return ins.name() == "select_module";} )) + { + return true; + } for(const auto& ddc : ddcs) { auto p_outputs = mm->get_parameter(ddc.dyn_param_str)->outputs(); diff --git a/src/truncate_float.cpp b/src/truncate_float.cpp index 15f807684d3..690815e7a70 100644 --- a/src/truncate_float.cpp +++ b/src/truncate_float.cpp @@ -38,6 +38,11 @@ inline namespace MIGRAPHX_INLINE_NS { static void quantize_module(module& m, const std::vector& ins_names, shape::type_t float_type) { + // skip main module that contains select_module + if(any_of(m.begin(), m.end(), [](auto ins) { return ins.name() == "select_module";} )) + { + return; + } for(auto ins : iterator_for(m)) { // instructions are not in the set to be quantized From f7f69d41f27ff3834800d3c1fc7a46db77f96a1f Mon Sep 17 00:00:00 2001 From: kahmed10 <15948690+kahmed10@users.noreply.github.com> Date: Tue, 25 Nov 2025 13:16:35 -0600 Subject: [PATCH 2/6] formatting --- src/quantization.cpp | 30 +++++++++-------- src/split_single_dyn_dim.cpp | 2 +- src/truncate_float.cpp | 2 +- test/algorithm.cpp | 64 ------------------------------------ 4 files changed, 18 insertions(+), 80 deletions(-) diff --git a/src/quantization.cpp b/src/quantization.cpp index b323a5cf94a..01a47bc700d 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -105,13 +105,13 @@ static void quantize_8bits(program& prog, // avoid loss of precision. run_passes(prog, {split_single_dyn_dim{}, - dead_code_elimination{}, - simplify_dyn_ops{}, - dead_code_elimination{}, - rewrite_rnn{}, - dead_code_elimination{}, - normalize_ops{}, - optimize_module{}}, + dead_code_elimination{}, + simplify_dyn_ops{}, + dead_code_elimination{}, + rewrite_rnn{}, + dead_code_elimination{}, + normalize_ops{}, + optimize_module{}}, quant_tracer()); std::shared_ptr>> quant_8bit_params = @@ -207,13 +207,15 @@ void quantize_int8(program& prog, void quantize_int4_weights(program& prog) { - run_passes(prog, {split_single_dyn_dim{}, - dead_code_elimination{}, - simplify_dyn_ops{}, - dead_code_elimination{}, - normalize_ops{}, - optimize_module{}, - quantize_int4_pass{}}, quant_tracer()); + run_passes(prog, + {split_single_dyn_dim{}, + dead_code_elimination{}, + simplify_dyn_ops{}, + dead_code_elimination{}, + normalize_ops{}, + optimize_module{}, + quantize_int4_pass{}}, + quant_tracer()); } void quantize_fp8(program& prog, const target& t, const std::vector& calibration) diff --git a/src/split_single_dyn_dim.cpp b/src/split_single_dyn_dim.cpp index e2996bc96b3..92248a9a5f3 100644 --- a/src/split_single_dyn_dim.cpp +++ b/src/split_single_dyn_dim.cpp @@ -96,7 +96,7 @@ has_one_unique_dyn_dim(const std::unordered_map& param_shape static bool any_sm_next(const_module_ref mm, const std::vector& ddcs) { // skip main module that contains select_module (meaning this pass already ran) - if(any_of(mm->begin(), mm->end(), [](auto ins) { return ins.name() == "select_module";} )) + if(any_of(mm->begin(), mm->end(), [](auto ins) { return ins.name() == "select_module"; })) { return true; } diff --git a/src/truncate_float.cpp b/src/truncate_float.cpp index 690815e7a70..006bf3fcda5 100644 --- a/src/truncate_float.cpp +++ b/src/truncate_float.cpp @@ -39,7 +39,7 @@ static void quantize_module(module& m, const std::vector& ins_names, shape::type_t float_type) { // skip main module that contains select_module - if(any_of(m.begin(), m.end(), [](auto ins) { return ins.name() == "select_module";} )) + if(any_of(m.begin(), m.end(), [](auto ins) { return ins.name() == "select_module"; })) { return; } diff --git a/test/algorithm.cpp b/test/algorithm.cpp index 358ae0063e5..9733b519689 100644 --- a/test/algorithm.cpp +++ b/test/algorithm.cpp @@ -83,68 +83,4 @@ MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(adjacent_remove_if_non_equivalence, int) EXPECT(v == Container{1, 1, 1, 4, 2, 4, 2, 5, 6}); } -MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_basic, int) -{ - Container v = {5, 3, 7, 1, 9, 2}; - auto is_even = [](int x) { return x % 2 == 0; }; - auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); - EXPECT(it != v.end()); - EXPECT(*it == 2); -} - -MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_no_valid, int) -{ - Container v = {5, 3, 7, 1, 9}; - auto is_even = [](int x) { return x % 2 == 0; }; - auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); - EXPECT(it == v.end()); -} - -MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_all_valid, int) -{ - Container v = {6, 2, 8, 4, 10}; - auto is_even = [](int x) { return x % 2 == 0; }; - auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); - EXPECT(it != v.end()); - EXPECT(*it == 2); -} - -MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_custom_compare, int) -{ - Container v = {5, 3, 7, 1, 9, 2, 8}; - auto is_even = [](int x) { return x % 2 == 0; }; - // Find the largest even number - auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::greater<>{}); - EXPECT(it != v.end()); - EXPECT(*it == 8); -} - -MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_empty, int) -{ - Container v; - auto is_even = [](int x) { return x % 2 == 0; }; - auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); - EXPECT(it == v.end()); -} - -MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_first_element, int) -{ - Container v = {2, 5, 3, 7, 1, 9}; - auto is_even = [](int x) { return x % 2 == 0; }; - auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); - EXPECT(it != v.end()); - EXPECT(*it == 2); - EXPECT(it == v.begin()); -} - -MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_complex_predicate, int) -{ - Container v = {15, 3, 20, 1, 9, 25, 8, 12}; - // Find the smallest number greater than 10 - auto greater_than_10 = [](int x) { return x > 10; }; - auto it = migraphx::min_element_if(v.begin(), v.end(), greater_than_10, std::less<>{}); - EXPECT(it != v.end()); - EXPECT(*it == 12); -} - int main(int argc, const char* argv[]) { test::run(argc, argv); } From f30135c7c5267a11370e02df6282aa2d85b86247 Mon Sep 17 00:00:00 2001 From: kahmed10 <15948690+kahmed10@users.noreply.github.com> Date: Tue, 25 Nov 2025 13:32:29 -0600 Subject: [PATCH 3/6] update quantize_8bit passes to skip main module --- src/quantize_8bits.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/quantize_8bits.cpp b/src/quantize_8bits.cpp index 1f06a3451b8..caa3ce09255 100644 --- a/src/quantize_8bits.cpp +++ b/src/quantize_8bits.cpp @@ -51,6 +51,11 @@ static std::vector& get_quantizable_type() void quantize_8bits_pass::apply(module& m) const // NOLINT { + // skip main module that contains select_module + if(any_of(m.begin(), m.end(), [](auto ins) { return ins.name() == "select_module"; })) + { + return; + } const auto& quantizable_types = get_quantizable_type(); for(auto ins : iterator_for(m)) { @@ -97,6 +102,11 @@ void quantize_8bits_pass::apply(module& m) const // NOLINT void capture_arguments_pass::apply(module& m) const // NOLINT { + // skip main module that contains select_module + if(any_of(m.begin(), m.end(), [](auto ins) { return ins.name() == "select_module"; })) + { + return; + } assert(param_index != nullptr); const auto& quantizable_types = get_quantizable_type(); From 2076d2935fc153bf1269ced396071f69222f341d Mon Sep 17 00:00:00 2001 From: kahmed10 <15948690+kahmed10@users.noreply.github.com> Date: Tue, 25 Nov 2025 13:35:44 -0600 Subject: [PATCH 4/6] revert irrelevant file --- test/algorithm.cpp | 66 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/test/algorithm.cpp b/test/algorithm.cpp index 9733b519689..d31ffc7e430 100644 --- a/test/algorithm.cpp +++ b/test/algorithm.cpp @@ -49,7 +49,7 @@ erase_iterator(Container& c, Iterator pos, Iterator last) -> decltype(c.erase_af template static auto erase_iterator(Container& c, Iterator pos, Iterator last) -> decltype(c.erase(pos, - last)) + last)) { return c.erase(pos, last); } @@ -83,4 +83,68 @@ MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(adjacent_remove_if_non_equivalence, int) EXPECT(v == Container{1, 1, 1, 4, 2, 4, 2, 5, 6}); } +MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_basic, int) +{ + Container v = {5, 3, 7, 1, 9, 2}; + auto is_even = [](int x) { return x % 2 == 0; }; + auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); + EXPECT(it != v.end()); + EXPECT(*it == 2); +} + +MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_no_valid, int) +{ + Container v = {5, 3, 7, 1, 9}; + auto is_even = [](int x) { return x % 2 == 0; }; + auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); + EXPECT(it == v.end()); +} + +MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_all_valid, int) +{ + Container v = {6, 2, 8, 4, 10}; + auto is_even = [](int x) { return x % 2 == 0; }; + auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); + EXPECT(it != v.end()); + EXPECT(*it == 2); +} + +MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_custom_compare, int) +{ + Container v = {5, 3, 7, 1, 9, 2, 8}; + auto is_even = [](int x) { return x % 2 == 0; }; + // Find the largest even number + auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::greater<>{}); + EXPECT(it != v.end()); + EXPECT(*it == 8); +} + +MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_empty, int) +{ + Container v; + auto is_even = [](int x) { return x % 2 == 0; }; + auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); + EXPECT(it == v.end()); +} + +MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_first_element, int) +{ + Container v = {2, 5, 3, 7, 1, 9}; + auto is_even = [](int x) { return x % 2 == 0; }; + auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); + EXPECT(it != v.end()); + EXPECT(*it == 2); + EXPECT(it == v.begin()); +} + +MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_complex_predicate, int) +{ + Container v = {15, 3, 20, 1, 9, 25, 8, 12}; + // Find the smallest number greater than 10 + auto greater_than_10 = [](int x) { return x > 10; }; + auto it = migraphx::min_element_if(v.begin(), v.end(), greater_than_10, std::less<>{}); + EXPECT(it != v.end()); + EXPECT(*it == 12); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 18a7db7d376c0ecea039323545be0a5035096e8f Mon Sep 17 00:00:00 2001 From: kahmed10 <15948690+kahmed10@users.noreply.github.com> Date: Tue, 25 Nov 2025 13:36:40 -0600 Subject: [PATCH 5/6] revert irrelevant file --- test/algorithm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/algorithm.cpp b/test/algorithm.cpp index d31ffc7e430..358ae0063e5 100644 --- a/test/algorithm.cpp +++ b/test/algorithm.cpp @@ -49,7 +49,7 @@ erase_iterator(Container& c, Iterator pos, Iterator last) -> decltype(c.erase_af template static auto erase_iterator(Container& c, Iterator pos, Iterator last) -> decltype(c.erase(pos, - last)) + last)) { return c.erase(pos, last); } From dc0bdf8902760cfa76420a45d20c2733c00231bf Mon Sep 17 00:00:00 2001 From: kahmed10 <15948690+kahmed10@users.noreply.github.com> Date: Tue, 25 Nov 2025 14:06:19 -0600 Subject: [PATCH 6/6] fix license --- src/truncate_float.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/truncate_float.cpp b/src/truncate_float.cpp index 006bf3fcda5..b56463bc2ff 100644 --- a/src/truncate_float.cpp +++ b/src/truncate_float.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal