SciML / DifferentialEquations.jl

Multi-language suite for high-performance solvers of differential equations and scientific machine learning (SciML) components. Ordinary differential equations (ODEs), stochastic differential equations (SDEs), delay differential equations (DDEs), differential-algebraic equations (DAEs), and more in Julia.
https://docs.sciml.ai/DiffEqDocs/stable/

Double callback crossing floating pointer reducer errored. Report this issue. #674

Closed samuela closed 3 years ago

samuela commented 4 years ago

I'm running Julia 1.5.1. Reproduction: https://github.com/samuela/research/commit/e8287f3bd17f4a1dcebdf038c80681210007ee86. You can recreate the error with ] instantiate, ] build, and then julia --project difftaichi/mass_spring.jl. The code is much messier than I'd like, but since I don't understand this error and the message specifically asks that it be reported, I thought I'd just put this out there.

samuela@n64:~/dev/research/julia/odecontrol$ julia --project difftaichi/mass_spring.jl 

[Taichi] mode=release
[Taichi] preparing sandbox at /tmp/taichi-xim6t566
[Taichi] version 0.6.35, llvm 10.0.0, commit 17c9fb79, linux, python 3.6.9
[Taichi] Starting on arch=x64
[Taichi] Starting on arch=x64
n_objects= 14    n_springs= 30
[Taichi] materializing...
[ Info: Testing DiffTaichi gradients
[ Info: Testing dynamics gradients
[ Info: Testing observation gradients
 22.380492 seconds (18.37 M allocations: 2.656 GiB, 1.91% gc time)
 34.355208 seconds (35.80 M allocations: 4.183 GiB, 2.92% gc time)
[ Info: InterpolatingAdjoint
ERROR: LoadError: Double callback crossing floating pointer reducer errored. Report this issue.
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] find_callback_time(::OrdinaryDiffEq.ODEIntegrator{VCABM,false,Array{Float64,1},Nothing,Float64,Array{Float64,1},Float64,Float64,Float64,Array{Array{Float64,1},1},DiffEqBase.ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},false,Array{Float64,1},DiffEqBase.ODEFunction{false,var"#aug_dynamics#14"{typeof(dynamics),typeof(cost),var"#53#54"},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},VCABM,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,var"#aug_dynamics#14"{typeof(dynamics),typeof(cost),var"#53#54"},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.VCABMConstantCache{Array{Float64,1},Array{Array{Float64,1},1},Array{Float64,2},Float64,Array{Float64,1}}},DiffEqBase.DEStats},DiffEqBase.ODEFunction{false,var"#aug_dynamics#14"{typeof(dynamics),typeof(cost),var"#53#54"},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.VCABMConstantCache{Array{Float64,1},Array{Array{Float64,1},1},Array{Float64,2},Float64,Array{Float64,1}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),DiffEqBase.CallbackSet{Tuple{VectorContinuousCallback{var"#49#51",var"#50#52",var"#50#52",typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,Base.Order.Forward
Ordering},DataStructures.BinaryHeap{Float64,Base.Order.ForwardOrdering},Nothing,Nothing,Int64,Tuple{},Tuple{},Tuple{}},Array{Float64,1},Float64,DiffEqBase.CallbackCache{Array{Float64,1},Array{Float64,1}},OrdinaryDiffEq.DefaultInit}, ::VectorContinuousCallback{var"#49#51",var"#50#52",var"#50#52",typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}, ::Int64) at /home/samuela/.julia/packages/DiffEqBase/kRzKx/src/callbacks.jl:739
 [3] find_first_continuous_callback at /home/samuela/.julia/packages/DiffEqBase/kRzKx/src/callbacks.jl:407 [inlined]
 [4] handle_callbacks! at /home/samuela/.julia/packages/OrdinaryDiffEq/IgZMk/src/integrators/integrator_utils.jl:247 [inlined]
 [5] _loopfooter!(::OrdinaryDiffEq.ODEIntegrator{VCABM,false,Array{Float64,1},Nothing,Float64,Array{Float64,1},Float64,Float64,Float64,Array{Array{Float64,1},1},DiffEqBase.ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},false,Array{Float64,1},DiffEqBase.ODEFunction{false,var"#aug_dynamics#14"{typeof(dynamics),typeof(cost),var"#53#54"},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},VCABM,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,var"#aug_dynamics#14"{typeof(dynamics),typeof(cost),var"#53#54"},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.VCABMConstantCache{Array{Float64,1},Array{Array{Float64,1},1},Array{Float64,2},Float64,Array{Float64,1}}},DiffEqBase.DEStats},DiffEqBase.ODEFunction{false,var"#aug_dynamics#14"{typeof(dynamics),typeof(cost),var"#53#54"},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.VCABMConstantCache{Array{Float64,1},Array{Array{Float64,1},1},Array{Float64,2},Float64,Array{Float64,1}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),DiffEqBase.CallbackSet{Tuple{VectorContinuousCallback{var"#49#51",var"#50#52",var"#50#52",typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,Base.Order.ForwardOrderi
ng},DataStructures.BinaryHeap{Float64,Base.Order.ForwardOrdering},Nothing,Nothing,Int64,Tuple{},Tuple{},Tuple{}},Array{Float64,1},Float64,DiffEqBase.CallbackCache{Array{Float64,1},Array{Float64,1}},OrdinaryDiffEq.DefaultInit}) at /home/samuela/.julia/packages/OrdinaryDiffEq/IgZMk/src/integrators/integrator_utils.jl:202
 [6] loopfooter! at /home/samuela/.julia/packages/OrdinaryDiffEq/IgZMk/src/integrators/integrator_utils.jl:166 [inlined]
 [7] solve!(::OrdinaryDiffEq.ODEIntegrator{VCABM,false,Array{Float64,1},Nothing,Float64,Array{Float64,1},Float64,Float64,Float64,Array{Array{Float64,1},1},DiffEqBase.ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},false,Array{Float64,1},DiffEqBase.ODEFunction{false,var"#aug_dynamics#14"{typeof(dynamics),typeof(cost),var"#53#54"},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},VCABM,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,var"#aug_dynamics#14"{typeof(dynamics),typeof(cost),var"#53#54"},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.VCABMConstantCache{Array{Float64,1},Array{Array{Float64,1},1},Array{Float64,2},Float64,Array{Float64,1}}},DiffEqBase.DEStats},DiffEqBase.ODEFunction{false,var"#aug_dynamics#14"{typeof(dynamics),typeof(cost),var"#53#54"},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.VCABMConstantCache{Array{Float64,1},Array{Array{Float64,1},1},Array{Float64,2},Float64,Array{Float64,1}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),DiffEqBase.CallbackSet{Tuple{VectorContinuousCallback{var"#49#51",var"#50#52",var"#50#52",typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,Base.Order.ForwardOrdering},Da
taStructures.BinaryHeap{Float64,Base.Order.ForwardOrdering},Nothing,Nothing,Int64,Tuple{},Tuple{},Tuple{}},Array{Float64,1},Float64,DiffEqBase.CallbackCache{Array{Float64,1},Array{Float64,1}},OrdinaryDiffEq.DefaultInit}) at /home/samuela/.julia/packages/OrdinaryDiffEq/IgZMk/src/solve.jl:444
 [8] __solve(::ODEProblem{Array{Float64,1},Tuple{Float64,Float64},false,Array{Float64,1},DiffEqBase.ODEFunction{false,var"#aug_dynamics#14"{typeof(dynamics),typeof(cost),var"#53#54"},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem}, ::VCABM; kwargs::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:atol, :rtol, :callback),Tuple{Float64,Float64,VectorContinuousCallback{var"#49#51",var"#50#52",var"#50#52",typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}}}}) at /home/samuela/.julia/packages/OrdinaryDiffEq/IgZMk/src/solve.jl:5
 [9] #solve_call#456 at /home/samuela/.julia/packages/DiffEqBase/kRzKx/src/solve.jl:65 [inlined]
 [10] #solve_up#458 at /home/samuela/.julia/packages/DiffEqBase/kRzKx/src/solve.jl:86 [inlined]
 [11] #solve#457 at /home/samuela/.julia/packages/DiffEqBase/kRzKx/src/solve.jl:74 [inlined]
 [12] (::var"#loss_pullback#15"{Float64,var"#aug_dynamics#14"{typeof(dynamics),typeof(cost),var"#53#54"}})(::Array{Float64,1}, ::Array{Float64,1}, ::VCABM, ::Dict{Symbol,Any}) at /home/samuela/dev/research/julia/odecontrol/ppg.jl:47
 [13] (::var"#ez_loss_and_grad#21"{var"#ez_loss_and_grad#4#22"{var"#loss_pullback#15"{Float64,var"#aug_dynamics#14"{typeof(dynamics),typeof(cost),var"#53#54"}}}})(::Array{Float64,1}, ::Array{Float64,1}, ::VCABM, ::InterpolatingAdjoint{0,true,Val{:central},ZygoteVJP,Bool}; fwd_solve_kwargs::Dict{Symbol,Any}) at /home/samuela/dev/research/julia/odecontrol/ppg.jl:186
 [14] (::var"#13#36"{Dict{Symbol,Any},Array{Float64,1},VCABM,InterpolatingAdjoint{0,true,Val{:central},ZygoteVJP,Bool}})(::Array{Float64,1}) at /home/samuela/dev/research/julia/odecontrol/ppg.jl:241
 [15] iterate at ./generator.jl:47 [inlined]
 [16] _collect(::Array{Array{Float64,1},1}, ::Base.Generator{Array{Array{Float64,1},1},var"#13#36"{Dict{Symbol,Any},Array{Float64,1},VCABM,InterpolatingAdjoint{0,true,Val{:central},ZygoteVJP,Bool}}}, ::Base.EltypeUnknown, ::Base.HasShape{1}) at ./array.jl:699
 [17] collect_similar at ./array.jl:628 [inlined]
 [18] map at ./abstractarray.jl:2162 [inlined]
 [19] (::var"#ez_loss_and_grad_many#34"{var"#ez_loss_and_grad_many#12#35"{var"#_aggregate_batch_results#26"}})(::Array{Array{Float64,1},1}, ::Array{Float64,1}, ::VCABM, ::InterpolatingAdjoint{0,true,Val{:central},ZygoteVJP,Bool}; fwd_solve_kwargs::Dict{Symbol,Any}) at /home/samuela/dev/research/julia/odecontrol/ppg.jl:239
 [20] (::var"#59#60")(::Array{Array{Float64,1},1}, ::Array{Float64,1}) at /home/samuela/dev/research/julia/odecontrol/difftaichi/mass_spring.jl:300
 [21] run(::var"#59#60") at /home/samuela/dev/research/julia/odecontrol/difftaichi/mass_spring.jl:266
 [22] top-level scope at /home/samuela/dev/research/julia/odecontrol/difftaichi/mass_spring.jl:299
 [23] include(::Function, ::Module, ::String) at ./Base.jl:380
 [24] include(::Module, ::String) at ./Base.jl:368
 [25] exec_options(::Base.JLOptions) at ./client.jl:296
 [26] _start() at ./client.jl:506
in expression starting at /home/samuela/dev/research/julia/odecontrol/difftaichi/mass_spring.jl:299

ChrisRackauckas commented 4 years ago

You can't mix continuous callbacks and InterpolatingAdjoint in most cases. Is ReverseDiffAdjoint fine here?

samuela commented 4 years ago

Oh I wasn't aware of that... Why is that not a safe combo? What should be done instead?

ChrisRackauckas commented 4 years ago

Adjoint methods, without a special correction, cannot differentiate through discontinuities introduced in the event handling code. Essentially, if you cause a discontinuous change, a boundary term has to be added to the integral. Automatic differentiation handles this automatically, but the adjoint methods, which define and solve the adjoint equation, don't do this unless extra terms are added to them, which we haven't done yet. I have been planning to get to this, but it's a bit time-consuming and tricky to handle since there's not really a good derivation out there.

What makes this trickier is that the extra term is only required when the callback introduces a discontinuity (integrator.u = ...). IIRC there are a few packages which rely on the fact that event detection without a modification works. For example, when computing Poincaré plots you only need the event detection to find the zero, but you don't actually make a modification at the zero, so that's a case where we would be adding an error to working code.
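
For concreteness, here is a sketch of where such an extra term comes from, written for forward sensitivities rather than the adjoint. It is a generic hybrid-system calculation, not the derivation used by any SciML package. Assumptions (none of them from this thread): the ODE is $u' = f(u, p)$, the event fires at the time $\tau$ where $g(u(\tau)) = 0$, the callback applies $u^+ = a(u^-)$, neither $g$ nor $a$ depends explicitly on $p$ or $t$, and the crossing is transversal ($g_u f^- \neq 0$). Here $s = \partial u / \partial p$ at fixed time, and the superscripts $-$ and $+$ denote values just before and just after the event.

```latex
\begin{aligned}
g\bigl(u^-(\tau(p), p)\bigr) = 0
  &\;\Longrightarrow\;
  \frac{d\tau}{dp} = -\frac{g_u\, s^-}{g_u\, f^-}, \\
u^+ = a(u^-) \text{ at } t = \tau(p)
  &\;\Longrightarrow\;
  s^+ + f^+ \frac{d\tau}{dp} = a_u \Bigl( s^- + f^- \frac{d\tau}{dp} \Bigr), \\
  &\;\Longrightarrow\;
  s^+ = a_u\, s^- + \bigl( a_u f^- - f^+ \bigr) \frac{d\tau}{dp}.
\end{aligned}
```

If the callback only detects the event, so that $a$ is the identity and $f$ is continuous across it, then $a_u = I$ and $f^+ = f^-$, and the second term vanishes: $s^+ = s^-$. That is consistent with pure event detection not needing a correction. The second term is the piece that has to be added once the callback modifies the state.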

I hope we can get someone on this in the next month because it's a known wart with weird edge cases. We do mention it in the docs:

https://diffeq.sciml.ai/latest/analysis/sensitivity/#Choosing-a-Sensitivity-Algorithm

The methods which utilize altered differential equation systems only work on ODEs (without events), but work on any ODE solver.

But probably not prominently enough.

samuela commented 4 years ago

Hmm, I could envision how a boundary term could pop out of these discontinuities, but I'll have to think a bit more about exactly what it should be...

If you end up coming across any references that touch on this I'd def be interested in giving them a read.

ChrisRackauckas commented 4 years ago

There's a reference in https://github.com/SciML/DiffEqSensitivity.jl/issues/4#issuecomment-369219280. Note that I could never find a single thing showing Sundials working with that case, so that seems to be a false lead. The thesis does at least have a decent derivation that an implementation could be built from, though it doesn't seem as clean as I'd expect it to be.

samuela commented 4 years ago

You can't mix continuous callbacks and InterpolatingAdjoint in most cases. Is ReverseDiffAdjoint fine here?

Would DiscreteCallback be safe instead?

ChrisRackauckas commented 4 years ago

Not if you modify integrator.u. There is no discontinuity term w.r.t. t any more but there still is w.r.t. u, so you need to perturb lambda by the derivative of the callback function itself.
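
As a sketch of that statement, under assumptions not spelled out in the thread: suppose the callback fires at a fixed time $t_k$ (independent of the state and the parameters) and applies $u^+ = a(u^-, p)$, and let $\lambda(t)$ denote the gradient of the loss $L$ with respect to the state at time $t$. The chain rule then gives a jump in $\lambda$ across the callback, plus a contribution to the parameter gradient if $a$ depends on $p$. Sign and transpose conventions for $\lambda$ vary between references.

```latex
\lambda(t_k^-) = \left( \frac{\partial a}{\partial u} \right)^{\!\top} \lambda(t_k^+),
\qquad
\frac{dL}{dp} \;\leftarrow\; \frac{dL}{dp} + \left( \frac{\partial a}{\partial p} \right)^{\!\top} \lambda(t_k^+).
```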

samuela commented 4 years ago

Oh, right right. So (IIUC?) the thing that's missing then is "backprop" through the callback functions themselves, ie figure out where the callbacks happen in the forward pass, start from the end doing a callback-less adjoint solve until you hit a known callback time, at which point you pipe the adjoint at that time through the reverse AD of the callback function and continue with the next smooth segment.

samuela commented 4 years ago

Oh, right right. So (IIUC?) the thing that's missing then is "backprop" through the callback functions themselves, ie figure out where the callbacks happen in the forward pass, start from the end doing a callback-less adjoint solve until you hit a known callback time, at which point you pipe the adjoint at that time through the reverse AD of the callback function and continue with the next smooth segment.

So that was a lie. Here's a counterexample: https://gist.github.com/samuela/95dfd752c53e712a984e22832ca72fbe. In this case we try simulating a simple billiards ball hitting a bumper in 1-d (pilfered from the DiffTaichi paper). tl;dr is that just piping the adjoints through is not sufficient to get correct gradients.
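
A stripped-down version of this effect can be checked by hand. The sketch below is not the code from the gist; it is a hypothetical 1-D bumper with constant velocity and a closed-form solution, using only the Julia standard library. All names (`final_x`, `final_state_jacobians`, `W`) are made up for illustration. It compares the Jacobian obtained by piping derivatives through the reset alone with one that also accounts for the impact time depending on the initial state.

```julia
using LinearAlgebra  # stdlib, for dot

# Hypothetical 1-D bumper: a ball at x0 moves with constant velocity v > 0,
# hits a wall at x = W at time τ = (W - x0) / v, and its velocity flips sign.
# State is u = [x, v]; between events x' = v, v' = 0.

# Closed-form final position, valid when the impact happens before time T.
final_x(x0, v, T; W = 1.0) = 2W - x0 - v * T

function final_state_jacobians(x0, v, T; W = 1.0)
    τ = (W - x0) / v                    # impact time
    Φ(Δt) = [1.0 Δt; 0.0 1.0]           # Jacobian of the free flow over Δt
    A = [1.0 0.0; 0.0 -1.0]             # Jacobian of the reset (x, v) -> (x, -v)
    f_minus = [v, 0.0]                  # vector field just before impact
    f_plus = [-v, 0.0]                  # vector field just after impact
    g_u = [1.0, 0.0]                    # gradient of the event function g(u) = x - W

    # Reset Jacobian plus the term from the event time depending on the state.
    S = A - (A * f_minus - f_plus) * g_u' / dot(g_u, f_minus)

    naive = Φ(T - τ) * A * Φ(τ)         # pipes derivatives through the reset only
    corrected = Φ(T - τ) * S * Φ(τ)     # also accounts for the moving event time
    return naive, corrected
end

x0, v, T = 0.0, 1.0, 2.0
naive, corrected = final_state_jacobians(x0, v, T)

# Central finite difference of the closed form with respect to x0.
h = 1e-6
fd = (final_x(x0 + h, v, T) - final_x(x0 - h, v, T)) / (2h)

println("naive     d x(T)/d x0 = ", naive[1, 1])      # 1.0
println("corrected d x(T)/d x0 = ", corrected[1, 1])  # -1.0
println("finite-difference     = ", fd)
```

In this toy problem the reset-only Jacobian has the wrong sign for the derivative of the final position with respect to the initial position, because moving the ball closer to the wall makes the bounce happen earlier.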

This example raised the question: How can I get the times where my callback actually fired in the forwards solve? I wasn't able to figure out any way to recover that information.

ChrisRackauckas commented 4 years ago

This example raised the question: How can I get the times where my callback actually fired in the forwards solve? I wasn't able to figure out any way to recover that information.

You'd have to create a cache array yourself there.
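
A minimal sketch of that idea, assuming the standard `ContinuousCallback(condition, affect!)` interface and a bouncing-ball problem chosen purely for illustration (the names `ball!`, `solve_with_event_log`, and `event_times` are hypothetical): the `affect!` function pushes `integrator.t` into an array it closes over each time it fires.

```julia
using OrdinaryDiffEq

# Illustrative bouncing ball: u = [height, velocity], p = gravitational acceleration.
function ball!(du, u, p, t)
    du[1] = u[2]
    du[2] = -p
end

function solve_with_event_log()
    event_times = Float64[]                  # user-managed cache of firing times

    condition(u, t, integrator) = u[1]       # event when the height crosses zero
    function affect!(integrator)
        push!(event_times, integrator.t)     # record when the callback fired
        integrator.u[2] = -integrator.u[2]   # elastic bounce
    end

    cb = ContinuousCallback(condition, affect!)
    prob = ODEProblem(ball!, [50.0, 0.0], (0.0, 15.0), 9.8)
    sol = solve(prob, Tsit5(); callback = cb)
    return sol, event_times
end

sol, event_times = solve_with_event_log()
println(event_times)
```

If the solve is itself being differentiated with an AD tool, `integrator.t` may not be a plain `Float64`, so the element type of the cache might need to be loosened; that case is not covered by this sketch.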

ChrisRackauckas commented 3 years ago

This one was fixed a while back and we forgot to close it.