Skip to content

Commit 49bef1c

Browse files
committed
Merge pull request opencv#21283 from rogday:flatten_fix
2 parents 175bcb1 + fec2c7e commit 49bef1c

5 files changed

+58
-36
lines changed

modules/dnn/src/layers/flatten_layer.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ class FlattenLayerImpl CV_FINAL : public FlattenLayer
100100
{
101101
outputShapeVec.push_back(inputs[0][i]);
102102
}
103-
CV_Assert(outputShapeVec.size() <= 4);
104103

105104
outputs.resize(inputs.size(), outputShapeVec);
106105

modules/dnn/src/onnx/onnx_importer.cpp

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,20 +1781,67 @@ void ONNXImporter::parseSqueeze(LayerParams& layerParams, const opencv_onnx::Nod
17811781
addLayer(layerParams, node_proto);
17821782
}
17831783

1784-
void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
1784+
void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
17851785
{
1786+
opencv_onnx::NodeProto node_proto = node_proto_;
17861787
CV_CheckEQ(node_proto.input_size(), 1, "");
1788+
int axis_ = layerParams.get<int>("axis", 1);
17871789
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
17881790
{
17891791
Mat input = getBlob(node_proto, 0);
1790-
int axis = normalize_axis(layerParams.get<int>("axis", 1), input.dims);
1792+
int axis = normalize_axis(axis_, input.dims);
17911793

1792-
std::vector<int> out_size(&input.size[0], &input.size[0] + axis);
1793-
out_size.push_back(input.total(axis));
1794-
Mat output = input.reshape(1, out_size);
1794+
int out_size[2] = {1, 1};
1795+
for (int i = 0; i < axis; ++i)
1796+
{
1797+
out_size[0] *= input.size[i];
1798+
}
1799+
for (int i = axis; i < input.dims; ++i)
1800+
{
1801+
out_size[1] *= input.size[i];
1802+
}
1803+
1804+
Mat output = input.reshape(1, 2, out_size);
17951805
addConstant(layerParams.name, output);
17961806
return;
17971807
}
1808+
IterShape_t shapeIt = outShapes.find(node_proto.input(0));
1809+
CV_Assert(shapeIt != outShapes.end());
1810+
MatShape inpShape = shapeIt->second;
1811+
int axis = normalize_axis(axis_, inpShape.size());
1812+
1813+
if (axis == 0 || axis == inpShape.size())
1814+
{
1815+
LayerParams reshapeLp;
1816+
reshapeLp.name = layerParams.name + "/reshape";
1817+
reshapeLp.type = "Reshape";
1818+
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
1819+
1820+
inpShape.insert(axis == 0 ? inpShape.begin() : inpShape.end(), 1);
1821+
reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
1822+
1823+
opencv_onnx::NodeProto proto;
1824+
proto.add_input(node_proto.input(0));
1825+
proto.add_output(reshapeLp.name);
1826+
addLayer(reshapeLp, proto);
1827+
node_proto.set_input(0, reshapeLp.name);
1828+
axis += 1;
1829+
}
1830+
1831+
LayerParams first_pass;
1832+
first_pass.name = layerParams.name + "/flatten";
1833+
CV_Assert(layer_id.find(first_pass.name) == layer_id.end());
1834+
first_pass.type = "Flatten";
1835+
first_pass.set("axis", 0);
1836+
first_pass.set("end_axis", axis - 1);
1837+
1838+
opencv_onnx::NodeProto proto;
1839+
proto.add_input(node_proto.input(0));
1840+
proto.add_output(first_pass.name);
1841+
addLayer(first_pass, proto);
1842+
1843+
layerParams.set("axis", 1);
1844+
node_proto.set_input(0, first_pass.name);
17981845
addLayer(layerParams, node_proto);
17991846
}
18001847

modules/dnn/test/test_onnx_conformance_layer_filter__halide_denylist.inl.hpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,6 @@
1717
"test_elu",
1818
"test_elu_default",
1919
"test_exp",
20-
"test_flatten_axis0",
21-
"test_flatten_axis2",
22-
"test_flatten_axis3",
23-
"test_flatten_negative_axis1",
24-
"test_flatten_negative_axis2",
25-
"test_flatten_negative_axis4",
2620
"test_leakyrelu",
2721
"test_leakyrelu_default",
2822
"test_logsoftmax_axis_1",

modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -561,35 +561,23 @@ CASE(test_eyelike_with_dtype)
561561
CASE(test_eyelike_without_dtype)
562562
// no filter
563563
CASE(test_flatten_axis0)
564-
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
565-
SKIP;
566-
#endif
564+
// no filter
567565
CASE(test_flatten_axis1)
568566
// no filter
569567
CASE(test_flatten_axis2)
570-
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
571-
SKIP;
572-
#endif
568+
// no filter
573569
CASE(test_flatten_axis3)
574-
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
575-
SKIP;
576-
#endif
570+
// no filter
577571
CASE(test_flatten_default_axis)
578572
// no filter
579573
CASE(test_flatten_negative_axis1)
580-
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
581-
SKIP;
582-
#endif
574+
// no filter
583575
CASE(test_flatten_negative_axis2)
584-
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
585-
SKIP;
586-
#endif
576+
// no filter
587577
CASE(test_flatten_negative_axis3)
588578
// no filter
589579
CASE(test_flatten_negative_axis4)
590-
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
591-
SKIP;
592-
#endif
580+
// no filter
593581
CASE(test_floor)
594582
// no filter
595583
CASE(test_floor_example)

modules/dnn/test/test_onnx_conformance_layer_filter_opencv_all_denylist.inl.hpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,6 @@
77
"test_castlike_FLOAT_to_STRING_expanded",
88
"test_castlike_STRING_to_FLOAT_expanded",
99
"test_concat_1d_axis_negative_1",
10-
"test_flatten_axis0",
11-
"test_flatten_axis2",
12-
"test_flatten_axis3",
13-
"test_flatten_negative_axis1",
14-
"test_flatten_negative_axis2",
15-
"test_flatten_negative_axis4",
1610
"test_logsoftmax_default_axis",
1711
"test_maxpool_2d_dilations",
1812
"test_maxpool_2d_same_lower",

0 commit comments

Comments
 (0)