Open oleg-kachan opened 11 months ago
The exact ot.emd2 solver is a compiled C++ solver, so everything has to run on the CPU and be converted to NumPy. That is why it cannot be used with vmap, which requires pure PyTorch operations. We might be able to make Sinkhorn compatible in the future, but emd2 cannot be (it is also highly non-vectorizable, so even if this were possible there would be no gain from batching).
Describe the bug
As my datapoints are empirical distributions, I want to use the Wasserstein distance as a loss function over a batch of shape (n_batch, n_points, dimension). The standard way to make a function take a batch as input is torch.vmap, yet I get the error described below.
To Reproduce
Error
Expected behavior
Make POT distance functions batchable via torch.vmap. The Sinkhorn distance code seems to have this problem too.
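For the Sinkhorn case, batching does not strictly need torch.vmap: the iterations can be written directly on tensors with a leading batch dimension. The following is a minimal sketch in plain PyTorch, not POT's implementation. It assumes uniform weights and a squared Euclidean cost, and the function name batched_sinkhorn and the defaults for reg and n_iter are illustrative choices only.

```python
import math
import torch


def batched_sinkhorn(x, y, reg=0.1, n_iter=200):
    """Entropic OT transport cost for each pair of point clouds in a batch.

    x: (n_batch, n, d), y: (n_batch, m, d). Illustrative sketch, not a POT
    function. Assumes uniform weights and a squared Euclidean ground cost.
    """
    B, n, _ = x.shape
    m = y.shape[1]
    # Pairwise squared Euclidean costs, shape (B, n, m).
    C = ((x[:, :, None, :] - y[:, None, :, :]) ** 2).sum(dim=-1)
    log_a = torch.full((B, n), -math.log(n), dtype=x.dtype, device=x.device)
    log_b = torch.full((B, m), -math.log(m), dtype=x.dtype, device=x.device)
    f = torch.zeros_like(log_a)  # dual potential on the points of x
    g = torch.zeros_like(log_b)  # dual potential on the points of y
    for _ in range(n_iter):
        # Log-domain Sinkhorn updates, applied to all batch elements at once.
        f = -reg * torch.logsumexp((g[:, None, :] - C) / reg + log_b[:, None, :], dim=2)
        g = -reg * torch.logsumexp((f[:, :, None] - C) / reg + log_a[:, :, None], dim=1)
    # Transport plan in log space: P_ij = a_i * b_j * exp((f_i + g_j - C_ij) / reg).
    log_P = (f[:, :, None] + g[:, None, :] - C) / reg + log_a[:, :, None] + log_b[:, None, :]
    return (log_P.exp() * C).sum(dim=(1, 2))  # shape (n_batch,)
```

Because every operation is a differentiable PyTorch operation, the result can be used as a loss over the whole batch, at the price of an entropic approximation to the exact Wasserstein distance.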