fwilliams / scalable-pytorch-sinkhorn

Fast, Memory-Efficient Approximate Wasserstein Distances

This repository contains PyTorch code to compute fast p-Wasserstein distances between d-dimensional point clouds using the Sinkhorn Algorithm.

This implementation uses memory linear in the number of input points, is stable in float32, runs on the GPU, and is fully differentiable.

[Figure: correspondences between two shapes found by computing the Sinkhorn distance on 200k input points.]

How to use:

  1. Copy the sinkhorn.py file in this repository to your PyTorch codebase.
  2. pip install pykeops tqdm
  3. Add from sinkhorn import sinkhorn to your code and call the sinkhorn function.

Running the example code

See example_basic.py for a basic example and example_optimize.py for an example of how to use Sinkhorn in your own optimization.

NOTE: To run the examples, first install their dependencies:

pip install pykeops tqdm numpy scipy polyscope point-cloud-utils

sinkhorn function documentation

sinkhorn(x: torch.Tensor, y: torch.Tensor, p: float = 2,
         w_x: Union[torch.Tensor, None] = None,
         w_y: Union[torch.Tensor, None] = None,
         eps: float = 1e-3,
         max_iters: int = 100, stop_thresh: float = 1e-5,
         verbose=False)

Computes the entropy-regularized p-Wasserstein distance between two d-dimensional point clouds using the Sinkhorn scaling algorithm. The computation runs on the GPU if you pass in GPU tensors. The function is differentiable, so you can backpropagate through it (though this may be slow when using many iterations).
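To make the idea behind the function concrete, here is a minimal sketch of log-domain Sinkhorn iterations in plain Python. It is not the code in sinkhorn.py: the names sinkhorn_dense and logsumexp are hypothetical, it assumes uniform point weights, it builds a dense cost matrix (quadratic memory, unlike the linear-memory implementation in this repository), and its stopping rule (largest change in a dual potential) is an assumption. The parameters p, eps, max_iters and stop_thresh are only meant to mirror the roles they play in the signature above.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sinkhorn_dense(x, y, p=2.0, eps=1e-3, max_iters=100, stop_thresh=1e-5):
    """Entropy-regularized transport cost between two small point clouds.

    x, y: lists of points (each a list of floats) with uniform weights.
    Returns (cost, plan), where plan[i][j] is the mass moved from x[i] to y[j].
    """
    n, m = len(x), len(y)
    log_a, log_b = -math.log(n), -math.log(m)  # log of the uniform weights
    # Dense cost matrix C[i][j] = ||x_i - y_j||^p (assumption: Euclidean norm).
    C = [[math.dist(xi, yj) ** p for yj in y] for xi in x]
    f, g = [0.0] * n, [0.0] * m  # dual potentials
    for _ in range(max_iters):
        f_prev = f
        # Alternately rescale rows and columns, working in the log domain.
        f = [-eps * logsumexp([log_b + (g[j] - C[i][j]) / eps for j in range(m)])
             for i in range(n)]
        g = [-eps * logsumexp([log_a + (f[i] - C[i][j]) / eps for i in range(n)])
             for j in range(m)]
        # Assumed stopping rule: stop once the potentials barely change.
        if max(abs(fi - fp) for fi, fp in zip(f, f_prev)) < stop_thresh:
            break
    # Recover the transport plan from the potentials.
    plan = [[math.exp(log_a + log_b + (f[i] + g[j] - C[i][j]) / eps)
             for j in range(m)] for i in range(n)]
    cost = sum(plan[i][j] * C[i][j] for i in range(n) for j in range(m))
    return cost, plan


# Two unit-spaced pairs of points, one unit apart vertically.
cost, plan = sinkhorn_dense([[0.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]])
```

Working with logsumexp rather than multiplying by exp(-C / eps) directly keeps the iteration well behaved for small eps, where the exponentials would otherwise underflow. In this sketch the largest entry in each row of plan indicates which point of y a point of x is mostly matched to.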

Arguments:

Returns:

A triple (d, corrs_x_to_y, corr_y_to_x) where: