diff --git a/pytorch2keras/reshape_layers.py b/pytorch2keras/reshape_layers.py index 83a8efb..bdb20b0 100644 --- a/pytorch2keras/reshape_layers.py +++ b/pytorch2keras/reshape_layers.py @@ -28,7 +28,14 @@ def convert_flatten(params, w_name, scope_name, inputs, layers, weights, names): else: tf_name = w_name + str(random.random()) - reshape = keras.layers.Reshape([-1], name=tf_name) + flat_size = 1 + for d in layers[inputs[0]].shape: + try: + flat_size = flat_size * int(d) + except TypeError: + pass + + reshape = keras.layers.Reshape([flat_size], name=tf_name) layers[scope_name] = reshape(layers[inputs[0]]) @@ -171,4 +178,4 @@ def target_layer(x): return tf.shape(x) lambda_layer = keras.layers.Lambda(target_layer) - layers[scope_name] = lambda_layer(layers[inputs[0]]) \ No newline at end of file + layers[scope_name] = lambda_layer(layers[inputs[0]])