SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Multiple neural networks tutorial #459

Closed: ChrisRackauckas closed this issue 3 years ago

ChrisRackauckas commented 3 years ago

https://github.com/ChrisRackauckas/universal_differential_equations/issues/31 It's not hard, but if someone asks a question, it means we're not being clear enough.

yewalenikhil65 commented 3 years ago

Hi @ChrisRackauckas. Thanks for the hint at https://github.com/ChrisRackauckas/universal_differential_equations/issues/31. I think you meant p = params(p1,p2) earlier. Still no luck, though: some sort of instability is detected while solving for sol_nn.

Please check my approach below (the example uses the Lorenz system; the rest of the code follows the UODE tutorial):

using OrdinaryDiffEq, DiffEqSensitivity, DiffEqFlux, Flux, Optim

function lorenz!(du, u, p, t)
    x,y,z = u
    σ, ρ, β = p
    du[1] = dx = σ*(y-x)
    du[2] = dy = x*(ρ-z) - y
    du[3] = dz = x*y - β*z
end
# Define the experimental parameters
tspan = (0.0,20.0)   # short span for training
u0 = Float32[1.0; 0.0; 0.0]
p_ = Float32[10.0, 28.0, 8.0/3]
prob = ODEProblem(lorenz!, u0, tspan, p_)
sol = solve(prob, Tsit5(), saveat = 0.5)

# Ideal data
X = Array(sol)
Xₙ = X + Float32(1e-3)*randn(eltype(X), size(X))  #noisy data

# For xz term
NN_1 = FastChain(FastDense(2, 64, tanh),FastDense(64, 32, tanh), FastDense(32, 1))
p1 = initial_params(NN_1)

# for xy term
NN_2 = FastChain(FastDense(2, 64, tanh),FastDense(64, 32, tanh), FastDense(32, 1))
p2 = initial_params(NN_2)
p = params(p1,p2)
function dudt_(u, p,t)
    x, y, z = u
    z1 = NN_1([x,z], p[1])
    z2 = NN_2([x,y], p[2])
    [p_[1]*(y - x),          # σ*(y-x)
     x*p_[2] + z1[1] - y,    # x*(ρ-z) - y
     z2[1] - p_[3]*z]        # x*y - β*z
end
prob_nn = ODEProblem(dudt_,u0, tspan, p)
sol_nn = solve(prob_nn, Rosenbrock23(), saveat = sol.t)   # the problem seems to be here: instability detected

function predict(θ)
    Array(solve(prob_nn, Vern7(), u0 = u0, p = θ, saveat = sol.t,
                         abstol=1e-6, reltol=1e-6,
                         sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())))
end

# No regularisation right now
function loss(θ)
    pred = predict(θ)
    sum(abs2, Xₙ .- pred), pred
end
loss(p)
const losses = []
callback(θ,l,pred) = begin
    push!(losses, l)
    if length(losses)%50==0
        println(losses[end])
    end
    false
end

res1_uode = DiffEqFlux.sciml_train(loss, p, ADAM(0.01), cb=callback, maxiters = 500)
res2_uode = DiffEqFlux.sciml_train(loss, res1_uode.minimizer, BFGS(initial_stepnorm=0.01), cb=callback, maxiters = 10000)
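A minimal sketch of the flat-vector approach that the later `p[end]` indexing in this thread implies: concatenate the two parameter vectors into one vector and slice it inside `dudt_`, instead of passing `params(p1, p2)`. It assumes the DiffEqFlux 1.x API used above (`FastChain`, `FastDense`, `initial_params`), uses smaller networks, and shortens the span to `(0, 1)`. The short span matters even before training: with an untrained network the -x*z damping is missing, so the remaining linear x-y part (matrix [-10 10; 28 -1]) grows roughly like e^(11.8t) and overflows Float32 well before t = 20.

```julia
using OrdinaryDiffEq, DiffEqFlux

# Known Lorenz coefficients σ, ρ, β and a deliberately short span
const p_ = Float32[10.0, 28.0, 8.0/3]
u0 = Float32[1.0, 0.0, 0.0]
tspan = (0.0f0, 1.0f0)

# NN_1 stands in for the -x*z term, NN_2 for the x*y term
NN_1 = FastChain(FastDense(2, 16, tanh), FastDense(16, 1))
NN_2 = FastChain(FastDense(2, 16, tanh), FastDense(16, 1))
p1 = initial_params(NN_1)
p2 = initial_params(NN_2)

# One flat parameter vector instead of params(p1, p2)
p = [p1; p2]
const n1 = length(p1)   # NN_1 owns p[1:n1], NN_2 owns the rest

function dudt_(u, p, t)
    x, y, z = u
    z1 = NN_1([x, z], p[1:n1])
    z2 = NN_2([x, y], p[n1+1:end])
    [p_[1]*(y - x),
     x*p_[2] + z1[1] - y,
     z2[1] - p_[3]*z]
end

prob_nn = ODEProblem(dudt_, u0, tspan, p)
sol_nn = solve(prob_nn, Tsit5(), saveat = 0.1f0)
```

The `predict` and `loss` functions from the question can then take this flat `p` directly.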
yewalenikhil65 commented 3 years ago

Hi @ChrisRackauckas, in my previous code I tried replacing FastChain/FastDense with Chain/Dense so that I could use Flux.train, as follows. However, I still face the same instability issue when using multiple neural networks in a system of ODEs. Could you take a look if possible?

NN₁ = Chain(Dense(2, 16, tanh), Dense(16, 1))
p₁, re₁ = Flux.destructure(NN₁)
# for xy term
NN₂ = Chain(Dense(2, 16, tanh), Dense(16, 1))
p₂, re₂ = Flux.destructure(NN₂)

p = params(p₁, p₂)
function dudt_(u, p,t)
    x, y, z = u
    z₁ = re₁(p[1])([x,z])
    z₂ = re₂(p[2])([x,y])
    [p_[1]*(y - x),          # σ*(y-x)
     x*p_[2] + z₁[1] - y,    # x*(ρ-z) - y
     z₂[1] - p_[3]*z]        # x*y - β*z
end
prob_nn = ODEProblem(dudt_,u0, tspan, p)
sol_nn = solve(prob_nn, Tsit5(), saveat = sol.t)   # isn't able to solve here; I tried multiple solvers
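`Flux.params(p₁, p₂)` returns a `Params` collection of the two arrays, not one flat vector. A small sketch (an illustration, not code from the thread) of the flat alternative with `Flux.destructure`: concatenate the two vectors, remember where the first block ends, and check that each network rebuilt from its slice gives the same output as the original.

```julia
using Flux

NN₁ = Chain(Dense(2, 16, tanh), Dense(16, 1))
NN₂ = Chain(Dense(2, 16, tanh), Dense(16, 1))

# destructure returns (flat parameter vector, function that rebuilds the model)
p₁, re₁ = Flux.destructure(NN₁)
p₂, re₂ = Flux.destructure(NN₂)

# One flat vector; NN₁ owns p[1:n₁], NN₂ owns p[n₁+1:end]
p = [p₁; p₂]
n₁ = length(p₁)

# Rebuilding each network from its slice should reproduce the original outputs
x = Float32[0.5, -1.0]
println(re₁(p[1:n₁])(x) ≈ NN₁(x))
println(re₂(p[n₁+1:end])(x) ≈ NN₂(x))
```

Inside `dudt_` the same slices replace `p[1]` and `p[2]`, e.g. `re₁(p[1:n₁])([x, z])`.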
ChrisRackauckas commented 3 years ago

Lorenz is a bad example because derivatives explode over long time spans due to chaos. I just finished grading and will make this example work, but on a shortened time span to avoid the diverging tangent space.
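The chaos argument can be checked numerically. The sketch below (an illustration, not from the thread) integrates Lorenz from two initial conditions that differ by δ = 1e-8 in x and reports |Δu|/δ, a finite-difference estimate of how strongly the state at time t depends on x(0). This ratio should grow by orders of magnitude between t = 1 and t = 20, and gradients taken through the solve inherit that growth.

```julia
using OrdinaryDiffEq

function lorenz!(du, u, p, t)
    x, y, z = u
    σ, ρ, β = p
    du[1] = σ*(y - x)
    du[2] = x*(ρ - z) - y
    du[3] = x*y - β*z
end

p = [10.0, 28.0, 8/3]
u0 = [1.0, 0.0, 0.0]
δ = 1e-8                         # size of the perturbation in x(0)
tsave = [1.0, 5.0, 10.0, 20.0]   # times at which the trajectories are compared

# Tight tolerances so that integration error is small compared with δ
solve_from(u) = solve(ODEProblem(lorenz!, u, (0.0, 20.0), p), Vern9();
                      abstol = 1e-12, reltol = 1e-12, saveat = tsave)

sol_a = solve_from(u0)
sol_b = solve_from(u0 .+ [δ, 0.0, 0.0])

# |Δu(t)| / δ approximates the magnitude of ∂u(t)/∂x(0)
sens = [maximum(abs, ub .- ua) / δ for (ua, ub) in zip(sol_a.u, sol_b.u)]
for (t, s) in zip(sol_a.t, sens)
    println("t = ", t, "   |Δu|/δ ≈ ", s)
end
```

Shortening the training span keeps this factor, and hence the gradient magnitude, in a usable range.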

ChrisRackauckas commented 3 years ago

Check out the new tutorial once it's done building.

yewalenikhil65 commented 3 years ago

Thanks @ChrisRackauckas. The tutorial at https://diffeqflux.sciml.ai/dev/examples/multiple_nn/ works fine.

I don't understand the scaling_factor part: why is it needed when returning [z1[1],p[end]*z2[1]] from dudt_? Is scaling_factor a parameter related to τinv?

ChrisRackauckas commented 3 years ago

It's just some random choice to show you can do things like that.
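What the `p[end]*z2[1]` pattern amounts to can be sketched without the tutorial's networks: one extra scalar is appended to the flat parameter vector, it multiplies the second network's output, and the optimizer updates it like any weight. The affine `net` below is a hypothetical stand-in for a neural network, and the numbers are arbitrary; the gradient is taken with ForwardDiff only to show that `p[end]` receives a nonzero gradient.

```julia
using ForwardDiff

# Hypothetical stand-in for a network: an affine map R^2 -> R with weights w
net(x, w) = w[1]*x[1] + w[2]*x[2] + w[3]

p1 = [0.1, -0.2, 0.3]          # "network 1" parameters
p2 = [0.5,  0.4, -0.1]         # "network 2" parameters
scaling_factor = 1.0
p = [p1; p2; scaling_factor]   # one flat, trainable vector
n1 = length(p1)

function rhs(u, p)
    z1 = net(u, @view p[1:n1])
    z2 = net(u, @view p[n1+1:end-1])   # everything except the last entry
    [z1, p[end]*z2]                    # p[end] rescales the second output
end

u = [1.0, 2.0]
loss(p) = sum(abs2, rhs(u, p))
g = ForwardDiff.gradient(loss, p)
println(g[end])   # d(loss)/d(scaling_factor) = 2 * scaling_factor * z2^2
```

Here z2 = 1.2, so the last gradient entry is about 2.88: the scaling factor is trained along with the weights rather than fixed.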