21
21
22
22
23
23
class SplitLayer (Layer ):
24
+ """The base (abstract) class for a split layer.
25
+ """
24
26
def __init__ (self , classname , input_layer , sizes , name ):
27
+ assert hasattr (PythonWrapper , classname ), 'Incorrect split layer specified: ' + classname
28
+
25
29
if type (input_layer ) is getattr (PythonWrapper , classname ):
26
30
super ().__init__ (input_layer )
27
31
return
28
32
29
33
layers , outputs = check_input_layers (input_layer , 1 )
30
34
31
- s = numpy .array (sizes , dtype = numpy .int32 , copy = False )
32
-
33
- if s .size > 3 :
34
- raise ValueError ('The `sizes` must contain not more than 3 elements.' )
35
-
36
- if numpy .any (s < 0 ):
37
- raise ValueError ('The `sizes` must contain only positive values.' )
38
-
39
- internal = getattr (PythonWrapper , classname )(str (name ), layers [0 ], int (outputs [0 ]), s )
35
+ internal = getattr (PythonWrapper , classname )(str (name ), layers [0 ], int (outputs [0 ]), self .__sizes_to_array (sizes ))
40
36
super ().__init__ (internal )
41
37
42
38
@property
@@ -48,8 +44,21 @@ def output_sizes(self):
48
44
@output_sizes .setter
49
45
def output_sizes (self , value ):
50
46
"""
51
- """
52
- self ._internal .set_output_counts (value )
47
+ """
48
+ self ._internal .set_output_counts (self .__sizes_to_array (value ))
49
+
50
+ @staticmethod
51
+ def __sizes_to_array (sizes ) -> numpy .ndarray :
52
+ sizes = numpy .array (sizes , dtype = numpy .int32 )
53
+ if sizes .ndim != 1 or sizes .size > 3 :
54
+ raise ValueError ('The `sizes` must be a one-dimentional sequence containing not more than 3 elements.' )
55
+
56
+ if numpy .any (sizes < 0 ):
57
+ raise ValueError ('The `sizes` must contain only positive values.' )
58
+
59
+ return sizes
60
+
61
+ # ----------------------------------------------------------------------------------------------------------------------
53
62
54
63
55
64
class SplitChannels (SplitLayer ):
0 commit comments