JuliaOptimalTransport / OptimalTransport.jl

Optimal transport algorithms for Julia
https://juliaoptimaltransport.github.io/OptimalTransport.jl/dev
MIT License

fix sinkhorn2 bug for ReverseDiff #130

Closed · zsteve closed this 2 years ago

zsteve commented 2 years ago

Using ReverseDiff with the current implementation of sinkhorn2 breaks:

using OptimalTransport
using ReverseDiff
using ForwardDiff
using LogExpFunctions
import NNlib
using LinearAlgebra

N = 100
C = rand(N, N)  # cost matrix
ε = 0.05        # entropic regularization parameter

# uniform source and target marginals
μ = NNlib.softmax(zeros(N); dims = 1)
ν = NNlib.softmax(zeros(N); dims = 1)

# differentiate the regularized Sinkhorn loss with respect to the target marginal
ReverseDiff.gradient(x -> sinkhorn2(μ, NNlib.softmax(x), C, ε; regularization = true), zero(ν))

yields an error

ERROR: MethodError: no method matching index_bound(::ReverseDiff.TrackedReal{Float64, Float64, Nothing}, ::ReverseDiff.TrackedReal{Float64, Float64, Nothing})
Closest candidates are:
  index_bound(::Any, ::AbstractArray{T, N}) where {T, N} at /home/syz/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/propagation.jl:25
Stacktrace:
  [1] broadcast_plus(x::ReverseDiff.TrackedReal{Float64, Float64, Nothing}, y::Array{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 0}, #unused#::Type{Float64})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/elementwise.jl:360
  [2] broadcast
    @ ~/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/elementwise.jl:342 [inlined]
  [3] _materialize
    @ ~/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/broadcast.jl:265 [inlined]
  [4] materialize
    @ ~/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/broadcast.jl:282 [inlined]
  [5] sinkhorn2(μ::Vector{Float64}, ν::Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, C::Matrix{Float64}, ε::Float64, alg::SinkhornGibbs; regularization::Bool, plan::Nothing, kwargs::Base.Iterators.Pairs{Symbol, Union{Nothing, Int64}, Tuple{Symbol, Symbol}, NamedTuple{(:atol, :check_convergence), Tuple{Nothing, Int64}}})
    @ OptimalTransport ~/OptimalTransport.jl/src/entropic/sinkhorn.jl:204
  [6] #sinkhorn2#37
    @ ~/OptimalTransport.jl/src/entropic/sinkhorn_gibbs.jl:110 [inlined]
  [7] (::var"#3#4")(x::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}})
    @ Main ./REPL[14]:1
  [8] ReverseDiff.GradientTape(f::var"#3#4", input::Vector{Float64}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/E4Tzn/src/api/tape.jl:199
  [9] gradient(f::Function, input::Vector{Float64}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}) (repeats 2 times)
    @ ReverseDiff ~/.julia/packages/ReverseDiff/E4Tzn/src/api/gradients.jl:22
 [10] top-level scope
    @ REPL[14]:1

The error is due to the second line of the snippet below, which handles both the case where the output is a scalar and the case where it is an array. This seems to cause issues with ReverseDiff.

        dot_matwise(γ, C) .+
        ε * reshape(sum(LogExpFunctions.xlogx, γ; dims=(1, 2)), size(γ)[3:end])

This PR instead uses multiple dispatch on the type of γ, which avoids the error. I've updated the tests to check gradient computations using both ForwardDiff and ReverseDiff.
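
For illustration, here is a minimal sketch of what dispatching on the type of γ could look like. The names sinkhorn_cost and _dot_matwise are placeholders invented for this example (the latter mimics the package-internal dot_matwise helper); this is not the actual code in the PR.

using LinearAlgebra
using LogExpFunctions

# slice-wise Frobenius inner product: a stand-in for the package-internal dot_matwise
_dot_matwise(γ::AbstractMatrix, C::AbstractMatrix) = dot(γ, C)
_dot_matwise(γ::AbstractArray, C::AbstractMatrix) =
    reshape(sum(γ .* C; dims=(1, 2)), size(γ)[3:end])

# single transport plan: the cost is a plain scalar
sinkhorn_cost(γ::AbstractMatrix, C, ε) = dot(γ, C) + ε * sum(LogExpFunctions.xlogx, γ)

# batch of plans (trailing dimensions of γ): the cost is an array, one entry per plan
function sinkhorn_cost(γ::AbstractArray, C, ε)
    return _dot_matwise(γ, C) .+
        ε .* reshape(sum(LogExpFunctions.xlogx, γ; dims=(1, 2)), size(γ)[3:end])
end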

zsteve commented 2 years ago

> Another problem with the proposed fix is that it no longer lets us fuse the broadcast operations: on the master branch we should have used dot_matwise(...) .+ ε .* reshape(...), since it seems that currently an intermediate array is created when reshape(...) returns an array. Maybe this already fixes the ReverseDiff issue?

Thanks for the insight - just checked this and indeed you're right :D It fixes the issue much more painlessly. Will revert to that and keep the tests.
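
For reference, here is a sketch of the fused-broadcast form suggested above, again using the placeholder _dot_matwise stand-in rather than the package's internal helper; only the dots on ε .* differ from the snippet quoted earlier.

using LinearAlgebra
using LogExpFunctions

# same slice-wise stand-in for the package-internal dot_matwise as in the sketch above
_dot_matwise(γ::AbstractMatrix, C::AbstractMatrix) = dot(γ, C)
_dot_matwise(γ::AbstractArray, C::AbstractMatrix) =
    reshape(sum(γ .* C; dims=(1, 2)), size(γ)[3:end])

# with ε .* instead of ε *, the scaling by ε is fused into the same broadcast as
# the .+, so no intermediate ε * reshape(...) array is materialized
sinkhorn_cost_fused(γ, C, ε) =
    _dot_matwise(γ, C) .+
    ε .* reshape(sum(LogExpFunctions.xlogx, γ; dims=(1, 2)), size(γ)[3:end])

This matches the stack trace above, where the MethodError comes from ReverseDiff's broadcast_plus being called with a TrackedReal and a 0-dimensional Array of TrackedReals.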

coveralls commented 2 years ago

Pull Request Test Coverage Report for Build 1215992977

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Totals Coverage Status
- Change from base Build 1186185679: 0.9%
- Covered Lines: 662
- Relevant Lines: 671

💛 - Coveralls
codecov-commenter commented 2 years ago

Codecov Report

Merging #130 (23a796a) into master (ab9bc76) will increase coverage by 0.88%. The diff coverage is n/a.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #130      +/-   ##
==========================================
+ Coverage   97.77%   98.65%   +0.88%     
==========================================
  Files          14       14              
  Lines         673      671       -2     
==========================================
+ Hits          658      662       +4     
+ Misses         15        9       -6     
| Impacted Files | Coverage Δ |
| --- | --- |
| src/entropic/sinkhorn.jl | 100.00% <ø> (ø) |
| src/exact.jl | 98.13% <0.00%> (+5.47%) ⬆️ |

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Powered by Codecov. Last update ab9bc76...23a796a.