jeanfeydy / geomloss

Geometric loss functions between point clouds, images and volumes
MIT License

Batch support of SamplesLoss #38

Open mi92 opened 3 years ago

mi92 commented 3 years ago

First of all, thanks for the great library!

I just tried to run SamplesLoss on batches of data, and it did not work.

I have two tensors x and y of the same shape [batch_dim, n_points, feature_dim]. I want to compute the Sinkhorn divergence between corresponding point clouds (x[0] and y[0], x[1] and y[1], and so on) in a batched way, to avoid a slow loop, and get back a tensor of shape [batch_dim].
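To make the expected semantics concrete, here is a rough NumPy sketch of a batched, debiased Sinkhorn divergence with uniform weights. It is only an illustration of the shapes involved, not GeomLoss's implementation: the function names, the cost |x - y|^2 / 2, the eps value and the fixed iteration count are all assumptions chosen for the example.

```python
import numpy as np

def _logsumexp(z, axis):
    # Numerically stable log-sum-exp along one axis.
    m = z.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(z - m).sum(axis=axis))

def _ot_eps(x, y, eps, n_iters):
    # Entropic OT value per batch entry, uniform weights (hypothetical helper).
    # x: (B, N, D), y: (B, M, D)  ->  (B,)
    B, N, _ = x.shape
    M = y.shape[1]
    # C[b, i, j] = |x[b, i] - y[b, j]|^2 / 2, built by broadcasting.
    C = 0.5 * ((x[:, :, None, :] - y[:, None, :, :]) ** 2).sum(-1)
    log_a, log_b = -np.log(N), -np.log(M)
    f = np.zeros((B, N))
    g = np.zeros((B, M))
    for _ in range(n_iters):
        # Log-domain Sinkhorn updates, vectorized over the batch axis.
        f = -eps * _logsumexp(log_b + (g[:, None, :] - C) / eps, axis=2)
        g = -eps * _logsumexp(log_a + (f[:, :, None] - C) / eps, axis=1)
    # Dual value <a, f> + <b, g> with uniform a and b.
    return f.mean(axis=1) + g.mean(axis=1)

def sinkhorn_divergence_batched(x, y, eps=0.05, n_iters=200):
    # Debiased divergence: OT(x, y) - OT(x, x) / 2 - OT(y, y) / 2, one value per pair.
    return (_ot_eps(x, y, eps, n_iters)
            - 0.5 * _ot_eps(x, x, eps, n_iters)
            - 0.5 * _ot_eps(y, y, eps, n_iters))

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 50, 3))        # [batch_dim, n_points, feature_dim]
y = rng.normal(size=(4, 50, 3)) + 1.0  # same shape, shifted clouds
div = sinkhorn_divergence_batched(x, y)
print(div.shape)  # (4,)
```

Entry div[b] only depends on x[b] and y[b], which is the behaviour I would expect from a batched loss.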

However, when I try this with SamplesLoss(), I get a shape error.

To reproduce the problem, here is a minimal example on Colab: https://colab.research.google.com/drive/1NqagWVIv-a8YN258NcFEBXbRBFAVuuiR?usp=sharing

NightWinkle commented 3 years ago

You have to install the GitHub version of GeomLoss with the following command: pip install git+https://github.com/jeanfeydy/geomloss

This problem was fixed in a commit that is more recent than the latest release.