diff --git a/src/models/mnist_module.py b/src/models/mnist_module.py index 6f613865c..294d5341c 100644 --- a/src/models/mnist_module.py +++ b/src/models/mnist_module.py @@ -51,6 +51,7 @@ def __init__( :param net: The model to train. :param optimizer: The optimizer to use for training. :param scheduler: The learning rate scheduler to use for training. + :param compile: Whether to compile the model before training. """ super().__init__()