Skip to content

Commit 1a8a21c

Browse files
v-Golubevchenhu-wang
authored andcommitted
Added a test case with bf16 enforcement and f32 input precision
1 parent 60c61ad commit 1a8a21c

File tree

2 files changed

+25
-8
lines changed
  • src
    • plugins/intel_cpu/tests/functional/shared_tests_instances/snippets
    • tests/functional/plugin/shared/src/snippets

2 files changed

+25
-8
lines changed

src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
160160
MHA,
161161
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
162162
::testing::ValuesIn(precision_bf16_if_supported(4)),
163-
::testing::Values(ov::element::f32),
163+
::testing::Values(ov::element::bf16),
164164
::testing::Values(false),
165165
::testing::Values(MHA::default_thread_count),
166-
::testing::Values(8), // decomposed Transpose + MHA + 5 Converts + 1 Transpose on output
167-
::testing::Values(6), // MHA + 5 Converts on inputs and output
166+
::testing::Values(3), // decomposed Transpose + MHA + 1 Transpose on output
167+
::testing::Values(2), // decomposed Transpose + MHA
168168
::testing::Values(ov::test::utils::DEVICE_CPU),
169169
::testing::Values(CPUTestUtils::empty_plugin_config)),
170170
MHA::getTestCaseName);
@@ -182,6 +182,19 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
182182
::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
183183
MHA::getTestCaseName);
184184

185+
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16_f32_in_prc,
186+
MHA,
187+
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
188+
::testing::ValuesIn(precision_f32(4)),
189+
::testing::Values(ov::element::f32),
190+
::testing::ValuesIn({false}),
191+
::testing::Values(MHA::default_thread_count),
192+
::testing::Values(3), // decomposed Transpose + MHA + 1 Transpose on output
193+
::testing::Values(2), // decomposed Transpose + MHA
194+
::testing::Values(ov::test::utils::DEVICE_CPU),
195+
::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
196+
MHA::getTestCaseName);
197+
185198
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_Without_Multiply,
186199
MHA,
187200
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),

src/tests/functional/plugin/shared/src/snippets/mha.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,11 @@ void MHABase::generate_inputs(const std::vector<ov::Shape>& targetInputStaticSha
2828
const auto& model_input = model_inputs[i];
2929
ov::Tensor tensor;
3030
ov::test::utils::InputGenerateData in_data;
31+
const bool bf16_precision =
32+
configuration.at(ov::hint::inference_precision.name()).as<ov::element::Type>() == ov::element::bf16 ||
33+
model_input.get_element_type() == ov::element::bf16;
3134
// To avoid big relative errors in the vicinity of zero, only positive values are generated for bf16 precision
32-
in_data.start_from = model_input.get_element_type() == ov::element::bf16 ? 0 : -1;
35+
in_data.start_from = bf16_precision ? 0 : -1;
3336
in_data.range = 2;
3437
in_data.resolution = 256;
3538
tensor =
@@ -55,16 +58,17 @@ void MHABase::SetUp() {
5558
setInferenceType(prc);
5659
}
5760

58-
void MHABase::init_thresholds() {
61+
void MHABase::init_thresholds() {
5962
// Note: Libxsmm calculates Exp in a slightly different way, so the abs values might differ a bit. Ticket: 130699
6063
#ifdef SNIPPETS_LIBXSMM_TPP
6164
abs_threshold = 1e-6;
6265
#endif
63-
if (inType == ov::element::bf16)
66+
auto infer_precision = configuration.at(ov::hint::inference_precision.name()).as<ov::element::Type>();
67+
if (infer_precision == ov::element::bf16)
6468
rel_threshold = 0.05f;
65-
if (inType == ov::element::f16)
69+
if (infer_precision == ov::element::f16)
6670
abs_threshold = 2e-2;
67-
}
71+
}
6872

6973
std::string MHA::getTestCaseName(const testing::TestParamInfo<ov::test::snippets::MHAParams>& obj) {
7074
const auto& [input_shapes,

0 commit comments

Comments
 (0)