Skip to content

Commit 50058c9

Browse files
committed
minor changes
Signed-off-by: voropz <[email protected]>
1 parent 49dc07c commit 50058c9

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

NeoML/Python/neoml/Dnn/Split.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,18 @@
2121

2222

2323
class SplitLayer(Layer):
24+
"""The base (abstract) class for a split layer.
25+
"""
2426
def __init__(self, classname, input_layer, sizes, name):
27+
assert hasattr(PythonWrapper, classname), 'Incorrect split layer specified: ' + classname
28+
2529
if type(input_layer) is getattr(PythonWrapper, classname):
2630
super().__init__(input_layer)
2731
return
2832

2933
layers, outputs = check_input_layers(input_layer, 1)
3034

31-
s = numpy.array(sizes, dtype=numpy.int32, copy=False)
32-
33-
if s.size > 3:
34-
raise ValueError('The `sizes` must contain not more than 3 elements.')
35-
36-
if numpy.any(s < 0):
37-
raise ValueError('The `sizes` must contain only positive values.')
38-
39-
internal = getattr(PythonWrapper, classname)(str(name), layers[0], int(outputs[0]), s)
35+
internal = getattr(PythonWrapper, classname)(str(name), layers[0], int(outputs[0]), self.__sizes_to_array(sizes))
4036
super().__init__(internal)
4137

4238
@property
@@ -48,8 +44,21 @@ def output_sizes(self):
4844
@output_sizes.setter
4945
def output_sizes(self, value):
5046
"""
51-
"""
52-
self._internal.set_output_counts(value)
47+
"""
48+
self._internal.set_output_counts(self.__sizes_to_array(value))
49+
50+
@staticmethod
51+
def __sizes_to_array(sizes) -> numpy.ndarray:
52+
sizes = numpy.array(sizes, dtype=numpy.int32)
53+
if sizes.ndim != 1 or sizes.size > 3:
54+
raise ValueError('The `sizes` must be a one-dimentional sequence containing not more than 3 elements.')
55+
56+
if numpy.any(sizes < 0):
57+
raise ValueError('The `sizes` must contain only positive values.')
58+
59+
return sizes
60+
61+
# ----------------------------------------------------------------------------------------------------------------------
5362

5463

5564
class SplitChannels(SplitLayer):

0 commit comments

Comments
 (0)