ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

Is any method implemented in ott to help reduce memory overhead? (Update: I found the `batch_size` option) #422

Open feiyang-k opened 1 year ago

feiyang-k commented 1 year ago

Hi,

In our ML tasks, the problem scale is typically num_of_training_samples by num_of_validation_samples. Our GPUs currently have 40-80 GB of memory per card, which can handle problems of size around 350k by 10k. This is fine for classic datasets such as CIFAR-10, but still far from the million scale of modern datasets (or the billion scale of language corpora). Is there a standard approach to reducing memory overhead? Does the package provide any approximation or batch-wise methods that help with memory usage?

Thanks!

Update: I found the `batch_size` option for the online cost computation, which looks highly relevant!
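For context, here is a toy NumPy sketch (my own illustration of the idea, not the OTT implementation) of why an online/batched cost computation saves memory: the Gibbs kernel `exp(-C/eps)` is applied to a vector one row chunk at a time, so the full n-by-m cost matrix is never materialized.

```python
import numpy as np

def apply_kernel_batched(x, y, v, eps=0.1, batch_size=256):
    """Compute exp(-||x_i - y_j||^2 / eps) @ v, materializing only
    batch_size rows of the cost matrix at a time."""
    n = x.shape[0]
    out = np.empty(n)
    for start in range(0, n, batch_size):
        xb = x[start:start + batch_size]                        # (b, d)
        cost = ((xb[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # (b, m)
        out[start:start + batch_size] = np.exp(-cost / eps) @ v
    return out

rng = np.random.default_rng(0)
x, y = rng.normal(size=(500, 3)), rng.normal(size=(400, 3))
v = rng.random(400)

# Dense reference: builds the full (500, 400) cost matrix.
full_cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
dense = np.exp(-full_cost / 0.1) @ v

batched = apply_kernel_batched(x, y, v, eps=0.1, batch_size=128)
print(np.allclose(dense, batched))  # True
```

Peak memory goes from O(n m) down to O(batch_size * m), at the price of recomputing cost chunks inside every Sinkhorn iteration.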

marcocuturi commented 1 year ago

Hi @feiyang-k !!

Indeed, scalability will be an important issue for OT tasks. There are two possible workarounds at the moment:

feiyang-k commented 1 year ago

Thanks @! These methods are marvelous and work like magic. I had some vague ideas about improving the efficiency of the OT solution; the moment I saw these methods, I finally understood what those ideas were about :)

The `batch_size` option works smoothly. I'm very interested in the 'low-rank' method. It seems to implement an idea similar to clustering, which has long been used to improve the scalability of computing distributional divergences, but in a much more natural and elegant way. I really enjoyed reading about this!

I read through the reference paper as well as the documentation.

Unlike the regular Sinkhorn algorithm, which (like other OT solvers) approaches the dual problem, LRSinkhorn directly applies a low-rank factorization to the coupling matrix, thereby solving the primal problem. The results are given as `q`, `r`, and `g`, whose product yields the transport plan.
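To check my understanding of the memory argument, here is a small NumPy sketch (my reading of the factorization in the paper, with my own variable names, not the OTT internals): the plan is kept as P = Q diag(1/g) R^T, so applying P to a vector costs O((n + m) r) instead of O(n m).

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, r = 1000, 800, 10
q = rng.random((n, r))      # left factor, shape (n, r)
rmat = rng.random((m, r))   # right factor, shape (m, r)
g = rng.random(r) + 0.1     # inner weights, shape (r,)

def apply_plan_low_rank(v):
    # P @ v without ever forming the (n, m) matrix P.
    return q @ ((rmat.T @ v) / g)

# Dense plan, built only to verify the factored application.
p_dense = q @ np.diag(1.0 / g) @ rmat.T
v = rng.random(m)
print(np.allclose(p_dense @ v, apply_plan_low_rank(v)))  # True
```

So for n = m = 10^6 and r = 100, the factors take ~2 * 10^8 floats instead of the 10^12 a dense coupling would need.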

Interestingly, I care about the dual solutions more than the primal ones. I need the derivatives to explore the corresponding practical problem. Directly recovering the dual solution from the primal can be tricky due to numerical issues. In this case, do you have any idea what would be a good way to obtain the dual solutions?

I also tried using automatic differentiation in JAX for this, where the dual solutions should be the derivatives with respect to the marginals `a` and `b`, but I ran into memory issues immediately. I saw similar issues in another post, but the difference here is that it seems I cannot use `danskin` with LRSinkhorn, can I?

More interestingly, I saw in the source code and parts of the documentation that there is a `use_danskin` attribute, but setting it to `True` does not seem to make any difference. Is this a usable option?
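For concreteness, here is a toy log-domain Sinkhorn in plain NumPy (my own sketch, not the OTT solver) showing the dual potentials `f`, `g` I'm after. As I understand it, the envelope (Danskin) argument says the gradient of the entropic OT value with respect to the marginal `a` is exactly `f`, which is why I hoped autodiff would recover the duals.

```python
import numpy as np

def logsumexp(z, axis):
    """Numerically stable log-sum-exp along the given axis."""
    zmax = z.max(axis=axis)
    return zmax + np.log(np.exp(z - np.expand_dims(zmax, axis)).sum(axis=axis))

def sinkhorn_duals(cost, a, b, eps=0.5, n_iters=1000):
    """Log-domain Sinkhorn; returns the dual potentials (f, g)."""
    f, g = np.zeros(len(a)), np.zeros(len(b))
    for _ in range(n_iters):
        f = eps * (np.log(a) - logsumexp((g[None, :] - cost) / eps, axis=1))
        g = eps * (np.log(b) - logsumexp((f[:, None] - cost) / eps, axis=0))
    return f, g

rng = np.random.default_rng(0)
x, y = rng.normal(size=(40, 2)), rng.normal(size=(30, 2))
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
a, b = np.full(40, 1 / 40), np.full(30, 1 / 30)

f, g = sinkhorn_duals(cost, a, b)
# The primal plan is recovered from the duals, not the other way around.
plan = np.exp((f[:, None] + g[None, :] - cost) / 0.5)
print(float(np.abs(plan.sum(axis=1) - a).max()))  # marginal violation, near zero
```

In the standard algorithm the duals are the primary object and the plan is derived from them; with LRSinkhorn the situation is reversed, which is exactly why recovering `f`, `g` seems tricky to me.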

Thanks a lot!