Lightning-AI / torchmetrics

Torchmetrics - Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

Allow arbitrary types for metric states #987

Open LarsHill opened 2 years ago

LarsHill commented 2 years ago

🚀 Feature

Metric states seem to be limited to torch.Tensor or List[torch.Tensor].

In my use case I want to store a dictionary as state. My dataset consists of samples that can be assigned to different documents. In order to calculate macro metrics (calculate metrics per document and average), I want to store my metric states (e.g. true positives, false positives, etc.) as a dictionary. Here is some pseudocode:

from collections import defaultdict
from typing import List

import torch
from torchmetrics import Metric


class MyMetric(Metric):

    def __init__(self, dist_sync_on_step: bool = False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        # per-document running counts of tp/tn/fp/fn
        self.add_state("statistics", default=defaultdict(lambda: defaultdict(float)), dist_reduce_fx=None)

    def update(
            self,
            predictions: torch.Tensor,
            targets: torch.Tensor,
            document_ids: List
    ):
        predictions = predictions.bool()
        targets = targets.bool()

        # element-wise confusion-matrix indicators
        tps = predictions * targets
        tns = predictions.logical_not() * targets.logical_not()
        fps = predictions * targets.logical_not()
        fns = predictions.logical_not() * targets

        for id_, tp, tn, fp, fn in zip(document_ids, tps, tns, fps, fns):
            self.statistics[id_]['tp'] += tp.float().item()
            self.statistics[id_]['tn'] += tn.float().item()
            self.statistics[id_]['fp'] += fp.float().item()
            self.statistics[id_]['fn'] += fn.float().item()

    def compute(self):
        ...

Unfortunately the above code is not allowed. Each metric state has to be a torch.Tensor or a List[torch.Tensor]. That means plain float values or numpy arrays cannot be used as metric states either. Is there a particular reason for that?
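For illustration, with the current restriction the class above already fails at construction time, because add_state rejects any default that is not a tensor or an empty list (a minimal sketch, assuming the MyMetric class from above):

try:
    metric = MyMetric()  # add_state raises because the default is a defaultdict, not a tensor or empty list
except ValueError as err:
    print(err)  # "state variable must be a tensor or any empty list (where you can append tensors)"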

github-actions[bot] commented 2 years ago

Hi! thanks for your contribution!, great first issue!

LarsHill commented 2 years ago

Hey,

Are there any updates on this matter? After checking metric.py I don't see a particular reason for this restriction, but maybe I overlooked something. Also, the reset method could simply reset each state variable to the user-defined default instead of checking the type to distinguish between List and torch.Tensor. Regarding the reduction logic, the current approach could be kept: if the state type is something custom (e.g. a Dict), the synced state is simply a List[Dict] and the user is in charge of providing a custom reduce fn to combine the different states across processes.
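To make that concrete, here is a minimal sketch of what such a custom reduce fn could look like for the nested-dict state from my example above (the function name is just illustrative; it assumes the synced state arrives as a list of per-process dicts):

from collections import defaultdict
from typing import Dict, List


def reduce_statistics(states: List[Dict[str, Dict[str, float]]]) -> Dict[str, Dict[str, float]]:
    # merge the per-process dicts by summing the counts per document id
    merged: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float))
    for state in states:
        for doc_id, counts in state.items():
            for name, value in counts.items():
                merged[doc_id][name] += value
    return merged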

amorehead commented 1 year ago

@LarsHill, I agree that this is an important question to ask. I am also encountering a use case where I need to track my custom Metric's state with a dictionary. Any thoughts on this query, @SkafteNicki or @Borda?

LarsHill commented 1 year ago

Once again a short ping. Are there any plans to tackle this issue in the near future?

SkafteNicki commented 1 year ago

Hi @LarsHill, it has not really been a priority for us, but I am willing to listen. At the moment I think we could consider adding the option to also support dicts of tensors as a third option. It would not completely solve your initial example, because that requires a dict of dicts of tensors, but then you could do something like:

for i in range(num_document_ids):
    self.add_state(f"statistics_{i}", default=defaultdict(), dist_reduce_fx=None)

and then in update:

for id_, tp, tn, fp, fn in zip(document_ids, tps, tns, fps, fns):
    getattr(self, f"statistics_{id_}")['tp'] += tp.float().item()
    getattr(self, f"statistics_{id_}")['tn'] += tn.float().item()
    getattr(self, f"statistics_{id_}")['fp'] += fp.float().item()
    getattr(self, f"statistics_{id_}")['fn'] += fn.float().item()

This would of course mean that num_document_ids needs to be known in advance.
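For completeness, compute could then macro-average over the per-document states, e.g. like this (just a sketch, assuming num_document_ids is stored on the metric in __init__ and that a macro F1 over the binary counts is wanted):

def compute(self):
    # compute F1 per document, then average over documents (macro averaging)
    f1_scores = []
    for i in range(self.num_document_ids):
        stats = getattr(self, f"statistics_{i}")
        tp, fp, fn = stats["tp"], stats["fp"], stats["fn"]
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom > 0 else 0.0)
    return sum(f1_scores) / len(f1_scores)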

LarsHill commented 10 months ago

Hi @SkafteNicki,

Thanks for your suggestion. That would work as a temporary solution, but I think it is not very elegant to iterate over the entire dataset at initialization time just to retrieve all the document ids.

Is there a particular reason why the state types are restricted in the first place? What about the suggestion to extend them to arbitrary types: if the state is something custom, the synced/reduced state is simply a List of that custom type and the user is in charge of reducing it properly, e.g. by providing a specific reduce function.

I am referring to these lines in metric.py:

if not isinstance(default, (Tensor, list)) or (isinstance(default, list) and default):
    raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)")

if dist_reduce_fx == "sum":
    dist_reduce_fx = dim_zero_sum
elif dist_reduce_fx == "mean":
    dist_reduce_fx = dim_zero_mean
elif dist_reduce_fx == "max":
    dist_reduce_fx = dim_zero_max
elif dist_reduce_fx == "min":
    dist_reduce_fx = dim_zero_min
elif dist_reduce_fx == "cat":
    dist_reduce_fx = dim_zero_cat
elif dist_reduce_fx is not None and not callable(dist_reduce_fx):
    raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', 'min', 'max', None]")

It seems users already have the option to pass a custom dist_reduce_fx Callable. So why not change the above code to something like this?

if (not isinstance(default, (Tensor, list)) or (isinstance(default, list) and default)) and not callable(dist_reduce_fx):
    raise ValueError(
        "If the state variable is not a tensor or an empty list (where you can append tensors), "
        "you have to provide a custom `dist_reduce_fx` that is callable."
    )

if dist_reduce_fx == "sum":
    dist_reduce_fx = dim_zero_sum
elif dist_reduce_fx == "mean":
    dist_reduce_fx = dim_zero_mean
elif dist_reduce_fx == "max":
    dist_reduce_fx = dim_zero_max
elif dist_reduce_fx == "min":
    dist_reduce_fx = dim_zero_min
elif dist_reduce_fx == "cat":
    dist_reduce_fx = dim_zero_cat
elif dist_reduce_fx is not None and not callable(dist_reduce_fx):
    raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', 'min', 'max', None]")

For current users who work with torch.Tensor or List[torch.Tensor] nothing changes, and users who need custom metric state types gain the flexibility to use them.
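With such a change in place, the dict state from my original example could hypothetically be registered like this, reusing a merge function along the lines of the reduce_statistics sketch above:

self.add_state(
    "statistics",
    default=defaultdict(lambda: defaultdict(float)),
    dist_reduce_fx=reduce_statistics,  # a callable that merges the list of per-process dicts
)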

Are there any issues with such a change that I am not aware of?

By the way, since I could not wait for changes in torchmetrics, I went ahead and implemented my own metric abstraction. However, I would prefer to switch to torchmetrics; this just seems to me like a missing feature that is limiting.