jeanfeydy / geomloss

Geometric loss functions between point clouds, images and volumes
MIT License

check_shapes() updates l_x, α, l_y, β internally, but the changes do not propagate to the caller's variables when it returns #16

Open RaghuSpaceRajan opened 4 years ago

RaghuSpaceRajan commented 4 years ago

Hi Jean,

Thanks for the library! I hope your PhD thesis writing's going fine.

I seem to have found an insidious bug. When measures are initialized as 2-D tensors with shapes like (N, 1) and (M, 1), check_shapes() reshapes them to 1-D, i.e. (N,) and (M,), but these changes are not propagated to the original variables in the function that calls check_shapes(). This leads to the following RuntimeError for me:

-> loss_p2 = Loss_p2(a_i, x_i, b_j, y_j)
(Pdb) c
Traceback (most recent call last):
  File "/home/rajanr/anaconda3/envs/py36/lib/python3.6/pdb.py", line 1667, in main
    pdb._runscript(mainpyfile)
  File "/home/rajanr/anaconda3/envs/py36/lib/python3.6/pdb.py", line 1548, in _runscript
    self.run(statement)
  File "/home/rajanr/anaconda3/envs/py36/lib/python3.6/bdb.py", line 434, in run
    exec(cmd, globals, locals)
  File "<string>", line 1, in <module>
  File "/home/rajanr/custom-gym-env/Wasserstein_sinkhorn_stuff.py", line 86, in <module>
    scaling=scaling, backend="multiscale")
  File "/home/rajanr/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/rajanr/anaconda3/envs/py36/lib/python3.6/site-packages/geomloss/samples_loss.py", line 237, in forward
    verbose = self.verbose )
  File "/home/rajanr/anaconda3/envs/py36/lib/python3.6/site-packages/geomloss/sinkhorn_samples.py", line 262, in sinkhorn_multiscale
    debias = debias)
  File "/home/rajanr/anaconda3/envs/py36/lib/python3.6/site-packages/geomloss/sinkhorn_divergence.py", line 162, in sinkhorn_loop
    at_x = λ * softmin(ε, C_xx, α_log + a_x/ε )  # OT(α,α)
  File "/home/rajanr/anaconda3/envs/py36/lib/python3.6/site-packages/geomloss/sinkhorn_samples.py", line 117, in softmin_multiscale
    return - ε * log_conv( x, y, f_y.view(-1,1), torch.Tensor([1/ε]).type_as(x), ranges=ranges_xy ).view(-1)
  File "/home/rajanr/anaconda3/envs/py36/lib/python3.6/site-packages/pykeops/torch/generic/generic_red.py", line 351, in __call__
    out = GenredAutograd.apply(self.formula, self.aliases, backend, self.dtype, device_id, ranges, *args)
  File "/home/rajanr/anaconda3/envs/py36/lib/python3.6/site-packages/pykeops/torch/generic/generic_red.py", line 43, in forward
    *args)
RuntimeError: [KeOps] Wrong value of the 'j' dimension 0for arg number 2 : is 9 but was 3 in previous 'j' arguments.
Uncaught exception. Entering post mortem debugging
Running 'cont' or 'step' will restart the program

The code is here: https://github.com/jeanfeydy/geomloss/blob/4e09e3bfd376d92f2bb7efdf0854a5b7c756eb0d/geomloss/samples_loss.py#L297
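The root cause is ordinary Python name binding: reassigning a parameter inside a function only rebinds the function's local name, so the caller's variable keeps pointing at the original object. The following sketch shows this with plain nested lists standing in for (N, 1) weights; `flatten_weights_buggy` is a hypothetical helper for illustration, not geomloss code.

```python
def flatten_weights_buggy(weights):
    # Hypothetical helper: "reshapes" (N, 1) weights to (N,) by rebinding
    # the local name. This creates a new list; the caller never sees it.
    weights = [w[0] for w in weights]


weights = [[0.25], [0.75]]  # weights with a trailing dummy dimension, shape (N, 1)
flatten_weights_buggy(weights)
print(weights)  # [[0.25], [0.75]] (still 2-D in the caller)
```

The same thing happens with tensors: `α = α.view(-1)` inside check_shapes() produces a new tensor object bound to a local name, so the caller keeps the (N, 1) version unless the new value is returned.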

Would you like me to fix it and create a pull request?

Greetings, Raghu.

jeanfeydy commented 4 years ago

Hi @RaghuSpaceRajan ,

Indeed, thanks for your detailed report! This is clearly the result of an unfortunate copy-paste in a separate function... As you may have noticed, a simple work-around is to remove the last "dummy" dimension of the weight vectors before feeding them to SamplesLoss. I'm currently a bit too busy to implement the fix and re-package the library, but I will definitely do it in January. Of course, if you want to fix it properly and create a pull request, I'll be happy to accept it :-)
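The work-around amounts to flattening the weights on the caller's side. This is a minimal sketch assuming uniform weights `a_i` and `b_j` built with a trailing dimension of size 1, as in the report; the sizes N = 5 and M = 3 are arbitrary.

```python
import torch

# Weights created with a trailing "dummy" dimension: shapes (N, 1) and (M, 1).
N, M = 5, 3
a_i = torch.ones(N, 1) / N
b_j = torch.ones(M, 1) / M

# Work-around: drop the last dimension so the weights are 1-D, shapes (N,) and (M,).
a_i = a_i.view(-1)
b_j = b_j.view(-1)

print(a_i.shape, b_j.shape)  # torch.Size([5]) torch.Size([3])
```

The flattened tensors can then be passed as in the original call, e.g. `Loss_p2(a_i, x_i, b_j, y_j)`. `a_i.squeeze(-1)` would work equally well here.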

Best regards, Jean

RaghuSpaceRajan commented 4 years ago

Hi @jeanfeydy,

I changed check_shapes() to also return the updated variables l_x, α, l_y, β, and created a pull request.
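The pattern behind such a fix is to return every value the function may rebind and have the caller unpack them. The sketch below is a simplified, hypothetical stand-in (`check_shapes_fixed`, with an assumed four-argument signature). It is not the actual geomloss function or the code in the pull request, and it only flattens trailing singleton dimensions of the weights.

```python
import torch


def check_shapes_fixed(l_x, α, l_y, β):
    """Hypothetical, simplified stand-in for a patched check_shapes().

    Weights of shape (N, 1) / (M, 1) are flattened to (N,) / (M,), and
    every possibly-updated value is returned so the caller can rebind it.
    """
    if α.dim() == 2 and α.shape[1] == 1:
        α = α.view(-1)
    if β.dim() == 2 and β.shape[1] == 1:
        β = β.view(-1)
    # Cluster labels are passed through unchanged in this sketch.
    return l_x, α, l_y, β


# Caller side: rebind all four names to the returned values.
α = torch.ones(4, 1) / 4
β = torch.ones(3, 1) / 3
l_x, α, l_y, β = check_shapes_fixed(None, α, None, β)
print(α.shape, β.shape)  # torch.Size([4]) torch.Size([3])
```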

By the way, I'm a bit confused about the usage of cluster labels. Is it only intended to be used with the multiscale backend? That's what the line here seems to suggest: https://github.com/jeanfeydy/geomloss/blob/4e09e3bfd376d92f2bb7efdf0854a5b7c756eb0d/geomloss/samples_loss.py#L329

But line 207 suggests they can also be used with the auto backend, which may select any of the three backends.

And then there is a related check beginning here: https://github.com/jeanfeydy/geomloss/blob/4e09e3bfd376d92f2bb7efdf0854a5b7c756eb0d/geomloss/samples_loss.py#L219 which also seems somewhat contradictory to the check in line 329.

Greetings, Raghu.