@@ -16,39 +16,35 @@ std::vector<std::vector<InputShape>> transposedShape_4D(bool with_static = true,
1616 std::vector<std::vector<ov::test::InputShape>> shapes;
1717 if (with_static) {
1818 auto static_shapes =
19- SNIPPETS_TESTS_STATIC_SHAPES ({{1 , 128 , 12 , 64 }, {1 , 128 , 12 , 64 }, {1 , 12 , 128 , 128 }, {1 , 128 , 12 , 64 }},
20- {{1 , 128 , 16 , 64 }, {1 , 128 , 16 , 64 }, {1 , 16 , 1 , 1 }, {1 , 128 , 16 , 64 }},
21- {{1 , 128 , 16 , 64 }, {1 , 128 , 16 , 64 }, {1 , 1 , 1 , 128 }, {1 , 128 , 16 , 64 }},
22- {{2 , 68 , 6 , 92 }, {2 , 68 , 6 , 92 }, {1 , 1 , 68 , 68 }, {2 , 68 , 6 , 92 }},
23- {{1 , 58 , 16 , 34 }, {1 , 58 , 16 , 34 }, {1 , 1 , 1 , 58 }, {1 , 58 , 16 , 34 }});
19+ SNIPPETS_TESTS_STATIC_SHAPES ({{1 , 128 , 12 , 64 }, {1 , 128 , 12 , 64 }, {1 , 12 , 128 , 128 }, {1 , 128 , 12 , 64 }});
2420 shapes.insert (shapes.end (), static_shapes.begin (), static_shapes.end ());
2521 }
26- if (with_dynamic) {
27- std::vector<std::vector<ov::test::InputShape>> dynamic_shapes = {
28- {
29- {PartialShape{-1 , -1 , -1 , 100 }, {{1 , 64 , 4 , 100 }, {2 , 16 , 2 , 100 }, {1 , 72 , 4 , 100 }}},
30- {PartialShape{-1 , 128 , -1 , 100 }, {{1 , 128 , 4 , 100 }, {2 , 128 , 2 , 100 }, {1 , 128 , 4 , 100 }}},
31- {PartialShape{-1 , -1 , -1 , 128 }, {{1 , 4 , 64 , 128 }, {2 , 2 , 16 , 128 }, {1 , 4 , 72 , 128 }}},
32- {PartialShape{-1 , 128 , -1 , 100 }, {{1 , 128 , 4 , 100 }, {2 , 128 , 2 , 100 }, {1 , 128 , 4 , 100 }}},
33- },
34- {
35- {PartialShape{-1 , -1 , -1 , -1 }, {{1 , 128 , 3 , 64 }, {2 , 16 , 2 , 100 }, {1 , 128 , 3 , 64 }}},
36- {PartialShape{-1 , -1 , -1 , -1 }, {{1 , 128 , 1 , 64 }, {2 , 128 , 2 , 100 }, {1 , 128 , 1 , 64 }}},
37- {PartialShape{-1 , -1 , -1 , -1 }, {{2 , 1 , 128 , 128 }, {2 , 2 , 16 , 128 }, {2 , 1 , 128 , 128 }}},
38- {PartialShape{-1 , -1 , -1 , -1 }, {{1 , 128 , 3 , 64 }, {2 , 128 , 2 , 100 }, {1 , 128 , 3 , 64 }}},
39- },
40- {
41- {PartialShape{-1 , -1 , 12 , 64 },
42- {{1 , 70 , 12 , 64 }, {1 , 20 , 12 , 64 }, {1 , 20 , 12 , 64 }, {1 , 20 , 12 , 64 }, {1 , 70 , 12 , 64 }}},
43- {PartialShape{-1 , -1 , 12 , 64 },
44- {{1 , 35 , 12 , 64 }, {2 , 10 , 12 , 64 }, {2 , 1 , 12 , 64 }, {2 , 10 , 12 , 64 }, {1 , 35 , 12 , 64 }}},
45- {PartialShape{-1 , 12 , -1 , -1 },
46- {{2 , 12 , 70 , 35 }, {1 , 12 , 20 , 10 }, {1 , 12 , 20 , 10 }, {1 , 12 , 20 , 1 }, {2 , 12 , 70 , 35 }}},
47- {PartialShape{-1 , -1 , 12 , 64 },
48- {{1 , 35 , 12 , 64 }, {1 , 10 , 12 , 64 }, {1 , 10 , 12 , 64 }, {1 , 10 , 12 , 64 }, {1 , 35 , 12 , 64 }}},
49- }};
50- shapes.insert (shapes.end (), dynamic_shapes.begin (), dynamic_shapes.end ());
51- }
22+ // if (with_dynamic) {
23+ // std::vector<std::vector<ov::test::InputShape>> dynamic_shapes = {
24+ // {
25+ // {PartialShape{-1, -1, -1, 100}, {{1, 64, 4, 100}, {2, 16, 2, 100}, {1, 72, 4, 100}}},
26+ // {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
27+ // {PartialShape{-1, -1, -1, 128}, {{1, 4, 64, 128}, {2, 2, 16, 128}, {1, 4, 72, 128}}},
28+ // {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
29+ // },
30+ // {
31+ // {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 16, 2, 100}, {1, 128, 3, 64}}},
32+ // {PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 128, 2, 100}, {1, 128, 1, 64}}},
33+ // {PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {2, 2, 16, 128}, {2, 1, 128, 128}}},
34+ // {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 128, 2, 100}, {1, 128, 3, 64}}},
35+ // },
36+ // {
37+ // {PartialShape{-1, -1, 12, 64},
38+ // {{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}},
39+ // {PartialShape{-1, -1, 12, 64},
40+ // {{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}},
41+ // {PartialShape{-1, 12, -1, -1},
42+ // {{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}},
43+ // {PartialShape{-1, -1, 12, 64},
44+ // {{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}},
45+ // }};
46+ // shapes.insert(shapes.end(), dynamic_shapes.begin(), dynamic_shapes.end());
47+ // }
5248 return shapes;
5349}
5450
@@ -160,11 +156,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
160156 MHA,
161157 ::testing::Combine (::testing::ValuesIn(transposedShape_4D()),
162158 ::testing::ValuesIn(precision_bf16_if_supported(4 )),
163- ::testing::Values(ov::element::f32 ),
159+ ::testing::Values(ov::element::bf16 ),
164160 ::testing::Values(false ),
165161 ::testing::Values(MHA::default_thread_count),
166- ::testing::Values(8 ), // decomposed Transpose + MHA + 5 Converts + 1 Transpose on output
167- ::testing::Values(6 ), // MHA + 5 Converts on inputs and output
162+ ::testing::Values(3 ), // decomposed Transpose + MHA + 1 Transpose on output
163+ ::testing::Values(2 ), // decomposed Transpose + MHA
168164 ::testing::Values(ov::test::utils::DEVICE_CPU),
169165 ::testing::Values(CPUTestUtils::empty_plugin_config)),
170166 MHA::getTestCaseName);
@@ -173,11 +169,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
173169 MHA,
174170 ::testing::Combine (::testing::ValuesIn(transposedShape_4D()),
175171 ::testing::ValuesIn(precision_f32(4 )),
176- ::testing::Values(ov::element::bf16 ),
172+ ::testing::Values(ov::element::f32 ),
177173 ::testing::ValuesIn({false }),
178174 ::testing::Values(MHA::default_thread_count),
179- ::testing::Values(8 ), // decomposed Transpose + MHA + 5 Converts + 1 Transpose on output
180- ::testing::Values(6 ), // MHA + 5 Reorders on inputs and output
175+ ::testing::Values(3 ), // decomposed Transpose + MHA + 1 Transpose on output
176+ ::testing::Values(2 ), // decomposed Transpose + MHA
181177 ::testing::Values(ov::test::utils::DEVICE_CPU),
182178 ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
183179 MHA::getTestCaseName);
@@ -224,7 +220,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_Without_Multiply,
224220 MHA,
225221 ::testing::Combine (::testing::ValuesIn(transposedShape_4D()),
226222 ::testing::ValuesIn(precision_f32(4 )),
227- ::testing::Values(ov::element::f16 ),
223+ ::testing::Values(ov::element::f32 , ov::element:: f16 ),
228224 ::testing::ValuesIn({false }),
229225 ::testing::Values(MHA::default_thread_count),
230226 ::testing::Values(3 ),
@@ -236,7 +232,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_With_Multiply_Static,
236232 MHA,
237233 ::testing::Combine (::testing::ValuesIn(transposedShape_4D(true , false )),
238234 ::testing::ValuesIn(precision_f32(4 )),
239- ::testing::Values(ov::element::f16 ),
235+ ::testing::Values(ov::element::f32 , ov::element:: f16 ),
240236 ::testing::ValuesIn({true }),
241237 ::testing::Values(MHA::default_thread_count),
242238 ::testing::Values(3 ),
@@ -248,7 +244,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_With_Multiply_Dynamic,
248244 MHA,
249245 ::testing::Combine (::testing::ValuesIn(transposedShape_4D(false , true )),
250246 ::testing::ValuesIn(precision_f32(4 )),
251- ::testing::Values(ov::element::f16 ),
247+ ::testing::Values(ov::element::f32 , ov::element:: f16 ),
252248 ::testing::ValuesIn({true }),
253249 ::testing::Values(MHA::default_thread_count),
254250 ::testing::Values(4 ),
0 commit comments