SciML / DifferenceEquations.jl

Solving difference equations with DifferenceEquations.jl and the SciML ecosystem.
MIT License

SciML/Zygote/ChainRules issue with adjoint type #75

Open jlperla opened 2 years ago

jlperla commented 2 years ago

Take the following MWE with a custom rrule for make_solution. Note that make_solution returns a subtype of a SciMLBase abstract solution type.

using SciMLBase, Zygote, ChainRulesTestUtils, ChainRulesCore
struct MySolution{T,N,uType,tType,IType} <: SciMLBase.AbstractRODESolution{T,N,uType}
    u::uType
    z::Vector{Vector{Float64}}
    t::tType
    interp::IType
    retcode::Symbol
end

function MySolution(u, t, z; interp = SciMLBase.ConstantInterpolation(t, u), retcode = :Default)
    T = eltype(eltype(u))
    N = 2
    # TODO: add support for has_analytic in the future
    sol = MySolution{T,N,typeof(u),typeof(t),typeof(interp)}(u, z, t, interp, retcode)
    return sol
end

function make_solution(u, t, z)
    sol = MySolution(u, t, z)
    return sol
end

function ChainRulesCore.rrule(::typeof(make_solution), u, t, z)
    sol = MySolution(u, t, z)

    function solve_pb(Δsol)
        @show typeof(Δsol)
        # Δsol is not always a tangent of the solution type!
        # The calculation itself doesn't matter here; only the type of Δsol is wrong.
        return NoTangent(), NoTangent(), NoTangent(), NoTangent()
    end
    return sol, solve_pb
end

#### using the make_solution
function f(a, z)
    u_in = [a, 2 * a, 3 * a]
    sol = make_solution(u_in, [0.0, 1.0, 2.0], z)
    return sol[1, 2]
end
function f_1(a, z)
    u_in = [a, 2 * a, 3 * a]
    sol = make_solution(u_in, [0.0, 1.0, 2.0], z)
    return sol.u[2][1] # should be equivalent to sol[1,2] in f(a,z)
end
function f_2(a, z)
    u_in = [a, 2 * a, 3 * a]
    sol = make_solution(u_in, [0.0, 1.0, 2.0], z)
    return sol.z[1][2]
end

# Question: should/can the tangent type be made type-stable?
# Note that f_1 and f compute the same thing, but Δsol has the wrong type in the sol[1,2] version (f).
z_val = [[1.0, 1.0], [0.1, 6.0], [6, 3.0]]

f([5.1, 0.3], z_val)
f_1([5.1, 0.3], z_val)
f_2([5.1, 0.3], z_val)
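gradient(f, [5.1, 0.3], z_val)   # different type! Δsol arrives as a VectorOfArray
gradient(f_1, [5.1, 0.3], z_val) # Δsol is a Tangent with the gradient in `u`
gradient(f_2, [5.1, 0.3], z_val) # Δsol is a Tangent with the gradient in `z`; only the SciML indexing path is off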

Note that MySolution subtypes AbstractRODESolution, which gives it all of the dispatches defined for SciML solution types, including the sol[i, j] indexing interface used in f.
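A quick way to see where sol[1,2] comes from is to walk the supertype chain of AbstractRODESolution. This is only an illustrative sketch, assuming SciMLBase and RecursiveArrayTools are installed; supertype_chain is a hypothetical helper, and the exact chain printed depends on the package versions:

```julia
using SciMLBase, RecursiveArrayTools

# Hypothetical helper: collect T and each of its supertypes, ending at Any.
function supertype_chain(T)
    chain = Any[T]
    while T !== Any
        T = supertype(T)
        push!(chain, T)
    end
    return chain
end

foreach(println, supertype_chain(SciMLBase.AbstractRODESolution))

# AbstractRODESolution sits below RecursiveArrayTools' array abstractions,
# so a MySolution picks up the generic indexing methods defined for them.
@show SciMLBase.AbstractRODESolution <: RecursiveArrayTools.AbstractVectorOfArray
```

If that subtype check is true, then sol[1,2] on a MySolution goes through generic indexing methods (and any AD rules attached to them) rather than anything specific to MySolution, which would be consistent with Δsol showing up as a VectorOfArray.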

The output printed by the three gradient calls is:

typeof(Δsol) = RecursiveArrayTools.VectorOfArray{Float64, 2, Vector{Vector{Float64}}}

typeof(Δsol) = Tangent{Any, NamedTuple{(:u, :z, :t, :interp, :retcode), Tuple{Vector{Any}, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}}}

typeof(Δsol) = Tangent{Any, NamedTuple{(:u, :z, :t, :interp, :retcode), Tuple{ZeroTangent, Vector{Any}, ZeroTangent, ZeroTangent, ZeroTangent}}}

Two things stand out.

  1. In f, which uses sol[1,2] through the SciML indexing dispatches, the cotangent passed to the pullback has the wrong type: it is a RecursiveArrayTools.VectorOfArray{Float64, 2, Vector{Vector{Float64}}} (presumably coming from the sol[...] interface) rather than a tangent of sol itself.
    • Note that f_1 is correct, and the only difference is that it uses sol.u[2][1] instead of sol[1,2].
  2. Even in the other cases, Δsol has a type like Tangent{Any, NamedTuple{(:u, :z, :t, :interp, :retcode), Tuple{Vector{Any}, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}}}, which is not type-stable (note the Tangent{Any, ...} primal type and the Vector{Any} field).
    • Is there any way to turn the Tangent{Any, ...} into something specific to the MySolution type so that it is type-stable?
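One possible workaround (an assumption, not something confirmed in this thread) is to make the pullback accept every form Δsol can arrive in and funnel them into one code path. A minimal sketch; extract_du is a hypothetical helper name:

```julia
using ChainRulesCore
using RecursiveArrayTools: AbstractVectorOfArray, VectorOfArray

# Hypothetical helper: return the cotangent of the solution's `u` field,
# whichever form Zygote passes to the pullback.
extract_du(Δ::Tangent) = Δ.u                 # structural tangent, e.g. Tangent{Any, NamedTuple{...}}
extract_du(Δ::AbstractVectorOfArray) = Δ.u   # array-style cotangent, as seen for sol[1,2]
extract_du(Δ::AbstractZero) = ZeroTangent()  # nothing flowed back into the solution

# Inside a pullback, unthunk first, since Δsol may arrive as a thunk:
#     du = extract_du(unthunk(Δsol))

Δ_struct = Tangent{Any}(u = [[0.0, 1.0]], z = ZeroTangent())
Δ_array = VectorOfArray([[0.0, 1.0]])
@show extract_du(Δ_struct) == extract_du(Δ_array)
```

This does not make the incoming Tangent{Any, ...} itself type-stable; it only lets the rrule handle both the structural and the array-style cotangent the same way.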
jlperla commented 2 years ago

It seemed like the issue might be the dispatch on getindex for sol[1,2] etc., which Zygote may be getting confused by. I tried adding a specialization, but its rrule isn't getting called.

function ChainRulesCore.rrule(::typeof(getindex), A::MySolution, I1::Int, I2::Int)
    @show "In rrule"
    @show A
    function pb(Δfield)
        ΔA = Tangent{typeof(A)}(; u = ZeroTangent())
        return NoTangent(), ΔA, NoTangent(), NoTangent()
    end
    return A.u[I2][I1], pb
end
Base.getindex(A::MySolution, I::Int) = A.u[I]
function Base.getindex(A::MySolution, I1::Int, I2::Int)
    return A.u[I2][I1]
end

You can check which method sol[1,2] dispatches to with @which:

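using InteractiveUtils  # provides @which outside the REPL
z_val = [[1.0, 1.0], [0.1, 6.0], [6, 3.0]]
u_val = [[1.0, 1.0], [0.1, 6.0], [6, 3.0]]
t_val = [0.0, 1.0, 2.0]  # was `ts`, but the next line uses `t_val`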
sol = MySolution(u_val, t_val, z_val)
@which sol[1,2] # etc.