JuliaDiff / TaylorDiff.jl

Taylor-mode automatic differentiation for higher-order derivatives
https://juliadiff.org/TaylorDiff.jl/
MIT License
73 stars 8 forks source link

Can't differentiate an ODE solver due to lack of isnan and other type errors. #73

Closed orebas closed 4 months ago

orebas commented 5 months ago

TaylorDiff.jl seems to throw an error when I try to differentiate a fairly simple ODE solver. There is an error on the MTK side, but even after the workaround there (see https://discourse.julialang.org/t/error-trying-to-forwarddiff-through-an-ode-solver/114339/6) I can't get TaylorDiff.jl to work.

MWE:

using ModelingToolkit, DifferentialEquations
using TaylorDiff, ForwardDiff
using DifferentiationInterface, Enzyme, Zygote, ReverseDiff
using SciMLSensitivity
#import Base.isnan
#function isnan(x::TaylorScalar{Float64, 2})
#   return false
#end

"""
    ADTest()

Minimal working example for differentiating an ODE solve with respect to its end time.

Builds the scalar linear system `D(x1) ~ a * x1` with ModelingToolkit (`a = 2.0`,
`x1(0) = 1.0`), solves it once on `[0.0, 1.0]` with `Tsit5`, and then differentiates
`just_t`, which re-solves the problem on `[0.0, new_t]` and evaluates `x1` at `new_t`,
using `TaylorDiff.derivative`. Intermediate results are shown with `display`; nothing
useful is returned.
"""
function ADTest()
    @parameters a
    @variables t x1(t)
    D = Differential(t)
    states = [x1]
    parameters = [a]

    @named pre_model = ODESystem([D(x1) ~ a * x1], t, states, parameters)
    model = structural_simplify(pre_model)

    ic = Dict(x1 => 1.0)
    p_true = Dict(a => 2.0)

    # Reference solve on the fixed time span [0, 1].
    problem = ODEProblem{true, SciMLBase.FullSpecialize}(model, ic, [0.0, 1.0], p_true)
    soln = ModelingToolkit.solve(problem, Tsit5(), abstol = 1e-12, reltol = 1e-12)
    display(soln(0.5, idxs = [x1]))

    # Re-solve `problem` on [0, new_t] with the given initial conditions and parameters,
    # and return the value of `x1` at `new_t` (as a one-element collection).
    function different_time(new_ic, new_params, new_t)
        newprob = remake(problem, u0 = new_ic, tspan = [0.0, new_t], p = new_params)
        # Convert every entry of u0 to the type of `new_t` so the state carries the same
        # number type as the time span when `new_t` is an AD number (e.g. `TaylorScalar`).
        newprob = remake(newprob, u0 = typeof(new_t).(newprob.u0))
        new_soln = ModelingToolkit.solve(newprob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
        # Evaluate the freshly computed solution. The original returned `soln(...)`, the
        # outer solve, which left `new_soln` unused and ignored `new_ic`/`new_params`.
        return new_soln(new_t, idxs = [x1])
    end

    # Scalar function of the end time only; this is what gets differentiated.
    function just_t(new_t)
        return different_time(ic, p_true, new_t)[1]
    end
    display(different_time(ic, p_true, 2e-5))
    display(just_t(0.5))

    # Alternative AD backends, kept for comparison:
    #display(ForwardDiff.derivative(just_t,1.0))
    display(TaylorDiff.derivative(just_t, 1.0, 1))  #isnan error
    #display(value_and_gradient(just_t, AutoForwardDiff(), 1.0))
    #display(value_and_gradient(just_t, AutoReverseDiff(), 1.0))
    #display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Reverse), 1.0))
    #display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Forward), 1.0))
    #display(value_and_gradient(just_t, AutoZygote(), 1.0))

end

# Run the reproduction.
ADTest()
orebas commented 5 months ago

Running the above, the error is

ERROR: LoadError: MethodError: no method matching isnan(::TaylorScalar{Float64, 2})

Closest candidates are:
  isnan(::Missing)
   @ Base missing.jl:101
  isnan(::BigFloat)
   @ Base mpfr.jl:982
  isnan(::Complex)
   @ Base complex.jl:151
  ...

Stacktrace:
  [1] _any(f::typeof(isnan), itr::Tuple{TaylorScalar{Float64, 2}, TaylorScalar{Float64, 2}}, ::Colon)
    @ Base ./reduce.jl:1220
  [2] any(f::Function, itr::Tuple{TaylorScalar{Float64, 2}, TaylorScalar{Float64, 2}})
    @ Base ./reduce.jl:1235
  [3] get_concrete_tspan(prob::ODEProblem{…}, isadapt::Bool, kwargs::@Kwargs{…}, p::ModelingToolkit.MTKParameters{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1287
  [4] get_concrete_problem(prob::ODEProblem{…}, isadapt::Bool; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1169
  [5] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1074
  [6] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1003
  [7] (::var"#different_time#1"{ODESolution{…}, ODEProblem{…}, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:32
  [8] (::var"#just_t#2"{var"#different_time#1"{ODESolution{…}, ODEProblem{…}, Num}, Dict{Num, Float64}, Dict{Num, Float64}})(new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:37
  [9] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:28 [inlined]
 [10] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:18 [inlined]
 [11] ADTest()
    @ Main ~/learning/ODETests/PLI/MWE3.jl:44
 [12] top-level scope
    @ ~/learning/ODETests/PLI/MWE3.jl:53
 [13] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [14] top-level scope
    @ REPL[1]:1
orebas commented 5 months ago

If I go ahead and define `isnan` myself (you can uncomment the 4 lines near the top of the MWE), the error becomes

ERROR: LoadError: Non-concrete element type inside of an `Array` detected.
Arrays with non-concrete element types, such as
`Array{Union{Float32,Float64}}`, are not supported by the
differential equation solvers. Anyways, this is bad for
performance so you don't want to be doing this!

If this was a mistake, promote the element types to be
all the same. If this was intentional, for example,
using Unitful.jl with different unit values, then use
an array type which has fast broadcast support for
heterogeneous values such as the ArrayPartition
from RecursiveArrayTools.jl. For example:

```julia
using RecursiveArrayTools
x = ArrayPartition([1.0,2.0],[1f0,2f0])
y = ArrayPartition([3.0,4.0],[3f0,4f0])
x .+ y # fast, stable, and usable as u0 into DiffEq!
```

Element type: Any

Some of the types have been truncated in the stacktrace for improved reading. To emit complete information in the stack trace, evaluate TruncatedStacktraces.VERBOSE[] = true and re-run the code.

Stacktrace:
  [1] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:592
  [2] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1080
  [3] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1003
  [4] (::var"#different_time#3"{ODESolution{…}, ODEProblem{…}, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:32
  [5] (::var"#just_t#4"{var"#different_time#3"{ODESolution{…}, ODEProblem{…}, Num}, Dict{Num, Float64}, Dict{Num, Float64}})(new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:37
  [6] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:28 [inlined]
  [7] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:18 [inlined]
  [8] ADTest()
    @ Main ~/learning/ODETests/PLI/MWE3.jl:44
  [9] top-level scope
    @ ~/learning/ODETests/PLI/MWE3.jl:53
 [10] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [11] top-level scope
    @ REPL[1]:1
in expression starting at /home/orebas/learning/ODETests/PLI/MWE3.jl:53
Some type information was truncated. Use show(err) to see complete types.

tansongchen commented 5 months ago

This was identified previously in #35; it is caused by inconsistencies in the type system. Unfortunately, I haven't figured out a good way to handle this yet...

tansongchen commented 4 months ago

OK, I now believe that `TaylorScalar` not being `<: Real` is a design error that needs to be fixed. I have started a fix at https://github.com/JuliaDiff/TaylorDiff.jl/tree/subtype-number ; once it is done, this application should work.

tansongchen commented 4 months ago

Fixed in latest version 0.2.2

orebas commented 4 months ago

I'm still getting this error with the above MWE:

ERROR: LoadError: MethodError: no method matching TaylorScalar{Float64, 2}(::Tuple{Float64, ChainRulesCore.ZeroTangent})

Closest candidates are:
  TaylorScalar{T, N}(::TaylorScalar{T, M}) where {T, N, M}
   @ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/scalar.jl:65
  TaylorScalar{T, N}(::S, ::S) where {T, S<:Real, N}
   @ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/scalar.jl:58
  TaylorScalar{T, N}(::S) where {T, S<:Real, N}
   @ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/scalar.jl:46
  ...

Stacktrace:
  [1] sign(t::TaylorScalar{Float64, 2})
    @ TaylorDiff ~/.julia/packages/TaylorDiff/75xsf/src/codegen.jl:20
  [2] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::TaylorScalar{…}, dtmin::TaylorScalar{…}, dtmax::TaylorScalar{…}, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/YXsFS/src/solve.jl:120
  [3] __solve(::ODEProblem{…}, ::Tsit5{…}; kwargs::@Kwargs{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/YXsFS/src/solve.jl:6
  [4] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:612
  [5] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1080
  [6] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1003
  [7] (::var"#different_time#5"{ODESolution{…}, ODEProblem{…}, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:32
  [8] (::var"#just_t#6"{var"#different_time#5"{ODESolution{…}, ODEProblem{…}, Num}, Dict{Num, Float64}, Dict{Num, Float64}})(new_t::TaylorScalar{Float64, 2})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:37
  [9] derivatives
    @ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:66 [inlined]
 [10] derivative
    @ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:54 [inlined]
 [11] derivative
    @ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:35 [inlined]
 [12] derivative
    @ ~/.julia/packages/TaylorDiff/75xsf/src/derivative.jl:30 [inlined]
 [13] ADTest()
    @ Main ~/learning/ODETests/PLI/MWE3.jl:44
 [14] top-level scope
    @ ~/learning/ODETests/PLI/MWE3.jl:53
 [15] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [16] top-level scope
    @ REPL[5]:1
tansongchen commented 4 months ago

Oh, that's a problem with codegen. I will run your example and make it work tomorrow.

tansongchen commented 4 months ago

OK, so I fixed a minor problem related to converting the special tangent types from ChainRules. Both should work now:

julia> ForwardDiff.derivative(just_t, 1.0)
14.778112197861631

julia> TaylorDiff.derivative(just_t, 1.0, 1)
14.77811219786163