Skip to content

Commit 8259fb7

Browse files
committed
WIP
1 parent c734ebe commit 8259fb7

File tree

4 files changed

+48
-48
lines changed

4 files changed

+48
-48
lines changed

src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,39 +16,35 @@ std::vector<std::vector<InputShape>> transposedShape_4D(bool with_static = true,
1616
std::vector<std::vector<ov::test::InputShape>> shapes;
1717
if (with_static) {
1818
auto static_shapes =
19-
SNIPPETS_TESTS_STATIC_SHAPES({{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
20-
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}},
21-
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}},
22-
{{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}},
23-
{{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}});
19+
SNIPPETS_TESTS_STATIC_SHAPES({{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}});
2420
shapes.insert(shapes.end(), static_shapes.begin(), static_shapes.end());
2521
}
26-
if (with_dynamic) {
27-
std::vector<std::vector<ov::test::InputShape>> dynamic_shapes = {
28-
{
29-
{PartialShape{-1, -1, -1, 100}, {{1, 64, 4, 100}, {2, 16, 2, 100}, {1, 72, 4, 100}}},
30-
{PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
31-
{PartialShape{-1, -1, -1, 128}, {{1, 4, 64, 128}, {2, 2, 16, 128}, {1, 4, 72, 128}}},
32-
{PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
33-
},
34-
{
35-
{PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 16, 2, 100}, {1, 128, 3, 64}}},
36-
{PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 128, 2, 100}, {1, 128, 1, 64}}},
37-
{PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {2, 2, 16, 128}, {2, 1, 128, 128}}},
38-
{PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 128, 2, 100}, {1, 128, 3, 64}}},
39-
},
40-
{
41-
{PartialShape{-1, -1, 12, 64},
42-
{{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}},
43-
{PartialShape{-1, -1, 12, 64},
44-
{{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}},
45-
{PartialShape{-1, 12, -1, -1},
46-
{{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}},
47-
{PartialShape{-1, -1, 12, 64},
48-
{{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}},
49-
}};
50-
shapes.insert(shapes.end(), dynamic_shapes.begin(), dynamic_shapes.end());
51-
}
22+
// if (with_dynamic) {
23+
// std::vector<std::vector<ov::test::InputShape>> dynamic_shapes = {
24+
// {
25+
// {PartialShape{-1, -1, -1, 100}, {{1, 64, 4, 100}, {2, 16, 2, 100}, {1, 72, 4, 100}}},
26+
// {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
27+
// {PartialShape{-1, -1, -1, 128}, {{1, 4, 64, 128}, {2, 2, 16, 128}, {1, 4, 72, 128}}},
28+
// {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
29+
// },
30+
// {
31+
// {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 16, 2, 100}, {1, 128, 3, 64}}},
32+
// {PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 128, 2, 100}, {1, 128, 1, 64}}},
33+
// {PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {2, 2, 16, 128}, {2, 1, 128, 128}}},
34+
// {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 128, 2, 100}, {1, 128, 3, 64}}},
35+
// },
36+
// {
37+
// {PartialShape{-1, -1, 12, 64},
38+
// {{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}},
39+
// {PartialShape{-1, -1, 12, 64},
40+
// {{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}},
41+
// {PartialShape{-1, 12, -1, -1},
42+
// {{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}},
43+
// {PartialShape{-1, -1, 12, 64},
44+
// {{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}},
45+
// }};
46+
// shapes.insert(shapes.end(), dynamic_shapes.begin(), dynamic_shapes.end());
47+
// }
5248
return shapes;
5349
}
5450

@@ -160,11 +156,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
160156
MHA,
161157
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
162158
::testing::ValuesIn(precision_bf16_if_supported(4)),
163-
::testing::Values(ov::element::f32),
159+
::testing::Values(ov::element::bf16),
164160
::testing::Values(false),
165161
::testing::Values(MHA::default_thread_count),
166-
::testing::Values(8), // decomposed Transpose + MHA + 5 Converts + 1 Transpose on output
167-
::testing::Values(6), // MHA + 5 Converts on inputs and output
162+
::testing::Values(3), // decomposed Transpose + MHA + 1 Transpose on output
163+
::testing::Values(2), // decomposed Transpose + MHA
168164
::testing::Values(ov::test::utils::DEVICE_CPU),
169165
::testing::Values(CPUTestUtils::empty_plugin_config)),
170166
MHA::getTestCaseName);
@@ -173,11 +169,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
173169
MHA,
174170
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
175171
::testing::ValuesIn(precision_f32(4)),
176-
::testing::Values(ov::element::bf16),
172+
::testing::Values(ov::element::f32),
177173
::testing::ValuesIn({false}),
178174
::testing::Values(MHA::default_thread_count),
179-
::testing::Values(8), // decomposed Transpose + MHA + 5 Converts + 1 Transpose on output
180-
::testing::Values(6), // MHA + 5 Reorders on inputs and output
175+
::testing::Values(3), // decomposed Transpose + MHA + 1 Transpose on output
176+
::testing::Values(2), // decomposed Transpose + MHA
181177
::testing::Values(ov::test::utils::DEVICE_CPU),
182178
::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
183179
MHA::getTestCaseName);
@@ -224,7 +220,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_Without_Multiply,
224220
MHA,
225221
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
226222
::testing::ValuesIn(precision_f32(4)),
227-
::testing::Values(ov::element::f16),
223+
::testing::Values(ov::element::f32, ov::element::f16),
228224
::testing::ValuesIn({false}),
229225
::testing::Values(MHA::default_thread_count),
230226
::testing::Values(3),
@@ -236,7 +232,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_With_Multiply_Static,
236232
MHA,
237233
::testing::Combine(::testing::ValuesIn(transposedShape_4D(true, false)),
238234
::testing::ValuesIn(precision_f32(4)),
239-
::testing::Values(ov::element::f16),
235+
::testing::Values(ov::element::f32, ov::element::f16),
240236
::testing::ValuesIn({true}),
241237
::testing::Values(MHA::default_thread_count),
242238
::testing::Values(3),
@@ -248,7 +244,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_With_Multiply_Dynamic,
248244
MHA,
249245
::testing::Combine(::testing::ValuesIn(transposedShape_4D(false, true)),
250246
::testing::ValuesIn(precision_f32(4)),
251-
::testing::Values(ov::element::f16),
247+
::testing::Values(ov::element::f32, ov::element::f16),
252248
::testing::ValuesIn({true}),
253249
::testing::Values(MHA::default_thread_count),
254250
::testing::Values(4),

src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_with_dyn_mul.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ INSTANTIATE_TEST_SUITE_P(
6161
MHAWithDynamicMul,
6262
::testing::Combine(::testing::ValuesIn(transposedShape_4D_WithMul),
6363
::testing::ValuesIn(precision_f32(5)),
64-
::testing::Values(ov::element::bf16),
64+
::testing::Values(ov::element::f32, ov::element::bf16),
6565
::testing::Values(MHA::default_thread_count),
6666
::testing::Values(9), // Transpose1 + MHA + 1 Transpose on output + 6 Converts around
6767
::testing::Values(7), // MHA + 6 Converts around

src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_wo_transpose.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ INSTANTIATE_TEST_SUITE_P(
100100
MHAWOTranspose,
101101
::testing::Combine(::testing::ValuesIn(originalShapes),
102102
::testing::ValuesIn(precision_f32(3)),
103-
::testing::Values(ov::element::bf16),
103+
::testing::Values(ov::element::f32, ov::element::bf16),
104104
::testing::Values(true), // Need to support False for graph builder in tests
105105
::testing::Values(MHA::default_thread_count),
106106
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
@@ -115,7 +115,7 @@ INSTANTIATE_TEST_SUITE_P(
115115
MHAWOTranspose,
116116
::testing::Combine(::testing::ValuesIn(originalShapes),
117117
::testing::ValuesIn(precision_f32(3)),
118-
::testing::Values(ov::element::f16),
118+
::testing::Values(ov::element::f32, ov::element::f16),
119119
::testing::Values(true), // Need to support False for graph builder in tests
120120
::testing::Values(MHA::default_thread_count),
121121
::testing::Values(1),

src/tests/functional/plugin/shared/src/snippets/mha.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,11 @@ void MHABase::generate_inputs(const std::vector<ov::Shape>& targetInputStaticSha
2828
const auto& model_input = model_inputs[i];
2929
ov::Tensor tensor;
3030
ov::test::utils::InputGenerateData in_data;
31+
const bool bf16_precision =
32+
configuration.at(ov::hint::inference_precision.name()).as<ov::element::Type>() == ov::element::bf16 ||
33+
model_input.get_element_type() == ov::element::bf16;
3134
// To avoid big relative errors in the vicinity of zero, only positive values are generated for bf16 precision
32-
in_data.start_from = model_input.get_element_type() == ov::element::bf16 ? 0 : -1;
35+
in_data.start_from = bf16_precision ? 0 : -1;
3336
in_data.range = 2;
3437
in_data.resolution = 256;
3538
tensor =
@@ -55,16 +58,17 @@ void MHABase::SetUp() {
5558
setInferenceType(prc);
5659
}
5760

58-
void MHABase::init_thresholds() {
61+
void MHABase::init_thresholds() {
5962
// Note: Libxsmm calculates Exp in a slightly different way, so the abs values might differ a bit. Ticket: 130699
6063
#ifdef SNIPPETS_LIBXSMM_TPP
6164
abs_threshold = 1e-6;
6265
#endif
63-
if (inType == ov::element::bf16)
66+
auto infer_precision = configuration.at(ov::hint::inference_precision.name()).as<ov::element::Type>();
67+
if (infer_precision == ov::element::bf16)
6468
rel_threshold = 0.05f;
65-
if (inType == ov::element::f16)
69+
if (infer_precision == ov::element::f16)
6670
abs_threshold = 2e-2;
67-
}
71+
}
6872

6973
std::string MHA::getTestCaseName(const testing::TestParamInfo<ov::test::snippets::MHAParams>& obj) {
7074
const auto& [input_shapes,

0 commit comments

Comments
 (0)