Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

ClasswiseWrapper and JaccardIndex confmat attribute error #2465

Closed: isaaccorley closed this issue 8 months ago

isaaccorley commented 8 months ago

šŸ› Bug

Wrapping a JaccardIndex in a ClasswiseWrapper inside a MetricCollection raises an AttributeError when iterating over the collection's named_children().

To Reproduce

from torchmetrics.classification import JaccardIndex
from torchmetrics import MetricCollection
from torchmetrics.wrappers import ClasswiseWrapper

metrics = MetricCollection(
    {
        "IoU": ClasswiseWrapper(
            JaccardIndex(
                task="multiclass", num_classes=2, average="none"
            ),
            labels=["cat", "dog"],
        ),
    },
    prefix="train_",
)

for name, module in metrics.named_children():
    print(name, module)

AttributeError                            Traceback (most recent call last)
Cell In[9], line 1
----> 1 for name, module in metrics.named_children():
      2     print(name, module)

File /opt/conda/envs/dbenv/lib/python3.10/site-packages/torch/nn/modules/module.py:2302, in Module.named_children(self)
   2300 memo = set()
   2301 for name, module in self._modules.items():
-> 2302     if module is not None and module not in memo:
   2303         memo.add(module)
   2304         yield name, module

File /opt/conda/envs/dbenv/lib/python3.10/site-packages/torchmetrics/metric.py:928, in Metric.__hash__(self)
    925 hash_vals = [self.__class__.__name__, id(self)]
    927 for key in self._defaults:
--> 928     val = getattr(self, key)
    929     # Special case: allow list values, so long
    930     # as their elements are hashable
    931     if hasattr(val, "__iter__") and not isinstance(val, Tensor):

File /opt/conda/envs/dbenv/lib/python3.10/site-packages/torchmetrics/wrappers/classwise.py:223, in ClasswiseWrapper.__getattr__(self, name)
    220 if name in ["tp", "fp", "fn", "tn"]:
    221     return getattr(self.metric, name)
--> 223 return super().__getattr__(name)

File /opt/conda/envs/dbenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1688, in Module.__getattr__(self, name)
   1686     if name in modules:
   1687         return modules[name]
-> 1688 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

AttributeError: 'ClasswiseWrapper' object has no attribute 'confmat'
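The top frame explains why listing children fails at all: `module not in memo` is a set-membership test, and Python hashes the element before looking it up, even when the set is empty. Any exception raised inside `__hash__` therefore surfaces from `named_children()` itself. A minimal stdlib sketch of just that step, using a hypothetical `BrokenHash` class and a hypothetical `first_unseen` helper rather than any torch or torchmetrics type:

```python
class BrokenHash:
    """Hypothetical object whose __hash__ needs an attribute it does not have."""

    def __hash__(self):
        # Like the Metric.__hash__ frame above, read a state by name;
        # 'confmat' is never set on this object, so getattr raises.
        return hash(getattr(self, "confmat"))


def first_unseen(items):
    """Hypothetical helper mimicking the dedup loop in Module.named_children."""
    memo = set()
    seen = []
    for item in items:
        if item not in memo:  # membership test calls type(item).__hash__
            memo.add(item)
            seen.append(item)
    return seen


try:
    first_unseen([BrokenHash()])
except AttributeError as err:
    # The error comes out of the membership test, not from any explicit access
    print("raised from the membership test:", err)
```

Ordinary hashable items pass through untouched, so the failure is specific to objects whose hash cannot resolve one of its states.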

Expected behavior

No error should occur: iterating over metrics.named_children() should complete without raising.

Environment

Additional context

I found this bug while training with PyTorch Lightning. During trainer setup, Lightning iterates over the named_children of the LightningModule, which raises this error.

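In torchmetrics/wrappers/classwise.py, ClasswiseWrapper.__getattr__ forwards only "tp", "fp", "fn", and "tn" to the wrapped metric. The traceback shows the hash looking up confmat, which is not in that list, so the lookup falls through to Module.__getattr__ and raises. Maybe confmat should be added to the list?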

    def __getattr__(self, name: str) -> Union[Tensor, "Module"]:
        """Get attribute from classwise wrapper."""
        # return state from self.metric
        if name in ["tp", "fp", "fn", "tn"]:
            return getattr(self.metric, name)

        return super().__getattr__(name)
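To make the suggestion concrete, here is a stdlib-only sketch of the forwarding logic with and without confmat in the list. `ToyJaccard`, `ToyWrapper`, and the `forwarded` parameter are hypothetical stand-ins, not torchmetrics classes, and the hash is simplified from the Metric.__hash__ frame in the traceback. The sketch assumes, as the traceback suggests, that the wrapper's hash looks up the wrapped metric's state names (here confmat) on the wrapper itself. It does not reproduce what #2424 actually changed.

```python
class ToyJaccard:
    """Hypothetical stand-in for JaccardIndex: its only state is 'confmat'."""

    def __init__(self):
        self._defaults = {"confmat": 0}
        self.confmat = 0


class ToyWrapper:
    """Hypothetical stand-in for ClasswiseWrapper with a configurable forward list."""

    def __init__(self, metric, forwarded):
        self.metric = metric
        self.forwarded = tuple(forwarded)
        # ASSUMPTION: the wrapper's hash walks the wrapped metric's state names
        self._defaults = dict(metric._defaults)

    def __getattr__(self, name):
        # Reached only when normal lookup fails, as with nn.Module.__getattr__.
        # Read via __dict__ so a missing attribute cannot recurse back in here.
        if name in self.__dict__.get("forwarded", ()):
            return getattr(self.__dict__["metric"], name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __hash__(self):
        # Simplified from Metric.__hash__: class name, id, then every state
        hash_vals = [type(self).__name__, id(self)]
        for key in self._defaults:
            hash_vals.append(getattr(self, key))
        return hash(tuple(hash_vals))


original = ToyWrapper(ToyJaccard(), forwarded=["tp", "fp", "fn", "tn"])
patched = ToyWrapper(ToyJaccard(), forwarded=["tp", "fp", "fn", "tn", "confmat"])

memo = set()
try:
    original in memo  # hashing needs 'confmat', which is not forwarded
except AttributeError as err:
    print(err)  # 'ToyWrapper' object has no attribute 'confmat'

print(patched in memo)  # False: hashing succeeds once 'confmat' is forwarded
```

A hardcoded list has to be kept in sync with every wrapped metric's state names, which is why a metric built on a confusion-matrix state slips through.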
github-actions[bot] commented 8 months ago

Hi! Thanks for your contribution! Great first issue!

isaaccorley commented 8 months ago

Closing because I see this is fixed in #2424. Would love to see this in a release!