@@ -160,11 +160,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
160160 MHA,
161161 ::testing::Combine (::testing::ValuesIn(transposedShape_4D()),
162162 ::testing::ValuesIn(precision_bf16_if_supported(4 )),
163- ::testing::Values(ov::element::f32 ),
163+ ::testing::Values(ov::element::bf16 ),
164164 ::testing::Values(false ),
165165 ::testing::Values(MHA::default_thread_count),
166- ::testing::Values(8 ), // decomposed Transpose + MHA + 5 Converts + 1 Transpose on output
167- ::testing::Values(6 ), // MHA + 5 Converts on inputs and output
166+ ::testing::Values(3 ), // decomposed Transpose + MHA + 1 Transpose on output
167+ ::testing::Values(2 ), // decomposed Transpose + MHA
168168 ::testing::Values(ov::test::utils::DEVICE_CPU),
169169 ::testing::Values(CPUTestUtils::empty_plugin_config)),
170170 MHA::getTestCaseName);
@@ -182,6 +182,19 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
182182 ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
183183 MHA::getTestCaseName);
184184
185+ INSTANTIATE_TEST_SUITE_P (smoke_Snippets_MHAEnforceBF16_f32_in_prc,
186+ MHA,
187+ ::testing::Combine (::testing::ValuesIn(transposedShape_4D()),
188+ ::testing::ValuesIn(precision_f32(4 )),
189+ ::testing::Values(ov::element::f32 ),
190+ ::testing::ValuesIn({false }),
191+ ::testing::Values(MHA::default_thread_count),
192+ ::testing::Values(3 ), // decomposed Transpose + MHA + 1 Transpose on output
193+ ::testing::Values(2 ), // decomposed Transpose + MHA
194+ ::testing::Values(ov::test::utils::DEVICE_CPU),
195+ ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
196+ MHA::getTestCaseName);
197+
185198INSTANTIATE_TEST_SUITE_P (smoke_Snippets_MHA_FP16_4D_Without_Multiply,
186199 MHA,
187200 ::testing::Combine (::testing::ValuesIn(transposedShape_4D()),
0 commit comments