import torch
from torchmetrics import Metric
class MyAccuracy(Metric):
    """Classification accuracy accumulated over batches (and processes).

    ``update`` takes per-class scores whose argmax over the last dimension is
    the predicted class, and ground-truth class indices of the resulting
    shape. ``compute`` returns ``correct / total`` over everything seen since
    the last reset.
    """

    # update() only touches the current batch, never the accumulated state,
    # so torchmetrics can use the cheaper single-update path in forward().
    full_state_update: bool = False
    # A larger accuracy is better; argmax makes the metric non-differentiable.
    higher_is_better: bool = True
    is_differentiable: bool = False

    def __init__(self, **kwargs) -> None:
        """Register the counter states.

        Args:
            **kwargs: Forwarded unchanged to ``torchmetrics.Metric.__init__``.
        """
        # remember to call super
        super().__init__(**kwargs)
        # Every internal state used by the computation is registered with
        # add_state. dist_reduce_fx is the function used to reduce that state
        # across multiple processes ("sum": add the per-process counts).
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate correct-prediction and sample counts for one batch.

        Args:
            preds: Per-class scores; the predicted class is the argmax over
                the last dimension.
            target: Ground-truth class indices, with the same shape as
                ``preds`` once its last dimension has been reduced.

        Raises:
            ValueError: If the reduced ``preds`` shape differs from
                ``target``'s shape.
        """
        # extract predicted class index for computing accuracy
        preds = preds.argmax(dim=-1)
        # Explicit check rather than `assert`, which `python -O` strips.
        if preds.shape != target.shape:
            raise ValueError(
                "preds (after argmax over the last dim) and target must have "
                f"the same shape, got {tuple(preds.shape)} and "
                f"{tuple(target.shape)}"
            )
        # update metric states
        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        """Return the accumulated accuracy as a floating-point tensor.

        If no samples have been accumulated (``total == 0``), this is 0/0 and
        evaluates to NaN.
        """
        return self.correct.float() / self.total
# Demo: score a single random batch of 10 samples over 5 classes.
# NOTE(review): consider wrapping this in `if __name__ == "__main__":` so that
# importing this module does not run the demo.
my_metric = MyAccuracy()
preds = torch.softmax(torch.randn(10, 5), dim=-1)  # per-class probabilities
target = torch.randint(0, 5, (10,))  # ground-truth class indices in [0, 5)
batch_accuracy = my_metric(preds, target)  # calling the metric runs Metric.forward
print(batch_accuracy)
# Reference: https://github.com/layer6ai-labs/dgm-eval/blob/master/dgm_eval/metrics/fd.py
# Structure: