@@ -131,6 +131,7 @@ class ONNXImporter
131
131
typedef void (ONNXImporter::*ONNXImporterNodeParser)(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
132
132
typedef std::map<std::string, ONNXImporterNodeParser> DispatchMap;
133
133
134
+ void parseMaxUnpool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
134
135
void parseMaxPool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
135
136
void parseAveragePool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
136
137
void parseReduce (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
@@ -625,6 +626,41 @@ void setCeilMode(LayerParams& layerParams)
625
626
}
626
627
}
627
628
629
+ void ONNXImporter::parseMaxUnpool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
630
+ {
631
+ layerParams.type = " MaxUnpool" ;
632
+
633
+ DictValue kernel_shape = layerParams.get (" kernel_size" );
634
+ CV_Assert (kernel_shape.size () == 2 );
635
+ layerParams.set (" pool_k_w" , kernel_shape.get <int >(0 ));
636
+ layerParams.set (" pool_k_h" , kernel_shape.get <int >(1 ));
637
+
638
+ int pool_pad_w = 0 , pool_pad_h = 0 ;
639
+ if (layerParams.has (" pad" ))
640
+ {
641
+ DictValue pads = layerParams.get (" pad" );
642
+ CV_CheckEQ (pads.size (), 2 , " " );
643
+ pool_pad_w = pads.get <int >(0 );
644
+ pool_pad_h = pads.get <int >(1 );
645
+ }
646
+ layerParams.set (" pool_pad_w" , pool_pad_w);
647
+ layerParams.set (" pool_pad_h" , pool_pad_h);
648
+
649
+
650
+ int pool_stride_w = 1 , pool_stride_h = 1 ;
651
+ if (layerParams.has (" stride" ))
652
+ {
653
+ DictValue strides = layerParams.get (" stride" );
654
+ CV_CheckEQ (strides.size (), 2 , " " );
655
+ pool_stride_w = strides.get <int >(0 );
656
+ pool_stride_h = strides.get <int >(1 );
657
+ }
658
+ layerParams.set (" pool_stride_w" , pool_stride_w);
659
+ layerParams.set (" pool_stride_h" , pool_stride_h);
660
+
661
+ addLayer (layerParams, node_proto);
662
+ }
663
+
628
664
void ONNXImporter::parseMaxPool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
629
665
{
630
666
layerParams.type = " Pooling" ;
@@ -659,11 +695,11 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
659
695
pool = " AVE" ;
660
696
layerParams.set (" pool" , pool);
661
697
layerParams.set (" global_pooling" , !layerParams.has (" axes" ));
698
+ bool keepdims = layerParams.get <int >(" keepdims" , 1 ) == 1 ;
662
699
if (layerParams.has (" axes" ) && (layer_type == " ReduceMean" || layer_type == " ReduceSum" || layer_type == " ReduceMax" ))
663
700
{
664
701
MatShape inpShape = outShapes[node_proto.input (0 )];
665
702
DictValue axes = layerParams.get (" axes" );
666
- bool keepdims = layerParams.get <int >(" keepdims" );
667
703
MatShape targetShape;
668
704
std::vector<bool > shouldDelete (inpShape.size (), false );
669
705
for (int i = 0 ; i < axes.size (); i++) {
@@ -771,7 +807,10 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
771
807
}
772
808
else if (!layerParams.has (" axes" ) && (layer_type == " ReduceMean" || layer_type == " ReduceSum" || layer_type == " ReduceMax" ))
773
809
{
774
- CV_CheckEQ (layerParams.get <int >(" keepdims" ), 0 , " layer only supports keepdims = false" );
810
+ IterShape_t shapeIt = outShapes.find (node_proto.input (0 ));
811
+ CV_Assert (shapeIt != outShapes.end ());
812
+ const size_t dims = keepdims ? shapeIt->second .size () : 1 ;
813
+
775
814
LayerParams reshapeLp;
776
815
reshapeLp.name = layerParams.name + " /reshape" ;
777
816
reshapeLp.type = " Reshape" ;
@@ -793,8 +832,8 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
793
832
addLayer (poolLp, node_proto);
794
833
795
834
layerParams.type = " Reshape" ;
796
- int targetShape[] = { 1 } ;
797
- layerParams.set (" dim" , DictValue::arrayInt (& targetShape[ 0 ], 1 ));
835
+ std::vector< int > targetShape (dims, 1 ) ;
836
+ layerParams.set (" dim" , DictValue::arrayInt (targetShape. data (), targetShape. size () ));
798
837
799
838
node_proto.set_input (0 , node_proto.output (0 ));
800
839
node_proto.set_output (0 , layerParams.name );
@@ -2341,6 +2380,7 @@ const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
2341
2380
{
2342
2381
DispatchMap dispatch;
2343
2382
2383
+ dispatch[" MaxUnpool" ] = &ONNXImporter::parseMaxUnpool;
2344
2384
dispatch[" MaxPool" ] = &ONNXImporter::parseMaxPool;
2345
2385
dispatch[" AveragePool" ] = &ONNXImporter::parseAveragePool;
2346
2386
dispatch[" GlobalAveragePool" ] = dispatch[" GlobalMaxPool" ] = dispatch[" ReduceMean" ] = dispatch[" ReduceSum" ] =
0 commit comments