Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

GeneralizedDiceScore fails with 2D tensors and `input_format="index"` #2816

Closed: fguiotte closed this issue 1 week ago

fguiotte commented 3 weeks ago

πŸ› Bug

The generalized Dice score raises a `ValueError` on 2D tensors of shape (N, L) when `input_format="index"`.

To Reproduce

import torch
from torchmetrics.segmentation import GeneralizedDiceScore

batch_size = 16
num_classes = 8
L = 100

y = torch.randint(num_classes, (batch_size, L))
pred = torch.randint(num_classes, (batch_size, L))

dice = GeneralizedDiceScore(num_classes=num_classes, input_format="index")

dice(pred, y)  # metrics are called as (preds, target)

This raises:

ValueError: Expected both `preds` and `target` to have at least 3 dimensions, but got 2.

Expected behavior

According to the documentation, the dice score should work on tensors of shape (N, ...) with input_format="index":

https://github.com/Lightning-AI/torchmetrics/blob/0a64b3f8ae5f7ea1186f4e732f232f73a084e4cd/src/torchmetrics/segmentation/generalized_dice.py#L54-L61
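For context, the "index" format stores one integer class label per element, and converting it to the one-hot (N, C, ...) layout adds exactly one axis. So an (N, L) index tensor becomes a 3-dimensional (N, C, L) one-hot tensor. A minimal sketch with plain PyTorch; the helper name `index_to_one_hot` is hypothetical and not a torchmetrics API:

```python
import torch
import torch.nn.functional as F


def index_to_one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert integer labels of shape (N, ...) to a one-hot tensor of shape (N, C, ...)."""
    # F.one_hot appends the class axis last: (N, ...) -> (N, ..., C)
    one_hot = F.one_hot(labels.long(), num_classes=num_classes)
    # Move the class axis to position 1 to get the (N, C, ...) layout
    return one_hot.movedim(-1, 1)


batch_size, num_classes, length = 16, 8, 100
target = torch.randint(num_classes, (batch_size, length))  # index format, shape (N, L)
one_hot = index_to_one_hot(target, num_classes)

print(target.ndim, one_hot.ndim)  # 2 3
print(one_hot.shape)  # torch.Size([16, 8, 100])
```

The check quoted in the traceback rejects the 2D index tensors. A workaround that may be worth trying (untested here) is to append a singleton axis, `dice(pred.unsqueeze(-1), y.unsqueeze(-1))`, so the inputs have shape (N, L, 1). Summing over that extra axis should not change the per-sample score.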

Environment

Additional context

The exception is raised by the dimensionality check in the functional implementation:

https://github.com/Lightning-AI/torchmetrics/blob/0a64b3f8ae5f7ea1186f4e732f232f73a084e4cd/src/torchmetrics/functional/segmentation/generalized_dice.py#L55-L56
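Judging from the error message, the three-dimension requirement is applied to the raw inputs. It only makes sense for the one-hot (N, C, ...) layout; an index tensor of shape (N, L) is already valid. Below is a hypothetical sketch, not the torchmetrics implementation, of a per-sample generalized Dice computation that runs the check after the index-to-one-hot conversion. The function name `generalized_dice_sketch`, the inverse-square class weighting (assumed to mirror the metric's "square" weight type), and the handling of classes absent from the target are all assumptions made for illustration:

```python
import torch
import torch.nn.functional as F


def generalized_dice_sketch(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    input_format: str = "index",
    eps: float = 1e-8,
) -> torch.Tensor:
    """Per-sample generalized Dice score (hypothetical sketch, not the torchmetrics code).

    Index inputs of shape (N, ...) are converted to one-hot (N, C, ...) *before*
    the dimensionality check, so (N, L) index tensors are accepted.
    """
    if preds.shape != target.shape:
        raise ValueError(f"Shape mismatch: {tuple(preds.shape)} vs {tuple(target.shape)}.")
    if input_format == "index":
        # (N, ...) integer labels -> (N, C, ...) one-hot
        preds = F.one_hot(preds.long(), num_classes).movedim(-1, 1)
        target = F.one_hot(target.long(), num_classes).movedim(-1, 1)
    # Check the one-hot layout: batch axis, class axis, and at least one more axis
    if preds.ndim < 3:
        raise ValueError(f"Expected (N, C, ...) tensors with at least 3 dimensions, got {preds.ndim}.")

    preds, target = preds.float(), target.float()
    reduce_dims = tuple(range(2, preds.ndim))  # every axis after (N, C)
    intersection = (preds * target).sum(dim=reduce_dims)  # (N, C)
    cardinality = (preds + target).sum(dim=reduce_dims)  # (N, C)
    volume = target.sum(dim=reduce_dims)  # (N, C) target elements per class

    # Inverse-square weighting: w_c = 1 / volume_c^2, which is inf for classes absent from target.
    # Assumed convention: replace those inf weights with the sample's largest finite weight.
    weights = 1.0 / volume**2
    finite = torch.isfinite(weights)
    max_finite = torch.where(finite, weights, torch.zeros_like(weights)).amax(dim=1, keepdim=True)
    weights = torch.where(finite, weights, max_finite)

    numerator = 2.0 * (weights * intersection).sum(dim=1)
    denominator = (weights * cardinality).sum(dim=1)
    return numerator / (denominator + eps)  # shape (N,)


batch_size, num_classes, length = 16, 8, 100
target = torch.randint(num_classes, (batch_size, length))
preds = torch.randint(num_classes, (batch_size, length))

scores = generalized_dice_sketch(preds, target, num_classes)  # no error for (N, L) index input
print(scores.shape)  # torch.Size([16])
```

The design point is where the check sits. Index inputs only need a batch axis plus at least one element axis, while one-hot inputs need the extra class axis, so validating after conversion covers both formats with one rule.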

github-actions[bot] commented 3 weeks ago

Hi! Thanks for your contribution! Great first issue!