jlperla opened this issue 2 years ago
It seemed like the issue might be the dispatch on `getindex` for `sol[1,2]` etc., which Zygote might be getting confused by. I tried to add a specialization, but its `rrule` isn't getting called:
```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(getindex), A::MySolution, I1::Int, I2::Int)
    # Debug output to check whether this rule is ever reached
    @show "In rrule"
    @show A
    function pb(Δfield)
        # Zero tangent for u: this pullback does not propagate Δfield
        ΔA = Tangent{typeof(A)}(; u = ZeroTangent())
        return NoTangent(), ΔA, NoTangent(), NoTangent()
    end
    return A.u[I2][I1], pb
end

Base.getindex(A::MySolution, I::Int) = A.u[I]

function Base.getindex(A::MySolution, I1::Int, I2::Int)
    return A.u[I2][I1]
end
```
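For comparison, here is a minimal sketch of a `getindex` rrule whose pullback propagates a one-hot cotangent for `u` instead of `ZeroTangent()`, and keys the result on the concrete primal type. It assumes a hypothetical `PlainSolution` struct that does not subtype `AbstractRODESolution`, so no RecursiveArrayTools method competes with the rule; it does not reproduce the SciML dispatch problem itself.

```julia
using ChainRulesCore

# Hypothetical stand-in for MySolution: a plain struct that does not subtype
# AbstractRODESolution, so no RecursiveArrayTools method competes for getindex.
struct PlainSolution{T}
    u::Vector{Vector{T}}
    t::Vector{T}
end

# sol[i, j] reads component i of the state at time index j, as in the issue.
Base.getindex(A::PlainSolution, I1::Int, I2::Int) = A.u[I2][I1]

function ChainRulesCore.rrule(::typeof(getindex), A::PlainSolution{T}, I1::Int, I2::Int) where {T}
    function getindex_pullback(Δ)
        Δ = unthunk(Δ)
        if Δ isa AbstractZero
            return NoTangent(), ZeroTangent(), NoTangent(), NoTangent()
        end
        # One-hot cotangent for A.u: zero everywhere except entry [I2][I1].
        Δu = [zeros(T, length(v)) for v in A.u]
        Δu[I2][I1] = Δ
        # Keyed on typeof(A) with a concrete field type, not Tangent{Any, ...}.
        ΔA = Tangent{typeof(A)}(; u = Δu)
        return NoTangent(), ΔA, NoTangent(), NoTangent()
    end
    return A[I1, I2], getindex_pullback
end
```

For `A = PlainSolution([[1.0, 2.0], [3.0, 4.0]], [0.0, 1.0])`, calling `rrule(getindex, A, 1, 2)` returns `3.0`, and its pullback applied to `1.0` gives a `Tangent{PlainSolution{Float64}}` whose `u` is `[[0.0, 0.0], [1.0, 0.0]]`.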
You can test with `@which` and:
```julia
using InteractiveUtils  # provides @which outside the REPL

z_val = [[1.0, 1.0], [0.1, 6.0], [6, 3.0]]
u_val = [[1.0, 1.0], [0.1, 6.0], [6, 3.0]]
t_val = [0.00, 1.0, 2.0]
sol = MySolution(u_val, t_val, z_val)
@which sol[1,2] # etc.
```
Take the following MWE for a custom `rrule` for `make_solution`. Note that it returns a SciMLBase type: `MySolution` uses the `AbstractRODESolution` type, which brings in all sorts of dispatches for solution types. The display of this from the 3 gradient calls is:
Two things stand out.

1. In `f`, where it uses `sol[1,2]` from the SciML dispatches, the type passed into the `rrule` is wrong: it is `RecursiveArrayTools.VectorOfArray{Float64, 2, Vector{Vector{Float64}}}` (presumably related to the `sol[...]` interface) rather than the type of `sol` itself. `f_1` is correct, and the only difference is using `sol.u[2,1]` instead of `sol[1,2]`.
2. Even in the correct cases, `Δsol` has `typeof(Δsol) = Tangent{Any, NamedTuple{(:u, :z, :t, :interp, :retcode), Tuple{Vector{Any}, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}}}`, which is not type-stable.

Is there a way to convert `Tangent{Any,....}` into something specific to the `MySolution` type to make it type-stable?
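One possible approach, sketched under assumptions: inside the pullback, rebuild the loosely typed cotangent for `u` as a concrete `Vector{Vector{T}}` (treating `nothing` and `ZeroTangent()` entries as zeros), and, where a structural tangent is needed, re-key it as `Tangent{typeof(sol)}`. `PlainSolution`, `make_solution`, `concrete_du` and `retype_tangent` are hypothetical stand-ins for the types in the MWE, and `t` is treated as non-differentiable for brevity.

```julia
using ChainRulesCore

# Hypothetical stand-in for MySolution (the real type uses AbstractRODESolution).
struct PlainSolution{T}
    u::Vector{Vector{T}}
    t::Vector{T}
end

# Hypothetical stand-in for make_solution.
make_solution(u::Vector{Vector{T}}, t::Vector{T}) where {T} = PlainSolution(u, t)

# Rebuild a loosely typed cotangent for `u` (e.g. a Vector{Any} mixing vectors,
# `nothing` and ZeroTangent()) as a concrete Vector{Vector{T}} shaped like `u`.
function concrete_du(u::Vector{Vector{T}}, Δu) where {T}
    Δu = unthunk(Δu)
    out = [zeros(T, length(v)) for v in u]
    (Δu isa AbstractZero || Δu === nothing) && return out
    for i in eachindex(u)
        d = unthunk(Δu[i])
        if !(d isa AbstractZero || d === nothing)
            out[i] .= d  # copy entries into the preallocated, concretely typed slot
        end
    end
    return out
end

# Re-key a Tangent{Any, ...} (or a NamedTuple) on the primal type.
retype_tangent(sol::PlainSolution, Δ) =
    Tangent{typeof(sol)}(; u = concrete_du(sol.u, Δ.u), t = ZeroTangent())

function ChainRulesCore.rrule(::typeof(make_solution), u::Vector{Vector{T}}, t::Vector{T}) where {T}
    sol = make_solution(u, t)
    function make_solution_pullback(Δsol)
        Δsol = unthunk(Δsol)
        if Δsol isa AbstractZero || Δsol === nothing
            return NoTangent(), ZeroTangent(), NoTangent()
        end
        # Δsol.u works for both a Tangent and a NamedTuple cotangent.
        return NoTangent(), concrete_du(u, Δsol.u), NoTangent()
    end
    return sol, make_solution_pullback
end
```

Preallocating `out` from the primal `u` fixes the element type up front, so whatever mix of entries arrives in `Δsol.u`, the cotangent handed back to the caller is a `Vector{Vector{T}}`.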