@5uperpalo I am traveling. If you have a sec, maybe you could have a look?
Otherwise I'll have a look in the coming days.
Hey @mandeep-starling,
there are a number of ways we could tackle this.
One is, as you suggest, to get the device from the Trainer
and pass it to the metrics, like:
```python
trainer = Trainer(model, objective="binary", metrics=[BinaryAUROC().to(device), ...])
```
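Why `.to(device)` helps: a stateful metric keeps its running state in tensors registered on the module, so those tensors must live on the same device as the predictions. The sketch below assumes a CUDA-or-CPU device choice and uses a hypothetical `RunningAccuracy` module as a stand-in for a real torchmetrics metric; it is not the library's own code.

```python
import torch
from torch import nn


class RunningAccuracy(nn.Module):
    """Hypothetical stand-in for a stateful metric such as torchmetrics' BinaryAUROC.

    The running state lives in registered buffers, so .to(device) moves it
    together with the module, just like a model's parameters.
    """

    def __init__(self, threshold: float = 0.5):
        super().__init__()
        self.threshold = threshold
        self.register_buffer("correct", torch.zeros((), dtype=torch.long))
        self.register_buffer("total", torch.zeros((), dtype=torch.long))

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # The state must share the inputs' device; moving the metric guarantees that.
        self.correct += ((preds >= self.threshold).long() == target).sum()
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        return self.correct.float() / self.total


# Assumption: pick the device the way a trainer typically would.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Analogous to metrics=[BinaryAUROC().to(device), ...] in the Trainer call.
metric = RunningAccuracy().to(device)

preds = torch.tensor([0.9, 0.2, 0.7, 0.4], device=device)
target = torch.tensor([1, 0, 0, 0], device=device)
metric.update(preds, target)
print(metric.compute().item())  # 0.75
```

Three of the four thresholded predictions match the targets, hence 0.75, and this holds whether `device` ends up being CPU or CUDA.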
I could also set it up internally, but that involves some complications depending on whether a user chooses the nn.Module Lightning metrics or their functional versions.
I'll think about it. I might open a separate branch so you can try it and see which approach you prefer.
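One plausible reading of that complication is that the two metric styles need different handling. Module metrics carry state and must be moved. Functional metrics are stateless and simply run on whatever device their inputs use. The sketch below assumes that split; `prepare_metrics`, `CountPositives` and `binary_accuracy` are hypothetical names, not part of any library.

```python
from typing import Callable, List, Union

import torch
from torch import nn

Metric = Union[nn.Module, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]


class CountPositives(nn.Module):
    """Hypothetical stateful metric: its running count is a registered buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("count", torch.zeros((), dtype=torch.long))

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        self.count += (target == 1).sum()
        return self.count


def binary_accuracy(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical functional metric: stateless, so it needs no device handling."""
    return ((preds >= 0.5).long() == target).float().mean()


def prepare_metrics(metrics: List[Metric], device: torch.device) -> List[Metric]:
    """Move stateful (nn.Module) metrics to `device`; leave functional ones untouched."""
    prepared = []
    for m in metrics:
        # Check nn.Module first: modules are callable too.
        if isinstance(m, nn.Module):
            prepared.append(m.to(device))
        elif callable(m):
            prepared.append(m)
        else:
            raise TypeError(f"Unsupported metric type: {type(m).__name__}")
    return prepared


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
metrics = prepare_metrics([CountPositives(), binary_accuracy], device)

preds = torch.tensor([0.9, 0.2, 0.7, 0.4], device=device)
target = torch.tensor([1, 0, 0, 0], device=device)
results = [m(preds, target) for m in metrics]
```

The `isinstance` dispatch keeps the user-facing API unchanged: users could pass either style in the same list, and only stateful metrics are touched.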
When running on a CUDA device, the following code:
It returns the following error:
I could work around this on my end by getting the device from here and moving the metrics to it manually, but shouldn't this work out of the box?
I read some discussion in pytorch-lightning (here and here) suggesting that metrics can be moved automatically to the same device as the model. Could something similar work here?
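As I understand the Lightning behaviour, it comes from plain PyTorch module registration. Metrics assigned as attributes of the model become submodules, so one `.to(device)` call on the model also moves their state. The sketch below reproduces that mechanism without Lightning. `ModelWithMetrics` and `RunningSum` are hypothetical names, and wrapping a user's model this way is an assumption, not how any trainer is known to work.

```python
import torch
from torch import nn


class RunningSum(nn.Module):
    """Hypothetical stateful metric whose state is a registered buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("total", torch.zeros(()))

    def update(self, values: torch.Tensor) -> None:
        self.total += values.sum()


class ModelWithMetrics(nn.Module):
    """Hypothetical wrapper: metrics registered as submodules follow the model's device."""

    def __init__(self, model: nn.Module, metrics: list):
        super().__init__()
        self.model = model
        # A plain Python list would hide the metrics from .to(); ModuleList registers them.
        self.metrics = nn.ModuleList(metrics)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A single .to(device) moves both the model parameters and the metric buffers.
wrapped = ModelWithMetrics(nn.Linear(3, 1), [RunningSum()]).to(device)

x = torch.randn(4, 3, device=device)
out = torch.sigmoid(wrapped(x))
wrapped.metrics[0].update(out)
```

Registration is what matters here. Metrics kept in an ordinary list would stay on the CPU, which is why a container such as `nn.ModuleList` is used.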