Open evolbio opened 2 years ago
The issue here is that `ode_noise(u, p, t) = sqrt_abs.(u)` does not depend on `p`, and so we get `nothing` in the vjp computation from the AD backend (Zygote). In the ODE case, we check whether `nothing` is returned while `p` is not a `NullParameter`, and throw an error.
It would be nice if we could detect automatically if the parameters are actually used in the drift/diffusion function. Is there a general way to detect if an input argument is used again in the body of a function? @ChrisRackauckas
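The `nothing` result can be observed in isolation. The following is a minimal sketch, assuming Zygote is installed; the values of `u` and `p` are made up for illustration, and the closure only mimics a vjp with respect to the parameters. It is not the code path SciMLSensitivity uses internally.

```julia
using Zygote

sqrt_abs(x) = (x > 16.0) ? sqrt(x) : abs(x) / 4.0
ode_noise(u, p, t) = sqrt_abs.(u)   # p is accepted but never used

u = [1.0, 25.0, -4.0]   # illustrative state
p = [0.5, 2.0]          # illustrative parameters

# Pull a cotangent back through the noise function with respect to p only.
_, back = Zygote.pullback(q -> ode_noise(u, q, 0.0), p)
dp = back(ones(length(u)))[1]

# Zygote reports "no dependence on p" as `nothing`, not as a zero array.
println(dp === nothing)
```

Checking for `nothing` in this way only shows that no derivative flowed to `p` for this particular call; it is not a static test of whether an argument appears in a function body.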
Just directly catch the `nothing` case like https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/derivative_wrappers.jl#L553-L554
but for SDEs, I think one is often only interested in fitting parameters of the drift (or diffusion) function. So the way we handle https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/derivative_wrappers.jl#L553-L554 seems a bit too restrictive (i.e. requiring that parameters occur in both `f` and `g`)?
yeah so instead of throwing an error in the noise case, we should probably just `.= false` in that branch
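A sketch of what that fallback could look like, written as a standalone function. `store_param_grad!` is a hypothetical helper name used only for illustration; it is not part of SciMLSensitivity, and the real branch in `derivative_wrappers.jl` may be structured differently.

```julia
# Hypothetical helper: copy a VJP result into a preallocated gradient buffer,
# treating `nothing` (no dependence on the parameters) as an all-zero gradient.
function store_param_grad!(dgrad::AbstractVector, vjp_result)
    if vjp_result === nothing
        dgrad .= false            # `false` is converted to zero of eltype(dgrad)
    else
        dgrad .= vec(vjp_result)  # flatten whatever shape the backend returned
    end
    return dgrad
end

buf = fill(NaN, 3)
store_param_grad!(buf, nothing)          # buf is overwritten with zeros
store_param_grad!(buf, [1.0 2.0 3.0])    # buf is overwritten with the VJP values
```

Zeroing the buffer instead of throwing lets the drift parameters be fitted even when the diffusion function ignores `p`.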
Following the idea that the current code demands dependency of the diffusion term on the parameters, in the above code I made one change
ode_noise(u, p, t) = sqrt_abs.(u) .+ p[1]*1e-20*ones(length(u))
which threw the error
ERROR: BoundsError: attempt to access 1003-element Vector{Tuple{Float64, Float64}} at index [1004]
with stacktrace: trace.txt
Updated to 7.3.0 but still get same error on MWE above
(@v1.8) pkg> st SciMLSensitivity
Status `/opt/julia/environments/v1.8/Project.toml`
[1ed8b502] SciMLSensitivity v7.3.0
throwing the same error with different "Closest candidates"
julia> result = solve(opt_prob(p, u0), ADAM(), maxiters=10)
ERROR: MethodError: no method matching vec(::Nothing)
Closest candidates are:
vec(::FillArrays.Zeros{T}) where T at /opt/julia/packages/FillArrays/5Arin/src/fillalgebra.jl:4
vec(::LazyArrays.ApplyArray{T, 2, typeof(hcat)} where T) at /opt/julia/packages/LazyArrays/z8CRn/src/lazyconcat.jl:459
vec(::Distributions.Distribution{<:Distributions.ArrayLikeVariate}) at /opt/julia/packages/Distributions/1PkiH/src/reshaped.jl:143
oh hmm -- strange, sorry about that. Could you paste the full stack trace?
In your code above, you are still loading the package with `using DiffEqSensitivity`. The package was renamed to `SciMLSensitivity` a few weeks ago (and the new name does show up correctly in your `]st` output).
The stacktrace still points to:
@ DiffEqSensitivity /opt/julia/packages/DiffEqSensitivity/Pn9H4/src/derivative_wrappers.jl:683
Could you try again and check whether it works with:
using SciMLSensitivity
Sorry, I missed the name change. Using the following code
using Optimization, DifferentialEquations, SciMLSensitivity, OptimizationOptimisers
n = 500
ode(u, p, t) = p[1:n] .* u .- p[n+1:2n]
sqrt_abs(x) = (x > 16.) ? sqrt(x) : abs(x) / 4.
ode_noise(u, p, t) = sqrt_abs.(u)
function loss(p, u0)
    prob = SDEProblem(ode, ode_noise, u0, (0.,1.), p, saveat=0:0.1:1)
    x = solve(prob, ISSEM(), p=p, reltol=1e-4, abstol=1e-6)
    y = [sin(2*pi*z) for z = 0:0.1:1]
    sum(abs2.(x[1,:] - y))
end
opt_func(p,u0) = OptimizationFunction((p,u0) -> loss(p,u0), Optimization.AutoZygote())
opt_prob(p,u0) = OptimizationProblem(opt_func(p,u0), p, u0)
p = randn(2n);
u0 = randn(n);
u0[1] = 0.0
result = solve(opt_prob(p, u0), ADAM(), maxiters=10)
I get
ERROR: BoundsError: attempt to access 1016-element Vector{Tuple{Float64, Float64}} at index [1017]
Stacktrace:
with stacktrace.txt
I think that might be a solver/tolerance issue. When I run your code with `n=500`, I can reproduce the error. For a smaller number of parameters and state size with `InterpolatingAdjoint`
function loss(p, u0)
    prob = SDEProblem(ode, ode_noise, u0, (0.,1.), p, saveat=0:0.1:1)
    x = solve(prob, ISSEM(), p=p, reltol=1e-4, abstol=1e-6, sensealg=InterpolatingAdjoint())
    y = [sin(2*pi*z) for z = 0:0.1:1]
    sum(abs2.(x[1,:] - y))
end
which is the automatically selected sensealg in your example for `n=500`, I get a warning:
┌ Warning: dt(-2.220446049250313e-16) <= dtmin(2.220446049250313e-16) at t=0.999998726537327. Aborting. There is either an error in your model specification or the true solution is unstable.
└ @ SciMLBase ~/.julia/packages/SciMLBase/QzHjf/src/integrator_interface.jl:484
I'm going to take a closer look at what's going on internally.
hmm, so the non-linear solve call to the Newton method within the step of `ISSEM` proposes `t = 1.000001`, which leads to the indexing error in one of the first steps. I'm still not sure why that happens though.
With different tolerances I get the `Warning: dt(-2.220446049250313e-16) <= dtmin(2.2 ...` warning. Is there an alternative to ISSEM()? As mentioned previously, my noise is
sqrt_abs(x) = (x > 16.) ? sqrt(x) : abs(x) / 4.
ode_noise(u, p, t) = sqrt_abs.(u)
I was not able to find anything other than ISSEM() to work with ForwardDiff, so presumably all others would also fail with Zygote.
> I was not able to find anything other than ISSEM() to work with ForwardDiff
Share an example of that? All of the codes use the same integrator and explicit handling so they all should do the same thing here. Same with Zygote (in that Zygote will require something that can handle non-diagonal noise SDEs, so any non-diagonal noise SDE solver would have the same behavior).
Is the equation actually stiff enough to warrant ISSEM? This looks like a case better served by SOSRI.
When I use SOSRI I get
ERROR: The algorithm is not compatible with the chosen noise type. Please see the documentation on the solver methods
with stacktrace.txt. Here is the MWE, which is the same as above but with SOSRI:
using Optimization, DifferentialEquations, SciMLSensitivity, OptimizationOptimisers
n = 500
ode(u, p, t) = p[1:n] .* u .- p[n+1:2n]
sqrt_abs(x) = (x > 16.) ? sqrt(x) : abs(x) / 4.
ode_noise(u, p, t) = sqrt_abs.(u)
function loss(p, u0)
    prob = SDEProblem(ode, ode_noise, u0, (0.,1.), p, saveat=0:0.1:1)
    x = solve(prob, SOSRI(), p=p, reltol=1e-4, abstol=1e-6)
    y = [sin(2*pi*z) for z = 0:0.1:1]
    sum(abs2.(x[1,:] - y))
end
opt_func(p,u0) = OptimizationFunction((p,u0) -> loss(p,u0), Optimization.AutoZygote())
opt_prob(p,u0) = OptimizationProblem(opt_func(p,u0), p, u0)
p = randn(2n);
u0 = randn(n);
u0[1] = 0.0
result = solve(opt_prob(p, u0), ADAM(), maxiters=10)
Yes, that's what I mentioned. SOSRI cannot solve the commutative SDE that the reverse solve will generate. Use `Optimization.AutoForwardDiff()` instead. The extra complexities mean that the cutoff to reverse mode will be higher on SDEs than ODEs, and I wouldn't be surprised if 500 parameters is well within the range of forward-mode outperforming reverse.
Yes, SOSRI works with ForwardDiff. I have not done full benchmarks, but my impression is that for high tolerances, SOSRI is faster than ISSEM on the MWE. On my code, which is currently set at reltol=1e-4, abstol=1e-6, SOSRI is so slow that it never completed a full iterate after a period during which ISSEM completed many. On my code, SRA3 also appears to be very slow compared with ISSEM.
When I drop the noise and use an ODE, Zygote works with Tsit5 and is faster than ForwardDiff (I need to recheck but am reasonably sure about that).
I am fitting DiffEqs, and so starting with random parameters. Perhaps the early iterates are a bit harder and the ideal solver changes over time. However, with ForwardDiff, overall, ISSEM seems to work most reliably and also fastest. Maybe I could raise the tolerances. However, currently, the code functions well w/ForwardDiff and ISSEM, although very slowly.
In other words, I was hoping to be able to try Zygote with ISSEM, which I think would be faster for my problem than ForwardDiff.
> On my code, which is currently set at reltol=1e-4, abstol=1e-6, SOSRI is so slow
Tolerances in the SDE space are written in terms of strong error, which is very different from the way you'd think of error in ODEs. Default ODE tolerances are `reltol=1e-3, abstol=1e-6`, and Tsit5 converges as O(dt^5). Default SDE tolerances are `reltol=1e-2, abstol=1e-2`, and the method converges as O(dt^(3/2)). Getting a real SDE to reltol=1e-4 in strong error will be a few orders of magnitude more work than doing the same for the ODE, and it does not translate to expectation error bounds.
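As a rough illustration of how the convergence order affects cost, the sketch below assumes the global error behaves like C * dt^order with a fixed constant C. That is a simplification: adaptive stepping, the constants, and the per-step cost of each method are not modelled, so it only covers the growth in step count. `step_factor` is a hypothetical helper name, and the factor of 100 corresponds to going from reltol=1e-2 to reltol=1e-4.

```julia
# Under error ≈ C * dt^order, shrinking the error by `err_ratio` requires
# shrinking dt by err_ratio^(1/order), so the step count grows by that factor.
step_factor(err_ratio, order) = err_ratio^(1 / order)

ode_factor = step_factor(100, 5)     # order-5 ODE method such as Tsit5
sde_factor = step_factor(100, 1.5)   # strong order 1.5 SDE method

println("ODE step count grows by about ", round(ode_factor, digits=2))
println("SDE step count grows by about ", round(sde_factor, digits=2))
```

The lower the order, the faster the step count grows for the same tightening of the tolerance, which is one reason tight strong-error tolerances are expensive for SDE solvers.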
> When I drop the noise and use an ODE, Zygote works with Tsit5 and is faster than ForwardDiff (I need to recheck but am reasonably sure about that).
Yes, for ODEs the cutoff is around 100 (see https://arxiv.org/abs/1812.01892). But for many technical factors, the cutoff will be higher for SDEs. I don't have a number, but I wouldn't be surprised if it takes at least 1000 parameters before reverse mode ends up being faster, if not more, because you cannot use the faster specialized integrators for diagonal noise (since the reverse pass is not diagonal in that case). While I would want to get this investigated and fixed, I am pretty sure it won't be faster even when fully optimized just due to the method differences.
Thanks, that is very helpful. I will try to raise my tolerances and see what happens.
I have Lux NNs embedded within the deterministic (drift) term, so I am often well over 1000 parameters. If I remember correctly from other runs, Zygote does great on NNs so it may be that Zygote would help a lot for some of my SDE runs.
My particular application may not be of interest, but just in case: I am searching for biological cellular transcription factor networks to solve specific environmental challenges. I got this to work reasonably well for SDEs when using the standard biological model for the binding of transcription factors to control gene expression
https://doi.org/10.1101/2022.07.05.498863
which is a fairly powerful proof of principle for how SciML can solve relatively tough and potentially realistic optimization problems in cellular biology.
However, I guessed that I would get much better optimization if I used neural nets to describe the functions that map the concentrations of transcription factors to the level of gene expression for each gene. My current runs show that this is true, significantly advancing the optimizations.
Then, once the optimizations are in place for NNs, one can fit the biological model of transcription factor binding to each NN and obtain a description of the control of binding and expression in terms of the standard biological factors.
The MWE below yields the error
with full stacktrace: trace.txt
When using AutoForwardDiff(), the code seems to run without error but is so slow that I gave up before it finished 10 iterates. Needed to set tolerances to avoid warnings.