Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

Metric not moved to device #531

Closed · cowwoc closed this issue 3 years ago

cowwoc commented 3 years ago

🐛 Bug

Version 1.4.7

Per https://torchmetrics.readthedocs.io/en/latest/pages/overview.html#metrics-and-devices if a metric is properly defined (identified as a child of a module) then it is supposed to be automatically moved to the same device as the module. Unfortunately, in my own project this does not occur.

When I run this code:

from pytorch_lightning import LightningModule
from torchmetrics import Accuracy

class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.accuracy = Accuracy()

    def forward(self, input):
        print(f"self.device: {self.device}")
        print(f"self.accuracy.device: {self.accuracy.device}")

I get:

self.device: cuda:0
self.accuracy.device: cpu

Expected behavior

I expect the metric to be on cuda:0.

Environment

github-actions[bot] commented 3 years ago

Hi! Thanks for your contribution, great first issue!

cowwoc commented 3 years ago

Interestingly, any states registered inside the metric are moved to the right device, but self.device inside the metric evaluates to the wrong value. The reason this is relevant is that I have methods in my metric that create new tensors on self.device. I suspect other users will also expect self.device to match the device of the metric's state variables.
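Something like this hypothetical metric (names made up for illustration) shows what I mean: update() allocates a new tensor on self.device, so a stale device attribute causes a device mismatch even though the registered state was moved correctly.

import torch
from torchmetrics import Metric

class CountingMetric(Metric):
    """Hypothetical metric whose update() creates new tensors on self.device."""

    def __init__(self):
        super().__init__()
        # registered state: this *does* get moved when the parent module moves
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # new tensor allocated on whatever self.device reports; if that is
        # still cpu while self.total is on cuda:0, the += raises a RuntimeError
        batch_count = torch.tensor(preds.shape[0], device=self.device)
        self.total += batch_count

    def compute(self) -> torch.Tensor:
        return self.total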

Borda commented 3 years ago

see #340

cowwoc commented 3 years ago

@Borda The linked issue does not resolve this problem. I am already doing what it recommends.

The documentation claims that the metric's device will be updated, but it is not. I consider this a bug report, not a question.

Either the documentation or the implementation is wrong. Please reopen this issue.

Borda commented 3 years ago

I see, then the issue is in the docs; no metric is automatically moved unless you use it with PL logging... Mind sending a PR to fix it?

cowwoc commented 3 years ago

The example code found at https://torchmetrics.readthedocs.io/en/latest/pages/overview.html#metrics-and-devices does not say anything about having to use PL logging. Quoting the relevant parts:

when properly defined inside a Module or LightningModule, the metric will be automatically moved to the same device as the module when using .to(device)

import torch
from torch import nn
from torchmetrics import Accuracy, MetricCollection

class MyModule(torch.nn.Module):
    def __init__(self):
        ...
        # valid ways metrics will be identified as child modules
        self.metric1 = Accuracy()
        self.metric2 = nn.ModuleList([Accuracy()])
        self.metric3 = nn.ModuleDict({'accuracy': Accuracy()})
        self.metric4 = MetricCollection([Accuracy()])  # torchmetrics built-in collection class

    def forward(self, batch):
        data, target = batch
        preds = self(data)
        ...
        val1 = self.metric1(preds, target)
        val2 = self.metric2[0](preds, target)
        val3 = self.metric3['accuracy'](preds, target)
        val4 = self.metric4(preds, target)

It sounds a bit odd that you have to use PL logging in order for a metric to get moved to the correct device... Can you point me to the relevant code in PL that moves the metric?

SkafteNicki commented 3 years ago

The problem seems to be that if you call .cuda on a parent module, it does not execute:

for m in self.modules():
    m.cuda()

but instead calls self._apply, which in turn calls

for module in self.children():
    module._apply(fn)

This moves the metric states to the correct device, but currently metric.device is only updated when .cuda, .cpu, or .to is called directly on the metric.
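A rough sketch (not the actual torchmetrics implementation) of how a metric-like module could keep its device attribute in sync by also updating it inside _apply, which is the hook the parent's .cuda()/.cpu()/.to() actually goes through:

import torch
from torch import nn

class DeviceAwareMetric(nn.Module):
    def __init__(self):
        super().__init__()
        # metric state registered as a buffer so _apply moves it automatically
        self.register_buffer("total", torch.tensor(0))
        self._device = torch.device("cpu")

    @property
    def device(self) -> torch.device:
        return self._device

    def _apply(self, fn, *args, **kwargs):
        # fn is the closure built by .cuda()/.cpu()/.to(); applying it to a
        # dummy tensor reveals the device the parent is moving us to
        module = super()._apply(fn, *args, **kwargs)
        self._device = fn(torch.zeros(1)).device
        return module

model = nn.Sequential(nn.Linear(2, 2))
model.metric = DeviceAwareMetric()
if torch.cuda.is_available():
    model.cuda()
    print(model.metric.device)        # cuda:0
    print(model.metric.total.device)  # cuda:0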

jlehrer1 commented 1 year ago

This is not resolved in 2.0.1. Setting up a LightningModule with the init like

        self.metrics = {
            "train": {"accuracy": Accuracy(task="binary")},
            "val": {"accuracy": Accuracy(task="binary")},
        }

And logging with

    def training_step(self, batch, batch_idx):
        loss, probs = self(batch)
        self.log(f"train_loss", loss, on_epoch=True, on_step=True)
        for name, metric in self.metrics["train"].items():
            value = metric(probs, batch[1])
            self.log(f"train_{name}", value=value)

Gives the error

RuntimeError: Encountered different devices in metric calculation (see stacktrace for details). 

This could be due to the metric class not being on the same device as input. Instead of `metric=BinaryAccuracy(...)` try to do `metric=BinaryAccuracy(...).to(device)` where device corresponds to the device of the input.

Where the stacktrace errors on

  File "/home/user/micromamba/lib/python3.9/site-packages/torchmetrics/metric.py", line 390, in wrapped_func
    update(*args, **kwargs)
  File "/home/user/micromamba/lib/python3.9/site-packages/torchmetrics/classification/stat_scores.py", line 322, in update
    self._update_state(tp, fp, tn, fn)
  File "/home/user/micromamba/lib/python3.9/site-packages/torchmetrics/classification/stat_scores.py", line 70, in _update_state
    self.tp += tp
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@Borda

SkafteNicki commented 1 year ago

@jlehrer1 in your case, it has to do with the initialization, which should use a ModuleDict:

        self.metrics = torch.nn.ModuleDict({
            "train": {"accuracy": Accuracy(task="binary")},
            "val": {"accuracy": Accuracy(task="binary")}
        })

You can read more about why here: https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metrics-and-devices

amorehead commented 12 months ago

@SkafteNicki, it doesn't seem like your latest code snippet works with PyTorch 2.0+, since a plain dict is not a subclass of nn.Module (that's the error PyTorch raises for me).

SkafteNicki commented 12 months ago

Hi @amorehead, in my last example I do not refer to a standard dict but to the special ModuleDict from torch, which is a subclass of nn.Module: https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#ModuleDict

amorehead commented 12 months ago

Hi, @SkafteNicki. When I referred to the dict object in your code example, I meant the inner dict objects assigned to the train and val keys in your outer ModuleDict. When I try instantiating the data structure as you have it above, PyTorch complains that the inner dicts must be subclasses of nn.Module. It works if I instead wrap the inner dicts in another ModuleDict, though. I mention this just in case anyone else runs into the same issue.
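For anyone who wants the working layout, a sketch (names made up) where the inner mappings are also ModuleDicts, so every metric is registered as a child module and follows .to()/.cuda() on the model:

from torch import nn
from torchmetrics import Accuracy

class MyModel(nn.Module):  # or LightningModule
    def __init__(self):
        super().__init__()
        # nested ModuleDicts: both levels are nn.Modules, so the metrics are
        # picked up as child modules and move together with the model
        self.metrics = nn.ModuleDict({
            "train": nn.ModuleDict({"accuracy": Accuracy(task="binary")}),
            "val": nn.ModuleDict({"accuracy": Accuracy(task="binary")}),
        })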

SkafteNicki commented 12 months ago

@amorehead thanks, and sorry for the confusion on my part, you are indeed correct :) Thanks for clarifying this for anyone that stumbles on this issue.

amorehead commented 12 months ago

No worries! Thanks for the original suggestion. It reminded me to organize my torchmetrics more cleanly :)