Lightning-AI / torchmetrics

Torchmetrics - Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

Cloning `MultitaskWrapper` with postfix parameters results in incorrect dictionary #2656

Open Buzzeitor30 opened 1 month ago

Buzzeitor30 commented 1 month ago

🐛 Bug

If you clone a `MultitaskWrapper` and use the `postfix` parameter, the clone then expects the postfix to be appended to the keys of the input dictionaries. However, the documentation indicates that only the keys of the output dictionary (the one returned after the metric values have been computed) should be modified.

To Reproduce

Steps to reproduce the behavior...

```py
from torchmetrics import MultitaskWrapper, F1Score
import torch

wrapper = MultitaskWrapper({"F1": F1Score(num_classes=2, average='macro', task="multiclass")})
wrapper2 = wrapper.clone(postfix="train")

preds = {"F1": torch.ones((5, 2))}
tgt = {"F1": torch.tensor([0, 1, 0, 1, 0], dtype=torch.long)}

# wrapper(preds, tgt)  # This one works
wrapper2(preds, tgt)  # Raises KeyError: the clone looks for "F1train" in the inputs
```

Expected behavior

The resulting dictionary should include `{"F1train": ...}`. Instead, a `KeyError` is raised.

Environment

Additional context

github-actions[bot] commented 1 month ago

Hi! Thanks for your contribution, great first issue!

MihaiBabiac commented 1 month ago

The same happens with the `prefix` parameter, which means that, unfortunately, most of the changes from #2330 aren't really usable 🙁. I think that PR assumed that changing the dictionary keys would only affect the output, as it does for `MetricCollection`, not the input.

MihaiBabiac commented 1 month ago

@SkafteNicki what do you think about adding `self.prefix` and `self.postfix` member variables to the `MultitaskWrapper` and modifying `MultitaskWrapper.keys()` and `MultitaskWrapper.items()` to prepend/append them to the returned keys? I think that would fix the issue.
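The proposal can be sketched in plain Python. This is a minimal, hypothetical sketch: `ToyMultitaskWrapper`, `accuracy` and the list-based inputs are made up for illustration and are not the torchmetrics implementation. Stateless functions stand in for real `Metric` objects. The idea is that metrics stay registered under their original task names (so inputs are looked up unchanged), while the prefix/postfix only decorate the names returned by `keys()`, `items()` and the output dictionary.

```py
from typing import Callable, Dict, Iterator, List, Optional, Tuple

# Hypothetical toy "metric": a stateless function of (preds, target).
Metric = Callable[[List[int], List[int]], float]


def accuracy(preds: List[int], target: List[int]) -> float:
    """Fraction of positions where the prediction matches the target."""
    return sum(p == t for p, t in zip(preds, target)) / len(target)


class ToyMultitaskWrapper:
    """Hypothetical stand-in for MultitaskWrapper, NOT the torchmetrics class.

    Metrics stay registered under their original task names, so the input
    dictionaries are indexed with unmodified keys. The prefix/postfix are
    applied only to the names exposed by keys(), items() and the returned dict.
    """

    def __init__(
        self,
        task_metrics: Dict[str, Metric],
        prefix: Optional[str] = None,
        postfix: Optional[str] = None,
    ) -> None:
        self.task_metrics = dict(task_metrics)
        self.prefix = prefix or ""
        self.postfix = postfix or ""

    def _rename(self, task_name: str) -> str:
        return f"{self.prefix}{task_name}{self.postfix}"

    def keys(self) -> Iterator[str]:
        # Decorated names, as suggested in the comment.
        return (self._rename(name) for name in self.task_metrics)

    def items(self) -> Iterator[Tuple[str, Metric]]:
        return ((self._rename(name), metric) for name, metric in self.task_metrics.items())

    def __call__(
        self, task_preds: Dict[str, List[int]], task_targets: Dict[str, List[int]]
    ) -> Dict[str, float]:
        # Inputs are looked up with the ORIGINAL task name; only the output key is renamed.
        return {
            self._rename(name): metric(task_preds[name], task_targets[name])
            for name, metric in self.task_metrics.items()
        }

    def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "ToyMultitaskWrapper":
        # The toy metrics are stateless, so sharing them between clones is fine here;
        # stateful metric objects would need to be deep-copied instead.
        return ToyMultitaskWrapper(self.task_metrics, prefix=prefix, postfix=postfix)


wrapper = ToyMultitaskWrapper({"Classification": accuracy})
train_wrapper = wrapper.clone(postfix="train")
val_wrapper = wrapper.clone(prefix="val_")

preds = {"Classification": [0, 1, 1, 1, 0]}
target = {"Classification": [0, 1, 0, 1, 0]}

print(train_wrapper(preds, target))  # {'Classificationtrain': 0.8}
print(val_wrapper(preds, target))    # {'val_Classification': 0.8}
print(list(train_wrapper.keys()))    # ['Classificationtrain']
```

With this design, all clones can be fed the same input dictionaries, and only the logged names differ.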

Borda commented 2 weeks ago

@MihaiBabiac thank you for your feedback. I have checked the sample code and looked at the implementation, and so far everything seems to work as expected; see the docs:

Args:
    prefix: a string to append in front of the metric keys
    postfix: a string to append after the keys of the output dict.

@Buzzeitor30 Could you please elaborate on why you would add a prefix/postfix to the keys of the "metric's dictionary" when the resulting keys are not present in the input?

MihaiBabiac commented 2 weeks ago

Hi @Borda, thanks for the reply.

I think a common scenario is to clone a metric collection or, in this case, the `MultitaskWrapper` in order to have separate metric objects for training, validation and test, each using different names when logged to TensorBoard. This works for the collection, but not for the wrapper.

The similarity to `MetricCollection` is also what makes the `prefix` and `postfix` arguments confusing, because they behave differently. For the metric collection, they only affect the keys in the output dictionary, so the clones can be applied to the same input data. For the multitask wrapper, however, they also change which entries the wrapper extracts from the input dictionary, so applying the clones to the same data is unlikely to work.
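The difference can be modelled with plain dictionaries. This is a rough, hypothetical sketch based only on the behaviour reported in this thread: `collection_style`, `renamed_task_keys_style` and the toy `accuracy` metric are illustrative names, not torchmetrics APIs.

```py
from typing import Callable, Dict, List

# Hypothetical toy "metric": a stateless function of (preds, target).
Metric = Callable[[List[int], List[int]], float]


def accuracy(preds: List[int], target: List[int]) -> float:
    """Fraction of positions where the prediction matches the target."""
    return sum(p == t for p, t in zip(preds, target)) / len(target)


def collection_style(
    metrics: Dict[str, Metric], preds: List[int], target: List[int], prefix: str = "", postfix: str = ""
) -> Dict[str, float]:
    # MetricCollection-like: every metric sees the same inputs and only the
    # names in the OUTPUT dictionary are decorated.
    return {f"{prefix}{name}{postfix}": metric(preds, target) for name, metric in metrics.items()}


def renamed_task_keys_style(
    task_metrics: Dict[str, Metric],
    task_preds: Dict[str, List[int]],
    task_targets: Dict[str, List[int]],
    prefix: str = "",
    postfix: str = "",
) -> Dict[str, float]:
    # Behaviour reported in this issue: the task keys themselves are renamed,
    # so the renamed key is also used to index the INPUT dictionaries.
    renamed = {f"{prefix}{name}{postfix}": metric for name, metric in task_metrics.items()}
    return {key: metric(task_preds[key], task_targets[key]) for key, metric in renamed.items()}


metrics = {"acc": accuracy}
preds, target = [0, 1, 1, 1, 0], [0, 1, 0, 1, 0]
task_preds, task_targets = {"acc": preds}, {"acc": target}

# Clones of a collection can share the same data.
print(collection_style(metrics, preds, target, postfix="train"))  # {'acctrain': 0.8}

# The undecorated wrapper works...
print(renamed_task_keys_style(metrics, task_preds, task_targets))  # {'acc': 0.8}

# ...but the "train" clone looks for an input key that does not exist.
try:
    renamed_task_keys_style(metrics, task_preds, task_targets, postfix="train")
    missing_key = None
except KeyError as err:
    missing_key = err.args[0]
print(missing_key)  # acctrain
```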

If anything is unclear or you would like an MRE (minimal reproducible example), just let me know.

SkafteNicki commented 1 day ago

Hi @MihaiBabiac, thanks for raising this issue. You are completely right that it is confusing that `MetricCollection` and `MultitaskWrapper` behave differently. I have created PR #2722, which should fix this by making the `prefix` and `postfix` arguments work the same way they do in `MetricCollection`.