JuliaDiff / ReverseDiff.jl

Reverse Mode Automatic Differentiation for Julia

@grad_from_chainrules macro fails when using multi-output functions #221

Open ThummeTo opened 1 year ago

ThummeTo commented 1 year ago

Dear team,

First: thanks for developing this nice package :-)

I think there is a bug in the macro @grad_from_chainrules when it is used on multi-output functions (for example, a function that returns a tuple of two vectors). Note that computing gradients/Jacobians is not part of the current GitHub tests: only the rrules are evaluated directly, and no gradient or Jacobian is computed with ReverseDiff using the corresponding rrule. For single-output functions, however, the macro works fine together with ReverseDiff.gradient.

See the following MWE:

using ForwardDiff, Zygote, ReverseDiff, ChainRulesCore

# SINGLE OUTPUT FUNCTION 

f(x) = sum(4x .+ 1)

function ChainRulesCore.rrule(::typeof(f), x)
    r = f(x)
    function back(d)
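        # Deliberately returns 3 instead of the true derivative 4, so the
        # result shows whether a backend actually uses this rrule
        return ChainRulesCore.NoTangent(), fill(3, size(x))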
    end
    return r, back
end

ReverseDiff.@grad_from_chainrules f(x::AbstractVector{<:ReverseDiff.TrackedReal})

seed = rand(3)

# Everything OK: ForwardDiff computes the correct derivatives (no frule defined),
# while ReverseDiff and Zygote use the new rrule, as expected
ForwardDiff.gradient(f, seed)
Zygote.gradient(f, seed)[1]
ReverseDiff.gradient(f, seed)

# MULTI OUTPUT FUNCTION 

f_multi(x, y) = (4x .+ 1, 3x .+ 1 .+ y)

function ChainRulesCore.rrule(::typeof(f_multi), x, y)
    r = f_multi(x, y)
    function back(d)
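        y1, y2 = d  # cotangents of the two outputs (unused in this MWE)
        # Deliberately "wrong" constant tangents, so the result shows whether the rrule is used
        return ChainRulesCore.NoTangent(), fill(2, size(x)), fill(3, size(y))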
    end
    return r, back
end

ReverseDiff.@grad_from_chainrules f_multi(x::AbstractVector{<:ReverseDiff.TrackedReal}, y::AbstractVector{<:Real})

# ForwardDiff computes the correct derivatives (no frule defined),
# Zygote uses the new rrule as expected, but ReverseDiff fails!
ForwardDiff.jacobian(x -> f_multi(x, ones(3))[1], seed)
Zygote.jacobian(x -> f_multi(x, ones(3))[1], seed)[1]
ReverseDiff.jacobian(x -> f_multi(x, ones(3))[1], seed) # this errors!

Tested with Julia 1.8.5 and up-to-date versions of all packages used.

Thanks in advance & best regards!

ThummeTo commented 1 year ago

Forgot to post the error message:

ERROR: MethodError: no method matching track(::Tuple{Vector{Float64}, Vector{Float64}}, ::Vector{ReverseDiff.AbstractInstruction})
Closest candidates are:
  track(::AbstractArray, ::Vector{ReverseDiff.AbstractInstruction}) at ...\ReverseDiff.jl\src\tracked.jl:469
  track(::Real, ::Vector{ReverseDiff.AbstractInstruction}) at ...\ReverseDiff.jl\src\tracked.jl:467
  track(::typeof(vcat), ::Union{Number, AbstractVecOrMat}...) at ...\ReverseDiff.jl\src\macros.jl:190       
  ...
Stacktrace:
 [1] track(#unused#::typeof(f_multi), x::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, y::Vector{Float64})
   @ Main ...\ReverseDiff.jl\src\macros.jl:329
 [2] f_multi(x::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, y::Vector{Float64})
   @ Main ...\ReverseDiff.jl\src\macros.jl:324
 [3] (::var"#17#18")(x::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}})
   @ Main ...\MWE_multi_reversediff.jl:44
 [4] ReverseDiff.JacobianTape(f::var"#17#18", input::Vector{Float64}, cfg::ReverseDiff.JacobianConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Nothing})
   @ ReverseDiff ...\ReverseDiff.jl\src\api\tape.jl:229
 [5] jacobian(f::Function, input::Vector{Float64}, cfg::ReverseDiff.JacobianConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Nothing}) (repeats 2 times)
   @ ReverseDiff ...\src\api\jacobians.jl:23
 [6] top-level scope
   @ ...\MWE_multi_reversediff.jl:44

cortner commented 1 year ago

+1 --- I've run into the same problem and my MWE is almost identical to the one above.

For me this is a big problem, because I am hoping to use ReverseDiff over Zygote to get second derivatives. But when you implement a pullback of a pullback, you will typically have multiple outputs to take care of.
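To illustrate why a pullback of a pullback is naturally multi-output, here is a minimal sketch using only ChainRulesCore. The function `g`, its rrule and `pullback_of_g` are hypothetical and not from the thread: the pullback of a two-argument function returns one tangent per argument, so the map that a second-order computation has to differentiate returns a tuple rather than a single array.

```julia
using ChainRulesCore

# Hypothetical two-argument function, used only for illustration.
g(x, y) = x .* y

# Reverse rule for g: the pullback returns one tangent per argument
# (NoTangent() for g itself, then one array each for x and y).
function ChainRulesCore.rrule(::typeof(g), x, y)
    function g_pullback(Δ)
        Δ = ChainRulesCore.unthunk(Δ)
        return ChainRulesCore.NoTangent(), Δ .* y, Δ .* x
    end
    return g(x, y), g_pullback
end

# The map (x, y, Δ) -> (dx, dy) that a second-order computation has to
# differentiate. Its result is a Tuple of two vectors, i.e. it is itself
# a multi-output function.
function pullback_of_g(x, y, Δ)
    _, back = ChainRulesCore.rrule(g, x, y)
    return Base.tail(back(Δ))  # drop the NoTangent() entry for g itself
end

dx, dy = pullback_of_g([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], ones(3))
# dx == [4.0, 5.0, 6.0] and dy == [1.0, 2.0, 3.0]
```

Presumably, registering such a tuple-returning function with `ReverseDiff.@grad_from_chainrules` (after giving it its own rrule) runs into the same `track(::Tuple, ...)` MethodError reported earlier in this thread, since the stacktrace shows the macro passing the rrule's whole primal output to `track`.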

If anybody can suggest how to fix this or work around it, I'd be very grateful.

CC @tjjarvinen
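One possible workaround for the problem in this thread, sketched under assumptions and not verified against a particular ReverseDiff version: since the report says the macro works for single-output functions, split the multi-output function into one function per output, give each its own rrule, and reassemble the tuple outside the macro. The names `f_multi_1`, `f_multi_2` and `f_multi_split` are hypothetical, and the rrules below return the true derivatives rather than the marker constants from the MWE. It is also an assumption that the single-output case works with `ReverseDiff.jacobian` for array-valued outputs, not only with `ReverseDiff.gradient`.

```julia
using ChainRulesCore, ReverseDiff

# Hypothetical split of f_multi(x, y) = (4x .+ 1, 3x .+ 1 .+ y) into two
# single-output functions, each with its own reverse rule.
f_multi_1(x, y) = 4x .+ 1
f_multi_2(x, y) = 3x .+ 1 .+ y

function ChainRulesCore.rrule(::typeof(f_multi_1), x, y)
    function back_1(Δ)
        Δ = ChainRulesCore.unthunk(Δ)
        # f_multi_1 does not depend on y, so its tangent is all zeros
        return ChainRulesCore.NoTangent(), 4 .* Δ, zero(y)
    end
    return f_multi_1(x, y), back_1
end

function ChainRulesCore.rrule(::typeof(f_multi_2), x, y)
    function back_2(Δ)
        Δ = ChainRulesCore.unthunk(Δ)
        return ChainRulesCore.NoTangent(), 3 .* Δ, Δ
    end
    return f_multi_2(x, y), back_2
end

# Same signature pattern as in the MWE: x tracked, y a plain vector.
ReverseDiff.@grad_from_chainrules f_multi_1(x::AbstractVector{<:ReverseDiff.TrackedReal}, y::AbstractVector{<:Real})
ReverseDiff.@grad_from_chainrules f_multi_2(x::AbstractVector{<:ReverseDiff.TrackedReal}, y::AbstractVector{<:Real})

# Reassemble the tuple outside the macro: each element is tracked on its own,
# so ReverseDiff never has to call `track` on a Tuple.
f_multi_split(x, y) = (f_multi_1(x, y), f_multi_2(x, y))

seed = rand(3)
J1 = ReverseDiff.jacobian(x -> f_multi_split(x, ones(3))[1], seed)
J2 = ReverseDiff.jacobian(x -> f_multi_split(x, ones(3))[2], seed)
```

The trade-off is that each output needs its own rrule and any work shared between the outputs is done separately for each, so this is a stopgap rather than a fix for the macro itself.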