Open RaghuSpaceRajan opened 4 years ago
Hi @RaghuSpaceRajan ,
Indeed, thanks for your detailed report!
This is clearly the result of an unfortunate copy-paste in a separate function... As you noticed, a simple workaround is to remove the last "dummy" dimension of the weight vectors before feeding them to SamplesLoss. I'm currently a bit too busy to implement the fix and re-package the library, but I will definitely do it in January. Of course, if you want to fix it properly and create a pull request, I'll be happy to accept it :-)
Best regards, Jean
Hi @jeanfeydy,
I changed check_shapes() to return the updated variables l_x, α, l_y, β as well and created a pull request.
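To illustrate the idea behind that change, here is a minimal pure-Python sketch, not the actual geomloss code. Nested lists stand in for tensors, and `flatten_weights`, `check_shapes_fixed` and `caller` are hypothetical names chosen for this example. The point is only that the checker hands the reshaped values back and the caller rebinds its own variables to them.

```python
def flatten_weights(weights):
    """Turn (N, 1)-style nested weights into a flat (N,)-style list (hypothetical helper)."""
    if weights and isinstance(weights[0], list):
        return [row[0] for row in weights]
    return weights


def check_shapes_fixed(l_x, alpha, x, l_y, beta, y):
    """Hypothetical stand-in for a shape checker that returns its updated inputs."""
    alpha = flatten_weights(alpha)
    beta = flatten_weights(beta)
    if len(alpha) != len(x) or len(beta) != len(y):
        raise ValueError("Weights and samples must have the same length.")
    # Hand the (possibly reshaped) values back to the caller.
    return l_x, alpha, l_y, beta


def caller(alpha, x, beta, y, l_x=None, l_y=None):
    # Rebind the caller's own names to the values returned by the checker.
    l_x, alpha, l_y, beta = check_shapes_fixed(l_x, alpha, x, l_y, beta, y)
    return alpha, beta


alpha_out, beta_out = caller(
    alpha=[[0.5], [0.5]], x=[[0.0], [1.0]],
    beta=[[0.25], [0.25], [0.5]], y=[[0.0], [1.0], [2.0]],
)
```

With this pattern, the flat weights are what the rest of `caller` sees, whatever shape the user passed in.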
By the way, I'm a bit confused about the usage of cluster labels. Are they only intended to be used with the multiscale backend? That's what the line here seems to suggest: https://github.com/jeanfeydy/geomloss/blob/4e09e3bfd376d92f2bb7efdf0854a5b7c756eb0d/geomloss/samples_loss.py#L329
But line 207 suggests they can also be used with the auto backend, which may lead to any of 3 backends being used.
And then there is a related check beginning here: https://github.com/jeanfeydy/geomloss/blob/4e09e3bfd376d92f2bb7efdf0854a5b7c756eb0d/geomloss/samples_loss.py#L219 which also seems somewhat contradictory to the check in line 329.
Greetings, Raghu.
Hi Jean,
Thanks for the library! I hope your PhD thesis writing's going fine.
I seem to have found an insidious bug. When we initialize measures as 2-D tensors with shapes like (N, 1) and (M, 1), check_shapes() tries to reshape them to be 1-D like (N,) and (M,), but these changes are not propagated to the original variables in the function calling check_shapes(). This leads to the following RuntimeError for me:
The code is here: https://github.com/jeanfeydy/geomloss/blob/4e09e3bfd376d92f2bb7efdf0854a5b7c756eb0d/geomloss/samples_loss.py#L297
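The mechanism can be reproduced without the library. The sketch below is a simplified illustration, not the geomloss source: nested lists stand in for tensors, and `check_shapes_buggy` and `caller_buggy` are hypothetical names. Reassigning a parameter inside a function only rebinds the local name, so the caller keeps the original (N, 1)-style object.

```python
def check_shapes_buggy(weights, points):
    """Hypothetical stand-in: flattens (N, 1) weights to (N,), but only locally."""
    if weights and isinstance(weights[0], list):
        # This rebinds the local name `weights`; the caller's variable is untouched.
        weights = [row[0] for row in weights]
    if len(weights) != len(points):
        raise ValueError("Weights and samples must have the same length.")
    # Nothing is returned, so the flattened list is discarded here.


def caller_buggy(weights, points):
    check_shapes_buggy(weights, points)
    # `weights` still has its original nested, (N, 1)-like shape.
    return weights


weights = [[0.25], [0.25], [0.5]]
points = [[0.0], [1.0], [2.0]]
result = caller_buggy(weights, points)
```

Here the check passes, yet `result` is still nested, which is the kind of mismatch that later code expecting flat weights would trip over.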
Would you like me to fix it and create a pull request?
Greetings, Raghu.