Closed: Farbodch closed this issue 2 months ago.
Hi @Farbodch, sorry about this. We did not expect this usage, but it is indeed very valid, especially since we explored it here: https://proceedings.neurips.cc/paper_files/paper/2022/hash/2d69e771d9f274f7c624198ea74f5b98-Abstract-Conference.html
Essentially this is just an API bug, and it should work fine. In that case we simply shouldn't try to pull the .f and .g potentials from the LR Sinkhorn output.
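A minimal sketch of the idea in the reply above, assuming hypothetical stand-in classes (DenseOutput, LowRankOutput, debiased_divergence) rather than the library's real output types or the actual patch: the debiased divergence S(x, y) = OT(x, y) - 0.5 * (OT(x, x) + OT(y, y)) only needs the regularized cost of each of the three solves, so it can be assembled without touching .f or .g. Potentials are read only when the solver actually produced them.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple, Union


@dataclass
class DenseOutput:
    """Hypothetical stand-in for a full-rank Sinkhorn output (carries dual potentials)."""
    reg_ot_cost: float
    f: Sequence[float]
    g: Sequence[float]


@dataclass
class LowRankOutput:
    """Hypothetical stand-in for a low-rank Sinkhorn output (no dual potentials)."""
    reg_ot_cost: float


Output = Union[DenseOutput, LowRankOutput]


def debiased_divergence(
    out_xy: Output, out_xx: Output, out_yy: Output
) -> Tuple[float, Optional[Tuple[Sequence[float], Sequence[float]]]]:
    """Return S(x, y) = OT(x, y) - 0.5 * (OT(x, x) + OT(y, y)) and potentials if available."""
    # The divergence itself only needs the three regularized costs.
    div = out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost)
    # Only full-rank outputs carry potentials; never read .f/.g from a low-rank output.
    potentials = (out_xy.f, out_xy.g) if isinstance(out_xy, DenseOutput) else None
    return div, potentials


# Low-rank case: 3.0 - 0.5 * (1.0 + 2.0) = 1.5, and no potentials are requested.
lr_div, lr_pots = debiased_divergence(
    LowRankOutput(3.0), LowRankOutput(1.0), LowRankOutput(2.0)
)
print(lr_div, lr_pots)  # prints "1.5 None"
```

The check uses isinstance on the output type rather than hasattr, so the code does not misbehave if a low-rank output type happens to define an attribute with the same name but a different meaning.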
How urgent is this? If you need it for ICML, let us know; we can come up with a slightly dumb patch.
Hi @marcocuturi,
Thank you for the reply! It's not super urgent (it won't make it to ICML), but I would greatly appreciate any updates!
Hi @Farbodch, it's finally implemented!
Closed via #568.
Describe the bug: When the function is called (through jax.jit(jax.value_and_grad(...))), the low-rank Sinkhorn (LRSinkhorn) solver is expected to be used. However, an AttributeError is thrown instead.
To reproduce: relevant code snippet used to reproduce the behavior:
Additional information: the overall script works as expected when sink_div (as stated above) is replaced with a direct low-rank Sinkhorn solver (sink_lr_cost below).