keras.layers.Flatten() not updating shape information after transformation #21502

@TheWalkingSea

Description

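import keras
import numpy as np
import tensorflow as tf

# Each sample is a pair of integer codes drawn from {0, 1, 2}.
one_hot_encode_input = keras.layers.Input(shape=(2,), name='code', dtype=tf.int64)
# One-hot encode each code against the 3-value vocabulary (no OOV slot).
one_hot_encode = keras.layers.IntegerLookup(
    vocabulary=[0, 1, 2],
    output_mode='one_hot',
    num_oov_indices=0
)(one_hot_encode_input)
# Expected: flatten (batch, 2, 3) into (batch, 6).
one_hot_encode = keras.layers.Flatten()(one_hot_encode)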


output = keras.layers.Dense(1, activation='sigmoid', name='score')(one_hot_encode)

model = keras.Model(
    inputs=one_hot_encode_input,
    outputs=output
)


model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[keras.metrics.Accuracy(), keras.metrics.AUC()]
)


model.fit(
    np.array([[1, 2]]),
    np.array([[1]]),
    epochs=4,
    verbose=2
)

This results in the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[680], line 8
      1 model.compile(
      2     optimizer=keras.optimizers.Adam(),
      3     loss=keras.losses.BinaryCrossentropy(),
      4     metrics=[keras.metrics.Accuracy(), keras.metrics.AUC()]
      5 )
----> 8 model.fit(
      9     np.array([[1, 2]]),
     10     np.array([[1]]),
     11     epochs=4,
     12     verbose=2
     13 )

File /mnt/c/Users/Austi/Downloads/code/py/solarisAIO/venv/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File /mnt/c/Users/Austi/Downloads/code/py/solarisAIO/venv/lib/python3.12/site-packages/keras/src/layers/input_spec.py:227, in assert_input_compatibility(input_spec, inputs, layer_name)
    222     for axis, value in spec.axes.items():
    223         if value is not None and shape[axis] not in {
    224             value,
    225             None,
    226         }:
--> 227             raise ValueError(
    228                 f'Input {input_index} of layer "{layer_name}" is '
    229                 f"incompatible with the layer: expected axis {axis} "
    230                 f"of input shape to have value {value}, "
    231                 "but received input with "
    232                 f"shape {shape}"
    233             )
    234 # Check shape.
    235 if spec.shape is not None:

ValueError: Exception encountered when calling Functional.call().

Input 0 of layer "score" is incompatible with the layer: expected axis -1 of input shape to have value 3, but received input with shape (None, 6)

Arguments received by Functional.call():
  • inputs=tf.Tensor(shape=(None, 2), dtype=int64)
  • training=True
  • mask=None
  • kwargs=<class 'inspect._empty'>
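The traceback points at the axis check in `keras/src/layers/input_spec.py`. Below is a simplified, standalone re-creation of only that check, copied from the lines shown in the traceback. The helper name `check_axes` and its arguments are hypothetical (not Keras API), and the input index is fixed at 0 for brevity. It shows why a layer whose spec pins axis -1 to 3 rejects a tensor of shape `(None, 6)`:

```python
def check_axes(axes, shape, layer_name):
    """Simplified re-creation of the axis check shown in the traceback.

    `axes` maps an axis index to the size the layer was built for,
    e.g. {-1: 3} for a Dense layer built on 3 input features.
    Hypothetical helper for illustration, not part of the Keras API.
    """
    for axis, value in axes.items():
        # None on either side means "unknown" and is always accepted.
        if value is not None and shape[axis] not in {value, None}:
            raise ValueError(
                f'Input 0 of layer "{layer_name}" is '
                f"incompatible with the layer: expected axis {axis} "
                f"of input shape to have value {value}, "
                "but received input with "
                f"shape {shape}"
            )


# "score" was built for 3 features, but the flattened one-hot
# tensor that reaches it at run time has 6.
try:
    check_axes({-1: 3}, (None, 6), "score")
except ValueError as err:
    print(err)
```

The printed message matches the one in the report, so the mismatch is between the size `score` was built with (3) and the size it receives when `fit` runs (6).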

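The Flatten layer does not convert the output shape of the IntegerLookup layer from (2, 3) to (6,), which causes an error during graph execution: the Dense layer "score" was built for an input whose last axis is 3, but at run time it receives an input whose last axis is 6. Attempting to reshape the output of Flatten with `keras.layers.Reshape` fails with a similar error.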
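To make the expected shapes concrete, here is a plain NumPy model of the arithmetic described above. It does not use Keras at all; it assumes, as the report describes, that one-hot encoding a `(batch, 2)` input against a 3-value vocabulary adds a trailing axis of size 3. The variable names are illustrative only.

```python
import numpy as np

# Vocabulary and a single input row taken from the report.
vocabulary = np.array([0, 1, 2])
codes = np.array([[1, 2]])  # shape (batch=1, 2)

# One-hot each code by comparing it against every vocabulary entry.
# This models the expected result, not Keras's implementation.
one_hot = (codes[..., np.newaxis] == vocabulary).astype("float32")
print(one_hot.shape)  # (1, 2, 3)

# Flatten every axis except the batch axis.
flat = one_hot.reshape(one_hot.shape[0], -1)
print(flat.shape)  # (1, 6)

# A Dense(1) kernel has shape (input_features, units).
kernel_built_for_3 = np.zeros((3, 1), dtype="float32")  # what "score" was built with
kernel_built_for_6 = np.zeros((6, 1), dtype="float32")  # what the data needs

print((flat @ kernel_built_for_6).shape)  # (1, 1)
try:
    flat @ kernel_built_for_3
except ValueError as err:
    print("shape mismatch:", err)
```

With a kernel sized for 6 features the multiplication succeeds, while a kernel sized for 3 features fails, which mirrors the `score` error above.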
