Tensorboard not working with Trainer Pattern #20809

Open
Description

@GeraudK

I'm using the Keras Trainer pattern as illustrated here. The issue with this pattern is that when you use the TensorBoard callback, only the top-level weights are recorded.

The reason is that TensorBoard records the weights for all the layers in self.model.layers here. But this is equal to [<Sequential name=sequential, built=True>], and the weights attribute of that Sequential object is [].
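This is easy to see by inspecting the trainer directly (a minimal sketch, assuming the MyTrainer and model_a definitions from the full example below):

trainer = MyTrainer(model_a)
trainer(x_train[:1])  # call once so the trainer is built
print(trainer.layers)
# [<Sequential name=sequential, built=True>]
print(trainer.layers[0].weights)
# [] here, even though model_a.weights is non-empty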

I tried several things:

  1. Passing a CallbackList to the trainer's fit, built with model_a instead of trainer_1 as its model, but this fails because model_a has no optimizer.
  2. Overriding the layers property on the Trainer to flatten layers recursively (recursive=True), but the weights still did not show up in TensorBoard, suggesting that something else is going on (see the sketch after this list).
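For reference, attempt 2 looked roughly like this (a sketch; _flatten_layers is a private Keras 3 helper, so its name and signature are an assumption and may change between versions):

class MyTrainer(keras.Model):
    # ... __init__, train_step, etc. unchanged ...

    @property
    def layers(self):
        # Flatten nested sub-models so the TensorBoard callback iterates
        # over the inner Dense layers, not just the Sequential wrapper.
        return list(self._flatten_layers(include_self=False, recursive=True))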

I'm open to any suggestions here.

Full example:

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras.callbacks import TensorBoard

# Load MNIST dataset and standardize the data
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

class MyTrainer(keras.Model):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Create loss and metrics here.
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy()
        self.accuracy_metric = keras.metrics.SparseCategoricalAccuracy()

    @property
    def metrics(self):
        # List metrics here.
        return [self.accuracy_metric]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)  # Forward pass
            # Compute loss value
            loss = self.loss_fn(y, y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        for metric in self.metrics:
            metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value.
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data

        # Inference step
        y_pred = self.model(x, training=False)

        # Update metrics
        for metric in self.metrics:
            metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def call(self, x):
        # Equivalent to `call()` of the wrapped keras.Model
        x = self.model(x)
        return x

model_a = keras.models.Sequential(
    [
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(256, activation="relu"),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation="softmax"),
    ]
)

callbacks = [TensorBoard(histogram_freq=1)]
trainer_1 = MyTrainer(model_a)
trainer_1.compile(optimizer=keras.optimizers.SGD())
trainer_1.fit(
    x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test), callbacks=callbacks,
)
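As a possible workaround in the meantime, a custom callback can log histograms for every weight by walking model.weights, which (unlike model.layers) does traverse nested sub-models. This is a minimal sketch, not the built-in TensorBoard behaviour; the class name, log_dir, and freq are my own:

import tensorflow as tf
import keras

class RecursiveWeightHistograms(keras.callbacks.Callback):
    def __init__(self, log_dir="logs", freq=1):
        super().__init__()
        self.writer = tf.summary.create_file_writer(log_dir)
        self.freq = freq

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.freq != 0:
            return
        with self.writer.as_default():
            for w in self.model.weights:
                # Variable.path gives a unique name such as
                # "sequential/dense/kernel" in Keras 3.
                tf.summary.histogram(w.path, w, step=epoch)
        self.writer.flush()

This can be passed to fit in place of (or alongside) the TensorBoard callback, e.g. callbacks=[RecursiveWeightHistograms(log_dir="logs")].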
