Reviewed LogisticRegression (#950)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>
3 people authored May 21, 2023
1 parent 3d1c652 commit 4ef2750
Showing 4 changed files with 158 additions and 90 deletions.
191 changes: 133 additions & 58 deletions src/pl_bolts/models/regression/logistic_regression.py
@@ -1,21 +1,23 @@
"""An implemen§ation of Logistic Regression in PyTorch-Lightning."""

from argparse import ArgumentParser
from typing import Any, Dict, List, Tuple, Type

import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.functional import softmax
from torch.optim import Adam
from torch.optim.optimizer import Optimizer
from torchmetrics.functional import accuracy
from torchmetrics import functional

from pl_bolts.utils.stability import under_review


@under_review()
class LogisticRegression(LightningModule):
"""Logistic regression model."""
"""Logistic Regression Model."""

criterion: nn.CrossEntropyLoss
linear: nn.Linear

def __init__(
self,
@@ -28,87 +30,160 @@ def __init__(
l2_strength: float = 0.0,
**kwargs: Any,
) -> None:
"""
"""Logistic Regression.
Args:
input_dim: number of dimensions of the input (at least 1)
num_classes: number of class labels (binary: 2, multi-class: >2)
bias: specifies if a constant or intercept should be fitted (equivalent to fit_intercept in sklearn)
learning_rate: learning_rate for the optimizer
optimizer: the optimizer to use (default: ``Adam``)
l1_strength: L1 regularization strength (default: ``0.0``)
l2_strength: L2 regularization strength (default: ``0.0``)
input_dim: Number of dimensions of the input (at least `1`).
num_classes: Number of class labels (binary: `2`, multi-class: > `2`).
bias: Specifies if a constant or intercept should be fitted (equivalent to `fit_intercept` in `sklearn`).
learning_rate: Learning rate for the optimizer.
optimizer: Model optimizer to use.
l1_strength: L1 regularization strength.
l2_strength: L2 regularization strength.
Attributes:
linear: Linear layer.
criterion: Cross-Entropy loss function.
optimizer: Model optimizer to use.
"""
super().__init__()
self.save_hyperparameters()
self.optimizer = optimizer

self.linear = nn.Linear(in_features=self.hparams.input_dim, out_features=self.hparams.num_classes, bias=bias)
self.criterion = nn.CrossEntropyLoss()
self.linear = nn.Linear(
in_features=self.hparams.input_dim, out_features=self.hparams.num_classes, bias=self.hparams.bias
)

def forward(self, x: Tensor) -> Tensor:
x = self.linear(x)
y_hat = softmax(x)
return y_hat
"""Forward pass of the model.
Args:
x: Input tensor.
Returns:
Output tensor.
"""
return self.linear(x)
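Note that the refactored forward now returns raw logits; softmax is no longer applied inside the model, since nn.CrossEntropyLoss expects unnormalized scores. A minimal usage sketch (the input_dim/num_classes values and the random batch are illustrative assumptions, not part of this commit):

import torch
from torch.nn.functional import softmax

model = LogisticRegression(input_dim=784, num_classes=10)
x = torch.randn(4, 784)            # illustrative batch of flattened inputs
logits = model(x)                  # forward() returns raw logits, shape (4, 10)
probs = softmax(logits, dim=-1)    # apply softmax explicitly when probabilities are needed
preds = probs.argmax(dim=-1)       # predicted class indices, shape (4,)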

def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
    x, y = batch

    # flatten any input
    x = x.view(x.size(0), -1)

    y_hat = self.linear(x)

    # PyTorch cross_entropy function combines log_softmax and nll_loss in single function
    loss = F.cross_entropy(y_hat, y, reduction="sum")

    # L1 regularizer
    if self.hparams.l1_strength > 0:
        l1_reg = self.linear.weight.abs().sum()
        loss += self.hparams.l1_strength * l1_reg

    # L2 regularizer
    if self.hparams.l2_strength > 0:
        l2_reg = self.linear.weight.pow(2).sum()
        loss += self.hparams.l2_strength * l2_reg

    loss /= x.size(0)

    tensorboard_logs = {"train_ce_loss": loss}
    progress_bar_metrics = tensorboard_logs
    return {"loss": loss, "log": tensorboard_logs, "progress_bar": progress_bar_metrics}

    """Training step for the model.
    Args:
        batch: Batch of data.
        batch_idx: Batch index.
    Returns:
        Loss tensor.
    """
    return self._shared_step(batch, "train")

def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
    x, y = batch
    x = x.view(x.size(0), -1)
    y_hat = self.linear(x)
    acc = accuracy(F.softmax(y_hat, -1), y)
    return {"val_loss": F.cross_entropy(y_hat, y), "acc": acc}

def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
    """Validation step for the model.
    Args:
        batch: Batch of data.
        batch_idx: Batch index.
    Returns:
        Loss tensor.
    """
    return self._shared_step(batch, "val")

def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
    x, y = batch
    x = x.view(x.size(0), -1)
    y_hat = self.linear(x)
    acc = accuracy(F.softmax(y_hat, -1), y)
    return {"test_loss": F.cross_entropy(y_hat, y), "acc": acc}

def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
    """Test step for the model.
    Args:
        batch: Batch of data.
        batch_idx: Batch index.
    Returns:
        Loss tensor.
    """
    return self._shared_step(batch, "test")

def validation_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
    acc = torch.stack([x["acc"] for x in outputs]).mean()
    val_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
    tensorboard_logs = {"val_ce_loss": val_loss, "val_acc": acc}
    progress_bar_metrics = tensorboard_logs
    return {"val_loss": val_loss, "log": tensorboard_logs, "progress_bar": progress_bar_metrics}

    """Validation epoch end for the model.
    Args:
        outputs: List of outputs from the validation step.
    Returns:
        Loss tensor.
    """
    return self._shared_epoch_end(outputs, "val")

def test_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
    acc = torch.stack([x["acc"] for x in outputs]).mean()
    test_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
    tensorboard_logs = {"test_ce_loss": test_loss, "test_acc": acc}
    progress_bar_metrics = tensorboard_logs
    return {"test_loss": test_loss, "log": tensorboard_logs, "progress_bar": progress_bar_metrics}

    """Test epoch end for the model.
    Args:
        outputs: List of outputs from the test step.
    Returns:
        Loss tensor.
    """
    return self._shared_epoch_end(outputs, "test")

def configure_optimizers(self) -> Optimizer:
"""Configure the optimizer for the model.
Returns:
Optimizer.
"""
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate)
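Since the optimizer is stored as a class and only instantiated here with the configured learning rate, any torch.optim optimizer can be passed at construction time. A hedged sketch (the SGD choice and values are illustrative, not part of this commit):

from torch.optim import SGD

# illustrative: the class is instantiated in configure_optimizers()
# as self.optimizer(self.parameters(), lr=self.hparams.learning_rate)
model = LogisticRegression(input_dim=784, num_classes=10, optimizer=SGD, learning_rate=0.01)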

def _prepare_batch(self, batch: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
x, y = batch
x = x.view(x.size(0), -1)
return self.linear(x), torch.tensor(y, dtype=torch.long)

def _shared_step(self, batch: Tuple[Tensor, Tensor], stage: str) -> Dict[str, Tensor]:
x, y = self._prepare_batch(batch)
loss = self.criterion(x, y)

if stage == "train":
loss = self._regularization(loss)
loss /= x.size(0)
metrics = {"loss": loss}
self.log_dict(metrics, on_step=True)
return metrics

acc = self._calculate_accuracy(x, y)
return self._log_metrics(acc, loss, stage, on_step=True)

def _shared_epoch_end(self, outputs: List[Dict[str, Tensor]], stage: str) -> Dict[str, Tensor]:
acc = torch.stack([x[f"{stage}_acc"] for x in outputs]).mean()
loss = torch.stack([x[f"{stage}_loss"] for x in outputs]).mean()
return self._log_metrics(acc, loss, stage, on_epoch=True)
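_shared_epoch_end reduces the per-batch metrics collected by the step methods to a single epoch-level value by stacking and averaging them. A toy illustration with made-up batch accuracies:

import torch

outputs = [{"val_acc": torch.tensor(0.90)}, {"val_acc": torch.tensor(0.94)}]
epoch_acc = torch.stack([o["val_acc"] for o in outputs]).mean()  # tensor(0.9200)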

def _log_metrics(self, acc: Tensor, loss: Tensor, stage: str, **kwargs: bool) -> Dict[str, Tensor]:
metrics = {f"{stage}_loss": loss, f"{stage}_acc": acc}
self.log_dict(metrics, **kwargs)
return metrics

def _calculate_accuracy(self, x: Tensor, y: Tensor) -> Tensor:
_, y_hat = torch.max(x, dim=-1)
return functional.accuracy(y_hat, y, average="weighted", num_classes=self.hparams.num_classes)
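torch.max over the class dimension returns both the maximum logit and its index; only the index (the predicted class) is passed to the accuracy metric, so this is equivalent to x.argmax(dim=-1). A toy example:

import torch

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
_, preds = torch.max(logits, dim=-1)  # tensor([0, 2]), same as logits.argmax(dim=-1)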

def _regularization(self, loss: Tensor) -> Tensor:
if self.hparams.l1_strength > 0:
l1_reg = self.linear.weight.abs().sum()
loss += self.hparams.l1_strength * l1_reg

if self.hparams.l2_strength > 0:
l2_reg = self.linear.weight.pow(2).sum()
loss += self.hparams.l2_strength * l2_reg
return loss
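The penalty terms are the L1 norm and squared L2 norm of the linear layer's weight matrix (the bias is not regularized), each scaled by its strength and added to the loss. A toy numeric sketch with illustrative values:

import torch

weight = torch.tensor([[1.0, -2.0], [0.5, 0.0]])
l1_reg = weight.abs().sum()       # 1 + 2 + 0.5 + 0 = 3.5
l2_reg = weight.pow(2).sum()      # 1 + 4 + 0.25 + 0 = 5.25
loss = torch.tensor(0.7)
loss = loss + 0.01 * l1_reg + 0.01 * l2_reg   # 0.7 + 0.035 + 0.0525 = 0.7875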

@staticmethod
def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
"""Adds model specific arguments to the parser.
Args:
parent_parser: Parent parser to which the arguments will be added.
Returns:
ArgumentParser with the added arguments.
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--learning_rate", type=float, default=0.0001)
parser.add_argument("--input_dim", type=int, default=None)
Empty file.
24 changes: 24 additions & 0 deletions tests/models/regression/test_logistic_regression.py
@@ -0,0 +1,24 @@
import functools
import operator

import pytorch_lightning as pl

from pl_bolts import datamodules
from pl_bolts.models import regression


def test_logistic_regression_model(datadir):
pl.seed_everything(0)

dm = datamodules.MNISTDataModule(datadir)

model = regression.LogisticRegression(
input_dim=functools.reduce(operator.mul, dm.dims, 1), num_classes=10, learning_rate=0.001
)

trainer = pl.Trainer(max_epochs=3, logger=False, enable_checkpointing=False)
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
assert trainer.state.finished
assert trainer.callback_metrics["test_acc"] > 0.9
assert trainer.callback_metrics["test_loss"] < 0.3
33 changes: 1 addition & 32 deletions tests/models/test_classic_ml.py
@@ -1,13 +1,9 @@
import functools
import operator

import numpy as np
from pytorch_lightning import Trainer, seed_everything
from torch.utils.data import DataLoader

from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset
from pl_bolts.models.regression import LinearRegression, LogisticRegression
from pl_bolts.models.regression import LinearRegression


def test_linear_regression_model(tmpdir):
@@ -37,30 +33,3 @@ def test_linear_regression_model(tmpdir):
coeffs = model.linear.weight.detach().numpy().flatten()
np.testing.assert_allclose(coeffs, [1, 2], rtol=1e-3)
trainer.test(model, loader)


def test_logistic_regression_model(tmpdir, datadir):
seed_everything(0)

# create dataset
dm = MNISTDataModule(num_workers=0, data_dir=datadir)

model = LogisticRegression(
input_dim=functools.reduce(operator.mul, dm.dims, 1), num_classes=10, learning_rate=0.001
)
model.prepare_data = dm.prepare_data
model.setup = dm.setup
model.train_dataloader = dm.train_dataloader
model.val_dataloader = dm.val_dataloader
model.test_dataloader = dm.test_dataloader

trainer = Trainer(
max_epochs=3,
default_root_dir=tmpdir,
logger=False,
enable_checkpointing=False,
)
trainer.fit(model)
trainer.test(model)
# todo: update model and add healthy check
# assert trainer.progress_bar_dict['test_acc'] >= 0.9
