CAIIVS / chuchichaestli

Where you find all the state-of-the-art cooking utensils (salt, pepper, gradient descent... the usual).
GNU General Public License v3.0
4 stars 0 forks source link

Slice-wise FD (e.g. FID) for 3D images #60

Open danielbarco opened 3 days ago

danielbarco commented 3 days ago

https://github.com/layer6ai-labs/dgm-eval/blob/master/dgm_eval/metrics/fd.py

Proposed structure, following the torchmetrics `Metric` API:

import torch
from torchmetrics import Metric

class MyAccuracy(Metric):
    """Multiclass accuracy accumulated over successive ``update`` calls.

    Minimal example of the torchmetrics ``Metric`` structure. The running
    counts are registered with ``add_state``, ``update`` accumulates
    per-batch statistics, and ``compute`` turns the accumulated counts into
    the final accuracy.
    """

    def __init__(self):
        # remember to call super
        super().__init__()
        # Call `self.add_state` for every internal state needed by the metric
        # computation. `dist_reduce_fx` names the function used to reduce the
        # state from multiple processes; "sum" adds up the per-process counts.
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate correct-prediction and sample counts for one batch.

        Args:
            preds: Per-class scores. The predicted class is the argmax over
                the last dimension.
            target: Ground-truth class indices. Its shape must equal the shape
                of ``preds`` with the last dimension removed.

        Raises:
            ValueError: If ``preds.argmax(dim=-1)`` and ``target`` differ in
                shape.
        """
        # extract predicted class index for computing accuracy
        preds = preds.argmax(dim=-1)
        # Use an explicit check rather than `assert`: asserts are stripped
        # under `python -O`, and then mismatched inputs would silently yield
        # a wrong accuracy.
        if preds.shape != target.shape:
            raise ValueError(
                "preds.argmax(dim=-1) and target must have the same shape, "
                f"got {tuple(preds.shape)} and {tuple(target.shape)}"
            )
        # update metric states
        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        """Return the accumulated accuracy as a float tensor.

        If no samples have been accumulated (``total == 0``), the division is
        ``0.0 / 0`` and the result is NaN.
        """
        return self.correct.float() / self.total

# Demo: score one random batch of 10 samples over 5 classes.
my_metric = MyAccuracy()
preds = torch.softmax(torch.randn(10, 5), dim=-1)
target = torch.randint(0, 5, (10,))

# Call the metric object directly on the batch, then print the result.
batch_accuracy = my_metric(preds, target)
print(batch_accuracy)
danielbarco commented 3 days ago

Implementation (branch `23-metrics-module-for-chuchichaestli`): https://github.com/CAIIVS/chuchichaestli/tree/23-metrics-module-for-chuchichaestli