jrzaurin / pytorch-widedeep

A flexible package for multimodal-deep-learning to combine tabular data with text and images using Wide and Deep models in Pytorch
Apache License 2.0

Metrics not on same device as Trainer/model #231

Closed: mandeep-starling closed this issue 1 month ago

mandeep-starling commented 2 months ago

When running on a CUDA device, the following code:

from pytorch_widedeep import Trainer
from torchmetrics.classification import BinaryAUROC, BinaryPrecision, BinaryRecall

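# model is a WideDeep model defined earlier (not shown in this snippet)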
trainer = Trainer(model, objective="binary", metrics=[BinaryAUROC(), BinaryPrecision(), BinaryRecall()])

raises the following error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
...
  File "/opt/conda/lib/python3.10/site-packages/pytorch_widedeep/utils/general_utils.py", line 12, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_widedeep/training/trainer.py", line 524, in fit
    train_score, train_loss = self._train_step(
  File "/opt/conda/lib/python3.10/site-packages/pytorch_widedeep/training/trainer.py", line 1000, in _train_step
    score = self._get_score(y_pred, y)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_widedeep/training/trainer.py", line 1049, in _get_score
    score = self.metric(torch.sigmoid(y_pred), y)
  File "/opt/conda/lib/python3.10/site-packages/pytorch_widedeep/metrics.py", line 40, in __call__
    metric.update(y_pred, y_true.int())  # type: ignore[attr-defined]
  File "/opt/conda/lib/python3.10/site-packages/torchmetrics/metric.py", line 486, in wrapped_func
    raise RuntimeError(
RuntimeError: Encountered different devices in metric calculation (see stacktrace for details). This could be due to the metric class not being on the same device as input. Instead of `metric=BinaryPrecision(...)` try to do `metric=BinaryPrecision(...).to(device)` where device corresponds to the device of the input.

I could work around this manually on my end, by getting the device from the Trainer and setting it on the metrics (sketch below). But I wonder if this should work out of the box?
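For concreteness, a rough sketch of that manual workaround (it assumes the model will end up on the default CUDA device, and that model is the WideDeep model from the snippet above):

import torch
from pytorch_widedeep import Trainer
from torchmetrics.classification import BinaryAUROC, BinaryPrecision, BinaryRecall

# pick the device the model will train on
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# move each class-based torchmetrics metric onto that device up front
metrics = [m.to(device) for m in (BinaryAUROC(), BinaryPrecision(), BinaryRecall())]
trainer = Trainer(model, objective="binary", metrics=metrics)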

I read some discussion in pytorch-lightning here and here, where it seems there is a way for metrics to be moved automatically to the same device as the model?
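If I understand those threads correctly, the mechanism is that torchmetrics.Metric subclasses torch.nn.Module, so a metric assigned as an attribute of the model is registered as a submodule and its state tensors move together with the weights. A minimal sketch (ModelWithMetric is just an illustrative toy module):

import torch
import torch.nn as nn
from torchmetrics.classification import BinaryAUROC

class ModelWithMetric(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 1)
        # torchmetrics.Metric is an nn.Module, so this assignment registers
        # the metric as a submodule of the model
        self.auroc = BinaryAUROC()

# .to() moves the weights and the metric's internal state tensors together
model_with_metric = ModelWithMetric().to("cuda" if torch.cuda.is_available() else "cpu")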

jrzaurin commented 2 months ago

@5uperpalo I am traveling. If you have a sec maybe you could have a look?

Otherwise I'll have a look in the coming days.

jrzaurin commented 2 months ago

Hey @mandeep-starling

There are a number of ways we could tackle this.

One is, as you suggest, getting the device from the Trainer and passing it to the metrics, like:

trainer = Trainer(model, objective="binary", metrics=[BinaryAUROC().to(device), ...])

I can also set it up internally, but that involves some complications depending on whether a user decides to use the nn.Module-based torchmetrics metrics or their functional versions.
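For illustration, a rough sketch of what that internal handling could look like; the helper name _move_metrics_to_device is hypothetical and not part of the library:

import torch
from torchmetrics import Metric

def _move_metrics_to_device(metrics, device: torch.device):
    # hypothetical helper: class-based torchmetrics carry internal state
    # tensors and need an explicit .to(device); functional metrics are plain
    # stateless functions, so they are passed through unchanged
    return [m.to(device) if isinstance(m, Metric) else m for m in metrics]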

I will think about it. I might create a branch so you could try it and see what you prefer.