SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Sensitivity Analysis for TwoPointBVProblem #353

Open jholland1 opened 3 years ago

jholland1 commented 3 years ago

Are there any sensitivity algorithms available for a TwoPointBVProblem? My solve call in this case is:

augmented = TwoPointBVProblem(aug_model,bc,[p[1],0.0],tspan,θ)

solve(augmented,MIRK4(),p=p,dt = 0.01)

I then define a loss function, but if I call:

Zygote.gradient(loss,θ)

I get the error "dt must be positive". I get the same error when calling DiffEqFlux.sciml_train().

I suspect this is caused by the difference between calling solve on an ODEProblem and on a TwoPointBVProblem, and changing the sensealg doesn't resolve it. Is this a feature that hasn't been implemented yet? I can get the gradients and the optimization to work using GalacticOptim.AutoFiniteDiff(), but I would much prefer to use reverse-mode autodiff if possible.
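To make the cost concern concrete: a central finite-difference gradient needs two loss evaluations, and therefore two BVP solves, per parameter, so its cost grows linearly with the length of θ. The stdlib-only Julia sketch below is a hypothetical illustration of that scaling, not the GalacticOptim.AutoFiniteDiff() implementation; `fd_gradient`, `toy_loss`, and the step size `h` are illustrative choices, and `toy_loss` merely stands in for a loss that would solve a TwoPointBVProblem.

```julia
# Hypothetical helper: central finite-difference gradient of a scalar loss.
# Each parameter costs two loss evaluations, so a BVP-based loss with n
# parameters needs 2n BVP solves per gradient.
function fd_gradient(loss, θ::Vector{Float64}; h = 1e-6)
    grad = similar(θ)
    for i in eachindex(θ)
        θp = copy(θ); θp[i] += h
        θm = copy(θ); θm[i] -= h
        grad[i] = (loss(θp) - loss(θm)) / (2h)
    end
    return grad
end

# Toy stand-in for an expensive BVP-based loss; the counter records how
# many times it is evaluated.
const n_evals = Ref(0)
function toy_loss(θ)
    n_evals[] += 1
    return sum(abs2, θ .- 1.0)
end

θ0 = zeros(10)
grad_fd = fd_gradient(toy_loss, θ0)
println(grad_fd)     # ≈ -2.0 in every entry
println(n_evals[])   # 20 evaluations for 10 parameters
```

With a 20-node hidden layer, the θ used in the code later in this thread already holds dozens of entries, so each finite-difference gradient costs dozens of BVP solves.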

Thanks! I appreciate the help. I'm new to this package, but it has a lot of great capabilities!

ChrisRackauckas commented 3 years ago

GalacticOptim.AutoForwardDiff() should work too, but indeed we haven't added adjoints to BVPs yet. I'm transferring this over to DiffEqSensitivity.jl so we can track and add the adjoint there. This is just an omission and it would be a good idea to get this added.
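What a BVP adjoint computes can be sketched in the discretize-then-optimize setting. Once a linear two-point BVP is discretized into A(p) u = b(p), the gradient of a loss L(u) needs one forward solve, one adjoint solve A' λ = ∂L/∂u, and then dL/dp = λ' (∂b/∂p - (∂A/∂p) u), independent of the number of parameters. The Julia sketch below uses only the LinearAlgebra standard library; the model problem -T'' + kT = s with T(0) = T(1) = 0, the target profile, and the helper names `assemble`, `bvp_loss`, and `adjoint_gradient` are hypothetical illustrations, not the SciMLSensitivity implementation or the heat problem from this thread.

```julia
using LinearAlgebra

# Discretize -T'' + k*T = s on (0, 1), T(0) = T(1) = 0, with N interior
# points and second-order central differences: A(k) * u = b(s).
function assemble(k, s, N)
    h = 1 / (N + 1)
    A = SymTridiagonal(fill(2 / h^2 + k, N), fill(-1 / h^2, N - 1))
    b = fill(s, N)
    return A, b
end

# Loss: squared misfit between the discrete solution and a target profile.
function bvp_loss(θ, target)
    A, b = assemble(θ[1], θ[2], length(target))
    u = A \ b
    return sum(abs2, u .- target)
end

# Discrete adjoint gradient: one forward solve and one adjoint solve,
# regardless of how many parameters there are.
function adjoint_gradient(θ, target)
    k, s = θ
    A, b = assemble(k, s, length(target))
    u = A \ b                   # forward solve
    dLdu = 2 .* (u .- target)   # ∂L/∂u
    λ = A \ dLdu                # adjoint solve; A is symmetric, so A' == A
    # dL/dp = λ' * (∂b/∂p - (∂A/∂p) * u), with ∂A/∂k = I and ∂b/∂s = ones(N)
    dLdk = -dot(λ, u)
    dLds = sum(λ)
    return [dLdk, dLds]
end

N = 50
z = range(0, 1; length = N + 2)[2:end-1]   # interior grid points
target = 0.5 .* z .* (1 .- z)              # hypothetical target profile
θ = [1.0, 2.0]                             # (k, s)

g_adj = adjoint_gradient(θ, target)

# Central finite-difference check: two extra solves per parameter.
δ = 1e-6
g_fd = [(bvp_loss(θ .+ δ .* e, target) - bvp_loss(θ .- δ .* e, target)) / (2δ)
        for e in ([1.0, 0.0], [0.0, 1.0])]

println(g_adj)
println(g_fd)
```

For a nonlinear BVP such as the heat problem in this thread, the discretized system becomes a residual F(u, p) = 0, A is replaced by the Jacobian ∂F/∂u at the converged solution, and dL/dp = -λ' ∂F/∂p; that extension is not shown here.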

mschauer commented 3 years ago

If you tell us a bit more about your specific problem, ideally with an MWE, we can keep it in mind.

ChrisRackauckas commented 3 years ago

I think a sensible MWE would be to just take https://diffeq.sciml.ai/stable/tutorials/bvp_example/#BVProblem and stick it into Zygote. Right now there's no adjoint defined on it so Zygote will try to work directly on the solver and my guess is that it fails in the mutation where it builds the banded matrix for the MIRK tableau.

jholland1 commented 3 years ago

Sorry for the (very) slow response on this. I meant to polish up my test code and post it, but I got side-tracked. I recently revisited this, so I thought I would share the problem I'm referencing in this issue. The IJulia notebook is attached and the code is below. I'm trying to reproduce results from my thesis, which I did in Python. I think it will be far easier to work with all the capabilities in this package (really amazing stuff!!). I can get results using finite differences, but I would also like to train one NN using multiple TwoPointBVProblem solutions, and the runtime would get a little gnarly using finite differences. I may still give it a shot. If you're interested, details on what I'm trying to do are in the 1D heat equation sections here: https://www.researchgate.net/publication/333808531_Field_Inversion_and_Machine_Learning_With_Embedded_Neural_Networks_Physics-Consistent_Neural_Network_Training

And my Python implementation is here: https://github.com/jholland1/py_1D_heat

Again, really amazing capability in this package, would like to be able to demo the same results with much less pain. Very cool.

IJulia Notebook with output and notes: 1DHeat.pdf

Alternatively, the code alone is below. It throws an error at the Zygote.gradient() call:

versioninfo()

using Plots, Optim, Flux, DiffEqFlux, DifferentialEquations, LaTeXStrings, DiffEqSensitivity, Zygote, GalacticOptim

function heat!(du,u,p,t)
    Tinf = p[1]
    T = u[1]
    dT = u[2]
    du[1] = dT
    du[2] = -(1.0+5.0*sin(3.0*pi*T/200.0)+exp(0.02*T))*10.0^(-4.0)*(Tinf^4.0-T^4.0)+0.5*(Tinf-T)
end

p = [50.0]
u0 = [p[1], 0.0]
tspan = (0.0,1.0)
dt = 0.1

function bc!(residual, u, p, t)
    residual[1] = u[1][1]
    residual[2] = u[end][1]
end
prob = TwoPointBVProblem(heat!,bc!,[p[1],0.0],tspan,p)

sol = solve(prob,MIRK4(),dt=dt)
scatter(sol.t,sol[1,:],label=L"\mathrm{Truth}, \ T_\infty",xlabel=L"z",ylabel=L"T")

z = sol.t
truth = sol[1,:]
function predict_truth()
    solve(prob,MIRK4(),p=p,dt = dt)[1,:]
end

function model!(du,u,p,t)
    Tinf = p[1]
    T = u[1]
    dT = u[2]
    du[1] = dT
    du[2] = -5.0*10.0^(-4.0)*(Tinf^4.0-T^4.0)
end
pred = TwoPointBVProblem(model!,bc!,[p[1],0.0],tspan,p)
function predict_model()
    solve(pred,MIRK4(),p=p,dt = dt)[1,:]
end
model_prediction = predict_model()

scatter!(sol.t,model_prediction,label="Model")

nodes = 20
dudt2 = Chain(x->[x[1],x[2]],
    Dense(2,nodes,tanh,initW = zeros, initb = zeros),
    #Dense(nodes,nodes,tanh,initW = zeros, initb = zeros), #Uncomment to add an additional hidden layer
    Dense(nodes,1, initW = zeros, initb = zeros))
g,re = Flux.destructure(dudt2)
re(g)([50.0,0.0])

function aug_model(du,u,p,t)
    global re
    # θ = [u0; p; g]: p[3] is Tinf and p[4:end] holds the NN weights
    Tinf = p[3]
    g = p[4:end]
    T = u[1]
    dT = u[2]
    du[1] = dT
    du[2] = (1.0+re(g)([(T-Tinf)/Tinf,Tinf/50.0])[1]) * -5.0*10.0^(-4.0)*(Tinf^4.0-T^4.0)
end

θ=[u0;p;g]
augmented = TwoPointBVProblem(aug_model,bc!,[p[1],0.0],tspan,θ)# sensealg=ForwardDiffSensitivity())

function predict_n_ode(θ)
  #solve(pred,MIRK4(),p=p,dt = 0.01)[1,:]
  solve(augmented,MIRK4(),p=θ,dt=dt)[1,:]
end
predict_n_ode(θ)

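function loss_n_ode(θ)
    pred = predict_n_ode(θ)
    l = sum(abs2,truth .- pred)
    l,pred
end
# Scalar loss used by Zygote.gradient and OptimizationFunction below;
# the optional second argument accepts the (unused) parameter object the optimizer passes
loss(θ, _p=nothing) = first(loss_n_ode(θ))
l, pred = loss_n_ode(θ)
loss(θ)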

display(Zygote.gradient(loss,θ))

show_result = function (θ, l, pred)
    scatter(z,truth,label="Truth")
    scatter!(z,model_prediction,label="Model")
    scatter!(z,predict_n_ode(θ),label="Augmented")
end
show_result(θ, l, pred)

cb = function (θ,l,pred;doplot=false)
    display(l)
    #pl = plot(sol)
    return false
end
cb(θ,l,pred)

#res = DiffEqFlux.sciml_train(loss_n_ode, θ, LBFGS(), cb = cb)
#res = DiffEqFlux.sciml_train(loss_n_ode, θ, ADAM(0.1), cb = cb, maxiters=100)

f = OptimizationFunction(loss, GalacticOptim.AutoFiniteDiff())
prob = OptimizationProblem(f,θ)
sol = solve(prob,BFGS())

#show_result(res.minimizer,loss_n_ode(res.minimizer)...)
show_result(sol.u,loss_n_ode(sol.u)...)

scatter(z,truth-predict_n_ode(θ),label="Model Error")
scatter!(z,truth-predict_n_ode(sol.u),label="Augmented Error")

ChrisRackauckas commented 3 years ago

Yeah we just don't have this adjoint yet. This is a good project for one of my students to pick up though.