🐛 Bug
Hi @SkafteNicki,

thanks for the great wrapper class! It is exactly what I needed; however, mypy is ruining all the fun :/
The `MultitaskWrapper`'s `update` method is type-hinted with `dict[str, Tensor]`. That works for the classification and regression test cases, but it produces mypy errors when the wrapper is used with a detection metric, which expects `list[dict[str, Tensor]]`.

The following test snippet reproduces this with a classification + detection multitask combination.
To Reproduce
Code sample
```python
import torch
from torchmetrics.classification import Accuracy
from torchmetrics.detection import MeanAveragePrecision
from torchmetrics.wrappers import MultitaskWrapper


def test_multitask_wrapper() -> None:
    """Test the instantiation of multi-task metrics."""
    accuracy = Accuracy(task="binary")
    mean_ap = MeanAveragePrecision(iou_type="bbox")  # avoid shadowing builtin `map`
    metrics = MultitaskWrapper(
        task_metrics={"classification": accuracy, "segmentation": mean_ap}
    )

    class_preds = torch.tensor([0, 1, 0, 1], dtype=torch.float32)
    class_targets = torch.tensor([1, 1, 0, 0])

    # Detection inputs are lists of per-image dicts, not plain tensors
    bbox_preds = [
        {
            "boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
            "scores": torch.tensor([0.536]),
            "labels": torch.tensor([0]),
        }
    ]
    bbox_targets = [
        {
            "boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
            "scores": torch.tensor([0.536]),
            "labels": torch.tensor([0]),
        }
    ]

    preds = {
        "classification": class_preds,
        "segmentation": bbox_preds,
    }
    targets = {
        "classification": class_targets,
        "segmentation": bbox_targets,
    }

    # mypy reports [arg-type] here: both dicts are inferred as dict[str, object]
    metrics.update(preds, targets)
    metrics.forward(preds, targets)
```

#### Error message

```shell
Argument 1 to "update" of "MultitaskWrapper" has incompatible type "dict[str, object]"; expected "dict[str, Tensor]" [arg-type]
Argument 2 to "update" of "MultitaskWrapper" has incompatible type "dict[str, object]"; expected "dict[str, Tensor]" [arg-type]
```

Expected behavior
No mypy error when passing inputs in the detection target format `list[dict[str, Tensor]]`.
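One possible direction for a fix, sketched only as an assumption (this is not torchmetrics' actual code): widen the per-task value type to a union of a plain tensor and the detection format. `TypedMultitaskSketch`, `TaskInput` and `_UpdatableMetric` are hypothetical names used for illustration; `torch` is imported only under `TYPE_CHECKING`, so the sketch runs without it.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Protocol, Union

if TYPE_CHECKING:
    # Only evaluated by the type checker; the sketch runs without torch installed.
    from torch import Tensor

    # Hypothetical alias: a task's input is either a plain tensor
    # (classification/regression) or a list of per-image dicts (detection).
    TaskInput = Union[Tensor, List[Dict[str, Tensor]]]


class _UpdatableMetric(Protocol):
    """Hypothetical structural type: anything with an `update(preds, target)`."""

    def update(self, preds: Any, target: Any) -> None: ...


class TypedMultitaskSketch:
    """Hypothetical stand-in for a multitask wrapper with widened type hints."""

    def __init__(self, task_metrics: Dict[str, _UpdatableMetric]) -> None:
        self.task_metrics = task_metrics

    def update(
        self,
        task_preds: Dict[str, TaskInput],
        task_targets: Dict[str, TaskInput],
    ) -> None:
        # Forward each task's inputs unchanged to the metric registered for it.
        for name, metric in self.task_metrics.items():
            metric.update(task_preds[name], task_targets[name])
```

Because mypy infers `dict[str, object]` for an unannotated mixed dict literal (as the error message shows) and `dict` is invariant in its value type, a caller would probably still need an annotation such as `preds: Dict[str, TaskInput] = {...}`. Annotating the parameters as `Dict[str, Any]` instead would accept the snippet above unchanged, at the cost of losing per-task type checking.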
Environment