PythonOT / POT

POT : Python Optimal Transport
https://PythonOT.github.io/
MIT License

Cannot batch `ot.emd2` via `torch.vmap` #532

Open oleg-kachan opened 11 months ago

oleg-kachan commented 11 months ago

Describe the bug

Since my data points are empirical distributions, I want to use the Wasserstein distance as a loss function over a batch of shape `(n_batch, n_points, dimension)`. The standard way to make a function accept a batched input is `torch.vmap`, yet I get the error described below.

To Reproduce

import ot
import torch

def wasserstein2_loss(X, Y):
    # uniform weights on the two point clouds
    n, m = X.shape[0], Y.shape[0]
    a = torch.ones(n) / n
    b = torch.ones(m) / m
    # pairwise squared-Euclidean cost matrix
    M = ot.dist(X, Y, metric="sqeuclidean")
    return ot.emd2(a, b, M) ** 0.5

wasserstein2_loss_batched = torch.vmap(wasserstein2_loss)
W2 = wasserstein2_loss_batched(X, Y)  # should be an array of shape `n_batch`
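
For reference, a minimal input setup that triggers the error; the batch size and point-cloud shapes below are assumptions, as the report does not specify them:

# hypothetical shapes: n_batch=8, n_points=100, dimension=2 (not given in the report)
X = torch.randn(8, 100, 2)
Y = torch.randn(8, 100, 2)
W2 = wasserstein2_loss_batched(X, Y)  # raises the RuntimeError shown below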

Error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 W2 = wasserstein2_loss_batched(X, Y)

File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:434, in vmap.<locals>.wrapped(*args, **kwargs)
    430     return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
    431                          args_spec, out_dims, randomness, **kwargs)
    433 # If chunk_size is not specified.
--> 434 return _flat_vmap(
    435     func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
    436 )

File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:39, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     36 @functools.wraps(f)
     37 def fn(*args, **kwargs):
     38     with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39         return f(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:619, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    617 try:
    618     batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 619     batched_outputs = func(*batched_inputs, **kwargs)
    620     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    621 finally:

Cell In[4], line 13, in wasserstein2_loss(X, Y)
     11 b = torch.ones(m) / m
     12 M = ot.dist(X, Y, metric="sqeuclidean")
---> 13 return wasserstein_distance(a, b, M) ** 0.5

File /usr/local/lib/python3.10/dist-packages/ot/lp/__init__.py:488, in emd2(a, b, M, processes, numItermax, log, return_matrix, center_dual, numThreads, check_marginals)
    485 nx = get_backend(M0, a0, b0)
    487 # convert to numpy
--> 488 M, a, b = nx.to_numpy(M, a, b)
    490 a = np.asarray(a, dtype=np.float64)
    491 b = np.asarray(b, dtype=np.float64)

File /usr/local/lib/python3.10/dist-packages/ot/backend.py:207, in Backend.to_numpy(self, *arrays)
    205     return self._to_numpy(arrays[0])
    206 else:
--> 207     return [self._to_numpy(array) for array in arrays]

File /usr/local/lib/python3.10/dist-packages/ot/backend.py:207, in <listcomp>(.0)
    205     return self._to_numpy(arrays[0])
    206 else:
--> 207     return [self._to_numpy(array) for array in arrays]

File /usr/local/lib/python3.10/dist-packages/ot/backend.py:1763, in TorchBackend._to_numpy(self, a)
   1761 if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
   1762     return np.array(a)
-> 1763 return a.cpu().detach().numpy()

RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

Expected behavior

Make POT distance functions batchable via `torch.vmap`. The Sinkhorn distance code seems to have this problem too.

rflamary commented 11 months ago

The exact `ot.emd2` solver relies on a compiled C++ solver, so everything needs to be done on CPU and converted to numpy; that is why it cannot be used with `vmap`, which requires pure PyTorch operations. We might be able to make Sinkhorn compatible in the future, but `emd2` cannot be (the solver is also highly non-vectorizable, so even if this were possible there would be no gain from batching).
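
Given that constraint, a plain Python loop over the batch dimension is a workable substitute for `torch.vmap` here. A minimal sketch, reusing `wasserstein2_loss` from the report (the `wasserstein2_loss_looped` name is illustrative); each per-sample `ot.emd2` call still supports backpropagation through POT's torch backend:

# batching by explicit loop: vmap cannot trace into the compiled C++ solver,
# but calling it once per batch element works
def wasserstein2_loss_looped(X_batch, Y_batch):
    return torch.stack([
        wasserstein2_loss(X, Y)
        for X, Y in zip(X_batch, Y_batch)
    ])

W2 = wasserstein2_loss_looped(X, Y)  # tensor of shape (n_batch,)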