Description
I'm using the Keras Trainer pattern as illustrated here. The issue with this pattern is that when you use TensorBoard, only the top-level weights are recorded.
The reason is that TensorBoard records the weights for all the layers in self.model.layers here. But that list is equal to [<Sequential name=sequential, built=True>], and the weights list of that Sequential object is [].
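As a quick sanity check (a hypothetical diagnostic, using the trainer_1 name from the full example below):

print(trainer_1.layers)
# [<Sequential name=sequential, built=True>]
print(trainer_1.layers[0].weights)
# [] -- so the nested Dense kernels/biases are never histogrammed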
I tried several things:
- Passing a CallbackList to fit with model_a attached as its model instead of trainer_1, but this fails because model_a has no optimizer.
- Overriding the layers property of the Trainer object to traverse with recursive=True (see the sketch after this list), but the weights were still not showing in TensorBoard, suggesting that something else is going on.
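A minimal sketch of that override attempt (this relies on the private _flatten_layers helper, whose exact signature may differ across Keras versions):

class MyTrainer(keras.Model):
    ...  # rest of the class as in the full example below

    @property
    def layers(self):
        # Walk sublayers recursively instead of returning only the direct
        # children, so a callback iterating over layers also sees the
        # nested Dense/Dropout layers.
        return list(self._flatten_layers(include_self=False, recursive=True))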
I'm open to any suggestions here.
Full example:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras.callbacks import TensorBoard

# Load MNIST dataset and standardize the data
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


class MyTrainer(keras.Model):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Create loss and metrics here.
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy()
        self.accuracy_metric = keras.metrics.SparseCategoricalAccuracy()

    @property
    def metrics(self):
        # List metrics here.
        return [self.accuracy_metric]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)  # Forward pass
            # Compute loss value
            loss = self.loss_fn(y, y_pred)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics
        for metric in self.metrics:
            metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data
        # Inference step
        y_pred = self.model(x, training=False)
        # Update metrics
        for metric in self.metrics:
            metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def call(self, x):
        # Equivalent to `call()` of the wrapped keras.Model
        x = self.model(x)
        return x


model_a = keras.models.Sequential(
    [
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(256, activation="relu"),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation="softmax"),
    ]
)

callbacks = [TensorBoard(histogram_freq=1)]

trainer_1 = MyTrainer(model_a)
trainer_1.compile(optimizer=keras.optimizers.SGD())
trainer_1.fit(
    x_train, y_train, epochs=5, batch_size=64,
    validation_data=(x_test, y_test), callbacks=callbacks,
)
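One workaround I'm considering (a rough sketch, not verified; NestedWeightHistograms and the log directory are my own placeholders) is to bypass the callback's per-layer traversal and log histograms from self.model.weights, which gathers variables from nested layers:

class NestedWeightHistograms(keras.callbacks.Callback):
    def __init__(self, log_dir="logs/custom"):
        super().__init__()
        self.writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        with self.writer.as_default():
            # model.weights collects variables recursively, unlike
            # iterating model.layers one level deep.
            for weight in self.model.weights:
                name = getattr(weight, "path", weight.name)  # Keras 3 vs. tf.Variable
                tf.summary.histogram(name, weight, step=epoch)
        self.writer.flush()

# e.g. callbacks = [TensorBoard(histogram_freq=1), NestedWeightHistograms()]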