spinjo opened this issue 1 month ago (Open)
Hi,
Thanks for opening this issue! I haven't tried this myself, and I'm very curious whether you've found any performance benefits.
From my understanding, there are no fast exact OT GPU solvers, only Sinkhorn-based ones. If you look under the hood of POT, I believe that even with a GPU-capable backend, exact OT is still solved on the CPU: https://pythonot.github.io/quickstart.html#gpu-acceleration
When I was implementing this package, there was a bug in POT that allocated extra GPU memory when using the PyTorch backend (see https://github.com/PythonOT/POT/issues/523). I also found that under DDP, POT seemed to run operations on only one of the GPUs (incurring GPU-to-GPU transfer costs). So I didn't explore this option at the time and manually disabled the backend.
I haven't checked whether they are doing anything smarter than what I implemented here. I'd be happy to include this in torchcfm, at least as an option, if you file a PR. Before making it the default, I'd want to see that it actually improves performance.
Do you know of any methods that make use of back-propagating through the solver in this setting? As you mention, it could be useful there.
Thanks again, Alex
In the current implementation, tensors are moved to the CPU and converted to numpy before calling the optimal transport solver, see e.g. https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/optimal_transport.py#L88.
Since version 0.8, the POT package supports backends beyond numpy, as well as GPU acceleration, see https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends-on-cpu-gpu. This can speed up the OT solver, especially for large batch sizes, and enables new features like differentiating through the solver. Is there a reason why torchcfm uses the numpy + CPU policy?
I am successfully using the torch + GPU support of POT and am happy to file a PR if there is interest in including this in torchcfm.