diff --git a/src/quantization.cpp b/src/quantization.cpp
index 9ab06cc6128..01a47bc700d 100644
--- a/src/quantization.cpp
+++ b/src/quantization.cpp
@@ -30,6 +30,8 @@
 #include
 #include
 #include
+#include <migraphx/split_single_dyn_dim.hpp>
+#include <migraphx/simplify_dyn_ops.hpp>
 #include
 #include
 #include
@@ -68,7 +70,11 @@ static tracer quant_tracer()
 void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
 {
     run_passes(prog,
-               {normalize_ops{},
+               {split_single_dyn_dim{},
+                dead_code_elimination{},
+                simplify_dyn_ops{},
+                dead_code_elimination{},
+                normalize_ops{},
                 optimize_module{{"quantizelinear", "dequantizelinear"}},
                 truncate_float_pass{ins_names, shape::half_type},
                 optimize_module{{"quantizelinear", "dequantizelinear"}}},
@@ -78,7 +84,11 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
 void quantize_bf16(program& prog, const std::vector<std::string>& ins_names)
 {
     run_passes(prog,
-               {normalize_ops{},
+               {split_single_dyn_dim{},
+                dead_code_elimination{},
+                simplify_dyn_ops{},
+                dead_code_elimination{},
+                normalize_ops{},
                 optimize_module{{"quantizelinear", "dequantizelinear"}},
                 truncate_float_pass{ins_names, shape::bf16_type},
                 optimize_module{{"quantizelinear", "dequantizelinear"}}},
@@ -93,7 +103,16 @@ static void quantize_8bits(program& prog,
 {
     // Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
     // avoid loss of precision.
-    run_passes(prog, {rewrite_rnn{}, normalize_ops{}, optimize_module{}}, quant_tracer());
+    run_passes(prog,
+               {split_single_dyn_dim{},
+                dead_code_elimination{},
+                simplify_dyn_ops{},
+                dead_code_elimination{},
+                rewrite_rnn{},
+                dead_code_elimination{},
+                normalize_ops{},
+                optimize_module{}},
+               quant_tracer());
 
     std::shared_ptr<std::vector<std::pair<float, float>>> quant_8bit_params =
         std::make_shared<std::vector<std::pair<float, float>>>();
@@ -188,7 +207,15 @@ void quantize_int8(program& prog,
 
 void quantize_int4_weights(program& prog)
 {
-    run_passes(prog, {normalize_ops{}, optimize_module{}, quantize_int4_pass{}}, quant_tracer());
+    run_passes(prog,
+               {split_single_dyn_dim{},
+                dead_code_elimination{},
+                simplify_dyn_ops{},
+                dead_code_elimination{},
+                normalize_ops{},
+                optimize_module{},
+                quantize_int4_pass{}},
+               quant_tracer());
 }
 
 void quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration)
diff --git a/src/quantize_8bits.cpp b/src/quantize_8bits.cpp
index 1f06a3451b8..caa3ce09255 100644
--- a/src/quantize_8bits.cpp
+++ b/src/quantize_8bits.cpp
@@ -51,6 +51,11 @@ static std::vector<shape::type_t>& get_quantizable_type()
 
 void quantize_8bits_pass::apply(module& m) const // NOLINT
 {
+    // skip main module that contains select_module
+    if(any_of(m.begin(), m.end(), [](auto ins) { return ins.name() == "select_module"; }))
+    {
+        return;
+    }
     const auto& quantizable_types = get_quantizable_type();
     for(auto ins : iterator_for(m))
     {
@@ -97,6 +102,11 @@
 void capture_arguments_pass::apply(module& m) const // NOLINT
 {
+    // skip main module that contains select_module
+    if(any_of(m.begin(), m.end(), [](auto ins) { return ins.name() == "select_module"; }))
+    {
+        return;
+    }
     assert(param_index != nullptr);
 
     const auto& quantizable_types = get_quantizable_type();
 
diff --git a/src/split_single_dyn_dim.cpp b/src/split_single_dyn_dim.cpp
index 66974cd8e59..92248a9a5f3 100644
--- a/src/split_single_dyn_dim.cpp
+++ b/src/split_single_dyn_dim.cpp
@@ -95,6 +95,11 @@
has_one_unique_dyn_dim(const std::unordered_map<std::string, shape>& param_shape
 */
 static bool any_sm_next(const_module_ref mm, const std::vector<dynamic_dimensions_check>& ddcs)
 {
+    // skip main module that contains select_module (meaning this pass already ran)
+    if(any_of(mm->begin(), mm->end(), [](auto ins) { return ins.name() == "select_module"; }))
+    {
+        return true;
+    }
     for(const auto& ddc : ddcs)
     {
         auto p_outputs = mm->get_parameter(ddc.dyn_param_str)->outputs();
diff --git a/src/truncate_float.cpp b/src/truncate_float.cpp
index 15f807684d3..b56463bc2ff 100644
--- a/src/truncate_float.cpp
+++ b/src/truncate_float.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
 *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -38,6 +38,11 @@ inline namespace MIGRAPHX_INLINE_NS {
 
 static void quantize_module(module& m, const std::vector<std::string>& ins_names, shape::type_t float_type)
 {
+    // skip main module that contains select_module
+    if(any_of(m.begin(), m.end(), [](auto ins) { return ins.name() == "select_module"; }))
+    {
+        return;
+    }
     for(auto ins : iterator_for(m))
     {
         // instructions are not in the set to be quantized