Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

`DiceScore(average="weighted")` doesn't work when in same compute group within `MetricCollection` #2847

Open · nkaenzig opened 4 days ago

nkaenzig commented 4 days ago

🐛 Bug

Using a `torchmetrics.MetricCollection` with `DiceScore(average="weighted")` and other `DiceScore` metrics in the same compute group raises `ValueError: No samples to concatenate`. Because the metrics in a compute group share state, and `DiceScore` only fills its `support` state when `average="weighted"`, the shared `support` state remains an empty list.

To Reproduce

The error can be reproduced with the following code. If you disable compute groups by setting the `COMPUTE_GROUPS` variable to `False`, it runs without errors.

Code sample:

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
import torchmetrics
from torchmetrics import segmentation


class DummySegmentationDataset(Dataset):
    def __init__(self, n_samples=16, n_classes=3):
        self.n_samples = n_samples
        self.n_classes = n_classes

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        target = torch.full((self.n_classes, 128, 128), 0, dtype=torch.int8)
        preds = torch.full((self.n_classes, 128, 128), 0, dtype=torch.int8)
        return preds, target


class MulticlassSegmentationMetrics(torchmetrics.MetricCollection):
    """Default metrics for multi-class semantic segmentation tasks."""

    def __init__(
        self,
        num_classes: int,
        include_background: bool = False,
        prefix: str | None = None,
        postfix: str | None = None,
        compute_groups: bool = False,
    ) -> None:
        super().__init__(
            metrics={
                "DiceScore (micro)": segmentation.DiceScore(
                    num_classes=num_classes,
                    include_background=include_background,
                    average="micro",
                ),
                "DiceScore (weighted)": segmentation.DiceScore(
                    num_classes=num_classes,
                    include_background=include_background,
                    average="weighted",
                ),
            },
            prefix=prefix,
            postfix=postfix,
            compute_groups=compute_groups,
        )


class DummySegmentationModel(pl.LightningModule):
    def __init__(self, n_classes, metrics):
        super().__init__()
        self.n_classes = n_classes
        self.metrics = metrics

    def training_step(self, batch, batch_idx):
        preds, target = batch
        self.metrics.update(preds, target)
        return {"loss": torch.tensor(0.0, requires_grad=True)}  # Dummy loss

    def on_train_epoch_end(self):
        metrics = self.metrics.compute()
        self.log_dict(metrics)
        print(f"Epoch {self.current_epoch} metrics:", metrics)

    def configure_optimizers(self):
        return torch.optim.SGD([torch.tensor([0.0], requires_grad=True)], lr=0.01)


def main():
    COMPUTE_GROUPS = True  # When setting this to False, it works as expected

    N_SAMPLES, N_CLASSES = 16, 3
    BATCH_SIZE = 4

    dataset = DummySegmentationDataset(n_samples=N_SAMPLES, n_classes=N_CLASSES)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    metrics = MulticlassSegmentationMetrics(num_classes=N_CLASSES, compute_groups=COMPUTE_GROUPS)
    model = DummySegmentationModel(n_classes=N_CLASSES, metrics=metrics)

    trainer = pl.Trainer(
        max_epochs=10, accelerator="cpu", enable_checkpointing=False, logger=False
    )
    trainer.fit(model, dataloader)


if __name__ == "__main__":
    main()
```
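For reference, the failure can be condensed to a few lines without Lightning (a sketch under the same assumptions as above: all-zero one-hot inputs of shape `[batch, num_classes, H, W]`, and multiple `update()` calls so the shared compute-group state kicks in):

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.segmentation import DiceScore

metrics = MetricCollection(
    {
        "dice_micro": DiceScore(num_classes=3, average="micro"),
        "dice_weighted": DiceScore(num_classes=3, average="weighted"),
    },
    compute_groups=True,  # with False, compute() succeeds
)

for _ in range(4):  # several batches, as in the Lightning run above
    preds = torch.zeros(4, 3, 128, 128, dtype=torch.int8)
    target = torch.zeros(4, 3, 128, 128, dtype=torch.int8)
    metrics.update(preds, target)

print(metrics.compute())  # ValueError: No samples to concatenate
```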

Expected behavior

It should be possible to compute different averaging flavours of `DiceScore` (e.g. `micro` and `weighted`) within the same compute group of a `MetricCollection`.

Environment

Additional context

A simple fix could be to remove this `if` statement in the `update()` method, so that the `support` state is always updated regardless of the `average` setting: https://github.com/Lightning-AI/torchmetrics/blob/d528131c8e7f130c65ba62a14e12f15906d488c1/src/torchmetrics/segmentation/dice.py#L134
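Roughly, the change would look like this (paraphrasing the linked code, so the exact line contents are approximate):

```diff
         self.numerator.append(numerator)
         self.denominator.append(denominator)
-        if self.average == "weighted":
-            self.support.append(support)
+        self.support.append(support)
```

Always appending `support` would keep the states of all `DiceScore` variants identical, so any member of the compute group could serve as the one whose `update()` is actually run.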