Closed by BramVanroy 10 months ago
This PR is not yet ready for integration. I have done some work on making the metric class `RegressionMetrics(Metric)` (not the model) compatible with distributed computing, and this seems to work for training. However, I can't get it to work with prediction. As far as I can tell, the issue is that the metrics are not gathered correctly across the different processes. In that piece of code, `RegressionMetrics` should receive `dist_sync_on_step=True`, but only in distributed settings.
PyTorch Lightning has so many hoops to jump through (subclasses, properties) that I lost patience trying to figure out how to do something like "if multi-GPU" or "if distributed". For someone who knows PyTorch Lightning well, this is perhaps an easy fix, so feel free to chime in.
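The conditional described above could be sketched like this. `in_distributed_run` is a hypothetical helper, and the `WORLD_SIZE` environment-variable check (which launchers such as `torchrun` export) is a dependency-free stand-in for the real check, `torch.distributed.is_available() and torch.distributed.is_initialized()`:

```python
import os


def in_distributed_run() -> bool:
    """Heuristic sketch: distributed launchers export WORLD_SIZE > 1.

    A dependency-free stand-in for torch.distributed.is_initialized(),
    used here only to illustrate gating dist_sync_on_step.
    """
    return int(os.environ.get("WORLD_SIZE", "1")) > 1


# Hypothetical usage: sync metric states across processes only when needed.
# metric = RegressionMetrics(dist_sync_on_step=in_distributed_run())
```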
A test case for the distributed scenario has also been added. Note that while developing this I updated torchmetrics to the newest version, to rule out issues in the underlying library.
Currently, `CometModel` defines lambda functions. These cannot be pickled, so multiprocessing (and hence multi-GPU training) is not possible. This PR replaces the lambda functions with partials, which should be picklable.
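A minimal sketch of why this matters, using a hypothetical module-level `scale` function rather than the actual `CometModel` internals: `pickle` serializes functions by qualified name, which fails for anonymous lambdas but works for a `functools.partial` over a named function.

```python
import pickle
from functools import partial


def scale(x, factor):
    """Module-level function: picklable by qualified name."""
    return x * factor


# A lambda capturing the same behaviour cannot be pickled...
double_lambda = lambda x: scale(x, 2)
try:
    pickle.dumps(double_lambda)
except (pickle.PicklingError, AttributeError):
    print("lambda: not picklable")

# ...but a partial over the named function round-trips fine.
double_partial = partial(scale, factor=2)
restored = pickle.loads(pickle.dumps(double_partial))
print(restored(21))  # → 42
```

Since multiprocessing spawns workers by pickling the objects it hands them, swapping the lambdas for partials is what makes multi-GPU training possible here.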
closes #159