## 🐛 Bug

Using a `torchmetrics.MetricCollection` with `DiceScore(average="weighted")` and other `DiceScore` metrics in the same compute group raises the following error: `ValueError: No samples to concatenate`. This happens because the `support` state remains an empty list due to the shared state within the compute group.
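For what it's worth, the failure does not seem to need Lightning at all. A minimal sketch (the tensor shapes and the double `update()` call are my assumptions, not taken from the training run below):

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.segmentation import DiceScore

collection = MetricCollection(
    {
        "dice_micro": DiceScore(num_classes=3, average="micro"),
        "dice_weighted": DiceScore(num_classes=3, average="weighted"),
    },
    compute_groups=True,
)

# One-hot style inputs matching DiceScore's default input_format.
preds = torch.randint(0, 2, (4, 3, 16, 16))
target = torch.randint(0, 2, (4, 3, 16, 16))

# Update twice so the compute-group merging has definitely kicked in
# (the exact moment the states get shared is an implementation detail).
collection.update(preds, target)
collection.update(preds, target)

# Expected to raise ValueError: No samples to concatenate, because the
# weighted metric's `support` list is shared with the micro metric, whose
# update() never appends to it.
collection.compute()
```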
### To Reproduce

The error can be reproduced with the following code. If you disable the compute groups by setting the `COMPUTE_GROUPS` variable to `False`, it runs without errors.

#### Code sample
```python
import torch
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader

import torchmetrics
from torchmetrics import segmentation


class DummySegmentationDataset(Dataset):
    def __init__(self, n_samples=16, n_classes=3):
        self.n_samples = n_samples
        self.n_classes = n_classes

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        target = torch.full((self.n_classes, 128, 128), 0, dtype=torch.int8)
        preds = torch.full((self.n_classes, 128, 128), 0, dtype=torch.int8)
        return preds, target


class MulticlassSegmentationMetrics(torchmetrics.MetricCollection):
    """Default metrics for multi-class semantic segmentation tasks."""

    def __init__(
        self,
        num_classes: int,
        include_background: bool = False,
        prefix: str | None = None,
        postfix: str | None = None,
        compute_groups: bool = False,
    ) -> None:
        super().__init__(
            metrics={
                "DiceScore (micro)": segmentation.DiceScore(
                    num_classes=num_classes,
                    include_background=include_background,
                    average="micro",
                ),
                "DiceScore (weighted)": segmentation.DiceScore(
                    num_classes=num_classes,
                    include_background=include_background,
                    average="weighted",
                ),
            },
            prefix=prefix,
            postfix=postfix,
            compute_groups=compute_groups,
        )


class DummySegmentationModel(pl.LightningModule):
    def __init__(self, n_classes, metrics):
        super().__init__()
        self.n_classes = n_classes
        self.metrics = metrics

    def training_step(self, batch, batch_idx):
        preds, target = batch
        self.metrics.update(preds, target)
        return {"loss": torch.tensor(0.0, requires_grad=True)}  # Dummy loss

    def on_train_epoch_end(self):
        metrics = self.metrics.compute()
        self.log_dict(metrics)
        print(f"Epoch {self.current_epoch} metrics:", metrics)

    def configure_optimizers(self):
        return torch.optim.SGD([torch.tensor([0.0], requires_grad=True)], lr=0.01)


def main():
    COMPUTE_GROUPS = True  # When setting this to False, it works as expected
    N_SAMPLES, N_CLASSES = 16, 3
    BATCH_SIZE = 4

    dataset = DummySegmentationDataset(n_samples=N_SAMPLES, n_classes=N_CLASSES)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    metrics = MulticlassSegmentationMetrics(num_classes=N_CLASSES, compute_groups=COMPUTE_GROUPS)
    model = DummySegmentationModel(n_classes=N_CLASSES, metrics=metrics)

    trainer = pl.Trainer(
        max_epochs=10,
        accelerator="cpu",
        enable_checkpointing=False,
        logger=False,
    )
    trainer.fit(model, dataloader)


if __name__ == "__main__":
    main()
```
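As a side note, a possible workaround (my suggestion, not verified against the snippet above) is to pin the grouping explicitly so the weighted metric keeps its own state; `MetricCollection` accepts a list of lists of metric names for `compute_groups`:

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.segmentation import DiceScore

# Explicit compute groups: the weighted DiceScore sits in its own group, so its
# update() keeps filling the `support` state instead of inheriting the (empty)
# list from a shared group representative.
collection = MetricCollection(
    {
        "dice_micro": DiceScore(num_classes=3, average="micro"),
        "dice_weighted": DiceScore(num_classes=3, average="weighted"),
    },
    compute_groups=[["dice_micro"], ["dice_weighted"]],
)
```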
### Expected behavior

It should be possible to compute standard flavours of `DiceScore` within the same compute group of a `MetricCollection`.
### Environment

- TorchMetrics version: 1.6.0
- Python version: 3.11.10
### Additional context

A simple fix could be to remove this `if` statement in the `update()` method:
https://github.com/Lightning-AI/torchmetrics/blob/d528131c8e7f130c65ba62a14e12f15906d488c1/src/torchmetrics/segmentation/dice.py#L134
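If I read the linked line correctly, the guard only appends to `support` when `average="weighted"`, so the proposed change would look roughly like this (paraphrased from the linked source, not an exact patch):

```diff
-        if self.average == "weighted":
-            self.support.append(support)
+        self.support.append(support)
```

`support` would then be accumulated for every averaging mode, trading a small amount of extra state for consistent behaviour when the state is shared across a compute group.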