SciML / DifferentialEquations.jl

Multi-language suite for high-performance solvers of differential equations and scientific machine learning (SciML) components. Ordinary differential equations (ODEs), stochastic differential equations (SDEs), delay differential equations (DDEs), differential-algebraic equations (DAEs), and more in Julia.
https://docs.sciml.ai/DiffEqDocs/stable/

Error when mixing datatypes while differentiating using ForwardDiff and using FunctionWrapperSpecialize #994

Open LilithHafner opened 11 months ago

LilithHafner commented 11 months ago

This MWE

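using DifferentialEquations
using ForwardDiff: gradient

# In-place RHS of the harmonic oscillator x'' = -x as a first-order system
function ode_f(du, u, p, t)
    x = u[1]
    v = u[2]
    dx = v
    dv = -x
    du[1] = dx
    du[2] = dv
    return nothing
end

# Solve from `initial` and return x(1); tspan is Float64 while `initial` is Float32
function f(initial)
    tspan = (0.0, 1.0)
    prob = ODEProblem(ode_f, initial, tspan)
    sol = solve(prob)
    sol.u[end][1]   # first component of the final state
end

gradient(f, Float32[1.0, 1.0])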

gives a "No matching function wrapper was found!" error with a very long stacktrace.

Empirically, I have many workarounds:

I suspect that there is an inconsistency between the code that decides which input types to precompute when using FunctionWrappersWrappers and the code that calls the doubly wrapped function.
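To make the suspected mismatch concrete: in frame [1] of the stacktrace, the failing call passes `du` as `Vector{Dual{Tag{typeof(f), Float32}, Float64, 2}}` but `u` as `Vector{Dual{Tag{typeof(f), Float32}, Float32, 2}}`, while the first registered wrapper expects `Float32` in both slots. The sketch below is a hypothetical, Base-only stand-in (`SignatureDispatcher` is an invented name, not FunctionWrappersWrappers.jl's code) that stores a function with a fixed list of argument signatures and only calls it on an exact match. Plain `Float32`/`Float64` vectors stand in for the dual-number vectors.

```julia
# Hypothetical "wrapper of wrappers": a function plus a fixed list of argument-type
# signatures, called only on an exact signature match. This mimics the idea behind
# FunctionWrappersWrappers.jl; it is NOT that package's implementation.
struct SignatureDispatcher{F}
    f::F
    signatures::Vector{DataType}   # concrete Tuple{...} types registered up front
end

function (d::SignatureDispatcher)(args...)
    callsig = Tuple{map(typeof, args)...}
    for sig in d.signatures
        # Exact match only: a Vector{Float64} does not match a Vector{Float32} slot.
        callsig == sig && return d.f(args...)
    end
    error("No matching function wrapper was found!")
end

# In-place RHS from the MWE (harmonic oscillator).
function ode_f(du, u, p, t)
    du[1] = u[2]
    du[2] = -u[1]
    return nothing
end

u0 = Float32[1.0, 1.0]
t = 0.0   # Float64 time, as with tspan = (0.0, 1.0)

# Register only the signature implied by u0's element type for both du and u.
wrapped = SignatureDispatcher(ode_f, [Tuple{Vector{Float32}, Vector{Float32}, Nothing, Float64}])

# A du buffer built from u0 alone matches the registered signature.
du_ok = similar(u0)
wrapped(du_ok, u0, nothing, t)

# A du buffer whose element type was promoted by mixing Float32 state with Float64 time.
du_promoted = similar(u0, promote_type(eltype(u0), typeof(t)))   # Vector{Float64}
try
    wrapped(du_promoted, u0, nothing, t)
catch err
    println("promoted du rejected: ", sprint(showerror, err))
end
```

Under these assumptions, any code path that allocates `du` from a promoted type while the precomputed signatures assume `eltype(u0)` throws exactly this error, even though calling `ode_f` directly would work.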

This issue stems from an investigation into https://github.com/SciML/juliatorch/issues/10. If this issue were fixed, I expect that https://github.com/SciML/juliatorch/issues/10 would also be fixed.

Manually truncated stacktrace:

```
ERROR: No matching function wrapper was found!
Stacktrace:
  [1] _call(#unused#::Tuple{}, arg::Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}, fww::FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Flo
    @ FunctionWrappersWrappers ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:23
  [2] _call(fw::Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, SciMLBase.NullParameters, ForwardDiff.Dual{ForwardDiff.Tag{DiffEq
    @ FunctionWrappersWrappers ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:13
  [3] _call(fw::Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.Ordin
    @ FunctionWrappersWrappers ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:13
  [4] _call(fw::Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, ForwardDif
    @ FunctionWrappersWrappers ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:13
  [5] _call(fw::Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}
    @ FunctionWrappersWrappers ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:13
  [6] (::FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32,
    @ FunctionWrappersWrappers ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:10
  [7] (::ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float
    @ SciMLBase ~/.julia/packages/SciMLBase/VS2ST/src/scimlfunctions.jl:2394
  [8] ode_determine_initdt(u0::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, t::Float64, tdir::Float64, dtmax::Float64, abstol::Float32, reltol::ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), prob::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Floa
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/yppG9/src/initdt.jl:53
  [9] auto_dt_reset!
    @ ~/.julia/packages/OrdinaryDiffEq/yppG9/src/integrators/integrator_interface.jl:449 [inlined]
 [10] handle_dt!(integrator::OrdinaryDiffEq.ODEIntegrator{CompositeAlgorithm{Tuple{Vern7{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Rodas5P{1, false, LinearSolve.DefaultLinearSolver, typeof(OrdinaryDiffEq.DEFAULT_PRECS), Val{:forward}, true, nothing}}, OrdinaryDiffEq.AutoSwitchCache{Vern7{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Rodas5P{0, false, Nothing, typeof(OrdinaryDiffEq.DEFAULT_PRECS), Val{:forward}, true, nothing}, Rational{Int64}, Int64}}, true, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Nothing, Float64, SciMLBase.NullParameters, Float64, Float32, Float32, Float64, Vector{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}}, ODESolution{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 2, Vector{V
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:555
 [11] __init(prob::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Forwar
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:517
 [12] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:10 [inlined]
 [13] __solve(::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{type
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:5
 [14] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:1 [inlined]
 [15] #solve_call#34
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:557 [inlined]
 [16] solve_call
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:523 [inlined]
 [17] #solve_up#42
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:1006 [inlined]
 [18] solve_up
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:992 [inlined]
 [19] #solve#40
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:929 [inlined]
 [20] __solve(::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float
    @ DifferentialEquations ~/.julia/packages/DifferentialEquations/Tu7HS/src/default_solve.jl:14
 [21] __solve
    @ ~/.julia/packages/DifferentialEquations/Tu7HS/src/default_solve.jl:1 [inlined]
 [22] #__solve#63
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:1285 [inlined]
 [23] __solve
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:1278 [inlined]
 [24] solve_call(::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, SciMLBase.NullParameters, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, SciMLBase.NullParameters, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float64, 2}, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f),
    @ DiffEqBase ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:557
 [25] solve_call
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:523 [inlined]
 [26] #solve_up#42
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:998 [inlined]
 [27] solve_up
    @ ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:992 [inlined]
 [28] solve(::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, typeof(ode_f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{true}, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:929
 [29] solve(::ODEProblem{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}, Tuple{Float64, Float64}, true, SciMLBase.NullParameters, ODEFunction{true, SciMLBase.AutoSpecialize, typeof(ode_f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/NpZ7U/src/solve.jl:919
 [30] f(initial::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}})
    @ Main ./REPL[19]:4
 [31] vector_mode_dual_eval!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:24 [inlined]
 [32] vector_mode_gradient(f::typeof(f), x::Vector{Float32}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:89
 [33] gradient(f::Function, x::Vector{Float32}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}}, ::Val{true})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:0
 [34] gradient(f::Function, x::Vector{Float32}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float32}, Float32, 2}}})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:17
 [35] gradient(f::Function, x::Vector{Float32})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:17
 [36] top-level scope
    @ REPL[20]:1
```
oscardssmith commented 11 months ago

Slightly better reproducer:

using OrdinaryDiffEq
using ForwardDiff: gradient
ode_f(du, u, p, t) = (du[1] = -u[1]; nothing)   # in-place du/dt = -u
function f(initial)
    tspan = (0.0, 1.0)
    prob = ODEProblem(ode_f, initial, tspan)
    solve(prob, Rodas5P()).u[end][1]   # first component of the final state
end
gradient(f, Float32[1.0])

Specifically, if you use solve(prob, FBDF()), it works, so the problem seems to be in the tgrad.
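Some context for the tgrad observation, hedged: Rosenbrock methods such as Rodas5P use the time derivative ∂f/∂t ("tgrad"), and when no analytic tgrad is supplied it is, by default, computed by forward-mode AD in t, so f is called with a dual-number t. The stacktrace above does list a registered wrapper whose t argument is `ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, …}, …}`. FBDF, a BDF method, only needs ∂f/∂u and so never makes that call. The Base-only sketch below shows the technique with a hypothetical `MiniDual` type (not ForwardDiff's `Dual`) and a hypothetical time-dependent RHS `rhs!`. With `Float32` state and `Float64` time, the `du` buffer has to hold `MiniDual{Float64}`, which is a different argument signature from the plain forward call.

```julia
# Hypothetical minimal dual number for forward-mode differentiation with respect to t.
# OrdinaryDiffEq uses ForwardDiff.Dual for this; MiniDual is for illustration only.
struct MiniDual{T<:AbstractFloat}
    val::T   # primal value
    der::T   # derivative with respect to t
end

# Only the operations that rhs! below may need.
Base.:-(a::MiniDual) = MiniDual(-a.val, -a.der)
Base.:*(a::Real, b::MiniDual) = MiniDual(a * b.val, a * b.der)
Base.sin(a::MiniDual) = MiniDual(sin(a.val), cos(a.val) * a.der)

# Hypothetical non-autonomous RHS du/dt = -u * sin(t), so that df/dt is nonzero.
# (The RHS in the reproducer above does not depend on t.)
function rhs!(du, u, p, t)
    du[1] = -u[1] * sin(t)
    return nothing
end

# Sketch of a tgrad: seed t with derivative 1, call f! once, read off the derivative parts.
function tgrad_sketch(f!, u, p, t::AbstractFloat)
    T = promote_type(eltype(u), typeof(t))   # Float64 for Float32 state and Float64 time
    tdual = MiniDual(T(t), one(T))
    # du must hold what f! produces for (u, tdual), i.e. MiniDual{T}, a type that
    # cannot be derived from eltype(u) alone.
    du = Vector{MiniDual{T}}(undef, length(u))
    f!(du, u, p, tdual)   # f! is called with a dual-number t: a new argument signature
    return [d.der for d in du]
end

u = Float32[2.0]
dfdt = tgrad_sketch(rhs!, u, nothing, 0.0)   # d/dt(-u*sin(t)) at t = 0 is -u*cos(0) = -2
```

Under these assumptions, a solver that never evaluates ∂f/∂t never produces this `(du, u, p, t)` type combination, which is consistent with FBDF working while Rodas5P fails.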

oscardssmith commented 11 months ago

Fixed by https://github.com/SciML/OrdinaryDiffEq.jl/pull/2051