JuliaOptimalTransport / OptimalTransport.jl

Optimal transport algorithms for Julia
https://juliaoptimaltransport.github.io/OptimalTransport.jl/dev
MIT License

ForwardDiff errors on differentiating through the output of `sinkhorn2` #86

Closed: zsteve closed this issue 3 years ago

zsteve commented 3 years ago

I'm currently trying to autodiff through `sinkhorn2` via Optim.jl, but I'm running into the following error:

julia> opt_primal = optimize(u -> f_primal(softmax(u), ε, K, interp_frac), zeros(size(μ0)), LBFGS(), Optim.Options(store_trace = true, show_trace = true, iterations = 250); autodiff = :forward)
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{var"#4#5", Float64}, Float64, 11})
Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
  (::Type{T})(::T) where T<:Number at boot.jl:760
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at char.jl:50
  ...
Stacktrace:
  [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#4#5", Float64}, Float64, 11})
    @ Base ./number.jl:7
  [2] setindex!
    @ ./array.jl:841 [inlined]
  [3] setindex!
    @ ./multidimensional.jl:639 [inlined]
  [4] macro expansion
    @ ./broadcast.jl:984 [inlined]
  [5] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [6] copyto!
    @ ./broadcast.jl:983 [inlined]
  [7] copyto!
    @ ./broadcast.jl:936 [inlined]
  [8] materialize!
    @ ./broadcast.jl:894 [inlined]
  [9] materialize!
    @ ./broadcast.jl:891 [inlined]
 [10] sinkhorn_gibbs(μ::Vector{Float64}, ν::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#4#5", Float64}, Float64, 11}}, K::Matrix{Float64}; tol::Nothing, atol::Nothing, rtol::Nothing, check_marginal_step::Nothing, check_convergence::Nothing, maxiter::Int64)
    @ OptimalTransport ~/OptimalTransport.jl/src/OptimalTransport.jl:194
 [11] sinkhorn_gibbs
    @ ~/OptimalTransport.jl/src/OptimalTransport.jl:161 [inlined]
 [12] sinkhorn(μ::Vector{Float64}, ν::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#4#5", Float64}, Float64, 11}}, C::Matrix{Float64}, ε::Float64; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OptimalTransport ~/OptimalTransport.jl/src/OptimalTransport.jl:262
 [13] sinkhorn
    @ ~/OptimalTransport.jl/src/OptimalTransport.jl:259 [inlined]
 [14] sinkhorn2(μ::Vector{Float64}, ν::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#4#5", Float64}, Float64, 11}}, C::Matrix{Float64}, ε::Float64; regularization::Bool, plan::Nothing, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OptimalTransport ~/OptimalTransport.jl/src/OptimalTransport.jl:286
 [15] #ot_smooth_primal#1
    @ ./REPL[18]:1 [inlined]
 [16] f_primal(μ::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#4#5", Float64}, Float64, 11}}, ε::Float64, K::Matrix{Float64}, f::Float64)

The relevant code appears to be OptimalTransport.jl:194:

187     norm_μ = sum(abs, μ) # for convergence check
188     isconverged = false
189     check_step = check_convergence === nothing ? 10 : check_convergence
190     for iter in 0:maxiter
191         if iter % check_step == 0
192             # check source marginal
193             # do not overwrite `tmp1` but reuse it for computing `u` if not converged
194             @. tmp2 = u * tmp1                                                                                                     
195             norm_uKv = sum(abs, tmp2)
196             @. tmp2 = μ - tmp2
197             norm_diff = sum(abs, tmp2)

I'm not sure why line 194 causes this issue. I tried looking at the types of `u`, `tmp1`, and `tmp2`:

[ Info: (Matrix{Float64}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#4#5", Float64}, Float64, 11}}, Matrix{Float64})

So it appears that the in-place assignment here results in a type incompatibility. I suppose this could be mitigated by avoiding explicit in-place assignments?
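
For reference, the failure can be reproduced in isolation. Here is a minimal sketch (my own construction, not code from the package) showing that broadcasting Dual numbers into a preallocated Float64 buffer hits the same missing conversion:

using ForwardDiff

u = rand(3)                             # Float64, like the scaling vector u
tmp1 = ForwardDiff.Dual.(rand(3), 1.0)  # Dual-valued, as when differentiating w.r.t. ν
tmp2 = zeros(3)                         # preallocated Float64 buffer, as in sinkhorn_gibbs

# The broadcast writes u[i] * tmp1[i] (a Dual) into the Float64 array tmp2,
# which needs a convert(Float64, ::ForwardDiff.Dual) method that does not exist:
@. tmp2 = u * tmp1   # ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{...})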

devmotion commented 3 years ago

Unfortunately I can't reproduce the issue (I don't know `f_primal`). The solution is not to remove the in-place assignments but to fix the initializations of the variables; I've already noticed that they are not correct, since both the element types and the containers can be wrong. I think I already fixed this in my local branch, so it would be helpful if you could post a full example that reproduces the issue.
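
(For illustration, the kind of initialization fix meant here is something like the following sketch; the actual branch code may differ:)

# Sketch only: allocate the scaling vectors with an element type promoted
# from both marginals and the kernel, so Dual numbers can be stored in them.
T = promote_type(eltype(μ), eltype(ν), eltype(K))
u = similar(μ, T)
v = similar(ν, T)
fill!(u, one(T))
fill!(v, one(T))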

devmotion commented 3 years ago

I added some of the changes that I wanted to make (performance can be improved a bit more, so it's not complete yet) to the dw/sinkhorn_gibbs branch: https://github.com/JuliaOptimalTransport/OptimalTransport.jl/tree/dw/sinkhorn_gibbs

You can check if this fixes your problem above. You can install the branch with

julia> ] add OptimalTransport#dw/sinkhorn_gibbs

zsteve commented 3 years ago

> Unfortunately I can't reproduce the issue (I don't know `f_primal`). The solution is not to remove the in-place assignments but to fix the initializations of the variables; I've already noticed that they are not correct, since both the element types and the containers can be wrong. I think I already fixed this in my local branch, so it would be helpful if you could post a full example that reproduces the issue.

Yep, sorry for the slow response. I was under the impression that I was making some 'rookie mistake' in trying to do autodiff, but after reading the above, you are right: the problem results from `u` and `v` being initialised with types derived from `μ`, I believe.

The issue arises because, in the case I posted, I wanted to differentiate with respect to `ν`. The problem then is that `typeof(u) == Vector{Float64}`, but we need to be able to propagate gradients through both `u` and `v`.
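
Concretely, with the fixed initializations, differentiating the entropic cost with respect to `ν` should go through directly. A toy sketch of that check (my own, not code from the thread; grid size and `ε` are arbitrary):

using OptimalTransport, ForwardDiff, Distances

support = LinRange(-1, 1, 8)
μ = fill(1 / 8, 8)
ν = fill(1 / 8, 8)
C = pairwise(SqEuclidean(), support'; dims = 2)  # quadratic cost on the grid
ε = 0.05

# gradient of the entropic OT cost with respect to the target marginal ν
g = ForwardDiff.gradient(ν -> sinkhorn2(μ, ν, C, ε; maxiter = 100), ν)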

> I added some of the changes that I wanted to make (performance can be improved a bit more, so it's not complete yet) to the dw/sinkhorn_gibbs branch: https://github.com/JuliaOptimalTransport/OptimalTransport.jl/tree/dw/sinkhorn_gibbs
>
> You can check if this fixes your problem above. You can install the branch with `julia> ] add OptimalTransport#dw/sinkhorn_gibbs`

Yes, thanks for fixing this! It is working now. For reference, here is the full code I was trying to run.

zsteve commented 3 years ago

Here is a working example:

using OptimalTransport, PythonOT
using StatsBase, Distances
using ReverseDiff, Optim, LinearAlgebra
using PyPlot
using BenchmarkTools
using LogExpFunctions

# entropy-regularized OT cost (including the entropic term) after `iter` Sinkhorn iterations
ot_smooth_primal(α, β, C, ε; iter = 50) = OptimalTransport.sinkhorn2(α, β, C, ε; regularization = true, maxiter = iter)
gaussian(x, σ) = @. exp(-x^2 / σ^2)

# two bimodal marginals on a common grid
support = LinRange(-1, 1, 64)
μ0 = @. gaussian(support - 0.5, 0.1) + gaussian(support + 0.5, 0.1)
μ0 = μ0 / sum(μ0)
μ1 = @. 0.5 * gaussian(support - 0.25, 0.1) + 0.5 * gaussian(support + 0.25, 0.1)
μ1 = μ1 / sum(μ1)

ε = 0.01
C = pairwise(SqEuclidean(), support'; dims = 2)  # squared-distance cost matrix (dims = 2 is required by recent Distances.jl)
K = @. exp(-C / ε)  # Gibbs kernel (not used below; sinkhorn2 takes the cost matrix C)

# weighted sum of entropic OT costs to μ0 and μ1
f_primal(μ, ε, C, f) = f * ot_smooth_primal(μ0, μ, C, ε; iter = 100) + (1.0 - f) * ot_smooth_primal(μ1, μ, C, ε; iter = 100)

interp_frac = 0.5

# optimize over unconstrained u, mapped onto the simplex via softmax; ForwardDiff supplies the gradients
opt_primal = optimize(u -> f_primal(softmax(u), ε, C, interp_frac), zeros(size(μ0)), LBFGS(), Optim.Options(store_trace = true, show_trace = true, iterations = 250); autodiff = :forward)
α_opt_primal = softmax(Optim.minimizer(opt_primal))

plot(support, α_opt_primal)