Wrapping `JaccardIndex` in `ClasswiseWrapper` inside a `MetricCollection` raises an `AttributeError` when listing the collection's `named_children()`.
To Reproduce
```python
from torchmetrics.classification import JaccardIndex
from torchmetrics import MetricCollection
from torchmetrics.wrappers import ClasswiseWrapper

metrics = MetricCollection(
    {
        "IoU": ClasswiseWrapper(
            JaccardIndex(
                task="multiclass", num_classes=2, average="none"
            ),
            labels=["cat", "dog"],
        ),
    },
    prefix="train_",
)

for name, module in metrics.named_children():
    print(name, module)
```
```
AttributeError                            Traceback (most recent call last)
Cell In[9], line 1
----> 1 for name, module in metrics.named_children():
      2     print(name, module)

File /opt/conda/envs/dbenv/lib/python3.10/site-packages/torch/nn/modules/module.py:2302, in Module.named_children(self)
   2300 memo = set()
   2301 for name, module in self._modules.items():
-> 2302     if module is not None and module not in memo:
   2303         memo.add(module)
   2304         yield name, module

File /opt/conda/envs/dbenv/lib/python3.10/site-packages/torchmetrics/metric.py:928, in Metric.__hash__(self)
    925 hash_vals = [self.__class__.__name__, id(self)]
    927 for key in self._defaults:
--> 928     val = getattr(self, key)
    929     # Special case: allow list values, so long
    930     # as their elements are hashable
    931     if hasattr(val, "__iter__") and not isinstance(val, Tensor):

File /opt/conda/envs/dbenv/lib/python3.10/site-packages/torchmetrics/wrappers/classwise.py:223, in ClasswiseWrapper.__getattr__(self, name)
    220 if name in ["tp", "fp", "fn", "tn"]:
    221     return getattr(self.metric, name)
--> 223 return super().__getattr__(name)

File /opt/conda/envs/dbenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1688, in Module.__getattr__(self, name)
   1686 if name in modules:
   1687     return modules[name]
-> 1688 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

AttributeError: 'ClasswiseWrapper' object has no attribute 'confmat'
```
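To see why merely listing children fails, here is a pure-Python sketch (no torch or torchmetrics needed) of the chain in the traceback: `named_children()` deduplicates modules with a `set`, the `in` check hashes each module, and a hash that reads every registered state by name raises `AttributeError` as soon as one name cannot be resolved. `StateHashing` and this `named_children` function are hypothetical simplifications of `Metric.__hash__` and `Module.named_children`, not the library code.

```python
class StateHashing:
    """Hypothetical stand-in for a Metric whose __hash__ reads its registered states."""

    def __init__(self, state_names):
        # Registered state names, analogous to Metric._defaults in the traceback.
        self._defaults = dict.fromkeys(state_names)

    def __hash__(self):
        # Mirrors Metric.__hash__: every registered state is read via getattr.
        return hash((type(self).__name__, id(self), *(getattr(self, k) for k in self._defaults)))


def named_children(children):
    # Mirrors the loop in torch.nn.Module.named_children shown in the traceback.
    memo = set()
    for name, module in children.items():
        if module is not None and module not in memo:  # `in` calls __hash__
            memo.add(module)
            yield name, module
```

A module whose registered state names all resolve is yielded normally; one that lists `"confmat"` without being able to resolve it fails at the membership test, before anything is yielded, which matches where the real traceback stops.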
Expected behavior
No error should occur.
Environment
- TorchMetrics: 1.3.2
- Python: 3.10.11
- PyTorch: 2.2.1
- OS: Ubuntu
Additional context
I found this bug while training with PyTorch Lightning. During trainer setup, Lightning iterates over the `named_children()` of the LightningModule, which raises this error.
In `torchmetrics/wrappers/classwise.py`, `ClasswiseWrapper.__getattr__` only forwards the state names `["tp", "fp", "fn", "tn"]` to the wrapped metric. Perhaps `confmat` should be added to this list:
```python
def __getattr__(self, name: str) -> Union[Tensor, "Module"]:
    """Get attribute from classwise wrapper."""
    # return state from self.metric
    if name in ["tp", "fp", "fn", "tn"]:
        return getattr(self.metric, name)
    return super().__getattr__(name)
```
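A sketch of the suggested change, written against hypothetical stand-ins (`FakeJaccard`, `FakeClasswiseWrapper`, and the `forward_all_states` flag are all illustrative, not real torchmetrics names). Adding `"confmat"` to the hardcoded list would fix this particular combination; the variant below instead assumes the wrapper may forward any state name registered in the wrapped metric's `_defaults`, so other wrapped metrics would be covered as well. That generalization is an assumption, not the library's actual fix.

```python
from typing import Any


class FakeJaccard:
    """Hypothetical stand-in for JaccardIndex: one registered state named "confmat"."""

    def __init__(self) -> None:
        self._defaults = {"confmat": None}
        self.confmat = [[0, 0], [0, 0]]


class FakeClasswiseWrapper:
    """Hypothetical stand-in for ClasswiseWrapper (not the real class)."""

    def __init__(self, metric: Any, forward_all_states: bool) -> None:
        self.metric = metric
        self.forward_all_states = forward_all_states
        # Assumption: the wrapper ends up listing the wrapped metric's state names,
        # as the traceback (ClasswiseWrapper.__hash__ reading "confmat") implies.
        self._defaults = dict(metric._defaults)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails; read __dict__ directly
        # so a partially initialized object cannot recurse into __getattr__.
        metric = self.__dict__.get("metric")
        if metric is not None:
            if name in ("tp", "fp", "fn", "tn"):  # the names currently forwarded
                return getattr(metric, name)
            # Hypothetical generalization: forward any state the wrapped metric registers.
            if self.__dict__.get("forward_all_states") and name in metric._defaults:
                return getattr(metric, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
```

With `forward_all_states=False`, looking up `confmat` fails exactly as in the report; with it enabled, every name in `_defaults` resolves, so a hash that reads all registered states would succeed. The trade-off is that a hardcoded list has to be kept in sync with the state names of every metric the wrapper may hold, while forwarding by `_defaults` does not.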