SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Optimizing simple SDE w/ISSEM fails with Zygote #697

Open evolbio opened 2 years ago

evolbio commented 2 years ago

The MWE below yields the error

ERROR: MethodError: no method matching vec(::Nothing)
Closest candidates are:
  vec(::FillArrays.Zeros{T}) where T at /opt/julia/packages/FillArrays/5Arin/src/fillalgebra.jl:4
  vec(::StrideArraysCore.PtrArray{S, D, T, 1, C, 0}) where {S, D, T, C} at /opt/julia/packages/StrideArraysCore/VQxXL/src/reshape.jl:1
  vec(::StrideArraysCore.PtrArray{S, D, T, N, C, 0}) where {S, D, T, N, C} at /opt/julia/packages/StrideArraysCore/VQxXL/src/reshape.jl:2
  ...
Stacktrace:
  [1] _jacNoise!(λ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, 

with full stacktrace: trace.txt

With AutoForwardDiff(), the code seems to run without error, but it is so slow that I gave up before it finished 10 iterations. I needed to set tolerances to avoid warnings.

using Optimization, DifferentialEquations, DiffEqSensitivity, OptimizationOptimisers

n = 500

ode(u, p, t) = p[1:n] .* u .- p[n+1:2n]

sqrt_abs(x) = (x > 16.) ? sqrt(x) : abs(x) / 4.
ode_noise(u, p, t) = sqrt_abs.(u)

function loss(p, u0)
    prob = SDEProblem(ode, ode_noise, u0, (0.,1.), p, saveat=0:0.1:1)
    x = solve(prob, ISSEM(), p=p, reltol=1e-4, abstol=1e-6)
    y = [sin(2*pi*z) for z = 0:0.1:1]
    sum(abs2.(x[1,:] - y))
end

opt_func(p,u0) = OptimizationFunction((p,u0) -> loss(p,u0), Optimization.AutoZygote())
opt_prob(p,u0) = OptimizationProblem(opt_func(p,u0), p, u0)

p = randn(2n);
u0 = randn(n);
u0[1] = 0.0

result = solve(opt_prob(p, u0), ADAM(), maxiters=10)

frankschae commented 2 years ago

The issue here is that

ode_noise(u, p, t) = sqrt_abs.(u)

does not depend on p, so the AD backend (Zygote) returns nothing for the parameter part of the vjp. In the ODE case, we check whether nothing is returned while p is not a NullParameter, and throw an error in that case.

It would be nice if we could automatically detect whether the parameters are actually used in the drift/diffusion function. Is there a general way to detect whether an input argument is used in the body of a function? @ChrisRackauckas
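The behavior described above can be reproduced in isolation with Zygote's pullback on a noise function shaped like the one in the MWE. This is a minimal sketch that assumes only Zygote.jl; the cotangent λ and the sample values of u and p are made up for illustration. Because p never appears in the body, the pullback returns nothing for it rather than a zero vector, which is presumably what ends up in vec(::Nothing) in the trace above.

```julia
using Zygote

# Same shape as the MWE's noise term: p is accepted but never used.
sqrt_abs(x) = x > 16.0 ? sqrt(x) : abs(x) / 4.0
noise(u, p, t) = sqrt_abs.(u)

u = [1.0, -2.0, 20.0]
p = [0.5, 0.1]
λ = ones(3)   # illustrative cotangent for the vector-Jacobian product

# Reverse-mode VJP with respect to both u and p.
_, back = Zygote.pullback((u, p) -> noise(u, p, 0.0), u, p)
du, dp = back(λ)

println(du)   # λ .* d(sqrt_abs)/du, elementwise
println(dp)   # nothing: p does not influence the output
```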

ChrisRackauckas commented 2 years ago

Just directly catch the nothing case like https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/derivative_wrappers.jl#L553-L554

frankschae commented 2 years ago

But for SDEs, I think one is often interested only in fitting the parameters of the drift (or diffusion) function. So the way we handle https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/derivative_wrappers.jl#L553-L554 seems a bit too restrictive (i.e., requiring that parameters occur in both f and g)?

ChrisRackauckas commented 2 years ago

Yeah, so instead of throwing an error in the noise case, we should probably just set the result to zero with .= false in that branch.
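One way to read the .= false suggestion is a small guard that treats a nothing parameter VJP as a zero contribution instead of erroring. The helper name accumulate_param_vjp! is hypothetical and is not SciMLSensitivity's actual code; this sketch only illustrates the branch being discussed.

```julia
# Hypothetical helper (not the library's implementation): write the parameter
# part of a VJP into a preallocated buffer, treating `nothing` as zero.
function accumulate_param_vjp!(out::AbstractVector, dp)
    if dp === nothing
        out .= false          # g does not depend on p: zero contribution
    else
        out .= vec(dp)        # flatten whatever array the AD backend returned
    end
    return out
end

buf = fill(3.0, 2)
accumulate_param_vjp!(buf, nothing)      # buf is now all zeros
accumulate_param_vjp!(buf, [1.0 2.0])    # a 1×2 matrix is flattened into buf
```

Filling with false rather than raising lets a user fit only drift parameters while the diffusion ignores p, which is the use case raised in the previous comment.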

evolbio commented 2 years ago

Since the current code requires the diffusion term to depend on the parameters, I made one change to the code above:

ode_noise(u, p, t) = sqrt_abs.(u) .+ p[1]*1e-20*ones(length(u))

which threw the error

ERROR: BoundsError: attempt to access 1003-element Vector{Tuple{Float64, Float64}} at index [1004]

with stacktrace: trace.txt

evolbio commented 2 years ago

I updated to 7.3.0 but still get the same error on the MWE above.

(@v1.8) pkg> st SciMLSensitivity
Status `/opt/julia/environments/v1.8/Project.toml`
  [1ed8b502] SciMLSensitivity v7.3.0

It throws the same error, with different "Closest candidates":

julia> result = solve(opt_prob(p, u0), ADAM(), maxiters=10)
ERROR: MethodError: no method matching vec(::Nothing)
Closest candidates are:
  vec(::FillArrays.Zeros{T}) where T at /opt/julia/packages/FillArrays/5Arin/src/fillalgebra.jl:4
  vec(::LazyArrays.ApplyArray{T, 2, typeof(hcat)} where T) at /opt/julia/packages/LazyArrays/z8CRn/src/lazyconcat.jl:459
  vec(::Distributions.Distribution{<:Distributions.ArrayLikeVariate}) at /opt/julia/packages/Distributions/1PkiH/src/reshaped.jl:143

frankschae commented 2 years ago

Oh hmm, strange, sorry about that. Could you paste the full stack trace?

evolbio commented 2 years ago

stacktrace.txt

frankschae commented 2 years ago

Your code above still loads the old package: using DiffEqSensitivity.

The package changed its name to SciMLSensitivity a few weeks ago (which looks correct from your ]st output).

The stacktrace still points to:

 @ DiffEqSensitivity /opt/julia/packages/DiffEqSensitivity/Pn9H4/src/derivative_wrappers.jl:683 

Could you try again and check whether it works with:

using SciMLSensitivity

evolbio commented 2 years ago

Sorry, I missed the name change. With the following code

using Optimization, DifferentialEquations, SciMLSensitivity, OptimizationOptimisers

n = 500

ode(u, p, t) = p[1:n] .* u .- p[n+1:2n]

sqrt_abs(x) = (x > 16.) ? sqrt(x) : abs(x) / 4.
ode_noise(u, p, t) = sqrt_abs.(u)

function loss(p, u0)
    prob = SDEProblem(ode, ode_noise, u0, (0.,1.), p, saveat=0:0.1:1)
    x = solve(prob, ISSEM(), p=p, reltol=1e-4, abstol=1e-6)
    y = [sin(2*pi*z) for z = 0:0.1:1]
    sum(abs2.(x[1,:] - y))
end

opt_func(p,u0) = OptimizationFunction((p,u0) -> loss(p,u0), Optimization.AutoZygote())
opt_prob(p,u0) = OptimizationProblem(opt_func(p,u0), p, u0)

p = randn(2n);
u0 = randn(n);
u0[1] = 0.0

result = solve(opt_prob(p, u0), ADAM(), maxiters=10)

I get

ERROR: BoundsError: attempt to access 1016-element Vector{Tuple{Float64, Float64}} at index [1017]
Stacktrace:

with stacktrace.txt

frankschae commented 2 years ago

I think that might be a solver/tolerance issue. When I run your code with n=500, I can reproduce the error. For a smaller number of parameters and a smaller state size, with InterpolatingAdjoint

function loss(p, u0)
    prob = SDEProblem(ode, ode_noise, u0, (0.,1.), p, saveat=0:0.1:1)
    x = solve(prob, ISSEM(), p=p, reltol=1e-4, abstol=1e-6, sensealg=InterpolatingAdjoint())
    y = [sin(2*pi*z) for z = 0:0.1:1]
    sum(abs2.(x[1,:] - y))
end

which is the automatically selected sensealg in your example for n=500, I get a warning:

┌ Warning: dt(-2.220446049250313e-16) <= dtmin(2.220446049250313e-16) at t=0.999998726537327. Aborting. There is either an error in your model specification or the true solution is unstable.
└ @ SciMLBase ~/.julia/packages/SciMLBase/QzHjf/src/integrator_interface.jl:484

I'm going to take a closer look at what's going on internally.

frankschae commented 2 years ago

Hmm, so the nonlinear solve (Newton's method) within an ISSEM step proposes t = 1.000001, which leads to the indexing error in one of the first steps. I'm still not sure why that happens, though.

evolbio commented 2 years ago

With different tolerances I get the Warning: dt(-2.220446049250313e-16) <= dtmin(2.2 ... warning. Is there an alternative to ISSEM()? As mentioned previously, my noise is

sqrt_abs(x) = (x > 16.) ? sqrt(x) : abs(x) / 4.
ode_noise(u, p, t) = sqrt_abs.(u)

I was not able to find anything other than ISSEM() to work with ForwardDiff, so presumably all others would also fail with Zygote.

ChrisRackauckas commented 2 years ago

I was not able to find anything other than ISSEM() to work with ForwardDiff

Could you share an example of that? All of the solvers use the same integrator and the same explicit handling, so they should all behave the same here. The same goes for Zygote: it requires a solver that can handle non-diagonal noise SDEs, so any non-diagonal noise SDE solver would show the same behavior.

ChrisRackauckas commented 2 years ago

Is the equation actually stiff enough to warrant ISSEM? This looks like a case better served by SOSRI.

evolbio commented 2 years ago

When I use SOSRI I get

ERROR: The algorithm is not compatible with the chosen noise type. Please see the documentation on the solver methods

with stacktrace.txt. Here is the MWE, which is the same as above but with SOSRI:

using Optimization, DifferentialEquations, SciMLSensitivity, OptimizationOptimisers

n = 500

ode(u, p, t) = p[1:n] .* u .- p[n+1:2n]

sqrt_abs(x) = (x > 16.) ? sqrt(x) : abs(x) / 4.
ode_noise(u, p, t) = sqrt_abs.(u)

function loss(p, u0)
    prob = SDEProblem(ode, ode_noise, u0, (0.,1.), p, saveat=0:0.1:1)
    x = solve(prob, SOSRI(), p=p, reltol=1e-4, abstol=1e-6)
    y = [sin(2*pi*z) for z = 0:0.1:1]
    sum(abs2.(x[1,:] - y))
end

opt_func(p,u0) = OptimizationFunction((p,u0) -> loss(p,u0), Optimization.AutoZygote())
opt_prob(p,u0) = OptimizationProblem(opt_func(p,u0), p, u0)

p = randn(2n);
u0 = randn(n);
u0[1] = 0.0

result = solve(opt_prob(p, u0), ADAM(), maxiters=10)

ChrisRackauckas commented 2 years ago

Yes, that's what I mentioned. SOSRI cannot solve the commutative SDE that the reverse solve will generate. Use Optimization.AutoForwardDiff() instead. The extra complexity means that the cutoff for reverse mode will be higher on SDEs than on ODEs, and I wouldn't be surprised if 500 parameters is well within the range where forward mode outperforms reverse.
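A minimal sketch of the suggested switch, assuming StochasticDiffEq.jl (the SDE part of DifferentialEquations.jl), a recent OptimizationOptimisers.jl that exports Adam, and a toy size n = 5 instead of 500 so it finishes quickly. It keeps the drift and the diagonal noise from the MWE, uses SOSRI at its default SDE tolerances, and changes only the AD backend to Optimization.AutoForwardDiff(). The names drift, noise, tsave, and target are illustrative.

```julia
using Optimization, OptimizationOptimisers, StochasticDiffEq
using Random

const n = 5   # toy size for illustration; the MWE uses n = 500

drift(u, p, t) = p[1:n] .* u .- p[n+1:2n]
sqrt_abs(x) = x > 16.0 ? sqrt(x) : abs(x) / 4.0
noise(u, p, t) = sqrt_abs.(u)       # diagonal noise, independent of p

const tsave = 0:0.1:1
const target = [sin(2pi * z) for z in tsave]

# Optimization.jl calls f(x, q): x = decision variables (the SDE parameters),
# q = fixed data (here the initial condition u0).
function loss(p, u0)
    prob = SDEProblem(drift, noise, u0, (0.0, 1.0), p; saveat = tsave)
    sol = solve(prob, SOSRI())     # default SDE tolerances
    return sum(abs2, sol[1, :] .- target)
end

Random.seed!(1)
u0 = randn(n); u0[1] = 0.0
p0 = randn(2n)

# Forward-mode AD through the solver instead of a reverse-mode adjoint.
optf = OptimizationFunction(loss, Optimization.AutoForwardDiff())
optprob = OptimizationProblem(optf, p0, u0)
result = solve(optprob, Adam(0.01); maxiters = 10)
```

Forward mode differentiates the SDE solve directly, so no reverse SDE is generated and SOSRI's noise-type restriction does not come into play.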

evolbio commented 2 years ago

Yes, SOSRI works with ForwardDiff. I have not done full benchmarks, but my impression is that at high tolerances SOSRI is faster than ISSEM on the MWE. In my code, which is currently set to reltol=1e-4, abstol=1e-6, SOSRI is so slow that it never completed a full iteration in the time ISSEM completed many. In my code, SRA3 also appears to be very slow compared with ISSEM.

When I drop the noise and use an ODE, Zygote works with Tsit5 and is faster than ForwardDiff (I need to recheck but am reasonably sure about that).

I am fitting DiffEqs, so I start with random parameters. Perhaps the early iterations are harder and the ideal solver changes over time. Overall, though, with ForwardDiff, ISSEM seems to work most reliably and also fastest. Maybe I could raise the tolerances. Currently, the code works well with ForwardDiff and ISSEM, albeit very slowly.

evolbio commented 2 years ago

In other words, I was hoping to be able to try Zygote with ISSEM, which I think would be faster for my problem than ForwardDiff.

ChrisRackauckas commented 2 years ago

On my code, which is currently set at reltol=1e-4, abstol=1e-6, SOSRI is so slow

Tolerances in the SDE space are written in terms of strong error, which is very different from the way you'd think of error in ODEs. Default ODE tolerances are reltol=1e-3, abstol=1e-6, and Tsit5 converges as O(dt^5). Default SDE tolerances are reltol=1e-2, abstol=1e-2, and SOSRI converges as O(dt^(3/2)). Getting a real SDE to reltol=1e-4 in strong error will be a few orders of magnitude more work than getting the ODE there, and it does not translate into error bounds on expectations.
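A crude back-of-the-envelope model makes the cost gap concrete. It assumes (this is an assumption, not how the adaptive step controllers actually choose steps) that the global error behaves like dt^order on a unit time span with an error constant of 1, so reaching error ε takes roughly ε^(-1/order) steps. It ignores stage counts, step rejections, and the difference between strong error and ODE error.

```julia
# Crude cost model (assumption): error ≈ dt^order on t ∈ [0, 1] with unit
# constant, so hitting error ε needs about ε^(-1/order) fixed steps.
nsteps(ε, order) = ε^(-1 / order)

ode_1e4 = nsteps(1e-4, 5)      # 5th-order ODE method such as Tsit5
sde_1e4 = nsteps(1e-4, 1.5)    # strong order 1.5 SDE method such as SOSRI
sde_1e2 = nsteps(1e-2, 1.5)    # same SDE method at a 1e-2 tolerance

println((ode_1e4, sde_1e4, sde_1e2))   # ≈ (6.3, 464.2, 21.5)
```

Under this model, tightening the SDE tolerance from 1e-2 to 1e-4 multiplies the step count by about 100^(2/3) ≈ 21.5, whereas the same factor of 100 costs a 5th-order ODE method only about 100^(1/5) ≈ 2.5.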

When I drop the noise and use an ODE, Zygote works with Tsit5 and is faster than ForwardDiff (I need to recheck but am reasonably sure about that).

Yes, for ODEs the cutoff is around 100 parameters (see https://arxiv.org/abs/1812.01892). But for many technical reasons, the cutoff will be higher for SDEs. I don't have a number, but I wouldn't be surprised if it takes at least 1000 parameters, if not more, before reverse mode ends up being faster, because you cannot use the faster specialized integrators for diagonal noise (the reverse pass is not diagonal in that case). While I would want to get this investigated and fixed, I am pretty sure reverse mode won't be faster even when fully optimized, just due to the method differences.

evolbio commented 2 years ago

Thanks, that is very helpful. I will try to raise my tolerances and see what happens.

I have Lux NNs embedded within the deterministic (drift) term, so I often have well over 1000 parameters. If I remember correctly from other runs, Zygote does great on NNs, so Zygote may help a lot for some of my SDE runs.


My particular application may not be of interest, but just in case: I am searching for biological cellular transcription factor networks that solve specific environmental challenges. I got this to work reasonably well for SDEs using the standard biological model for the binding of transcription factors to control gene expression:

https://doi.org/10.1101/2022.07.05.498863

which is a fairly powerful proof of principle for how SciML can solve relatively tough and potentially realistic optimization problems in cellular biology.

However, I guessed that I would get much better optimization if I used neural nets to describe the functions that map the concentrations of transcription factors to the level of gene expression for each gene. My current runs show that this is true, significantly advancing the optimizations.

Then, once the optimizations are in place for NNs, one can fit the biological model of transcription factor binding to each NN and obtain a description of the control of binding and expression in terms of the standard biological factors.