jeanfeydy / geomloss

Geometric loss functions between point clouds, images and volumes
MIT License

Latest release 0.2.3 RuntimeError #9

Closed: pleaseconnectwifi closed this issue 5 years ago

pleaseconnectwifi commented 5 years ago

When I run the following Python script

import torch

from geomloss import SamplesLoss
WD = SamplesLoss(loss='sinkhorn', p=2, blur=.05)
a = torch.randn((8,4096,3))
b = torch.randn((8,4096,3))
c= WD(a,b)
print(c)

with the latest release (0.2.3), the following error occurs:

Traceback (most recent call last):
  File "t.py", line 7, in <module>
    c= WD(a,b)
  File "/home/miniconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/miniconda3/envs/py37/lib/python3.7/site-packages/geomloss/samples_loss.py", line 237, in forward
    verbose = self.verbose )
  File "/home/miniconda3/envs/py37/lib/python3.7/site-packages/geomloss/sinkhorn_samples.py", line 52, in sinkhorn_tensorized
    C_xx, C_yy, C_xy, C_yx, ε_s, ρ, debias=debias )
  File "/home/miniconda3/envs/py37/lib/python3.7/site-packages/geomloss/sinkhorn_divergence.py", line 162, in sinkhorn_loop
    at_x = λ * softmin(ε, C_xx, α_log + a_x/ε ) # OT(α,α)
RuntimeError: The size of tensor a (4096) must match the size of tensor b (32768) at non-singleton dimension 1

However, the same script runs fine with the older version (0.2.1) in the same environment.
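The two sizes in the error message are related: 32768 = 8 × 4096, i.e. the batch size times the number of points per cloud in the script. That pattern suggests (an inference, not something confirmed in this thread) that one operand of the failing addition had its batch dimension merged into its point dimension. The minimal sketch below reproduces the same kind of broadcasting error with plain PyTorch tensors; the names `alpha_log`, `a_x_flat` and `a_x` only mimic the traceback and are hypothetical, not geomloss internals.

```python
import torch

B, N = 8, 4096  # batch size and points per cloud, as in the reported script

# Per-batch log-weights: one row per point cloud, shape (B, N).
alpha_log = torch.zeros(B, N)

# Hypothetical buggy operand: batch and point dimensions flattened together,
# giving shape (B * N,) = (32768,).
a_x_flat = torch.zeros(B * N)

message = ""
try:
    # Broadcasting aligns trailing dimensions: 4096 vs 32768 cannot match.
    alpha_log + a_x_flat
except RuntimeError as err:
    message = str(err)
    print(message)

# With the expected shape (B, N), the same addition broadcasts without error.
a_x = torch.zeros(B, N)
ok = alpha_log + a_x
print(ok.shape)
```

Under this assumption, the symptom depends only on the shapes, which would also be consistent with the same script working under a release where the shapes are handled correctly.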

jeanfeydy commented 5 years ago

Hi @pleaseconnectwifi,

Thanks a lot for your feedback. This bug was due to a stupid mistake on my part, and is now fixed in commit 4e09e3bfd376d92f2bb7efdf0854a5b7c756eb0d. This shows why setting up continuous integration will really be needed at some point... The fix will be included in the next PyPI release. Until then, if you don't want to bother with a git install, you can apply it by hand: it's a two-character modification in the geomloss/sinkhorn_samples.py file.

Best regards,

Jean
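Jean mentions that continuous integration would have caught this. As a rough sketch of the kind of regression test that could run in CI, the helper below feeds batched point clouds of shape (B, N, D) to any loss callable and checks that it returns one finite value per batch element. The helper name `check_batched_loss` and its default sizes are hypothetical, and the expected output shape (B,) is an assumption about how a batched loss such as `SamplesLoss` reports its result, not something stated in this thread.

```python
import torch


def check_batched_loss(loss_fn, batch=8, n_x=4096, n_y=4096, dim=3, seed=0):
    """Hypothetical regression check: call `loss_fn` on batched point clouds
    of shape (batch, n_x, dim) and (batch, n_y, dim), then assert that it
    returns a finite tensor with one value per batch element (an assumption)."""
    gen = torch.Generator().manual_seed(seed)  # reproducible random inputs
    x = torch.randn(batch, n_x, dim, generator=gen)
    y = torch.randn(batch, n_y, dim, generator=gen)
    out = loss_fn(x, y)
    # One loss value per pair of point clouds in the batch.
    assert out.shape == (batch,), f"expected shape ({batch},), got {tuple(out.shape)}"
    # Guard against silent numerical failures as well as shape bugs.
    assert torch.isfinite(out).all(), "loss contains NaN or inf"
    return out
```

For example, `check_batched_loss(SamplesLoss(loss='sinkhorn', p=2, blur=.05))` would exercise the same call as the reported script; on the affected 0.2.3 release it would presumably stop with the RuntimeError above before reaching the shape assertion.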

zhangmozhe commented 4 years ago

Hi, do you plan to publish a release with this fix to PyPI? Thanks.