JuliaOptimalTransport / OptimalTransport.jl

Optimal transport algorithms for Julia
https://juliaoptimaltransport.github.io/OptimalTransport.jl/dev
MIT License

Example: variational problems using autodiff #122

Closed · zsteve closed this 2 years ago

zsteve commented 2 years ago

Added a new example notebook, examples/variational, demonstrating how to use the package to differentiate through the output of sinkhorn2 and to approximate solutions to PDEs as Wasserstein gradient flows.
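For reference, differentiating through sinkhorn2 only requires applying forward-mode autodiff to the scalar cost it returns. A minimal sketch (the grid, marginals, and regularization strength below are illustrative assumptions, not the notebook's values):

using OptimalTransport, ForwardDiff

# Illustrative setup: two discrete measures on a 1-D grid.
x = range(-1, 1; length = 64)
C = (x .- x') .^ 2                              # squared Euclidean cost matrix
ν = exp.(-(x .- 0.5) .^ 2 ./ 0.1); ν ./= sum(ν)
μ0 = exp.(-(x .+ 0.5) .^ 2 ./ 0.1); μ0 ./= sum(μ0)
ε = 0.05                                        # entropic regularization strength

# sinkhorn2 returns the entropic OT cost; ForwardDiff propagates dual numbers
# through the Sinkhorn iterations to get its gradient w.r.t. the source weights.
∇μ = ForwardDiff.gradient(μ -> sinkhorn2(μ, ν, C, ε), μ0)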

coveralls commented 2 years ago

Pull Request Test Coverage Report for Build 1173827433


Totals:
Change from base Build 1165287400: 0.0%
Covered Lines: 666
Relevant Lines: 682

codecov-commenter commented 2 years ago

Codecov Report

Merging #122 (5061289) into master (ce8110c) will not change coverage. The diff coverage is n/a.


@@           Coverage Diff           @@
##           master     #122   +/-   ##
=======================================
  Coverage   97.65%   97.65%           
=======================================
  Files          13       13           
  Lines         682      682           
=======================================
  Hits          666      666           
  Misses         16       16           


davibarreira commented 2 years ago

Hey @zsteve, what exactly is the optimization you used to minimize over $\rho$ (the variational problem)? This is the only thing that wasn't quite clear to me when reading the example. I mean, since $\rho$ is actually a distribution, how are you using gradient descent on it?

function step(ρ0, τ, ε, C, G)
    # Optimize over an unconstrained vector u and map it to the probability
    # simplex via softmax, so plain L-BFGS can be applied to a distribution.
    opt = optimize(
        u -> G(softmax(u), ρ0, τ, ε, C),  # reparameterize: ρ = softmax(u)
        ones(size(ρ0)),                   # initial unconstrained iterate
        LBFGS(),
        Optim.Options(; iterations=50, g_tol=1e-6);
        autodiff=:forward,                # forward-mode autodiff (ForwardDiff)
    )
    return softmax(Optim.minimizer(opt)) # map the minimizer back to a distribution
end
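A minimal usage sketch of the above: the softmax maps an unconstrained u onto the probability simplex, so the optimizer works in u-space while the objective only ever sees a valid distribution softmax(u). Here G is a hypothetical JKO-type objective (entropic transport cost to ρ0 as the proximal term, plus an entropy energy); the notebook's actual functional may differ:

using OptimalTransport, Optim, NNlib   # NNlib supplies softmax

# Hypothetical JKO-type objective (an assumption, not the notebook's exact G).
G(ρ, ρ0, τ, ε, C) = sinkhorn2(ρ, ρ0, C, ε) / (2τ) + sum(ρ .* log.(ρ .+ 1e-12))

x = range(-1, 1; length = 32)
C = (x .- x') .^ 2                     # squared Euclidean cost on the grid
ρ0 = fill(1 / length(x), length(x))    # uniform initial distribution

ρ1 = step(ρ0, 0.1, 0.05, C, G)         # one gradient-flow step; ρ1 sums to 1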