SciML / NeuralPDE.jl

Physics-Informed Neural Networks (PINN) Solvers of (Partial) Differential Equations for Scientific Machine Learning (SciML) accelerated simulation
https://docs.sciml.ai/NeuralPDE/stable/

NNODE fails on simple logistic curve #634

Closed: sdwfrost closed this issue 1 year ago.

sdwfrost commented 1 year ago

NeuralPDE.jl fails to reproduce this simple logistic curve. I'm not sure whether changing the neural network would help (the output lies between 0 and 1); I've tried many architectures and also tried refining the optimization. Adam followed by BFGS should work fine for this example.

using ModelingToolkit
using OrdinaryDiffEq
using NeuralPDE
using DomainSets
using Flux
using Optimization
using OptimizationOptimJL
using OptimizationOptimisers
using Plots

@parameters t
@variables x(..)
Dt = Differential(t)
eqs = [Dt(x(t)) ~ x(t)*(1-x(t))]

@named ode_sys = ODESystem(eqs)
ode_prob = ODEProblem(ode_sys, [0.01], (0.0,10.0), [])
ode_sol = solve(ode_prob, Tsit5(), saveat=0.1)
plot(ode_sol)

bcs = [x(0) ~ 0.01]
domains = [t ∈ Interval(0.0, 10.0)];
@named pde_sys = PDESystem(eqs, bcs, domains, [t], [x(t)])

numhid = 32
chain = [Flux.Chain(Flux.Dense(1, numhid, Flux.σ), Flux.Dense(numhid, 1))]
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization = NeuralPDE.PhysicsInformedNN(chain, grid_strategy)
pde_prob = NeuralPDE.discretize(pde_sys, discretization)

global i=1
callback = function (p,l)
    println("Epoch $i: Current loss is: $l")
    global i += 1
    return false
end

res = Optimization.solve(pde_prob, OptimizationOptimisers.Adam(0.1); callback=callback, maxiters=5000, abstol = 1e-10, reltol = 1e-10)
# res.minimizer is deprecated; res.u holds the optimized parameters
pde_prob = remake(pde_prob, u0 = res.u)
res = Optimization.solve(pde_prob, OptimizationOptimJL.BFGS(); callback=callback, maxiters=5000, abstol = 1e-10, reltol = 1e-10)
phi = discretization.phi
ts = infimum(domains[1].domain):0.1:supremum(domains[1].domain)
xpred  = hcat([phi[1]([t],res.u) for t in ts]...)'
plot(ode_sol)
plot!(ts,xpred)
lhoenig commented 1 year ago

I don't think the problem is strictly with NeuralPDE; rather, depending on the initialization, the network sometimes gets stuck in a local minimum. With networks this small, that is a real concern. Your code does sometimes train successfully for me with Adam, BFGS, or both. You could try making the network wider or deeper to improve the chances of training successfully.

sdwfrost commented 1 year ago

I've tried wider, deeper, and both. Sometimes I get a sigmoid-like curve, but it is still far from the true solution, and it uses many more parameters than a simple basis-function approach would need. I was hoping this example wouldn't require a large or complex network; the problems I'm working on have the same difficulty as this simple logistic curve.

lhoenig commented 1 year ago

I'm more confused after experimenting further. Sometimes the final loss is very low (around 1e-10) but the solution is completely wrong, whereas successful runs end with a loss of about 1e-7.

Sometimes I get NaN solutions or parameters without any error being raised.

I also tried Lux instead of Flux, since the docs recommend it for double precision. With Lux I was never able to train successfully: the loss decreased more consistently, but the solutions were still wrong.

I also just noticed that the issue title references NNODE (https://docs.sciml.ai/NeuralPDE/stable/manual/ode/#NeuralPDE.NNODE), but the code doesn't actually use NNODE (or is it used internally?). Since this is an ODE, the NNODE specialization might work better, but I couldn't get NNODE to work from the example code there. I always get "Optimization algorithm not found. Either the chosen algorithm is not a valid solver choice for the OptimizationProblem, or the Optimization solver library is not loaded." even though everything is loaded.
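Since a low loss does not guarantee a correct solution here, one direct check is to measure the error against the closed-form solution of the logistic equation, x(t) = x0·e^t / (1 - x0 + x0·e^t). The sketch below is plain Julia with no NeuralPDE calls; `logistic_exact` and `max_abs_error` are hypothetical helper names, and the candidate is assumed to be any callable t -> x(t), for example a closure around the trained network's prediction.

```julia
# Closed-form solution of x' = x(1 - x), x(0) = x0:
#   x(t) = x0 * e^t / (1 - x0 + x0 * e^t)
logistic_exact(t; x0 = 0.01) = x0 * exp(t) / (1 - x0 + x0 * exp(t))

# Hypothetical helper: maximum absolute error of a candidate t -> x(t)
# over a set of test times, measured against the closed-form solution
max_abs_error(candidate, ts; x0 = 0.01) =
    maximum(abs(candidate(t) - logistic_exact(t; x0 = x0)) for t in ts)

ts = 0.0:0.01:10.0
err_exact = max_abs_error(t -> logistic_exact(t), ts)  # 0.0 for the exact solution
err_zero  = max_abs_error(t -> 0.0, ts)                # ≈ 0.9955 for the trivial x ≡ 0
```

The second call shows that the trivial solution x ≡ 0, which is an equilibrium of the ODE and satisfies the differential equation everywhere, is off by almost 1 at t = 10.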

lhoenig commented 1 year ago

Quoting @sdwfrost: "sometimes I get a sigmoidish curve, but still way off the true solution"

FWIW, the successful runs come very close to the sigmoid curve for me. I just got a successful solution using Lux; it looks like this:

[Figure: lux, plot of the successful Lux run]

But I'm not sure what to make of the loss values. Sometimes training works and sometimes it doesn't; sometimes that shows up in the loss, and other times the loss doesn't seem correlated with solution quality. Maybe it becomes clearer when looking at the individual loss terms (data and physics loss).
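Splitting the loss as suggested can be sketched outside NeuralPDE. The function below is a hypothetical helper (`loss_terms` is not NeuralPDE's internal loss) that computes a PINN-style mean squared ODE residual on a set of points plus a squared initial-condition error, using a central finite difference for the derivative. One thing it makes visible: x ≡ 0 is an equilibrium of x' = x(1 - x), so the zero function has zero physics loss and an initial-condition error of only (0 - 0.01)^2 = 1e-4. That is one plausible way a small total loss can coexist with a completely wrong solution, depending on how the terms are weighted.

```julia
using Statistics

# Closed-form solution of x' = x(1 - x), x(0) = x0 (used only as a reference)
exact(t; x0 = 0.01) = x0 * exp(t) / (1 - x0 + x0 * exp(t))

# Hypothetical helper: split a PINN-style loss into its two terms for a
# candidate function u(t). The derivative uses a central finite difference.
function loss_terms(u, ts; x0 = 0.01, h = 1e-4)
    residual(t) = (u(t + h) - u(t - h)) / (2h) - u(t) * (1 - u(t))
    physics = mean(residual(t)^2 for t in ts)  # mean squared ODE residual
    ic = (u(0.0) - x0)^2                        # squared initial-condition error
    return (; physics, ic, total = physics + ic)
end

ts = 0:0.1:10
good    = loss_terms(exact, ts)      # physics ≈ 0, ic ≈ 0
trivial = loss_terms(t -> 0.0, ts)   # physics == 0, ic ≈ 1e-4
```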

sdwfrost commented 1 year ago

Sorry, the title comes from a discussion of this issue on the Julia Discourse.

lhoenig commented 1 year ago

I got NNODE to run; I was missing the Optimisers package. But on this problem it doesn't fit the function more reliably than PDESystem + PhysicsInformedNN; for me it is actually less reliable. I really don't understand the highly variable loss values of the PINNs and how they relate to solution quality. With NNODE I get the same results @sdwfrost described, i.e. never really an exact fit, which I did sometimes get with the PDESystem + PhysicsInformedNN code.

Here's my code for using the NNODE:

using Flux
using NeuralPDE
using OrdinaryDiffEq, Optimisers
using Plots
import Lux, OptimizationOptimisers, OptimizationOptimJL

eq = (u, p, t) -> u .* (1 .- u)
tspan = (0.0, 10.0)
u0 = 0.01
prob = ODEProblem(eq, u0, tspan)
ode_sol = solve(prob, Tsit5(), saveat=0.1)

function run()
    # Flux chain kept for comparison; only the Lux chain below is passed to NNODE
    chain = Flux.Chain(Flux.Dense(1, 20, σ), Flux.Dense(20, 1))
    luxchain = Lux.Chain(Lux.Dense(1, 20, Lux.σ), Lux.Dense(20, 1))
    opt = OptimizationOptimJL.BFGS()
    # opt = OptimizationOptimisers.Adam(0.1)  # alternative optimizer
    nnode_sol = solve(prob, NeuralPDE.NNODE(luxchain, opt), dt=1 / 10.0, verbose=true, abstol=1.0e-10, maxiters=5000)
    ts = tspan[1]:0.01:tspan[2]
    xpred = [nnode_sol(t) for t in ts]
    plot(ode_sol, label="ODE solution")
    plot!(ts, xpred, label="NNODE solution")
end

run()
sdesai1287 commented 1 year ago

I've been experimenting with a new training strategy on this problem, and the absolute value of the mean difference between the NNODE and ODE solutions is about 0.02. With some other strategies this difference is sometimes very large, which is confusing for such a simple problem. Here is the code; I'm going to try to improve it:

using OrdinaryDiffEq, Optimisers, NeuralPDE
using Plots, Statistics
import Lux, OptimizationOptimisers, OptimizationOptimJL

eq = (u, p, t) -> u * (1 .- u)
tspan = (0.0, 10.0)
u0 = 0.01
prob = ODEProblem(eq, u0, tspan)
ode_sol = solve(prob, Tsit5(), saveat=0.01)

N = 4
func = Lux.σ
luxchain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, 1,func))
# opt = OptimizationOptimJL.BFGS()
opt = OptimizationOptimisers.Adam(0.01)
weights = [1/14, 4/14, 4/14, 4/14, 1/14]
samples = 200
alg = NeuralPDE.NNODE(luxchain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
nnode_sol = solve(prob, alg, verbose=true, maxiters=50000, saveat=0.01)
ts = tspan[1]:0.01:tspan[2]
xpred = [nnode_sol(t) for t in ts]
print(abs(mean(ode_sol .- nnode_sol)))

using Plots
plot(ode_sol, label="ODE solution")
plot!(ts, xpred, label="NNODE solution")

Here is the plot:

[Figure: plot_2]
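The `WeightedIntervalTraining(weights, samples)` strategy used above is, as I understand its docstring, meant to split the time span into `length(weights)` equal subintervals and draw a number of points in each proportional to its weight. The function below is a hypothetical standalone re-implementation of that idea for illustration only (`weighted_interval_points` is not a NeuralPDE function, and NeuralPDE's own sampling may differ in details such as rounding). With `weights = [1, 4, 4, 4, 1] ./ 14` and `samples = 200` on `(0.0, 10.0)`, the three middle subintervals [2, 8] get about 57 points each; they cover the steep part of the curve around t = ln(99) ≈ 4.6, where x(t) = 0.5.

```julia
using Random

# Hypothetical sketch (not NeuralPDE's implementation): split `tspan` into
# length(weights) equal subintervals and draw uniform random points in each,
# with the count for subinterval i proportional to weights[i].
function weighted_interval_points(tspan, weights, samples;
                                  rng = Random.default_rng())
    t0, t1 = tspan
    n = length(weights)
    edges = range(t0, t1; length = n + 1)
    counts = round.(Int, samples .* weights ./ sum(weights))
    pts = Float64[]
    for i in 1:n
        a, b = edges[i], edges[i + 1]
        append!(pts, a .+ (b - a) .* rand(rng, Float64, counts[i]))
    end
    return sort!(pts)
end

pts = weighted_interval_points((0.0, 10.0), [1, 4, 4, 4, 1] ./ 14, 200)
n_pts = length(pts)   # 14 + 3 * 57 + 14 = 199 after per-interval rounding
```

Rounding each subinterval's share separately loses one point here, which is why the total is 199 rather than 200 in this sketch.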

sdesai1287 commented 1 year ago

QuadratureTraining seems to work well for this: I fairly consistently get an absolute mean difference between the solutions of 0.01 or less. Here is the code and a graph:

using OrdinaryDiffEq, Optimisers, NeuralPDE
using Plots, Statistics
import Lux, OptimizationOptimisers, OptimizationOptimJL

eq = (u, p, t) -> u .* (1 .- u)
tspan = (0.0, 10.0)
u0 = 0.01
prob = ODEProblem(eq, u0, tspan)
ode_sol = solve(prob, Tsit5(), saveat=0.01)

N = 4
func = Lux.σ
luxchain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N - 1, func), Lux.Dense(N - 1, N - 2, func), Lux.Dense(N - 2, 1, func))
opt = OptimizationOptimisers.Adam(0.001)
alg = NeuralPDE.NNODE(luxchain, opt, autodiff = false, strategy = NeuralPDE.QuadratureTraining())
nnode_sol = solve(prob, alg, verbose=true, abstol = 1e-10, maxiters=100000, saveat=0.01)

ts = tspan[1]:0.01:tspan[2]
xpred = [nnode_sol(t) for t in ts]
xreal = [ode_sol(t) for t in ts]
println(abs(mean(xpred .- xreal)))
println(abs(mean(nnode_sol .- ode_sol)))

using Plots
plot(ode_sol, label="ODE solution")
plot!(ts, xpred, label="NNODE solution")

[Figure: plot_4]
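Both snippets above report `abs(mean(a .- b))`, the absolute value of the mean signed difference. Positive and negative errors cancel in that quantity, so it can be close to zero even when the curves clearly differ. The mean absolute difference, `mean(abs.(a .- b))`, or the worst-case error, `maximum(abs.(a .- b))`, does not have this problem. A toy example with made-up vectors (not results from the thread):

```julia
using Statistics

xreal = [0.1, 0.5, 0.9]
xpred = [0.2, 0.5, 0.8]   # pointwise errors of +0.1, 0.0 and -0.1

abs_of_mean = abs(mean(xpred .- xreal))      # ≈ 0: the +0.1 and -0.1 cancel
mean_of_abs = mean(abs.(xpred .- xreal))     # ≈ 0.0667
worst_case  = maximum(abs.(xpred .- xreal))  # ≈ 0.1
```

The 0.01 to 0.02 figures quoted above were computed with `abs(mean(...))`, so they may understate the typical pointwise error.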

ChrisRackauckas commented 1 year ago

👍. I think we should change the tutorials in the docs to use QuadratureTraining and then close this. GridTraining shouldn't be used (as the docs say elsewhere).
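One way to see why a fixed grid can be misleading is a deliberately exaggerated, hand-built case (standalone Julia, not NeuralPDE code): a candidate that matches the exact solution and its slope at every grid point 0:0.1:10, mirroring the `GridTraining(0.1)` step in the first post, but bumps away from it between those points. Its mean squared residual on the grid is essentially zero, while the residual averaged over the whole interval is not. Composite Simpson's rule here is only a stand-in; as I understand it, QuadratureTraining integrates the residual with an adaptive quadrature routine instead. The perturbation size ε = 0.05 is an arbitrary choice.

```julia
using Statistics

exact(t; x0 = 0.01) = x0 * exp(t) / (1 - x0 + x0 * exp(t))
dexact(t) = exact(t) * (1 - exact(t))   # the exact solution satisfies x' = x(1 - x)

# Candidate that agrees with the exact solution *and its slope* at every
# multiple of 0.1 (where sin(10πt) = 0) but bumps upward between them.
ε = 0.05
u(t)  = exact(t) + ε * sin(10π * t)^2
du(t) = dexact(t) + 20π * ε * sin(10π * t) * cos(10π * t)
residual(t) = du(t) - u(t) * (1 - u(t))

# Loss on a fixed grid: mean squared residual at 0, 0.1, ..., 10
grid_loss = mean(residual(t)^2 for t in 0:0.1:10)

# Composite Simpson's rule on [a, b] with n (even) subintervals
function simpson(f, a, b, n)
    iseven(n) || throw(ArgumentError("n must be even"))
    h = (b - a) / n
    s = f(a) + f(b)
    for k in 1:n-1
        s += (isodd(k) ? 4 : 2) * f(a + k * h)
    end
    return s * h / 3
end

# Mean squared residual over the whole interval (integral divided by length)
integrated_loss = simpson(t -> residual(t)^2, 0.0, 10.0, 10_000) / 10.0
```

A real network is rarely this oscillatory, but the same mechanism (fitting the collocation points while drifting between them) is easier to catch when the loss is an integral over the domain rather than a sum over fixed points.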

sathvikbhagavan commented 1 year ago

@ChrisRackauckas, can we close this? https://github.com/SciML/NeuralPDE.jl/pull/729 switches the docs to QuadratureTraining.