DENG-MIT/CRNN.jl_old

Chemical Reaction Neural Network in Julia

A semi-working demo for case 1 single initial condition #4

Closed: jiweiqi closed this issue 3 years ago

jiweiqi commented 4 years ago
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots

# Ground-truth mechanism (Case 1), mass-action kinetics:
#   r1: 2A -> B       rate k1*[A]^2
#   r2:  A -> C       rate k2*[A]
#   r3:  C -> D       rate k3*[C]
#   r4: B + D -> E    rate k4*[B]*[D]
function trueODEfunc(dydt, y, k, t)
    dydt[1] = -2 * k[1] * y[1]^2 - k[2] * y[1]
    dydt[2] = k[1] * y[1]^2 - k[4] * y[2] * y[4]
    dydt[3] = k[2] * y[1] - k[3] * y[3]
    dydt[4] = k[3] * y[3] - k[4] * y[2] * y[4]
    dydt[5] = k[4] * y[2] * y[4]
end

u0 = Float64[1.0; 1.0; 0.0; 0.0; 0.0]   # initial concentrations of the five species
datasize = 20
tspan = Float64[0.0, 20.0]
tsteps = range(tspan[1], tspan[2], length = datasize)

k = Float64[0.1, 0.2, 0.13, 0.3]       # true rate constants
alg = Rosenbrock23(autodiff = false)   # stiff solver; finite-difference Jacobian

prob_trueode = ODEProblem(trueODEfunc, u0, tspan, k)
ode_data = Array(solve(prob_trueode, alg, saveat = tsteps))
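
The resulting data matrix is species by time; a quick illustrative check (not in the original script):

@assert size(ode_data) == (5, datasize)   # 5 species sampled at 20 time points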

lb = 1e-5   # clamp bounds keep the log argument positive and the state bounded
ub = 10.0

# CRNN: take logs of the (clamped) concentrations, apply a linear layer with
# exp activation to get the 4 reaction rates, then a linear layer for the
# net production rate of each of the 5 species
dudt2 = FastChain((x, p)->log.(clamp.(x, lb, ub)),
                  FastDense(5, 4, exp),
                  FastDense(4, 5))
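
For reference, this chain implements the mass-action form dy/dt = w_out * exp.(w_in * log.(y) .+ b): each row of the first-layer weights holds one reaction's reactant orders, the first-layer bias holds the log rate constants, and the second layer holds the net stoichiometry. A minimal sketch (not part of the original script; the hand-written matrices below encode the Case 1 mechanism read off from trueODEfunc) checks that this form reproduces the true right-hand side:

# Sketch: express the true mechanism as CRNN weights and check that
# w_out * exp.(w_in * log.(y) .+ log.(k)) matches trueODEfunc.
w_in  = [2 0 0 0 0;    # r1: 2A -> B
         1 0 0 0 0;    # r2:  A -> C
         0 0 1 0 0;    # r3:  C -> D
         0 1 0 1 0]    # r4: B + D -> E
w_out = [-2 -1  0  0;  # net stoichiometry, one column per reaction
          1  0  0 -1;
          0  1 -1  0;
          0  0  1 -1;
          0  0  0  1]

y_test = [0.5, 0.4, 0.3, 0.2, 0.1]
dydt_true = similar(y_test)
trueODEfunc(dydt_true, y_test, k, 0.0)
@assert w_out * exp.(w_in * log.(y_test) .+ log.(k)) ≈ dydt_true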

prob_neuralode = NeuralODE(dudt2, tspan, alg, saveat = tsteps)
p = prob_neuralode.p
# Parameter layout: layer 1 has 5*4 + 4 = 24 parameters, layer 2 has 4*5 + 5 = 25
p[1:20] .= clamp.(p[1:20], 0, 2.5)   # reactant orders: non-negative and bounded
p[45:49] .= 0.0                      # pin the output-layer bias to zero
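
The index ranges above follow the FastDense convention of storing the flattened weight matrix first, then the bias. A sketch of how the 49 parameters unpack under that assumption:

# Sketch of the assumed parameter layout behind the index ranges above:
w_in_hat  = reshape(p[1:20], 4, 5)   # reactant orders (clamped to [0, 2.5])
b_hat     = p[21:24]                 # log rate constants, so k_hat = exp.(b_hat)
w_out_hat = reshape(p[25:44], 5, 4)  # net stoichiometry
# p[45:49] is the second-layer bias, pinned to zero above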

function predict_neuralode(p)
    # clamp the prediction so an unstable solve cannot blow up the loss
    return clamp.(Array(prob_neuralode(u0, p)), -ub, ub)
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end
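
Before launching the optimizer it is worth evaluating the loss once; an illustrative check (not in the original script):

l0, pred0 = loss_neuralode(p)
println("initial loss = ", l0)   # finite thanks to the clamps above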

function display_p(p)
    # print the learned reactant orders, one row of 5 weights per reaction
    for i in 1:4
        println("r$i")
        println(p[(5i - 4):(5i)])
    end
end

# Callback function to observe training
list_plots = []
list_loss = []
iter = 0
p_now = p
cb = function (p, l, pred; doplot = true)
    global list_plots, iter, p_now

    # project the parameters back onto the constraints after each optimizer step
    p[1:20] .= clamp.(p[1:20], 0, 2.5)
    p[45:49] .= 0.0

    p_now = p

    if iter == 0
        list_plots = []
    end
    iter += 1

    push!(list_loss, l)

    # plot current prediction against data

    list_plt = []

    for i in 1:5
        plt = scatter(tsteps, ode_data[i,:], label = string("data",i))
        plot!(plt, tsteps, pred[i,:], label = string("pred",i))
        push!(list_plt, plt)
    end

    if iter < 2000
        plt_loss = plot(list_loss, yscale = :log10, label="loss")
    else
        plt_loss = plot(list_loss, xscale = :log10, yscale = :log10, label="loss")
    end

    push!(list_plt, plt_loss)

    plt_all = plot(list_plt..., layout = (2,3), legend = true)

    if doplot && (iter % 100 == 0)
        display(plt_all)
        display(iter)
        display(l)
        display_p(p)
    end

    return false
end

pstart = DiffEqFlux.sciml_train(loss_neuralode, p, ADAM(0.001), cb = cb, maxiters = 10000).minimizer

pstart = DiffEqFlux.sciml_train(loss_neuralode, pstart, ADAM(0.001), cb = cb, maxiters = 200000).minimizer

# refine the ADAM result with a second-order method
pmin = DiffEqFlux.sciml_train(loss_neuralode, pstart, Optim.KrylovTrustRegion(), cb = cb, maxiters = 1000)
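
After training, the learned mechanism can be read directly off the first layer; a sketch, under the same parameter-layout assumption as above:

# Sketch: inspect the learned reaction network
p_learned = pmin.minimizer
println("learned reactant orders:")
display(round.(reshape(p_learned[1:20], 4, 5), digits = 2))
println("learned rate constants: ", round.(exp.(p_learned[21:24]), digits = 3))
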
jiweiqi commented 4 years ago

Loss and species profiles:

[figure: loss history and predicted vs. true species profiles]

Input weights:

[figure: learned input-layer weights]