jaanli opened this issue 3 months ago
Hey @jaanli
I'm converting this to a feature request because it isn't a bug. The docs for self.log
state that you can log a float, tensor, or TorchMetrics Metric. The ValueError you get comes directly from Lightning, informing the user that arbitrary objects cannot be logged, which is correct.
If you want to log a TorchEval metric, I suggest you compute it normally and then pass the resulting value (a scalar tensor) to the self.log()
call. Regarding device management, that's just general advice that Lightning gives. If you see the need, feel free to force your metric modules onto the CPU to perform their computations.
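The compute-then-log pattern the maintainer suggests can be sketched without the real libraries. Below is a minimal pure-Python illustration: TinyAccuracy stands in for a torcheval metric's update()/compute() interface, and the log() function stands in for Lightning's self.log(), which only accepts plain numbers or scalar tensors. All names here are invented for illustration, not actual torcheval or Lightning APIs.

```python
class TinyAccuracy:
    """Stand-in for a torcheval metric (update/compute interface)."""
    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        # Final reduction to a plain scalar.
        return self.correct / self.total if self.total else 0.0


logged = {}

def log(name, value):
    # Like Lightning's self.log(): metric objects are rejected,
    # only scalar values are accepted.
    if not isinstance(value, (int, float)):
        raise ValueError(f"{type(value).__name__} values cannot be logged")
    logged[name] = value


metric = TinyAccuracy()
metric.update([0, 1, 2, 2], [0, 1, 1, 2])

# Wrong: log("val_acc", metric) would raise ValueError, as in the report below.
# Right: compute the scalar first, then log it.
log("val_acc", metric.compute())
print(logged["val_acc"])  # 0.75
```

The same shape works in a real LightningModule: call metric.update(...) in validation_step, then self.log("val_acc", metric.compute()) so only the scalar reaches the logger.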
Super helpful, thanks so much @awaelchli ! Will see if we have the engineering support to fix this compatibility issue :)
Bug description
Half of our team uses vanilla PyTorch and the other half uses (PyTorch) Lightning. We need to use several custom metrics for our use case, and need fine-grained control over which device these metrics are on (and for use with fully-sharded data parallel language models).
However, it seems we are blocked from using Lightning together with these custom metrics, potentially because Lightning recommends deleting any model.to(device) calls: https://lightning.ai/docs/pytorch/stable/accelerators/accelerator_prepare.html#delete-cuda-or-to-calls
And custom metrics, such as those from torcheval, sometimes need to call metric.to(device), as some computation can only happen on GPU and some only on CPU. One example is binned precision-recall curves in extreme multi-label classification, where the computation is infeasible (takes too long) on CPU, but GPU memory is exhausted for large language models -- so metrics and intermediate state must be copied onto and off of GPU memory during training for early stopping.
Here's one example:
https://pytorch.org/torcheval/main/_modules/torcheval/metrics/classification/binned_precision_recall_curve.html#MulticlassBinnedPrecisionRecallCurve
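The device-shuffling pattern described above (accumulate on GPU, reduce on CPU) can be sketched in pure Python. The class below is an invented stand-in, not the torcheval API; the device strings only tag where the state lives. With a real torcheval metric you would call metric.to("cuda") before the update loop and metric.to("cpu") before the final compute.

```python
class BinnedMetricSketch:
    """Illustrative stand-in for a metric whose updates belong on GPU
    but whose final reduction must happen on CPU to save GPU memory."""
    def __init__(self):
        self.device = "cpu"
        self.state = []  # accumulated intermediate results

    def to(self, device):
        # In torch this would copy the state tensors across devices;
        # here we just record where the state "lives".
        self.device = device
        return self

    def update(self, batch_scores):
        assert self.device == "gpu", "per-batch updates run on GPU"
        self.state.extend(batch_scores)

    def compute(self):
        assert self.device == "cpu", "final reduction moved off GPU"
        return sum(self.state) / len(self.state)


metric = BinnedMetricSketch().to("gpu")
for batch in ([0.25, 0.75], [0.5, 0.5]):
    metric.update(batch)

# Move intermediate state off the GPU before the memory-hungry reduction.
result = metric.to("cpu").compute()
print(result)  # 0.5
```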
To reproduce, based on a vanilla PyTorch example: https://raw.githubusercontent.com/pytorch/examples/main/mnist/main.py
The specific error with custom metrics is shown under "How to reproduce the bug" below.
Any advice on how to include custom metrics in Lightning, from torcheval or otherwise, that require passing intermediate state onto and off of GPUs?
Or must such evaluation for early stopping happen outside Lightning modules?
Thank you -- any advice on the canonical way to solve this problem is appreciated; we can't be the only ones hitting this blocker when using Lightning with standard PyTorch tools.
What version are you seeing the problem on?
v2.2
How to reproduce the bug
This runs fine with vanilla PyTorch.
With PyTorch Lightning, however, we get:
ValueError: self.log(val_acc, <torcheval.metrics.classification.accuracy.MulticlassAccuracy object at 0x16fc1ed10>) was called, but MulticlassAccuracy values cannot be logged