Description & Motivation
Problem
Given the script below, when running it with the LightningCLI, hparams.yaml becomes:
```yaml
optimizer:
  class_path: __main__.SGD
  init_args:
    params:
    - 1
    - 2
    - 3
    lr: 123.0
myds:
  x: hello
_instantiator: lightning.pytorch.cli.instantiate_module
```
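For reference, entries in this class_path/init_args form are plain data and can be turned back into objects with `lightning.pytorch.cli.instantiate_class`. A minimal sketch, with the values copied from the YAML above (it assumes it runs in the same `__main__` module that defines `SGD`):

```python
from lightning.pytorch.cli import instantiate_class

# instantiate_class(args, init): `args` are extra positional arguments,
# `init` is a dict carrying "class_path" and "init_args".
init = {
    "class_path": "__main__.SGD",  # only resolvable where SGD lives in __main__
    "init_args": {"params": [1, 2, 3], "lr": 123.0},
}
optimizer = instantiate_class((), init)  # equivalent to SGD(params=[1, 2, 3], lr=123.0)
print(type(optimizer))
```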
When running the hardcoded training path (`train_hardcode`) instead, hparams.yaml becomes:
```yaml
myds: !!python/object:__main__.MyDataclass
  x: hello
optimizer: !!python/object:__main__.SGD {}
```
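A side effect of this difference: the CLI-written file is plain YAML, while the hardcoded one carries `!!python/object` tags and cannot be read back with a safe loader. A small sketch, assuming PyYAML and placeholder paths for the two log directories:

```python
import yaml

# CLI run: plain YAML, loads fine with the safe loader.
with open("lightning_logs/version_0/hparams.yaml") as f:  # placeholder path
    cli_hparams = yaml.safe_load(f)

# Hardcoded run: safe_load raises yaml.constructor.ConstructorError on the
# !!python/object tags; the file only loads with the unsafe loader, and only
# in a process where __main__.SGD and __main__.MyDataclass actually exist.
with open("lightning_logs/version_1/hparams.yaml") as f:  # placeholder path
    hardcoded_hparams = yaml.unsafe_load(f)
```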
In other words, even though the hyperparameters are identical, hparams.yaml looks different depending on how the module was instantiated. Put another way: what is the best practice for defining more complex (object-valued) hyperparameters so that they serialize consistently?
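One workaround I can imagine (not necessarily best practice, hence the question) is to have the hardcoded path record a plain, YAML-friendly dict instead of the raw objects, mirroring the class_path/init_args layout the CLI writes. A sketch against the script below; the optimizer's init_args are written out by hand because the toy `SGD` does not keep its constructor arguments, and the class name is mine:

```python
from dataclasses import asdict

class LitAutoEncoderManualHparams(L.LightningModule):  # illustrative name
    def __init__(self, optimizer: Optimizer, myds: MyDataclass):
        super().__init__()
        # save_hyperparameters also accepts an explicit dict, so we can store
        # serializable data rather than the Python objects themselves.
        self.save_hyperparameters({
            "optimizer": {
                "class_path": f"{type(optimizer).__module__}.{type(optimizer).__qualname__}",
                "init_args": {"params": [1, 2, 3], "lr": 123.0},  # written by hand here
            },
            "myds": asdict(myds),
        })
        self.encoder = encoder
        self.decoder = decoder
```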
Script below
"""
# How to trigger hardcoded training
Comment out `main()` at the very end. Run `python script.py
# How to trigger CLI training
python script.py --config cfg.yaml
model:
optimizer:
class_path: SGD
init_args:
lr: 123
params: [1, 2, 3]
"""
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
from lightning.pytorch import cli
import torch
import lightning as L
from torch.utils.data import random_split, DataLoader
from torchvision.transforms import v2
# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms
# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
# custom objects
from dataclasses import dataclass
import abc
from typing import Iterable, Callable
from lightning.pytorch.core.mixins import HyperparametersMixin
class Optimizer(abc.ABC):
def __init__(self, params: Iterable = [1, 2, 3]):
pass
class SGD(Optimizer):
def __init__(self, params: Iterable, lr: float):
super().__init__()
@dataclass
class MyDataclass:
x: str = "hello"
# define the LightningModule
class LitAutoEncoder(L.LightningModule):
def __init__(
self,
# optimizer: Callable[[], Optimizer],
optimizer: Optimizer,
myds: MyDataclass,
):
print(type(optimizer))
# print(type(optimizer()))
super().__init__()
self.save_hyperparameters()
self.encoder = encoder
self.decoder = decoder
def training_step(self, batch, batch_idx):
# training_step defines the train loop.
# it is independent of forward
x, _ = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = nn.functional.mse_loss(x_hat, x)
# Logging to TensorBoard (if installed) by default
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=1e-3)
return optimizer
class MNISTDataModule(L.LightningDataModule):
def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
def setup(self, stage: str):
transform = v2.ToTensor() # mainly to convert PIL to tensor
self.mnist_test = MNIST(
self.data_dir,
train=False,
download=True,
transform=transform,
)
self.mnist_predict = self.mnist_test
mnist_full = MNIST(
self.data_dir,
train=True,
download=True,
transform=transform,
)
self.mnist_train, self.mnist_val = random_split(
mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)
def predict_dataloader(self):
return DataLoader(self.mnist_predict, batch_size=self.batch_size)
def train_hardcode():
autoencoder = LitAutoEncoder(
optimizer=SGD(params=[1, 2, 3], lr=123),
myds=MyDataclass(),
)
# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=10, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
def main():
cli.LightningCLI(
LitAutoEncoder,
MNISTDataModule,
trainer_defaults=dict(
max_epochs=1,
limit_train_batches=10,
limit_val_batches=10,
),
)
if __name__ == "__main__":
main()
# train_hardcode()
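For completeness, the commented-out lines in `__init__` hint at a callable-based variant, where the module receives a factory instead of a constructed optimizer. A sketch of how the hardcoded path would look under that signature, continuing from the script above and using functools.partial (the CLI side would need a matching config, not shown, and the class name is mine):

```python
import functools
from typing import Callable

class LitAutoEncoderCallable(L.LightningModule):  # illustrative name
    def __init__(self, optimizer: Callable[[], Optimizer], myds: MyDataclass):
        super().__init__()
        print(type(optimizer()))  # the zero-argument factory builds the optimizer
        self.save_hyperparameters()
        self.encoder = encoder
        self.decoder = decoder

# Hardcoded path: bind the arguments with functools.partial so the module
# receives a factory rather than an already constructed SGD instance.
autoencoder = LitAutoEncoderCallable(
    optimizer=functools.partial(SGD, params=[1, 2, 3], lr=123),
    myds=MyDataclass(),
)
```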
Pitch
No response
Alternatives
No response
Additional context
No response
cc @Borda