sebapersson closed this issue 10 months ago.
Are you sure that AMICI is using the exact same loss function here? It seems very numerically unstable. I get:
du, dp = adjoint_sensitivities(solForward, Rosenbrock23(), dgdu_discrete=compute∂G∂u, t=[120.0],
sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP()),
abstol=1e-8, reltol=1e-8)
([-216421.83685668814, -1.6858632567123687e8, 7.018270541507498e-7, 362.3966048917397, 543.0384299075158, 1.635347306251986e8, -2.754778878095396e7, -206524.68630875406, -2.3349172307966763e11, 347.52497578730214 … 368.05406281668667, 356.948127920856, 1.7258208429450923e8, -7.321321335008193e8, -216421.83685668817, -216420.34264405907, 226915.23606140065, -216420.34264405907, -2.311567443657711e11, -3.229788032430855e11], [68199.89783762788 -1.5098419512755957e9 … 0.003940099991672051 1.8215793216270419e-6])
and you can see the values are much more than floating point eps apart, so this wouldn't be a loss function you'd want to use in practice. Can you confirm those derivative values?
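Whether a hand-written dgdu_discrete function really is the derivative of the intended loss can be checked independently of any adjoint method by comparing it with central finite differences of the scalar loss. Below is a minimal sketch in plain Julia. `fd_gradient`, the toy observable `h_toy`, the data `y_obs` and the noise `σ_obs` are hypothetical and chosen only for illustration; they are not taken from the Bachmann model. The toy derivative also drops the `p, t, i` arguments of the real `(out, u, p, t, i)` signature.

```julia
# Hypothetical helper (not from PEtab.jl or SciMLSensitivity): central
# finite-difference gradient of a scalar function G at the point u.
function fd_gradient(G, u::AbstractVector{<:Real}; ε::Real = 1e-6)
    g = zeros(length(u))
    for k in eachindex(u)
        up = float.(u); up[k] += ε   # perturb component k upwards
        um = float.(u); um[k] -= ε   # perturb component k downwards
        g[k] = (G(up) - G(um)) / (2ε)
    end
    return g
end

# Toy Gaussian loss G(u) = (h(u) - y)^2 / (2σ^2) with a linear observable
# (assumed values, for illustration only)
const y_obs = 0.5
const σ_obs = 0.1
h_toy(u) = u[1] + 2.0 * u[2]
G_toy(u) = 0.5 * (h_toy(u) - y_obs)^2 / σ_obs^2

# Hand-written ∂G/∂u in the same in-place (out, u) style as compute∂G∂u
function dGdu_toy!(out, u)
    ∂h∂u = [1.0, 2.0]
    out .= ∂h∂u .* (h_toy(u) - y_obs) / σ_obs^2
    return out
end

u = [0.3, 0.4]
analytic = dGdu_toy!(zeros(2), u)
numeric = fd_gradient(G_toy, u)
println("max |analytic - numeric| = ", maximum(abs.(analytic .- numeric)))
```

The same `fd_gradient` call can be pointed at a real loss by wrapping it as a function of `u` only; a mismatch far above the finite-difference error would indicate that the hand-written derivative and the loss disagree.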
Does AMICI overwrite some default values of Sundials? If I set dtmin, I get values that agree reasonably well with Rosenbrock23:
du, dp = adjoint_sensitivities(solForward, CVODE_BDF(), dgdu_discrete=compute∂G∂u, t=[120.0],
sensealg=InterpolatingAdjoint(),
abstol=1e-8, reltol=1e-8, dtmin = 1e-14)
julia> du, dp
([-216421.6404173965, -1.6858617571679205e8, -5.0432981731661105e-6, 362.3962733866135, 543.0379210572581, 1.6353457628614962e8, -2.754776317992651e7, -206524.4942973113, -2.3349151465126083e11, 347.52466850475935 … 368.0537219715651, 356.94780532366275, 1.725819189064261e8, -7.321314792626908e8, -216421.64041739647, -216420.144702485, 226914.94915938325, -216420.14470248503, -2.3115653802170447e11, -3.229785149334503e11], [68199.84529885986 -1.509840790874716e9 … 0.003940097888561451 3.163782126191852e-7])
(Sundials default is here: https://github.com/LLNL/sundials/blob/1ea097bb3bce207335ac35f0b5e78df5d71c6409/src/ida/ida_impl.h#L57)
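To make "agree reasonably well" quantitative, one can compute the componentwise relative difference between the two du vectors printed above (Rosenbrock23 with QuadratureAdjoint versus CVODE_BDF with InterpolatingAdjoint and dtmin = 1e-14). A minimal sketch in plain Julia; `rel_diff` is a hypothetical helper, and only the first five (non-elided) components of each printed vector are used.

```julia
# Hypothetical helper: componentwise relative difference between two gradients.
# `floor` avoids division by zero for components that are exactly zero.
function rel_diff(a::AbstractVector, b::AbstractVector; floor::Real = 1e-12)
    length(a) == length(b) || throw(DimensionMismatch("gradients differ in length"))
    return abs.(a .- b) ./ max.(abs.(a), abs.(b), floor)
end

# First five components of du as printed above
du_rosenbrock = [-216421.83685668814, -1.6858632567123687e8, 7.018270541507498e-7,
                 362.3966048917397, 543.0384299075158]
du_cvode      = [-216421.6404173965, -1.6858617571679205e8, -5.0432981731661105e-6,
                 362.3962733866135, 543.0379210572581]

r = rel_diff(du_rosenbrock, du_cvode)
for (k, rk) in enumerate(r)
    println("component $k: relative difference = $rk")
end
```

For these five components, four agree to within about 1e-6 relative, while the third, which is close to zero, changes sign between the two runs. For such near-zero components a combined absolute/relative criterion is more informative than a purely relative one.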
I'll wait on @sebapersson here. It looks like it's just a numerically unstable kind of problem and it comes down to some tuning of solver parameters.
Thanks for looking at this again. I am currently on vacation (without computer access), but I will come back to this as soon as I am back (14th of August).
I can confirm that AMICI uses the same loss function (we actually use AMICI to produce reference values for testing PEtab.jl). Moreover, AMICI produces a sufficiently accurate gradient for the full loss function, which, in addition to the MVE above, includes more simulation conditions, observables, etc.
using Zygote
using SciMLSensitivity
using OrdinaryDiffEq
using Sundials
using PEtab
# To run the code you need to have downloaded the Bachmann folder
pathYML = joinpath(@__DIR__, "Bachmann_MSB2011", "Bachmann_MSB2011.yaml")
petabModel = readPEtabModel(pathYML, verbose=false, forceBuildJuliaFiles=true)
# Gradient via ForwardDiff
petabProblem = createPEtabODEProblem(petabModel,
gradientMethod=:ForwardDiff,
odeSolverOptions=ODESolverOptions(Rodas4P(), abstol=1e-8, reltol=1e-8))
# Gradient via adjoint sensitivity analysis
petabProblemAdjoint = createPEtabODEProblem(petabModel,
gradientMethod=:Adjoint,
odeSolverOptions=ODESolverOptions(Rodas4P(), abstol=1e-8, reltol=1e-8),
odeSolverGradientOptions=ODESolverOptions(CVODE_BDF(), abstol=1e-8, reltol=1e-8, maxiters=100000, dtmin=1e-14),
sensealg=InterpolatingAdjoint(autojacvec=EnzymeVJP()))
# Gradient computed in AMICI (hard-coded here)
gradAMICI = [2.9123036726432547e7, 26.75753206184606, 8.139436368976583, 3.979215132439631e6, 2.362669741125894e7, 56567.96099941749, 829981.2894948393, 9.249979318847144e-11, -3.674355490308252e-11, 6.4713903456660934e7, -6.216618460775475e7, -1.2302004250372443e7, 2.812836122198149e6, -81817.54182044673, 23.81253297730834, 1.1502376612847047e-8, -0.7444354455312014, 0.6584068806862132, 0.3521714827335728, 8.174147852179505, 4.041393946368265e7, 8.172620249742219e6, -2.9104178930013895e7, 2.780567403648323e7, 0.31335070822723, -3.0256829194270867, 15141.1326948684, 6.510155952896329e8, -200.01871759739583, 107.45945117271157, -579.8857903941822, 2.2254074903412387e7, -9.193123740384564, 3.9898200148833704e8, -1.6775086908858556, 371.8707779489648, -50.65674284670836, 107198.61656083872, 1.5148174890084148e6, -0.0020129092986699054, -6.019687884413975e-12, 8.382925355697271e-5, 4.895916881583459e-12, 0.0002068309667750989, 2.4063616048099682e-8, 0.001302805507842194, 4.8218980022871345e-5, -0.0018898089753366174, -5.956819123448245e-7, 0.0069504265085979355, 0.008229558502845482, 9.905985853846503e-5, 4.14481029220916e-5, 0.005386598590721612, -1.6947560786713038e-5, 0.0001973291750177248, -0.0014853751559264195, -0.00013533007099318588, -7.676164759365769e-5, -5.317945519800476e-5, -0.0013875876954289959, -0.0004197826946574939, -0.003895401763819917, -0.002060359521383121, -2.1329316752522168e-7, -0.004842482446853321, -0.0019366248209785665, -4.903539343791784e-6, -0.00032014447232609584, -0.010491063662829791, -0.0022864447741598294, -0.0011221762291364582, -0.005515201267928384, -3.5757281075962954e-6, -0.0010322146707361008, -2.4901223595792135e-7, -0.000862145369358733, -8.226572863530094e-7, 0.0005460028418970843, 0.00017272708113702173, -0.0008015884639911653, -0.0002561303822940498, -8.061971740784452e-5, -0.004254578045748392, 0.00014810766168345247, -0.010747507432695314, -8.463803129554054e-5, 3.5328451149772463e-6, -178577.7840564104, 
-491515.3668710257, -3406.081995278741, -121202.28875719375, -544614.123204796, -30.21315520454095, 926485.4807104656, -465463.77468134294, -315694.7848000287, -43217.3449109291, -51.49314829614667, 4.8775035352683924e7, -2.2146742441410364e-5, -1.367975198174858e-13, -1.3521061665928313e-5, -3.676672528955571e-8, -2.535348361972927e-6, 8.78484393486553e-13, 0.002143788550220361, 0.0002305910170887085, 2.6759744664511304e-6, -315425.8670240166, -29655.634334714232, -155287.10560768208, -737467.3777072319] .* -1
# Parameter vector (the one that crashed the MVE above)
p = [-2.787878787878788, -0.4545454545454546, 0.3333333333333335, -0.27272727272727293, -2.2727272727272725, 0.8787878787878789, 2.02020202020202, 5.545454545454545, -1.4242424242424243, 6.212121212121213, 0.11111111111111116, 0.4545454545454546, -1.606060606060606, 0.09090909090909083, -0.8181818181818183, -2.757575757575758, -2.1515151515151514, 0.21212121212121193, -0.21212121212121193, -1.5858585858585859, -0.9393939393939394, -2.4545454545454546, 1.4242424242424239, -2.9393939393939394, 1.121212121212121, -1.8484848484848484, -1.1818181818181819, -3.0, 2.090909090909091, -0.5151515151515151, 1.9696969696969697, -2.1515151515151514, 1.9090909090909092, -2.0303030303030303, 0.5757575757575757, -0.030303030303030276, 2.8181818181818183, -1.606060606060606, -2.090909090909091, 1.7272727272727275, -1.121212121212121, -0.9393939393939394, -2.4545454545454546, -1.2424242424242424, 1.1818181818181817, -2.6363636363636362, 1.666666666666667, 2.878787878787879, 0.21212121212121193, -1.3636363636363635, -2.696969696969697, -0.030303030303030276, -0.27272727272727293, 1.8484848484848486, -2.090909090909091, -1.303030303030303, 0.8181818181818183, -0.4545454545454546, -1.0, -0.8181818181818183, 0.6969696969696968, 0.15151515151515138, 2.6969696969696972, 1.8484848484848486, -2.757575757575758, 1.7272727272727275, 0.15151515151515138, -2.696969696969697, 0.21212121212121193, 1.7272727272727275, 2.6969696969696972, 2.2727272727272725, 2.5151515151515156, -0.09090909090909083, 0.3333333333333335, -2.5757575757575757, 0.21212121212121193, -1.8484848484848484, -1.0, -1.606060606060606, -0.030303030303030276, -0.030303030303030276, -1.3636363636363635, 2.090909090909091, -2.2727272727272725, 2.757575757575758, -0.6363636363636362, -1.7272727272727273, 1.0606060606060606, 2.5151515151515156, -2.090909090909091, 2.6363636363636367, 2.212121212121212, -0.6363636363636362, -1.4242424242424243, 1.1818181818181817, 1.9696969696969697, -0.21212121212121193, -1.121212121212121, 
0.9393939393939394, 1.3030303030303028, -1.6666666666666667, 0.030303030303030276, 1.3030303030303028, -0.21212121212121193, -0.8787878787878789, 0.7575757575757578, 0.030303030303030276, -2.212121212121212, 1.5454545454545459, -0.030303030303030276, 0.5757575757575757, 1.9090909090909092]
gradientForward = petabProblem.computeGradient(p)
gradientAdjoint = petabProblemAdjoint.computeGradient(p)
println(gradientForward')
println(gradientAdjoint')
println(gradAMICI')
# Gradient ForwardDiff
-2.91232e7 -26.7575 -8.138 -3.9796e6 -2.36292e7 -56575.4 -8.30025e5 -9.24997e-11 3.67433e-11 … -0.0021438 -0.000230591 -2.67598e-6 3.15426e5 29655.6 1.55287e5 7.37467e5
# Gradient Adjoint
-2.9123e7 -26.7575 -8.13944 -3.97922e6 -2.36267e7 -56568.0 -8.29981e5 -9.24998e-11 … -0.00214379 -0.000230591 -2.67597e-6 3.15426e5 29655.6 1.55287e5 7.37467e5
# Gradient AMICI
-2.9123e7 -26.7575 -8.13944 -3.97922e6 -2.36267e7 -56568.0 -8.29981e5 -9.24998e-11 … -0.00214379 -0.000230591 -2.67597e-6 3.15426e5 29655.6 1.55287e5 7.37467e5
To run the code you need the Bachmann folder.
And @frankschae, to my knowledge AMICI only overrides maxiter (set to 1e6 for the reverse pass), abstol, and reltol (both set to 1e-8). Otherwise, the default Sundials values are used.
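Since abstol and reltol are the tolerance settings AMICI changes, it matters how they enter the error test. CVODES documents a weighted root-mean-square (WRMS) norm for its local error test, and a step passes when that norm of the local error estimate is at most 1. A minimal sketch of that norm in plain Julia; `wrms_norm` is a hypothetical helper that assumes a scalar abstol, the defaults mirror the 1e-8 tolerances mentioned above, and the state and error values are made up for illustration.

```julia
# Weighted root-mean-square (WRMS) norm in the form documented for CVODES:
#   ||v||_WRMS = sqrt( (1/N) * sum_i (v_i * w_i)^2 ),  w_i = 1 / (reltol*|y_i| + abstol)
# Hypothetical helper; assumes a scalar abstol (CVODES also allows one per component).
function wrms_norm(v::AbstractVector{<:Real}, y::AbstractVector{<:Real};
                   abstol::Real = 1e-8, reltol::Real = 1e-8)
    length(v) == length(y) || throw(DimensionMismatch("v and y must have the same length"))
    w = 1 ./ (reltol .* abs.(y) .+ abstol)   # error weights
    return sqrt(sum(abs2, v .* w) / length(v))
end

# Hypothetical state with one large and one near-zero component, and a
# hypothetical local error estimate; the step passes the test if the norm <= 1.
y   = [1.0e3, 1.0e-12]
err = [1.0e-6, 1.0e-9]
n = wrms_norm(err, y)
println("WRMS norm = ", n, ", step accepted: ", n <= 1)
```

With reltol = abstol = 1e-8, the near-zero component is controlled almost entirely by abstol. Two tools that use the same numeric tolerances but scale the error differently (a different norm or different weights) can therefore take noticeably different steps.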
Further, decreasing dtmin solves the problem for the case above. However, if I use another randomly generated parameter vector (to mimic a multi-start optimization), the code crashes regardless of dtmin (something appears to be off with the Sundials interface), while AMICI computes a sufficiently accurate gradient.
p_crash = [1.2424242424242422, -2.3333333333333335, 2.3030303030303028, -2.515151515151515, 3.0, -0.4545454545454546, -2.3636363636363638, 2.545454545454546, 2.9393939393939394, 7.181818181818182, 2.090909090909091, -0.15151515151515138, 0.7575757575757578, -2.8181818181818183, -1.1818181818181819, -1.7878787878787878, 2.090909090909091, -2.1515151515151514, 2.8181818181818183, 2.8686868686868685, 0.6969696969696968, -2.3333333333333335, -2.0303030303030303, -1.0, -1.606060606060606, 1.1818181818181817, 1.4242424242424239, -1.6666666666666667, 2.9393939393939394, -0.8181818181818183, -1.4242424242424243, 0.21212121212121193, 1.1818181818181817, 2.454545454545454, 0.8181818181818183, -1.606060606060606, 2.2727272727272725, 1.121212121212121, 2.9393939393939394, 0.09090909090909083, 1.121212121212121, 1.6060606060606064, -2.393939393939394, -2.515151515151515, -0.3333333333333335, 0.8787878787878789, -2.878787878787879, -0.3333333333333335, -2.4545454545454546, -0.8787878787878789, 2.3939393939393936, -1.4242424242424243, 1.6060606060606064, 0.9393939393939394, 0.39393939393939403, -1.3636363636363635, -2.4545454545454546, 0.4545454545454546, 2.9393939393939394, -2.4545454545454546, 0.21212121212121193, -2.393939393939394, -0.4545454545454546, -1.5454545454545454, 1.4848484848484844, 2.333333333333333, -1.303030303030303, -0.8181818181818183, -0.15151515151515138, 2.090909090909091, -0.39393939393939403, 2.5151515151515156, -2.757575757575758, 0.030303030303030276, -2.3333333333333335, -0.39393939393939403, -0.030303030303030276, 1.3636363636363633, 0.3333333333333335, -1.9090909090909092, -1.7272727272727273, 0.7575757575757578, 1.0, 2.757575757575758, 0.5757575757575757, 0.5151515151515151, -1.4848484848484849, -2.4545454545454546, -1.1818181818181819, -1.606060606060606, -1.8484848484848484, -2.2727272727272725, -1.0606060606060606, -2.393939393939394, -2.0303030303030303, -0.39393939393939403, -2.5757575757575757, -1.7878787878787878, 1.7272727272727275, 
1.787878787878788, -1.6666666666666667, 1.666666666666667, 0.27272727272727293, 0.8787878787878789, -0.5757575757575757, -2.757575757575758, -1.303030303030303, 2.5151515151515156, 2.878787878787879, -0.030303030303030276, 2.454545454545454, 2.0303030303030303, 0.6969696969696968]
gradientForward = petabProblem.computeGradient(p_crash)
gradientAdjoint = petabProblemAdjoint.computeGradient(p_crash)
# Gradient ForwardDiff
-41276.3 0.000180495 0.000182287 -1.96312e5 -2.09225 -5318.03 18645.5 28.5328 … -5.18931e-12 -0.00908472 4.26499 -2.48366 4.30205e-6 12.5726 8.35768 5.44818
# Gradient Adjoint
# Fails
# Gradient AMICI
-41276.3 0.000180495 0.000182288 -1.96312e5 -2.09745 -5318.03 18645.5 28.5323 … -5.18931e-12 -0.00908472 4.26499 -2.48366 4.30205e-6 12.5726 8.35768 5.44818
The error message I receive is (heavily truncated):
[CVODES WARNING] CVode
Internal t = 240 and h = -1.31973e-14 are such that t + h = t on the next step. The solver will continue anyway.
ERROR: BoundsError: attempt to access 14-element Vector{Float64} at index [0]
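The CVODES warning says that, at t = 240, the proposed backward step h is too small to change t in Float64 arithmetic. This can be checked directly in plain Julia using the h from the warning above and a step equal to dtmin = 1e-14, the value passed to adjoint_sensitivities in this thread:

```julia
t = 240.0
# Step from the CVODES warning above, and a step equal to dtmin = 1e-14
for h in (-1.31973e-14, -1e-14)
    println("h = ", h, "  ->  t + h == t: ", t + h == t)
end
# Distance from 240.0 to the next smaller Float64
gap = t - prevfloat(t)
println("gap below 240.0 = ", gap, " (eps(240.0) = ", eps(t), ")")
```

Since the gap below 240.0 is about 2.8e-14, any step smaller than half of it rounds back to t. A minimum step of dtmin = 1e-14 therefore still permits steps that cannot advance t at this time point, which may be worth keeping in mind when tuning dtmin.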
The failure above can also be reproduced with the following MVE, which gives a similar error message (I have isolated what crashes the code):
using ModelingToolkit
using OrdinaryDiffEq
using Sundials
using SciMLSensitivity
# Autogenerated from PEtab.jl
function getODEModel_Bachmann_MSB2011()
# Model name: Bachmann_MSB2011
# Number of parameters: 39
# Number of species: 25
### Define independent and dependent variables
ModelingToolkit.@variables t p1EpoRpJAK2(t) pSTAT5(t) EpoRJAK2_CIS(t) SOCS3nRNA4(t) SOCS3RNA(t) SHP1(t) STAT5(t) EpoRJAK2(t) CISnRNA1(t) SOCS3nRNA1(t) SOCS3nRNA2(t) CISnRNA3(t) CISnRNA4(t) SOCS3(t) CISnRNA5(t) SOCS3nRNA5(t) SOCS3nRNA3(t) SHP1Act(t) npSTAT5(t) p12EpoRpJAK2(t) p2EpoRpJAK2(t) CIS(t) EpoRpJAK2(t) CISnRNA2(t) CISRNA(t)
### Store dependent variables in array for ODESystem command
stateArray = [p1EpoRpJAK2, pSTAT5, EpoRJAK2_CIS, SOCS3nRNA4, SOCS3RNA, SHP1, STAT5, EpoRJAK2, CISnRNA1, SOCS3nRNA1, SOCS3nRNA2, CISnRNA3, CISnRNA4, SOCS3, CISnRNA5, SOCS3nRNA5, SOCS3nRNA3, SHP1Act, npSTAT5, p12EpoRpJAK2, p2EpoRpJAK2, CIS, EpoRpJAK2, CISnRNA2, CISRNA]
### Define variable parameters
### Define potential algebraic variables
### Define parameters
ModelingToolkit.@parameters SOCS3RNATurn STAT5Imp SOCS3Eqc EpoRCISRemove STAT5ActEpoR SHP1ActEpoR JAK2EpoRDeaSHP1 CISTurn SOCS3Turn init_EpoRJAK2_CIS SOCS3Inh ActD init_CIS_multiplier cyt CISRNAEqc JAK2ActEpo Epo SOCS3oe CISInh SHP1Dea SOCS3EqcOE CISRNADelay init_SHP1 CISEqcOE EpoRActJAK2 SOCS3RNAEqc CISEqc SHP1ProOE SOCS3RNADelay init_STAT5 CISoe CISRNATurn init_SHP1_multiplier init_EpoRJAK2 nuc EpoRCISInh STAT5ActJAK2 STAT5Exp init_SOCS3_multiplier
### Store parameters in array for ODESystem command
parameterArray = [SOCS3RNATurn, STAT5Imp, SOCS3Eqc, EpoRCISRemove, STAT5ActEpoR, SHP1ActEpoR, JAK2EpoRDeaSHP1, CISTurn, SOCS3Turn, init_EpoRJAK2_CIS, SOCS3Inh, ActD, init_CIS_multiplier, cyt, CISRNAEqc, JAK2ActEpo, Epo, SOCS3oe, CISInh, SHP1Dea, SOCS3EqcOE, CISRNADelay, init_SHP1, CISEqcOE, EpoRActJAK2, SOCS3RNAEqc, CISEqc, SHP1ProOE, SOCS3RNADelay, init_STAT5, CISoe, CISRNATurn, init_SHP1_multiplier, init_EpoRJAK2, nuc, EpoRCISInh, STAT5ActJAK2, STAT5Exp, init_SOCS3_multiplier]
### Define an operator for the differentiation w.r.t. time
D = Differential(t)
### Continuous events ###
### Discrete events ###
### Derivatives ###
eqs = [
D(p1EpoRpJAK2) ~ -1.0 * ( 1 /cyt ) * (cyt*(((JAK2EpoRDeaSHP1*SHP1Act)*p1EpoRpJAK2)/init_SHP1))+1.0 * ( 1 /cyt ) * (cyt*((EpoRpJAK2*EpoRActJAK2)/(((SOCS3*SOCS3Inh)/SOCS3Eqc)+1)))-1.0 * ( 1 /cyt ) * (cyt*(((3*EpoRActJAK2)*p1EpoRpJAK2)/((((SOCS3*SOCS3Inh)/SOCS3Eqc)+1)*((EpoRCISInh*EpoRJAK2_CIS)+1)))),
D(pSTAT5) ~ -1.0 * ( 1 /cyt ) * ((cyt*STAT5Imp)*pSTAT5)+1.0 * ( 1 /cyt ) * (cyt*(((STAT5*STAT5ActEpoR)*((p12EpoRpJAK2+p1EpoRpJAK2)^2))/(((init_EpoRJAK2^2)*(((SOCS3*SOCS3Inh)/SOCS3Eqc)+1))*(((CIS*CISInh)/CISEqc)+1))))+1.0 * ( 1 /cyt ) * (cyt*(((STAT5*STAT5ActJAK2)*(((EpoRpJAK2+p12EpoRpJAK2)+p1EpoRpJAK2)+p2EpoRpJAK2))/(init_EpoRJAK2*(((SOCS3*SOCS3Inh)/SOCS3Eqc)+1)))),
D(EpoRJAK2_CIS) ~ -1.0 * ( 1 /cyt ) * (cyt*(((EpoRJAK2_CIS*EpoRCISRemove)*(p12EpoRpJAK2+p1EpoRpJAK2))/init_EpoRJAK2)),
D(SOCS3nRNA4) ~ +1.0 * ( 1 /nuc ) * ((nuc*SOCS3nRNA3)*SOCS3RNADelay)-1.0 * ( 1 /nuc ) * ((nuc*SOCS3nRNA4)*SOCS3RNADelay),
D(SOCS3RNA) ~ +1.0 * ( 1 /cyt ) * ((nuc*SOCS3nRNA5)*SOCS3RNADelay)-1.0 * ( 1 /cyt ) * ((cyt*SOCS3RNA)*SOCS3RNATurn),
D(SHP1) ~ -1.0 * ( 1 /cyt ) * (cyt*(((SHP1*SHP1ActEpoR)*(((EpoRpJAK2+p12EpoRpJAK2)+p1EpoRpJAK2)+p2EpoRpJAK2))/init_EpoRJAK2))+1.0 * ( 1 /cyt ) * ((cyt*SHP1Dea)*SHP1Act),
D(STAT5) ~ +1.0 * ( 1 /cyt ) * ((nuc*STAT5Exp)*npSTAT5)-1.0 * ( 1 /cyt ) * (cyt*(((STAT5*STAT5ActEpoR)*((p12EpoRpJAK2+p1EpoRpJAK2)^2))/(((init_EpoRJAK2^2)*(((SOCS3*SOCS3Inh)/SOCS3Eqc)+1))*(((CIS*CISInh)/CISEqc)+1))))-1.0 * ( 1 /cyt ) * (cyt*(((STAT5*STAT5ActJAK2)*(((EpoRpJAK2+p12EpoRpJAK2)+p1EpoRpJAK2)+p2EpoRpJAK2))/(init_EpoRJAK2*(((SOCS3*SOCS3Inh)/SOCS3Eqc)+1)))),
D(EpoRJAK2) ~ +1.0 * ( 1 /cyt ) * (cyt*(((JAK2EpoRDeaSHP1*SHP1Act)*p12EpoRpJAK2)/init_SHP1))+1.0 * ( 1 /cyt ) * (cyt*(((JAK2EpoRDeaSHP1*SHP1Act)*p2EpoRpJAK2)/init_SHP1))+1.0 * ( 1 /cyt ) * (cyt*(((JAK2EpoRDeaSHP1*SHP1Act)*p1EpoRpJAK2)/init_SHP1))+1.0 * ( 1 /cyt ) * (cyt*(((EpoRpJAK2*JAK2EpoRDeaSHP1)*SHP1Act)/init_SHP1))-1.0 * ( 1 /cyt ) * (cyt*(((Epo*EpoRJAK2)*JAK2ActEpo)/(((SOCS3*SOCS3Inh)/SOCS3Eqc)+1))),
D(CISnRNA1) ~ -1.0 * ( 1 /nuc ) * ((nuc*CISnRNA1)*CISRNADelay)+1.0 * ( 1 /nuc ) * (nuc*((((CISRNAEqc*CISRNATurn)*npSTAT5)*ActD)/init_STAT5)),
D(SOCS3nRNA1) ~ +1.0 * ( 1 /nuc ) * (nuc*((((SOCS3RNAEqc*SOCS3RNATurn)*npSTAT5)*ActD)/init_STAT5))-1.0 * ( 1 /nuc ) * ((nuc*SOCS3nRNA1)*SOCS3RNADelay),
D(SOCS3nRNA2) ~ +1.0 * ( 1 /nuc ) * ((nuc*SOCS3nRNA1)*SOCS3RNADelay)-1.0 * ( 1 /nuc ) * ((nuc*SOCS3nRNA2)*SOCS3RNADelay),
D(CISnRNA3) ~ -1.0 * ( 1 /nuc ) * ((nuc*CISnRNA3)*CISRNADelay)+1.0 * ( 1 /nuc ) * ((nuc*CISnRNA2)*CISRNADelay),
D(CISnRNA4) ~ -1.0 * ( 1 /nuc ) * ((nuc*CISnRNA4)*CISRNADelay)+1.0 * ( 1 /nuc ) * ((nuc*CISnRNA3)*CISRNADelay),
D(SOCS3) ~ +1.0 * ( 1 /cyt ) * (cyt*(((SOCS3RNA*SOCS3Eqc)*SOCS3Turn)/SOCS3RNAEqc))+1.0 * ( 1 /cyt ) * ((((cyt*SOCS3oe)*SOCS3Eqc)*SOCS3Turn)*SOCS3EqcOE)-1.0 * ( 1 /cyt ) * ((cyt*SOCS3)*SOCS3Turn),
D(CISnRNA5) ~ +1.0 * ( 1 /nuc ) * ((nuc*CISnRNA4)*CISRNADelay)-1.0 * ( 1 /nuc ) * ((nuc*CISnRNA5)*CISRNADelay),
D(SOCS3nRNA5) ~ -1.0 * ( 1 /nuc ) * ((nuc*SOCS3nRNA5)*SOCS3RNADelay)+1.0 * ( 1 /nuc ) * ((nuc*SOCS3nRNA4)*SOCS3RNADelay),
D(SOCS3nRNA3) ~ -1.0 * ( 1 /nuc ) * ((nuc*SOCS3nRNA3)*SOCS3RNADelay)+1.0 * ( 1 /nuc ) * ((nuc*SOCS3nRNA2)*SOCS3RNADelay),
D(SHP1Act) ~ +1.0 * ( 1 /cyt ) * (cyt*(((SHP1*SHP1ActEpoR)*(((EpoRpJAK2+p12EpoRpJAK2)+p1EpoRpJAK2)+p2EpoRpJAK2))/init_EpoRJAK2))-1.0 * ( 1 /cyt ) * ((cyt*SHP1Dea)*SHP1Act),
D(npSTAT5) ~ +1.0 * ( 1 /nuc ) * ((cyt*STAT5Imp)*pSTAT5)-1.0 * ( 1 /nuc ) * ((nuc*STAT5Exp)*npSTAT5),
D(p12EpoRpJAK2) ~ -1.0 * ( 1 /cyt ) * (cyt*(((JAK2EpoRDeaSHP1*SHP1Act)*p12EpoRpJAK2)/init_SHP1))+1.0 * ( 1 /cyt ) * (cyt*((EpoRActJAK2*p2EpoRpJAK2)/(((SOCS3*SOCS3Inh)/SOCS3Eqc)+1)))+1.0 * ( 1 /cyt ) * (cyt*(((3*EpoRActJAK2)*p1EpoRpJAK2)/((((SOCS3*SOCS3Inh)/SOCS3Eqc)+1)*((EpoRCISInh*EpoRJAK2_CIS)+1)))),
D(p2EpoRpJAK2) ~ -1.0 * ( 1 /cyt ) * (cyt*(((JAK2EpoRDeaSHP1*SHP1Act)*p2EpoRpJAK2)/init_SHP1))-1.0 * ( 1 /cyt ) * (cyt*((EpoRActJAK2*p2EpoRpJAK2)/(((SOCS3*SOCS3Inh)/SOCS3Eqc)+1)))+1.0 * ( 1 /cyt ) * (cyt*(((3*EpoRpJAK2)*EpoRActJAK2)/((((SOCS3*SOCS3Inh)/SOCS3Eqc)+1)*((EpoRCISInh*EpoRJAK2_CIS)+1)))),
D(CIS) ~ +1.0 * ( 1 /cyt ) * ((((cyt*CISEqc)*CISTurn)*CISEqcOE)*CISoe)+1.0 * ( 1 /cyt ) * (cyt*(((CISRNA*CISEqc)*CISTurn)/CISRNAEqc))-1.0 * ( 1 /cyt ) * ((cyt*CIS)*CISTurn),
D(EpoRpJAK2) ~ -1.0 * ( 1 /cyt ) * (cyt*((EpoRpJAK2*EpoRActJAK2)/(((SOCS3*SOCS3Inh)/SOCS3Eqc)+1)))-1.0 * ( 1 /cyt ) * (cyt*(((3*EpoRpJAK2)*EpoRActJAK2)/((((SOCS3*SOCS3Inh)/SOCS3Eqc)+1)*((EpoRCISInh*EpoRJAK2_CIS)+1))))-1.0 * ( 1 /cyt ) * (cyt*(((EpoRpJAK2*JAK2EpoRDeaSHP1)*SHP1Act)/init_SHP1))+1.0 * ( 1 /cyt ) * (cyt*(((Epo*EpoRJAK2)*JAK2ActEpo)/(((SOCS3*SOCS3Inh)/SOCS3Eqc)+1))),
D(CISnRNA2) ~ +1.0 * ( 1 /nuc ) * ((nuc*CISnRNA1)*CISRNADelay)-1.0 * ( 1 /nuc ) * ((nuc*CISnRNA2)*CISRNADelay),
D(CISRNA) ~ -1.0 * ( 1 /cyt ) * ((cyt*CISRNA)*CISRNATurn)+1.0 * ( 1 /cyt ) * ((nuc*CISnRNA5)*CISRNADelay)
]
@named sys = ODESystem(eqs, t, stateArray, parameterArray)
### Initial species concentrations ###
initialSpeciesValues = [
p1EpoRpJAK2 => 0.0,
pSTAT5 => 0.0,
EpoRJAK2_CIS => init_EpoRJAK2_CIS,
SOCS3nRNA4 => 0.0,
SOCS3RNA => 0.0,
SHP1 => init_SHP1*((init_SHP1_multiplier*SHP1ProOE)+1),
STAT5 => init_STAT5,
EpoRJAK2 => init_EpoRJAK2,
CISnRNA1 => 0.0,
SOCS3nRNA1 => 0.0,
SOCS3nRNA2 => 0.0,
CISnRNA3 => 0.0,
CISnRNA4 => 0.0,
SOCS3 => (init_SOCS3_multiplier*SOCS3EqcOE)*SOCS3Eqc,
CISnRNA5 => 0.0,
SOCS3nRNA5 => 0.0,
SOCS3nRNA3 => 0.0,
SHP1Act => 0.0,
npSTAT5 => 0.0,
p12EpoRpJAK2 => 0.0,
p2EpoRpJAK2 => 0.0,
CIS => (init_CIS_multiplier*CISEqc)*CISEqcOE,
EpoRpJAK2 => 0.0,
CISnRNA2 => 0.0,
CISRNA => 0.0
]
### SBML file parameter values ###
trueParameterValues = [
SOCS3RNATurn => 0.00830917643120369,
STAT5Imp => 0.0268865083829685,
SOCS3Eqc => 173.64470023136,
EpoRCISRemove => 5.42980693903448,
STAT5ActEpoR => 38.9957991073948,
SHP1ActEpoR => 0.00100000000000006,
JAK2EpoRDeaSHP1 => 142.72332309738,
CISTurn => 0.0083988695167017,
SOCS3Turn => 9999.99999999912,
init_EpoRJAK2_CIS => 0.0,
SOCS3Inh => 10.4078649133666,
ActD => 1.25e-7,
init_CIS_multiplier => 0.0,
cyt => 0.4,
CISRNAEqc => 1.0,
JAK2ActEpo => 633167.430600806,
Epo => 1.25e-7,
SOCS3oe => 1.25e-7,
CISInh => 7.85269991450496e8,
SHP1Dea => 0.00816220490950374,
SOCS3EqcOE => 0.679165515556864,
CISRNADelay => 0.14477775532111,
init_SHP1 => 26.7251164277109,
CISEqcOE => 0.530264447119609,
EpoRActJAK2 => 0.267304849333058,
SOCS3RNAEqc => 1.0,
CISEqc => 432.860413434913,
SHP1ProOE => 2.82568153411555,
SOCS3RNADelay => 1.06458446742251,
init_STAT5 => 79.75363993771,
CISoe => 1.25e-7,
CISRNATurn => 999.999999999946,
init_SHP1_multiplier => 1.0,
init_EpoRJAK2 => 3.97622369384192,
nuc => 0.275,
EpoRCISInh => 999999.999999912,
STAT5ActJAK2 => 0.0781068855795467,
STAT5Exp => 0.0745150819016423,
init_SOCS3_multiplier => 0.0
]
return sys, initialSpeciesValues, trueParameterValues
end
function compute∂G∂u(out, u, p, t, i)
dataObserved = [0.1021160220994476]
σ = 0.037649358067924674
h = 0.001747528400007683 + 1.072267222010323 * ( 2.0 * u[23] + 2.0 * u[20] + 2.0 * u[1] + 2.0 * u[21] ) / p[34]
∂h∂u = zeros(length(u))
∂h∂u[1] = (2.0*1.072267222010323) / p[34]
∂h∂u[20] = (2.0*1.072267222010323) / p[34]
∂h∂u[21] = (2.0*1.072267222010323) / p[34]
∂h∂u[23] = (2.0*1.072267222010323) / p[34]
out .= 0.0
for i in eachindex(dataObserved)
∂h∂u .*= (1 / (log(10) * h)) * (log10(exp10(h)) - log10(dataObserved[i])) / σ^2
out .+= ∂h∂u
end
end
sys, stateMap, parameterMap = getODEModel_Bachmann_MSB2011()
odeProblem = ODEProblem{true, SciMLBase.FullSpecialize}(sys, stateMap, [0.0, 250.0], parameterMap, jac=true)
# Parameter vector and initial value vector that crash
p = [657.9332246575682, 0.1, 0.0657933224657568, 869.7490026177834, 4.977023564332109, 0.7054802310718645, 123.28467394420659, 0.3511191734215131, 739.0722033525775, 0.0, 123.28467394420659, 1.0, 0.0, 0.4, 1.0, 1.5199110829529332e7, 1.25e-7, 0.0, 200.92330025650458, 5.72236765935022, 0.016297508346206444, 0.0030538555088334154, 15.199110829529332, 0.004641588833612777, 0.004328761281083057, 1.0, 17.47528400007683, 0.0015199110829529332, 0.007054802310718645, 26.56087782946684, 0.0, 1000.0, 0.0, 0.024770763559917114, 0.275, 351.11917342151344, 0.004641588833612777, 0.0093260334688322, 0.0]
u0 = [0.0, 0.0, 0.0, 0.0, 0.0, 15.199110829529332, 26.56087782946684, 0.024770763559917114, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
odeProblem.p .= p
odeProblem.u0 .= u0
solForward = solve(odeProblem, CVODE_BDF(), abstol=1e-8, reltol=1e-8)
# Here the code crashes
du, dp = adjoint_sensitivities(solForward, CVODE_BDF(), dgdu_discrete=compute∂G∂u, t=[240.0],
sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP()),
abstol=1e-8, reltol=1e-8, dtmin=1e-14)
TL;DR: Decreasing dtmin can prevent some integration failures; however, there are still cases where AMICI can compute the gradient while SciMLSensitivity fails. Moreover, I think there is a bug in the Sundials interface, since the code crashes upon failure.
I used Julia 1.9.2 above.
> Moreover, I think there is a bug in the Sundial's interface as upon failure the code crashes.
Did you check whether AMICI and Sundials.jl use the same Sundials version? What happens for other (perhaps smaller, easier-to-solve) ODEs? Do you get exactly the same gradient values?
AMICI uses Sundials version 5.8.0 and Sundials.jl uses what I assume is the latest (6.6).
For smaller, easier-to-solve ODEs, SciMLSensitivity encounters fewer integration failures (but still more than AMICI). If it would help, I could try to put together an MVE for a smaller stiff model with 8 ODEs?
No, we are currently using Sundials v5.2.
> If it would help I could try to put together a MVE for a smaller 8 ODE:s stiff ODE model?
Yes, that would be immensely helpful. Right now there is nothing actionable.
Also, can you try QuadratureAdjoint instead?
I have now put together an MVE for the smaller stiff Boehm model.
using ModelingToolkit
using OrdinaryDiffEq
using Sundials
using SciMLSensitivity
# Autogenerated by PEtab.jl
function getODEModel_Boehm_JProteomeRes2014()
# Model name: Boehm_JProteomeRes2014
# Number of parameters: 10
# Number of species: 8
### Define independent and dependent variables
ModelingToolkit.@variables t STAT5A(t) pApA(t) nucpApB(t) nucpBpB(t) STAT5B(t) pApB(t) nucpApA(t) pBpB(t) BaF3_Epo(t)
### Store dependent variables in array for ODESystem command
stateArray = [STAT5A, pApA, nucpApB, nucpBpB, STAT5B, pApB, nucpApA, pBpB, BaF3_Epo]
### Define variable parameters
### Define potential algebraic variables
### Define parameters
ModelingToolkit.@parameters specC17 Epo_degradation_BaF3 k_exp_homo k_phos cyt ratio nuc k_imp_homo k_imp_hetero k_exp_hetero
### Store parameters in array for ODESystem command
parameterArray = [specC17, Epo_degradation_BaF3, k_exp_homo, k_phos, cyt, ratio, nuc, k_imp_homo, k_imp_hetero, k_exp_hetero]
### Define an operator for the differentiation w.r.t. time
D = Differential(t)
### Continuous events ###
### Discrete events ###
### Derivatives ###
eqs = [
D(STAT5A) ~ -2.0 * ( 1 /cyt ) * (((cyt*BaF3_Epo)*(STAT5A^2))*k_phos)+1.0 * ( 1 /cyt ) * ((nuc*k_exp_hetero)*nucpApB)+2.0 * ( 1 /cyt ) * ((nuc*k_exp_homo)*nucpApA)-1.0 * ( 1 /cyt ) * ((((cyt*BaF3_Epo)*STAT5A)*STAT5B)*k_phos),
D(pApA) ~ +1.0 * ( 1 /cyt ) * (((cyt*BaF3_Epo)*(STAT5A^2))*k_phos)-1.0 * ( 1 /cyt ) * ((cyt*k_imp_homo)*pApA),
D(nucpApB) ~ +1.0 * ( 1 /nuc ) * ((cyt*k_imp_hetero)*pApB)-1.0 * ( 1 /nuc ) * ((nuc*k_exp_hetero)*nucpApB),
D(nucpBpB) ~ -1.0 * ( 1 /nuc ) * ((nuc*k_exp_homo)*nucpBpB)+1.0 * ( 1 /nuc ) * ((cyt*k_imp_homo)*pBpB),
D(STAT5B) ~ -2.0 * ( 1 /cyt ) * (((cyt*BaF3_Epo)*(STAT5B^2))*k_phos)+2.0 * ( 1 /cyt ) * ((nuc*k_exp_homo)*nucpBpB)+1.0 * ( 1 /cyt ) * ((nuc*k_exp_hetero)*nucpApB)-1.0 * ( 1 /cyt ) * ((((cyt*BaF3_Epo)*STAT5A)*STAT5B)*k_phos),
D(pApB) ~ -1.0 * ( 1 /cyt ) * ((cyt*k_imp_hetero)*pApB)+1.0 * ( 1 /cyt ) * ((((cyt*BaF3_Epo)*STAT5A)*STAT5B)*k_phos),
D(nucpApA) ~ +1.0 * ( 1 /nuc ) * ((cyt*k_imp_homo)*pApA)-1.0 * ( 1 /nuc ) * ((nuc*k_exp_homo)*nucpApA),
D(pBpB) ~ +1.0 * ( 1 /cyt ) * (((cyt*BaF3_Epo)*(STAT5B^2))*k_phos)-1.0 * ( 1 /cyt ) * ((cyt*k_imp_homo)*pBpB),
BaF3_Epo ~ 1.25e-7*exp((-1*Epo_degradation_BaF3)*t)
]
@named sys = ODESystem(eqs, t, stateArray, parameterArray)
### Initial species concentrations ###
initialSpeciesValues = [
STAT5A => 207.6*ratio,
pApA => 0.0,
nucpApB => 0.0,
nucpBpB => 0.0,
STAT5B => 207.6-(207.6*ratio),
pApB => 0.0,
nucpApA => 0.0,
pBpB => 0.0,
BaF3_Epo => 1.25e-7
]
### SBML file parameter values ###
trueParameterValues = [
specC17 => 0.107,
Epo_degradation_BaF3 => 0.0269738286367359,
k_exp_homo => 0.00617193081581346,
k_phos => 15766.8336642826,
cyt => 1.4,
ratio => 0.693,
nuc => 0.45,
k_imp_homo => 96945.5391768823,
k_imp_hetero => 0.0163708512310568,
k_exp_hetero => 1.00094251286741e-5
]
return sys, initialSpeciesValues, trueParameterValues
end
function compute∂G∂u(out, u, pODEProblem, t, i)
dataObserved = [32.2110771608676]
# Measurement noise to be estimated
σ = exp10(-4.1591591591591595)
# Observation model
h = ( 100.0 * u[6] + 100.0 * u[1] * pODEProblem[1] + 200.0 * u[2] * pODEProblem[1] ) / ( 2.0 * u[6] + u[1] * pODEProblem[1] + 2.0 * u[2] * pODEProblem[1] - u[5] * ( pODEProblem[1] - 1.0 ) - 2.0 * u[8] * ( pODEProblem[1] - 1.0 ) )
# Symbolically computed ∂h∂u
∂h∂u = zeros(length(u))
∂h∂u[1] = (100.0u[5]*pODEProblem[1] + 100.0u[6]*pODEProblem[1] + 200.0u[8]*pODEProblem[1] - 100.0u[5]*(pODEProblem[1]^2) - 200.0u[8]*(pODEProblem[1]^2)) / ((u[5] + 2.0u[6] + 2.0u[8] + u[1]*pODEProblem[1] + 2.0u[2]*pODEProblem[1] - u[5]*pODEProblem[1] - 2.0u[8]*pODEProblem[1])^2)
∂h∂u[2] = (200.0u[5]*pODEProblem[1] + 200.0u[6]*pODEProblem[1] + 400.0u[8]*pODEProblem[1] - 200.0u[5]*(pODEProblem[1]^2) - 400.0u[8]*(pODEProblem[1]^2)) / ((u[5] + 2.0u[6] + 2.0u[8] + u[1]*pODEProblem[1] + 2.0u[2]*pODEProblem[1] - u[5]*pODEProblem[1] - 2.0u[8]*pODEProblem[1])^2)
∂h∂u[5] = ((1.0 - pODEProblem[1])*(-100.0u[6] - 100.0u[1]*pODEProblem[1] - 200.0u[2]*pODEProblem[1])) / ((u[5] + 2.0u[6] + 2.0u[8] + u[1]*pODEProblem[1] + 2.0u[2]*pODEProblem[1] - u[5]*pODEProblem[1] - 2.0u[8]*pODEProblem[1])^2)
∂h∂u[6] = (100.0u[5] + 200.0u[8] - 100.0u[1]*pODEProblem[1] - 100.0u[5]*pODEProblem[1] - 200.0u[2]*pODEProblem[1] - 200.0u[8]*pODEProblem[1]) / ((u[5] + 2.0u[6] + 2.0u[8] + u[1]*pODEProblem[1] + 2.0u[2]*pODEProblem[1] - u[5]*pODEProblem[1] - 2.0u[8]*pODEProblem[1])^2)
∂h∂u[8] = (2.0(pODEProblem[1] - 1.0)*(100.0u[6] + 100.0u[1]*pODEProblem[1] + 200.0u[2]*pODEProblem[1])) / ((u[5] + 2.0u[6] + 2.0u[8] + u[1]*pODEProblem[1] + 2.0u[2]*pODEProblem[1] - u[5]*pODEProblem[1] - 2.0u[8]*pODEProblem[1])^2)
out .= 0.0
for i in eachindex(dataObserved)
∂h∂u .*= (h - dataObserved[i]) / σ^2
out .+= ∂h∂u
end
end
sys, stateMap, parameterMap = getODEModel_Boehm_JProteomeRes2014()
odeProblem = ODEProblem{true, SciMLBase.FullSpecialize}(structural_simplify(sys), stateMap, [0.0, 250.0], parameterMap, jac=true)
# Parameter vector and initial value vector that crash
odeProblem.p .= [0.107, 56201.73848083188, 6438.857427240426, 7.959777002314978e-5, 1.4, 0.693, 0.45, 5737.976414214139, 0.6995920165435374, 0.006350425168595962]
odeProblem.u0 .= [143.86679999999998, 0.0, 0.0, 0.0, 63.73320000000001, 0.0, 0.0, 0.0]
solForward = solve(odeProblem, CVODE_BDF(), abstol=1e-8, reltol=1e-8)
# Here the code crashes
du, dp = adjoint_sensitivities(solForward, CVODE_BDF(), dgdu_discrete=compute∂G∂u, t=[240.0],
sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP()),
abstol=1e-8, reltol=1e-8, dtmin=1e-14)
For this random parameter vector, it outputs (heavily truncated):
[CVODES WARNING] CVode
Internal t = 240 and h = -1e-14 are such that t + h = t on the next step. The solver will continue anyway.
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [0]
This result holds for both InterpolatingAdjoint and QuadratureAdjoint, regardless of the VJP choice. For the Bachmann model above, QuadratureAdjoint also fails for the same parameter vector, with the same error message.
Are you sure about the tolerance changes and the choice of norms? I cannot see how they are the same. See this code, which works:
using ModelingToolkit
using OrdinaryDiffEq
using Sundials
using SciMLSensitivity
# Autogenerated by PEtab.jl
function getODEModel_Boehm_JProteomeRes2014()
# Model name: Boehm_JProteomeRes2014
# Number of parameters: 10
# Number of species: 8
### Define independent and dependent variables
ModelingToolkit.@variables t STAT5A(t) pApA(t) nucpApB(t) nucpBpB(t) STAT5B(t) pApB(t) nucpApA(t) pBpB(t) BaF3_Epo(t)
### Store dependent variables in array for ODESystem command
stateArray = [STAT5A, pApA, nucpApB, nucpBpB, STAT5B, pApB, nucpApA, pBpB, BaF3_Epo]
### Define variable parameters
### Define potential algebraic variables
### Define parameters
ModelingToolkit.@parameters specC17 Epo_degradation_BaF3 k_exp_homo k_phos cyt ratio nuc k_imp_homo k_imp_hetero k_exp_hetero
### Store parameters in array for ODESystem command
parameterArray = [specC17, Epo_degradation_BaF3, k_exp_homo, k_phos, cyt, ratio, nuc, k_imp_homo, k_imp_hetero, k_exp_hetero]
### Define an operator for the differentiation w.r.t. time
D = Differential(t)
### Continuous events ###
### Discrete events ###
### Derivatives ###
eqs = [
D(STAT5A) ~ -2.0 * ( 1 /cyt ) * (((cyt*BaF3_Epo)*(STAT5A^2))*k_phos)+1.0 * ( 1 /cyt ) * ((nuc*k_exp_hetero)*nucpApB)+2.0 * ( 1 /cyt ) * ((nuc*k_exp_homo)*nucpApA)-1.0 * ( 1 /cyt ) * ((((cyt*BaF3_Epo)*STAT5A)*STAT5B)*k_phos),
D(pApA) ~ +1.0 * ( 1 /cyt ) * (((cyt*BaF3_Epo)*(STAT5A^2))*k_phos)-1.0 * ( 1 /cyt ) * ((cyt*k_imp_homo)*pApA),
D(nucpApB) ~ +1.0 * ( 1 /nuc ) * ((cyt*k_imp_hetero)*pApB)-1.0 * ( 1 /nuc ) * ((nuc*k_exp_hetero)*nucpApB),
D(nucpBpB) ~ -1.0 * ( 1 /nuc ) * ((nuc*k_exp_homo)*nucpBpB)+1.0 * ( 1 /nuc ) * ((cyt*k_imp_homo)*pBpB),
D(STAT5B) ~ -2.0 * ( 1 /cyt ) * (((cyt*BaF3_Epo)*(STAT5B^2))*k_phos)+2.0 * ( 1 /cyt ) * ((nuc*k_exp_homo)*nucpBpB)+1.0 * ( 1 /cyt ) * ((nuc*k_exp_hetero)*nucpApB)-1.0 * ( 1 /cyt ) * ((((cyt*BaF3_Epo)*STAT5A)*STAT5B)*k_phos),
D(pApB) ~ -1.0 * ( 1 /cyt ) * ((cyt*k_imp_hetero)*pApB)+1.0 * ( 1 /cyt ) * ((((cyt*BaF3_Epo)*STAT5A)*STAT5B)*k_phos),
D(nucpApA) ~ +1.0 * ( 1 /nuc ) * ((cyt*k_imp_homo)*pApA)-1.0 * ( 1 /nuc ) * ((nuc*k_exp_homo)*nucpApA),
D(pBpB) ~ +1.0 * ( 1 /cyt ) * (((cyt*BaF3_Epo)*(STAT5B^2))*k_phos)-1.0 * ( 1 /cyt ) * ((cyt*k_imp_homo)*pBpB),
BaF3_Epo ~ 1.25e-7*exp((-1*Epo_degradation_BaF3)*t)
]
@named sys = ODESystem(eqs, t, stateArray, parameterArray)
### Initial species concentrations ###
initialSpeciesValues = [
STAT5A => 207.6*ratio,
pApA => 0.0,
nucpApB => 0.0,
nucpBpB => 0.0,
STAT5B => 207.6-(207.6*ratio),
pApB => 0.0,
nucpApA => 0.0,
pBpB => 0.0,
BaF3_Epo => 1.25e-7
]
### SBML file parameter values ###
trueParameterValues = [
specC17 => 0.107,
Epo_degradation_BaF3 => 0.0269738286367359,
k_exp_homo => 0.00617193081581346,
k_phos => 15766.8336642826,
cyt => 1.4,
ratio => 0.693,
nuc => 0.45,
k_imp_homo => 96945.5391768823,
k_imp_hetero => 0.0163708512310568,
k_exp_hetero => 1.00094251286741e-5
]
return sys, initialSpeciesValues, trueParameterValues
end
function compute∂G∂u(out, u, pODEProblem, t, i)
dataObserved = [32.2110771608676]
# Measurement noise to be estimated
σ = exp10(-4.1591591591591595)
# Observation model
h = ( 100.0 * u[6] + 100.0 * u[1] * pODEProblem[1] + 200.0 * u[2] * pODEProblem[1] ) / ( 2.0 * u[6] + u[1] * pODEProblem[1] + 2.0 * u[2] * pODEProblem[1] - u[5] * ( pODEProblem[1] - 1.0 ) - 2.0 * u[8] * ( pODEProblem[1] - 1.0 ) )
# Symbolically computed ∂h∂u
∂h∂u = zeros(length(u))
∂h∂u[1] = (100.0u[5]*pODEProblem[1] + 100.0u[6]*pODEProblem[1] + 200.0u[8]*pODEProblem[1] - 100.0u[5]*(pODEProblem[1]^2) - 200.0u[8]*(pODEProblem[1]^2)) / ((u[5] + 2.0u[6] + 2.0u[8] + u[1]*pODEProblem[1] + 2.0u[2]*pODEProblem[1] - u[5]*pODEProblem[1] - 2.0u[8]*pODEProblem[1])^2)
∂h∂u[2] = (200.0u[5]*pODEProblem[1] + 200.0u[6]*pODEProblem[1] + 400.0u[8]*pODEProblem[1] - 200.0u[5]*(pODEProblem[1]^2) - 400.0u[8]*(pODEProblem[1]^2)) / ((u[5] + 2.0u[6] + 2.0u[8] + u[1]*pODEProblem[1] + 2.0u[2]*pODEProblem[1] - u[5]*pODEProblem[1] - 2.0u[8]*pODEProblem[1])^2)
∂h∂u[5] = ((1.0 - pODEProblem[1])*(-100.0u[6] - 100.0u[1]*pODEProblem[1] - 200.0u[2]*pODEProblem[1])) / ((u[5] + 2.0u[6] + 2.0u[8] + u[1]*pODEProblem[1] + 2.0u[2]*pODEProblem[1] - u[5]*pODEProblem[1] - 2.0u[8]*pODEProblem[1])^2)
∂h∂u[6] = (100.0u[5] + 200.0u[8] - 100.0u[1]*pODEProblem[1] - 100.0u[5]*pODEProblem[1] - 200.0u[2]*pODEProblem[1] - 200.0u[8]*pODEProblem[1]) / ((u[5] + 2.0u[6] + 2.0u[8] + u[1]*pODEProblem[1] + 2.0u[2]*pODEProblem[1] - u[5]*pODEProblem[1] - 2.0u[8]*pODEProblem[1])^2)
∂h∂u[8] = (2.0(pODEProblem[1] - 1.0)*(100.0u[6] + 100.0u[1]*pODEProblem[1] + 200.0u[2]*pODEProblem[1])) / ((u[5] + 2.0u[6] + 2.0u[8] + u[1]*pODEProblem[1] + 2.0u[2]*pODEProblem[1] - u[5]*pODEProblem[1] - 2.0u[8]*pODEProblem[1])^2)
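out .= 0.0
# Accumulate (h - y) / σ^2 * ∂h∂u per observation; use a separate index (the argument i is not shadowed)
# and do not rescale ∂h∂u in place, which would compound the factor when there are several observations
for j in eachindex(dataObserved)
out .+= ∂h∂u .* ((h - dataObserved[j]) / σ^2)
end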
end
sys, stateMap, parameterMap = getODEModel_Boehm_JProteomeRes2014()
odeProblem = ODEProblem{true, SciMLBase.FullSpecialize}(structural_simplify(sys), stateMap, [0.0, 250.0], parameterMap, jac=true)
# Parameter vector and initial value vector that crashes
odeProblem.p .= [0.107, 56201.73848083188, 6438.857427240426, 7.959777002314978e-5, 1.4, 0.693, 0.45, 5737.976414214139, 0.6995920165435374, 0.006350425168595962]
odeProblem.u0 .= [143.86679999999998, 0.0, 0.0, 0.0, 63.73320000000001, 0.0, 0.0, 0.0]
solForward = solve(odeProblem, FBDF(), abstol=1e-6, reltol=1e-6)
du, dp = adjoint_sensitivities(solForward, FBDF(), dgdu_discrete=compute∂G∂u, t=[240.0],
sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP()),
abstol=1e-6, reltol=1e-6)
maximum(du)
maximum(dp)
Works just fine for us. But the interesting fact is at the bottom:
julia> maximum(du)
1.1952849995050876e9
julia> maximum(dp)
3161.4134643580383
julia> 1e9/1e-6
1.0e15
julia> eps(1.1952849995050876e9)
2.384185791015625e-7
Since du contains values of order 1e9, the smallest absolute difference you can resolve there is about 2e-7. That means it's effectively impossible to compute this problem to an absolute tolerance of 1e-8: you literally don't have enough digits in your floating-point numbers to enforce that.
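The arithmetic behind this can be checked with base Julia alone. The sketch below uses the value printed by maximum(du) above and the noise level σ from compute∂G∂u; it assumes, for illustration only, that the large scale of du comes mainly from the 1/σ² weight in (h - y)/σ².

```julia
# Noise level used in compute∂G∂u
σ = exp10(-4.1591591591591595)

# compute∂G∂u scales ∂h/∂u by (h - y) / σ^2, so 1/σ^2 sets much of the magnitude
weight = 1 / σ^2
println("1/σ² ≈ ", weight)

# Largest adjoint entry reported by maximum(du) above
duMax = 1.1952849995050876e9

# Gap between adjacent Float64 values at that magnitude
spacing = eps(duMax)
println("eps(duMax) = ", spacing)

# An absolute tolerance smaller than this gap cannot be enforced on that entry
abstol = 1e-8
println("abstol at least the gap? ", abstol >= spacing)
```

Here 1/σ² is roughly 2e8, so even moderate residuals h - y push du into the 1e8 to 1e11 range seen in the outputs above.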
Note two things here (https://sundials.readthedocs.io/en/latest/cvodes/Mathematics_link.html#selection-of-the-absolute-tolerances-for-sensitivity-variables). First, Sundials does not automatically include the sensitivities in its error test; whether the step-size adaptivity accounts for the forward and adjoint sensitivities is an option to be set. Our methods always enforce the accuracy of the gradient values by default, so these terms are error controlled. Second, you're setting the tolerance directly here, while in Sundials it's usually scaled.
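As a rough illustration of what "scaled" means on the linked CVODES page: when sensitivity tolerances are estimated from the state tolerances, the relative tolerance is reused and the absolute tolerance is divided by a parameter scale p̄ᵢ. The helper below is hypothetical (not a Sundials or SciMLSensitivity API), and using parameter values from the Boehm example as the scales p̄ is an assumption.

```julia
# Hypothetical helper mirroring the estimate described in the CVODES docs:
# rtolS = rtol and atolS_i = atol / |pbar_i| for each parameter scale pbar_i
function estimate_sens_tolerances(abstol, reltol, pbar)
    atolS = [abstol / abs(pb) for pb in pbar]
    return reltol, atolS
end

# Some parameter values from the Boehm example above, used as scales
pbar = [0.107, 56201.73848083188, 7.959777002314978e-5]
rtolS, atolS = estimate_sens_tolerances(1e-8, 1e-8, pbar)
println(atolS)
```

Large parameters get a tighter absolute tolerance and small ones a looser one. The linked page concerns forward sensitivity variables; how AMICI sets its adjoint tolerances is not shown here.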
So my guess is that the effective tolerance is much lower for Julia. If you think you have a value of 1e9 calculated to an absolute tolerance of 1e-8 in AMICI, you don't, since that's not possible; and if it's not exiting with a message that it's unable to hit that accuracy, then it's not controlling the accuracy here, which is why there's a difference.
With this plus https://github.com/SciML/OrdinaryDiffEq.jl/pull/2098, I think this is solved. Feel free to ask any more questions; it's an interesting question with a rather deep answer, because the interpretation of errors is not necessarily always the same.
Thanks a lot for looking into this!
We also saw in benchmarks that the SciMLSensitivity gradients were more accurate than those in AMICI, which aligns with SciMLSensitivity having a lower practical tolerance.
My follow-up is then whether you have any idea how to deal with this in practice. The crux is: close to an optimum I do not get the problem above (the extreme du is typically a consequence of bad parameters), and close to an optimum accurate gradients help convergence, which is why we set low tolerances. This problem appears for "bad" parameters, where I think not computing du to an accuracy of 1e-8 is perfectly fine, as the optimization algorithm will probably find a reasonably good descent direction anyway.
Therefore, to avoid failing the gradient evaluation (and thus the entire parameter estimation for that start guess), would it be possible to use some form of scaled abstol, or, in case the tolerance cannot be fulfilled, to continue solving for du with a small dt to at least get a somewhat accurate gradient for "bad" parameter vectors? (Does https://github.com/SciML/OrdinaryDiffEq.jl/pull/2098 solve this?)
Some additional info: on this problem AMICI takes many integration steps (orders of magnitude more than in the forward solve), which suggests they control the error of du? (This is a bit outside my expertise and I do not really know what else they could control.) So I guess abstol is scaled somehow.
It sounds like what you generally want is a low relative tolerance but not a low absolute tolerance. It sounds like you've been setting both to the same value, but generally that's not a good idea. Relative tolerance will be more forgiving when the gradients are large and less forgiving when they are small.
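To make the difference concrete, here is a minimal sketch of the usual mixed error test in adaptive ODE solvers, where each component gets the tolerance abstol + reltol·|u| and a step is accepted when the RMS of the scaled error is at most 1. The function names are hypothetical, and the exact norm and weights inside OrdinaryDiffEq, Sundials, or AMICI may differ in details.

```julia
# Per-component tolerance: abstol + reltol * (magnitude of the component)
tolerance(u, uprev, abstol, reltol) = abstol + reltol * max(abs(u), abs(uprev))

# RMS of the error estimate scaled by the tolerances; a step is accepted when <= 1
function scaled_error(err, u, uprev; abstol, reltol)
    w = tolerance.(u, uprev, abstol, reltol)
    return sqrt(sum(abs2, err ./ w) / length(err))
end

u = [1.0e9, 1.0e-3]     # one very large and one small component

# Equal abstol and reltol, as in the calls above
tight = tolerance.(u, u, 1e-8, 1e-8)

# Same low reltol, looser abstol
mixed = tolerance.(u, u, 1e-3, 1e-8)

println(tight)
println(mixed)
```

With abstol = reltol = 1e-8, the 1e9 component only needs to be resolved to about 10 (the reltol term dominates), while the small component must be resolved to about 1e-8 in absolute terms; raising abstol relaxes the small component without changing the large one.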
SciMLSensitivity, and DiffEq in general, controls the error of du in the forward pass. If you want to check whether that's what is going on, you can just run ForwardDiff on your loss function and see how the number of steps in the forward solution changes.
From a discussion on discourse Chris asked me to file an issue here.
Recently we have run some benchmarks to test the performance of DifferentialEquations.jl for parameter estimation of ODE models in systems biology, and we have run into problems when trying to compute the gradient for stiff models via adjoint sensitivity analysis with a discrete cost function.
Below is an MVE for the Bachmann model (a stiff 21-state ODE model commonly used for benchmarks). For most parameter vectors the gradient is correctly computed with adjoint_sensitivities, but for a subset of parameter vectors some entries in dgdu_discrete (specifically dgdu_discrete[22] = -9e10 below) become very large, which I think is why we get the error message from CVODE_BDF shown below. After this error message the code crashes (top of stacktrace below).
The main issue is that AMICI manages to compute the gradient via adjoint sensitivity analysis (albeit needing 67k integration steps to solve the adjoint ODE for λ). AMICI uses the Sundials CVODES adjoint methods.
The code crashes at the same time point (t=120, where we have data) for both QuadratureAdjoint and InterpolatingAdjoint when using either ReverseDiffVJP or EnzymeVJP. Furthermore, both QNDF and FBDF run into integration problems with dt_min at t=120. I am confident the model is implemented correctly, as we get the same cost and gradient as AMICI (established software in systems biology) for random parameter vectors.
Lastly, I want to add that so far we have found that the SciML family of packages achieves great performance for systems biology models, and we really like working with it. It would therefore be great if we could find a way to compute the gradient via adjoint sensitivity analysis for challenging cases like this (we typically do parameter estimation via multi-start gradient-based optimization, and having the gradient computation crash is quite detrimental).
Top of stacktrace: