feiyang-k opened this issue 1 year ago
Hi @feiyang-k !!
Indeed, scalability will be an important issue for OT tasks. There are two possible workarounds at the moment:

1. The `batch_size` option when creating a geometry object, which ensures that the `num_train` x `num_val` cost matrix you mentioned is never materialized.
2. The low-rank Sinkhorn solver (`LRSinkhorn`), which works with a factorized coupling instead of the full matrix.

Thanks @! These methods are marvelous! They work like magic. I had been having some vague ideas about improving the efficiency of OT solutions, and the moment I saw these methods, I finally understood what those ideas were about :)
The `batch_size` option works smoothly. I'm very interested in this low-rank method. It seems to implement an idea similar to clustering, which has been tried before as a way to improve the scalability of distributional divergence computations, but in a much more natural and elegant way. I really enjoyed reading this!
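To check my understanding of the batching idea, here is a plain-numpy sketch of streaming the cost matrix in row blocks (`apply_cost_batched` is my own toy helper, not the library's API):

```python
import numpy as np

def apply_cost_batched(x, y, v, batch_size=256):
    """Accumulate C @ v in row batches, where C[i, j] = ||x_i - y_j||^2,
    without ever materializing the full len(x) x len(y) cost matrix."""
    out = np.empty(len(x))
    for start in range(0, len(x), batch_size):
        xb = x[start:start + batch_size]
        # (batch, m) slice of the squared-Euclidean cost matrix
        Cb = ((xb[:, None, :] - y[None, :, :]) ** 2).sum(-1)
        out[start:start + batch_size] = Cb @ v
    return out
```

Peak memory is `batch_size * m` instead of `n * m`, which is presumably what the online cost computation does under the hood.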
I read through the reference paper as well as the documentation. Unlike the regular Sinkhorn algorithm, which approaches the dual problem (as other OT solvers do), `LRSinkhorn` directly applies a low-rank factorization to the coupling matrix and solves the primal problem. The results are given as `q`, `r`, `g`, whose product yields the transport plan.
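To make the factorized structure concrete, here is a small numpy sketch with made-up random factors shaped like the solver's `q`, `r`, `g` outputs (if I read the paper right, the plan is `q @ diag(1/g) @ r.T`; only the shapes and that structure matter here):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, rank = 1_000, 800, 8
# hypothetical factors: q is (n, rank), r is (m, rank), g is (rank,)
q = rng.random((n, rank))
r = rng.random((m, rank))
g = np.full(rank, 1.0 / rank)

def apply_plan(v):
    # P @ v using only the factors: O((n + m) * rank) work and memory
    return q @ ((1.0 / g) * (r.T @ v))

v = rng.normal(size=m)
P = q @ np.diag(1.0 / g) @ r.T  # dense plan, built here only for checking
assert np.allclose(apply_plan(v), P @ v)
```

So anything that only needs matrix-vector products with the plan never has to form the `n x m` matrix at all, which is where the scalability comes from.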
However, I'm more interested in the dual solutions than the primal ones: I need the derivatives to explore the corresponding practical problem. Directly recovering the dual solution from the primal can be tricky due to numerical issues. In this case, do you have any idea what would be a good way to obtain the dual solutions?
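(For contrast, in the regular Sinkhorn case the duals come almost for free from the scaling vectors; a minimal numpy sketch of what I mean, standard entropic OT, my own toy code:)

```python
import numpy as np

def sinkhorn_duals(a, b, C, eps=0.5, n_iter=500):
    """Plain Sinkhorn iterations; the dual potentials fall out of the
    scaling vectors as f = eps * log(u), g = eps * log(v)."""
    K = np.exp(-C / eps)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    P = u[:, None] * K * v[None, :]
    # sanity: P == exp((f + g - C) / eps) by construction
    return P, eps * np.log(u), eps * np.log(v)
```

Nothing comparable seems to exist once the solver works purely with the primal factors, which is exactly my question.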
I also tried using auto-differentiation in JAX, where the dual solutions are the derivatives with respect to the marginals `a` and `b`, but I ran into memory issues immediately. I saw similar issues in another post, but the difference here is that it seems I cannot use Danskin's theorem with `LRSinkhorn`, can I?
More interestingly, I saw in the source code and in parts of the documentation that there is a `use_danskin` attribute, but setting it to `True` does not seem to make any difference. Is this a usable option?
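(In case it helps to state what I expect the flag to do: Danskin's theorem, i.e. the envelope theorem, says one can differentiate the objective at the fixed optimum instead of differentiating through the solver loop. A toy numeric check, pure Python:)

```python
# F(theta) = min_x h(x, theta) with h(x, theta) = (x - theta)**2 + x**2
# Closed form: x* = theta / 2, so F(theta) = theta**2 / 2 and dF/dtheta = theta.
theta = 1.7
x_star = theta / 2.0                    # inner "solve" (closed form here)
danskin_grad = -2.0 * (x_star - theta)  # d h / d theta, holding x* fixed
true_grad = theta                       # derivative of theta**2 / 2
assert abs(danskin_grad - true_grad) < 1e-12
```

That is why I expected `use_danskin=True` to cut the memory use of backpropagation: the solver loop would be treated as a constant at convergence.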
Thanks a lot!
Hi,
In our ML tasks, the problem scale is often defined by `num_of_training_samples` x `num_of_validation_samples`. Our GPUs currently have 40-80 GB of memory per card, which can handle problems of size around 350k x 10k. This is fine for classic datasets such as CIFAR-10, but it is still far from the million scale of modern datasets (or the billion scale of language corpora). Is there a standard approach to reducing the memory overhead? Does the package provide any approximation or batch-wise methods that help with memory usage?
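For reference, a quick back-of-envelope for the sizes I mean, assuming float32 entries:

```python
# Memory for a dense num_train x num_val float32 cost matrix
num_train, num_val = 350_000, 10_000
gb_current = num_train * num_val * 4 / 1e9   # fits on a 40 GB card
gb_million = 1_000_000 * num_val * 4 / 1e9   # already fills a 40 GB card
print(gb_current, gb_million)
```

So even before the billion scale, a single dense matrix is the bottleneck.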
Thanks!
Update: I found the `batch_size` option for the online cost computation, which looks highly relevant!