FD (e.g. FID) slice-wise for 3D #60

Open
danielbarco opened this issue Nov 28, 2024 · 1 comment

@danielbarco

https://github.com/layer6ai-labs/dgm-eval/blob/master/dgm_eval/metrics/fd.py
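For reference, the Fréchet distance fits a Gaussian to each feature set and computes ||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 (S_a S_b)^(1/2)). The NumPy sketch below illustrates that formula; it is not a copy of the linked fd.py, and `frechet_distance` is a hypothetical name. It takes the trace of the matrix square root from the eigenvalues of S_a^(1/2) S_b S_a^(1/2) rather than calling `scipy.linalg.sqrtm`.

```python
import numpy as np


def frechet_distance(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two (n_samples, dim) feature arrays.

    Hypothetical helper for illustration; not the dgm-eval implementation.
    """
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    # rowvar=False: each row is one sample, each column one feature dimension
    sigma_a = np.cov(feats_a, rowvar=False)
    sigma_b = np.cov(feats_b, rowvar=False)

    # Symmetric square root of sigma_a via its eigendecomposition (clip tiny negatives).
    w, v = np.linalg.eigh(sigma_a)
    sqrt_a = (v * np.sqrt(np.clip(w, 0, None))) @ v.T

    # Tr((S_a S_b)^(1/2)) equals the sum of square roots of the eigenvalues of the
    # symmetric PSD matrix S_a^(1/2) S_b S_a^(1/2), so eigvalsh can be used.
    inner = sqrt_a @ sigma_b @ sqrt_a
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(inner), 0, None)).sum()

    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(sigma_a) + np.trace(sigma_b) - 2.0 * tr_covmean)
```

Using an eigendecomposition keeps the result real-valued; a general `sqrtm` on the non-symmetric product S_a S_b can return small imaginary parts that then have to be discarded.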

Structure of a custom torchmetrics `Metric`:

import torch
from torchmetrics import Metric


class MyAccuracy(Metric):
    def __init__(self):
        # remember to call super
        super().__init__()
        # call `self.add_state` for every internal state that is needed for the metric computations
        # dist_reduce_fx indicates the function that should be used to reduce
        # state from multiple processes
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # extract predicted class index for computing accuracy
        preds = preds.argmax(dim=-1)
        assert preds.shape == target.shape
        # update metric states
        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        # compute final result
        return self.correct.float() / self.total


my_metric = MyAccuracy()
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

print(my_metric(preds, target))
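Applied to the slice-wise 3D case, one option (an assumption, not something settled in this issue) is to treat every 2D slice along a chosen axis as one sample, extract 2D features per slice, and keep the `update()`/`compute()` split from the structure above. The sketch below does that in plain NumPy so it runs without a GPU or pretrained network: `extract_slices`, `toy_features`, and `SliceWiseFD` are hypothetical names, and `toy_features` is only a placeholder for a real 2D extractor such as Inception-v3. In an actual torchmetrics `Metric`, the two feature lists would be registered with `self.add_state(..., default=[], dist_reduce_fx="cat")`; the `update(volumes, real=...)` signature mirrors the one used by torchmetrics' `FrechetInceptionDistance`.

```python
import numpy as np


def extract_slices(volumes: np.ndarray, axis: int = 1) -> np.ndarray:
    """Reshape (N, D, H, W) volumes into a stack of 2D slices taken along `axis`."""
    moved = np.moveaxis(volumes, axis, 1)  # bring the slicing axis next to the batch axis
    return moved.reshape(-1, *moved.shape[2:])


def toy_features(slices: np.ndarray) -> np.ndarray:
    """HYPOTHETICAL stand-in for a 2D feature extractor such as Inception-v3."""
    flat = slices.reshape(slices.shape[0], -1)
    return np.stack([flat.mean(1), flat.std(1), flat.min(1), flat.max(1)], axis=1)


def frechet_distance(a: np.ndarray, b: np.ndarray) -> float:
    """||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 (S_a S_b)^(1/2)) for (n, dim) feature arrays."""
    sa, sb = np.cov(a, rowvar=False), np.cov(b, rowvar=False)
    w, v = np.linalg.eigh(sa)
    sqrt_sa = (v * np.sqrt(np.clip(w, 0, None))) @ v.T
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(sqrt_sa @ sb @ sqrt_sa), 0, None)).sum()
    diff = a.mean(0) - b.mean(0)
    return float(diff @ diff + np.trace(sa) + np.trace(sb) - 2.0 * tr_covmean)


class SliceWiseFD:
    """Mirrors the update()/compute() split of a torchmetrics Metric, using NumPy only."""

    def __init__(self, axis: int = 1):
        self.axis = axis  # 1, 2 or 3 for (N, D, H, W) input
        self.real_feats = []  # per-batch feature arrays of real volumes
        self.fake_feats = []  # per-batch feature arrays of generated volumes

    def update(self, volumes: np.ndarray, real: bool) -> None:
        # Every 2D slice of every volume becomes one sample for the Gaussian fit.
        feats = toy_features(extract_slices(volumes, self.axis))
        (self.real_feats if real else self.fake_feats).append(feats)

    def compute(self) -> float:
        # Pool all accumulated slices, then compare the two fitted Gaussians.
        return frechet_distance(np.concatenate(self.real_feats), np.concatenate(self.fake_feats))


rng = np.random.default_rng(0)
vols = rng.normal(size=(4, 8, 16, 16))
metric = SliceWiseFD(axis=1)
metric.update(vols, real=True)
metric.update(vols + 1.0, real=False)
print(metric.compute())  # ≈ 3.0: the shift moves three feature means by 1, covariances unchanged
```

Pooling all slices into one Gaussian is only one design choice; computing FD separately per axis (or per slice position) and averaging is an alternative, at the cost of fewer samples per fitted Gaussian.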