SSIM 3D #58

Open
danielbarco opened this issue Nov 28, 2024 · 0 comments

danielbarco commented Nov 28, 2024

References:

- MONAI losses documentation: https://docs.monai.io/en/stable/losses.html
- torchmetrics image metrics source: https://github.com/Lightning-AI/torchmetrics/tree/master/src/torchmetrics/image
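
The SSIM formula itself does not change in 3D; only the local window becomes a cube of voxels instead of a square of pixels. Written out as a sketch, with $w_i$ the weights of a normalised window (summing to 1, usually Gaussian) over the neighbourhood $\mathcal{N}$ of a voxel, $L$ the data range, and $K_1 = 0.01$, $K_2 = 0.03$ assumed as the constants commonly used for 2D SSIM:

```math
\begin{aligned}
\mu_x &= \sum_{i \in \mathcal{N}} w_i x_i, \qquad \sigma_x^2 = \sum_{i \in \mathcal{N}} w_i x_i^2 - \mu_x^2, \qquad \sigma_{xy} = \sum_{i \in \mathcal{N}} w_i x_i y_i - \mu_x \mu_y \\
\mathrm{SSIM}(x, y) &= \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}, \qquad C_1 = (K_1 L)^2,\ C_2 = (K_2 L)^2
\end{aligned}
```

$\mu_y$ and $\sigma_y^2$ are defined analogously, and the score for a volume is the mean of $\mathrm{SSIM}$ over all window positions.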

Structure (template for a custom torchmetrics `Metric`):

```python
import torch
from torchmetrics import Metric


class MyAccuracy(Metric):
    def __init__(self):
        # remember to call super
        super().__init__()
        # call `self.add_state` for every internal state needed for the metric's computation;
        # dist_reduce_fx indicates the function that should be used to reduce
        # state from multiple processes
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # extract predicted class index for computing accuracy
        preds = preds.argmax(dim=-1)
        if preds.shape != target.shape:
            raise ValueError("preds and target must have the same shape")
        # update metric states
        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        # compute final result
        return self.correct.float() / self.total


my_metric = MyAccuracy()
# 10 samples, 5 classes: class probabilities and integer labels
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

print(my_metric(preds, target))
```