Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

Typing error for detection metrics in MultitaskWrapper #2398

Closed: ponderbb closed this issue 8 months ago

ponderbb commented 9 months ago

🐛 Bug

Hi @SkafteNicki,

Thanks for the great wrapper class! It is exactly what I needed; however, mypy is ruining all the fun :/

The `update` method of `MultitaskWrapper` is type-hinted to accept `dict[str, Tensor]`. That works for classification and regression metrics, but it produces mypy errors when one of the task metrics is a detection metric, whose inputs have the format `list[dict[str, Tensor]]`.
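To make the mismatch concrete, here is a minimal sketch of what a widened annotation could look like, assuming `torch` is installed. `TaskInput` and `update_sketch` are hypothetical names used only for illustration; they are not torchmetrics API, and this is not necessarily how the library resolved the issue. The sketch uses `Mapping` rather than `Dict` because `Mapping` is covariant in its value type, so an existing caller passing a `dict[str, Tensor]` would still type-check, whereas `dict` is invariant and would reject it.

```python
from typing import Dict, List, Mapping, Union

import torch
from torch import Tensor

# Hypothetical alias (not torchmetrics API): a task's input is either a plain
# tensor (classification/regression) or a list of per-image dicts (detection).
TaskInput = Union[Tensor, List[Dict[str, Tensor]]]


def update_sketch(preds: Mapping[str, TaskInput], target: Mapping[str, TaskInput]) -> Dict[str, str]:
    """Hypothetical stand-in for a wrapper `update` with a widened annotation.

    It computes no metric; it only reports which input form each task received,
    to show that both forms are accepted by a single signature.
    """
    kinds: Dict[str, str] = {}
    for task, task_preds in preds.items():
        task_target = target[task]
        if isinstance(task_preds, Tensor) and isinstance(task_target, Tensor):
            kinds[task] = "tensor"
        else:
            kinds[task] = "list of dicts"
    return kinds


# Explicit annotations on the caller side (see the note below).
preds: Dict[str, TaskInput] = {
    "classification": torch.tensor([0.0, 1.0, 0.0, 1.0]),
    "segmentation": [
        {
            "boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
            "scores": torch.tensor([0.536]),
            "labels": torch.tensor([0]),
        }
    ],
}
targets: Dict[str, TaskInput] = {
    "classification": torch.tensor([1, 1, 0, 0]),
    "segmentation": [
        {
            "boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
            "labels": torch.tensor([0]),
        }
    ],
}

print(update_sketch(preds, targets))
# prints {'classification': 'tensor', 'segmentation': 'list of dicts'}
```

Note that the caller's dicts are annotated explicitly. Left unannotated, a dict literal mixing tensors and lists is inferred by mypy as `dict[str, object]` (as the error message in this report shows), and `object` is not a `TaskInput`, so a widened signature alone would not silence the error for code written like the reproduction snippet.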

The following test snippet reproduces the problem with a classification + detection combination in a multitask setup.

To Reproduce

Code sample

```python
import torch

from torchmetrics.classification import Accuracy
from torchmetrics.detection import MeanAveragePrecision
from torchmetrics.wrappers import MultitaskWrapper


def test_multitask_wrapper() -> None:
    """Test the instantiation of multi-task metrics."""
    accuracy = Accuracy(task="binary")
    # Named mean_ap instead of map to avoid shadowing the built-in map()
    mean_ap = MeanAveragePrecision(iou_type="bbox")
    metrics = MultitaskWrapper(
        task_metrics={"classification": accuracy, "segmentation": mean_ap}
    )

    class_preds = torch.tensor([0, 1, 0, 1], dtype=torch.float32)
    class_targets = torch.tensor([1, 1, 0, 0])

    bbox_preds = [
        {
            "boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
            "scores": torch.tensor([0.536]),
            "labels": torch.tensor([0]),
        }
    ]
    bbox_targets = [
        {
            "boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
            "scores": torch.tensor([0.536]),
            "labels": torch.tensor([0]),
        }
    ]

    preds = {
        "classification": class_preds,
        "segmentation": bbox_preds,
    }
    targets = {
        "classification": class_targets,
        "segmentation": bbox_targets,
    }

    # mypy reports the arg-type errors below on this update() call
    metrics.update(preds, targets)
    metrics.forward(preds, targets)
```

#### Error message

```shell
Argument 1 to "update" of "MultitaskWrapper" has incompatible type "dict[str, object]"; expected "dict[str, Tensor]"  [arg-type]
Argument 2 to "update" of "MultitaskWrapper" has incompatible type "dict[str, object]"; expected "dict[str, Tensor]"  [arg-type]
```

Expected behavior

Not getting a mypy error for detection target format list[dict[str, Tensor]].

Environment

github-actions[bot] commented 9 months ago

Hi! Thanks for your contribution, great first issue!

Borda commented 9 months ago

> Not getting a mypy error for detection target format list[dict[str, Tensor]].

That looks good to me, @SkafteNicki?