JuliaOptimalTransport / OptimalTransport.jl

Optimal transport algorithms for Julia
https://juliaoptimaltransport.github.io/OptimalTransport.jl/dev
MIT License
94 stars · 8 forks

Refactor `sinkhorn` and `sinkhorn2` #100

Closed devmotion closed 3 years ago

devmotion commented 3 years ago

I apologize in advance for the number of changes in this PR, but I didn't manage to keep it smaller without leaving the package in a half-broken state.

Basically, this PR unifies `sinkhorn`, `sinkhorn_stabilized`, and `sinkhorn_stabilized_epsscaling` under a single `sinkhorn(mu, nu, C, eps, algorithm; kwargs...)` syntax (the name `sinkhorn` seems a bit redundant; see comments below). More details:

- `sinkhorn(mu, nu, C, eps; kwargs...) = sinkhorn(mu, nu, C, eps, SinkhornGibbs(); kwargs...)`, i.e., without an algorithm argument the standard Sinkhorn algorithm is used, as before
- `sinkhorn2` allows specifying an algorithm as well (default: `SinkhornGibbs()`) and hence supports all Sinkhorn variants (and also supports the additional regularization term for all algorithms)
- `SinkhornEpsilonScaling` works with both the regular Sinkhorn algorithm and the stabilized version
- `SinkhornStabilized` supports batches of histograms as well (requires `NNlib.batched_mul!` for performance and GPU support)
- existing methods and keyword arguments are deprecated, i.e., the changes should be non-breaking

The name `sinkhorn(...)` seems a bit redundant since it is already implied by the name of the algorithm (apart from the default choice). However, I did not want to make too many changes in this PR. One approach would be to define a type for entropically regularized OT problems (this could also be used in the `SinkhornSolver` struct to group together the source and target marginals, cost matrix, and regularization parameter) and then to use the `solve` approach (which is already used internally, without depending on CommonSolve). One could then, e.g., still use `SinkhornGibbs()` as the default algorithm for such problems but also write `solve(prob, SinkhornStabilized(); kwargs...)` etc. One could even keep `sinkhorn(mu, nu, C, eps, alg; kwargs...) = solve(EntropicOTProblem(mu, nu, C, eps), alg; kwargs...)`.
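For context, all of the variants being unified build on the same basic scaling iteration. A minimal self-contained sketch of the plain Sinkhorn update that `SinkhornGibbs()` corresponds to (illustrative only, not the package's implementation):

```julia
# Minimal sketch of the plain Sinkhorn scaling iteration; not the package's code.
function sinkhorn_plan(mu, nu, C, eps; maxiter=1_000, tol=1e-9)
    K = exp.(-C ./ eps)              # Gibbs kernel
    u = fill!(similar(mu), 1)
    v = fill!(similar(nu), 1)
    for _ in 1:maxiter
        u = mu ./ (K * v)            # match row marginals
        v = nu ./ (K' * u)           # match column marginals
        # stop once the row marginals of the current plan are back at mu
        maximum(abs, u .* (K * v) .- mu) < tol && break
    end
    return u .* K .* v'              # transport plan diag(u) * K * diag(v)
end

mu = [0.5, 0.5]
nu = [0.25, 0.75]
C = [0.0 1.0; 1.0 0.0]
gamma = sinkhorn_plan(mu, nu, C, 0.1)
# row and column sums of gamma recover mu and nu (up to the tolerance)
```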

coveralls commented 3 years ago

Pull Request Test Coverage Report for Build 928044999

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details


| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|---|---|---|---|
| src/entropic/sinkhorn_barycenter.jl | 21 | 23 | 91.3% |
| src/entropic/sinkhorn_epsscaling.jl | 40 | 42 | 95.24% |
| Totals | 300 | 304 | 98.68% |

Totals Coverage Status:
- Change from base Build 928044156: 1.6%
- Covered Lines: 540
- Relevant Lines: 561

💛 - Coveralls
davibarreira commented 3 years ago

> I apologize in advance for the number of changes in this PR, but I didn't manage to keep it smaller without leaving the package in a half-broken state.
>
> Basically, this PR unifies `sinkhorn`, `sinkhorn_stabilized`, and `sinkhorn_stabilized_epsscaling` under a single `sinkhorn(mu, nu, C, eps, algorithm; kwargs...)` syntax (the name `sinkhorn` seems a bit redundant; see comments below). More details:
>
> - `sinkhorn(mu, nu, C, eps; kwargs...) = sinkhorn(mu, nu, C, eps, SinkhornGibbs(); kwargs...)`, i.e., without an algorithm argument the standard Sinkhorn algorithm is used, as before
> - `sinkhorn2` allows specifying an algorithm as well (default: `SinkhornGibbs()`) and hence supports all Sinkhorn variants (and also supports the additional regularization term for all algorithms)
> - `SinkhornEpsilonScaling` works with both the regular Sinkhorn algorithm and the stabilized version
> - `SinkhornStabilized` supports batches of histograms as well (requires `NNlib.batched_mul!` for performance and GPU support)
> - existing methods and keyword arguments are deprecated, i.e., the changes should be non-breaking
>
> The name `sinkhorn(...)` seems a bit redundant since it is already implied by the name of the algorithm (apart from the default choice). However, I did not want to make too many changes in this PR. One approach would be to define a type for entropically regularized OT problems (this could also be used in the `SinkhornSolver` struct to group together the source and target marginals, cost matrix, and regularization parameter) and then to use the `solve` approach (which is already used internally, without depending on CommonSolve). One could then, e.g., still use `SinkhornGibbs()` as the default algorithm for such problems but also write `solve(prob, SinkhornStabilized(); kwargs...)` etc. One could even keep `sinkhorn(mu, nu, C, eps, alg; kwargs...) = solve(EntropicOTProblem(mu, nu, C, eps), alg; kwargs...)`.

Great PR! I still have to review the code you submitted, but in terms of design I have a slightly different proposal related to the "problem" and "solver" paradigm. I'd say that the "problem" would always be the original optimal transport problem (i.e., `mu` to `nu` with cost `C` and no regularization), and the "solver" would incorporate the regularization or whatever else it uses. Hence, instead of `solve(EntropicOTProblem(mu, nu, C, eps), alg; kwargs...)`, it would be `solve(OTProblem(mu, nu, C), alg; kwargs...)`. This proposal is grounded more in theoretical consistency than in programming convenience: the point of using Sinkhorn is (in theory) to approximately solve the original OT problem. Hence, although it technically solves a different problem, I'm arguing that the regularization is better understood as part of the solver.
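The two designs under discussion can be contrasted in a hypothetical sketch (neither `EntropicOTProblem` nor `OTProblem` exists in this PR; both are illustrations):

```julia
# Proposal A (devmotion): the regularized problem is its own type,
# and the algorithm carries no regularization parameter.
solve(EntropicOTProblem(mu, nu, C, eps), SinkhornStabilized(); kwargs...)

# Proposal B (davibarreira): the problem is always the unregularized OT
# problem; the regularization parameter travels with the solver.
solve(OTProblem(mu, nu, C), SinkhornStabilized(eps); kwargs...)
```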

devmotion commented 3 years ago

Interesting, I'll think about your comment. To me, the entropically regularized OT problem really is a separate problem, not just an artifact of the approximation algorithm. Of course it is derived from the exact OT problem, but it has always seemed to me a separate entity, similar to how in other fields regularized problems are studied in their own right, e.g., least squares with Tikhonov regularization, the Lasso, or the elastic net, all of which have specific mathematical properties.

Intuitively, I assume that making the regularization part of the algorithm could lead to problems or inconsistencies when composing different algorithms, such as epsilon scaling wrapped around different Sinkhorn algorithms: the regularization would then be a parameter of the subalgorithm but not of the main algorithm.
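The composition concern can be made concrete with a hypothetical sketch (constructors taking an `eps` argument do not exist in this PR; they illustrate the alternative design):

```julia
# Hypothetical: if eps were stored in the algorithm rather than the problem,
# composing algorithms would leave it ambiguous which eps is in effect.
inner = SinkhornStabilized(eps)      # eps of the subalgorithm
alg = SinkhornEpsilonScaling(inner)  # varies eps over its iterations --
                                     # so which eps does the composed run target?
```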

davibarreira commented 3 years ago

> Interesting, I'll think about your comment. To me, the entropically regularized OT problem really is a separate problem, not just an artifact of the approximation algorithm. Of course it is derived from the exact OT problem, but it has always seemed to me a separate entity, similar to how in other fields regularized problems are studied in their own right, e.g., least squares with Tikhonov regularization, the Lasso, or the elastic net, all of which have specific mathematical properties.
>
> Intuitively, I assume that making the regularization part of the algorithm could lead to problems or inconsistencies when composing different algorithms, such as epsilon scaling wrapped around different Sinkhorn algorithms: the regularization would then be a parameter of the subalgorithm but not of the main algorithm.

I think the approach you described is actually the more standard way. But the reason I tend to think like this in OT is that the regularization is mostly added not to penalize errors somehow, but simply to make the original problem tractable; the use of Sinkhorn instead of the unregularized problem is usually purely for computational reasons (at least in ML, it's the only reason people choose it). In linear regression, by contrast, regularization is added not to "facilitate" solving the original regression, but to emphasize a different aspect of the solution. So I see a fundamental distinction in the motivation.

But this is just a way of thinking about things. So if it makes the coding problematic, I'd say we should just stick with your design; I don't see that much of a gain.

zsteve commented 3 years ago

I also have yet to go through the code, but thinking about the above comments, I'd also lean towards what is implemented currently. The reason I am not too keen on viewing the regularisation as an 'add-on' to unregularised OT is that the scaling algorithms are highly specific to the regulariser: the choice of regulariser dictates the available algorithms, and there is very little scope for mixing and matching regularisers and algorithms.

I think in the future it might be useful to have a more general solver that can deal with an arbitrary regulariser, using gradient methods such as L-BFGS, similar to the ot.smooth module in POT.

devmotion commented 3 years ago

> I'd also lean towards what is implemented currently.

Actually, none of the problem types or drastic interface changes are implemented in this PR. I just thought it would be worth considering, since `sinkhorn(..., SinkhornStabilized())` etc. seemed a bit redundant, and already in the internal solver it seemed natural to group together the other arguments (source and target marginals, cost matrix, and regularization parameter).

If there are problem types for entropically regularized OT, quadratically regularized OT, etc. (a single type that allows dispatching on the type of regularization would also be sufficient), then it would be immediately clear that `solve(entropic_prob)` should use a default algorithm for entropically regularized OT problems, whereas `solve(quadratic_prob)` should use one for quadratically regularized OT problems. In contrast, it is not possible to deduce the type of regularization from the arguments `mu`, `nu`, `C`, and `eps` alone: one either needs specific function names for each problem type (similar to the current approach, which however also mixes in the different algorithms) or an additional argument that specifies the type of regularization (but in that case it arguably seems simpler and more direct to indicate this with the problem type).

> I think in the future it might be useful to have a more general solver that can deal with an arbitrary regulariser, using gradient methods such as L-BFGS, similar to the ot.smooth module in POT.

The nice thing is that one could define, for each algorithm, which problem types it can be applied to, and show an error message if the user applies an incorrect algorithm to the provided problem. Alternatively, one could disallow incorrect combinations entirely by encoding this in the type structure of the algorithms and problems (which seems a bit less flexible, though). Similarly, for some algorithms one could declare that they may be applied to OT problems with an arbitrary regulariser.
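A self-contained sketch of this dispatch idea (all of these types are illustrations, not the package's API; `SinkhornGibbs` and `LBFGSSolver` are stand-ins defined here):

```julia
# Hypothetical sketch: problems parameterized by their regularization type,
# and algorithms declaring which regularizations they support.
abstract type AbstractRegularization end
struct EntropicRegularization{T} <: AbstractRegularization
    eps::T
end
struct QuadraticRegularization{T} <: AbstractRegularization
    eps::T
end

struct OTProblem{R<:AbstractRegularization}
    reg::R
    # ... source/target marginals and cost matrix would go here
end

struct SinkhornGibbs end   # stand-in for the PR's algorithm type
struct LBFGSSolver end     # stand-in for a generic gradient-based solver

# Each algorithm declares which regularizations it supports:
supports(::SinkhornGibbs, ::EntropicRegularization) = true
supports(::LBFGSSolver, ::AbstractRegularization) = true  # generic solver
supports(::Any, ::Any) = false

function solve(prob::OTProblem, alg)
    supports(alg, prob.reg) ||
        throw(ArgumentError("$(typeof(alg)) cannot solve a problem with $(typeof(prob.reg))"))
    # ... dispatch to the actual implementation
    return :ok
end
```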

devmotion commented 3 years ago

What's the status here? Did you have a look at the changes?

davibarreira commented 3 years ago

I don't know much about Sinkhorn beyond the regular implementation. So I tried reading the whole code, but I was not very thorough, especially on the unbalanced case or the other modifications, such as epsilon scaling.

devmotion commented 3 years ago

GPU tests pass now (the new CUDA + NNlibCUDA versions fix some bugs in CUDA that caused timeouts when instantiating a new project environment, as done in our tests).