diff --git a/scripts/tf_cnn_benchmarks/models/experimental/deepspeech.py b/scripts/tf_cnn_benchmarks/models/experimental/deepspeech.py index 7b18f8da..df11a466 100644 --- a/scripts/tf_cnn_benchmarks/models/experimental/deepspeech.py +++ b/scripts/tf_cnn_benchmarks/models/experimental/deepspeech.py @@ -121,12 +121,24 @@ def decode_logits(self, logits): class DeepSpeech2Model(model_lib.Model): """Define DeepSpeech2 model.""" - # Supported rnn cells. - SUPPORTED_RNNS = { - 'lstm': tf.nn.rnn_cell.BasicLSTMCell, - 'rnn': tf.nn.rnn_cell.RNNCell, - 'gru': tf.nn.rnn_cell.GRUCell, - } + # Check TensorFlow Keras version. + keras_version = tf.keras.__version__.split('.') + major_version = int(keras_version[0]) + + if major_version >= 3: + # Supported rnn cells for Keras 3.x + SUPPORTED_RNNS = { + 'lstm': tf.keras.layers.LSTM, + 'rnn': tf.keras.layers.SimpleRNN, + 'gru': tf.keras.layers.GRU, + } + else: + # Supported rnn cells for Keras versions below 3.x + SUPPORTED_RNNS = { + 'lstm': tf.nn.rnn_cell.BasicLSTMCell, + 'rnn': tf.nn.rnn_cell.RNNCell, + 'gru': tf.nn.rnn_cell.GRUCell, + } # Parameters for batch normalization. BATCH_NORM_EPSILON = 1e-5