SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Neural SDE example has convergence and plotting issue #114

Closed · amckay1 closed this issue 4 years ago

amckay1 commented 4 years ago

Very excited to try out some of the demo code for UDEs. I didn't have any trouble with the ODE example (starting at https://github.com/JuliaDiffEq/DiffEqFlux.jl#universal-differential-equations), but when I tried the SDE example (starting at https://github.com/JuliaDiffEq/DiffEqFlux.jl#neural-sde-example) I ran into a couple of issues:

  1. Training did not appear to converge
  2. The final plot did not match the README.md version, looking instead like this: [screenshot]

I'm guessing 1 and 2 are linked, possibly due to a versioning issue: I'm up to date on everything but not on master, with DiffEqFlux v1.0.0, StochasticDiffEq v6.17.1, Plots v0.28.4, Flux v0.10.1, and DiffEqBase v6.12.2.
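In case it helps with reproducing, pinning those exact versions should look roughly like this (untested sketch using the standard Pkg API):

using Pkg
Pkg.add([
    PackageSpec(name="DiffEqFlux",       version="1.0.0"),
    PackageSpec(name="StochasticDiffEq", version="6.17.1"),
    PackageSpec(name="Plots",            version="0.28.4"),
    PackageSpec(name="Flux",             version="0.10.1"),
    PackageSpec(name="DiffEqBase",       version="6.12.2"),
])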

Will try to debug it more later if I can :)

Exact code, run in a fresh Julia v1.3.1 session (macOS), below (copied from README.md):

#First let's build training data from the same example as the neural ODE:

using Flux, DiffEqFlux, StochasticDiffEq, Plots, DiffEqBase.EnsembleAnalysis

u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0,1.0f0)

function trueSDEfunc(du,u,p,t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
t = range(tspan[1],tspan[2],length=datasize)
mp = Float32[0.2,0.2]
function true_noise_func(du,u,p,t)
    du .= mp.*u
end
prob = SDEProblem(trueSDEfunc,true_noise_func,u0,tspan)
#For our dataset we will use DifferentialEquations.jl's parallel ensemble interface to generate data from the average of 100 runs of the SDE:

# Take a typical sample from the mean
ensemble_prob = EnsembleProblem(prob)
ensemble_sol = solve(ensemble_prob,SOSRI(),trajectories = 100)
ensemble_sum = EnsembleSummary(ensemble_sol)
sde_data = Array(timeseries_point_mean(ensemble_sol,t))
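# sde_data is a 2×30 Array (state × time point): the pointwise ensemble mean
# that serves as the fitting target below.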
#Now we build a neural SDE. For simplicity we will use the NeuralDMSDE multiplicative noise neural SDE layer:

drift_dudt = Chain(x -> x.^3,
             Dense(2,50,tanh),
             Dense(50,2))
n_sde = NeuralDMSDE(drift_dudt,mp,tspan,SOSRI(),saveat=t,reltol=1e-1,abstol=1e-1)
ps = Flux.params(n_sde)
#Let's see what that looks like:

pred = n_sde(u0) # Get the prediction using the correct initial condition

drift_(u,p,t) = drift_dudt(u)

# Note that if this line uses scalar indexing, you may need to
# `Tracker.collect()` the output in a separate dispatch i.e.
# g(u::Tracker.TrackedArray,p,t) = Tracker.collect(mp.*u)
g(u,p,t) = mp.*u
nprob = SDEProblem(drift_,g,u0,(0.0f0,1.2f0),nothing)
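# Note: this rebuilds the same model as a plain SDEProblem so the ensemble
# interface can be used for plotting; the tspan (0, 1.2) extends slightly past
# the training span (0, 1.0), so the plot includes a short extrapolation.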

ensemble_nprob = EnsembleProblem(nprob)
ensemble_nsol = solve(ensemble_nprob,SOSRI(),trajectories = 100)
ensemble_nsum = EnsembleSummary(ensemble_nsol)
p1 = plot(ensemble_nsum, title = "Neural SDE: Before Training")
scatter!(p1,t,sde_data',lw=3)
scatter(t,sde_data[1,:],label="data")
scatter!(t,pred[1,:],label="prediction")
#Now just as with the neural ODE we define a loss function:

function predict_n_sde()
  n_sde(u0)
end
loss_n_sde1() = sum(abs2,sde_data .- predict_n_sde())
loss_n_sde10() = sum([sum(abs2,sde_data .- predict_n_sde()) for i in 1:10])

data = Iterators.repeated((), 10)
opt = ADAM(0.025)
cb = function () #callback function to observe training
  sample = predict_n_sde()
  # loss against current data
  display(sum(abs2,sde_data .- sample))
  # plot current prediction against data
  pl = scatter(t,sde_data[1,:],label="data")
  scatter!(pl,t,sample[1,:],label="prediction")
  display(plot(pl))
end

# Display the SDE with the initial parameter values.
cb()
#Here we made two loss functions: one which uses single runs of the SDE and another which uses multiple runs. This is because an SDE is stochastic, so trying to fit the mean to high precision may require taking the mean of a few trajectories (the more trajectories, the more precise the calculation). Thus to fit this, we first get in the general area through single SDE trajectory backprops, and then hone in with the mean:

Flux.train!(loss_n_sde1 , ps, Iterators.repeated((), 100), opt, cb = cb)
Flux.train!(loss_n_sde10, ps, Iterators.repeated((), 20), opt, cb = cb)
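# A hypothetical variant (not in the README): the "mean of a few trajectories"
# idea can also be written by averaging the predictions first and then taking
# the loss, instead of summing independent single-trajectory losses as
# loss_n_sde10 does:
using Statistics
loss_n_sde_mean() = sum(abs2, sde_data .- mean([predict_n_sde() for i in 1:10]))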
#And now we plot the solution to an ensemble of the trained neural SDE:

ensemble_nprob = EnsembleProblem(nprob)
ensemble_nsol = solve(ensemble_nprob,SOSRI(),trajectories = 100)
ensemble_nsum = EnsembleSummary(ensemble_nsol)
p2 = plot(ensemble_nsum, title = "Neural SDE: After Training", xlabel="Time")
scatter!(p2,t,sde_data',lw=3,label=["x" "y"])  # sde_data' has two columns, so two series/labels

plot(p1,p2,layout=(2,1))
ChrisRackauckas commented 4 years ago

The plot of the data looks transposed.

As for it not converging: we just updated the docs for 1.0, and this may have a bug. My guess is that I accidentally didn't pass the parameters somewhere.
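For context, Plots treats each column of a matrix as a separate series, so the orientation of sde_data matters; a minimal illustrative sketch (dummy data, not the actual recipe fix):

using Plots
A = rand(Float32, 2, 30)   # state × time, same shape as sde_data
scatter(1:30, A')          # transposed: two series of 30 points each
# scatter(1:30, A) would mismatch the 30-point time axis and error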

ChrisRackauckas commented 4 years ago

FWIW, we know it works because NeuralNetDiffEq works, and it uses the SDE-based method for solving high-dimensional PDEs, so it's some silly doc issue.

amckay1 commented 4 years ago

Definitely not high priority; just wanted to let you know, since the new paper drives more traffic to the docs. And I agree it looks transposed, but probably because the x-axis is not responding correctly to the timespan from ensemble_nsum; I think if you zoom in on the 0.0:1.0 range you'll see that sde_data is plotting correctly. Maybe something to do with the plot recipe for ensemble_nsum? Speculating, will keep exploring...
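One quick check might be to force the axis range by hand, e.g. (rough sketch using xlims! from Plots):

pz = plot(ensemble_nsum)
scatter!(pz, t, sde_data')
xlims!(pz, 0.0, 1.0)   # zoom to the data's timespan to see whether sde_data lines up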

ChrisRackauckas commented 4 years ago

Fixed! It had a plot recipe issue in DiffEqBase mixed with a small thing I forgot to handle in the Zygote transition (the test had an @test_broken on it to remind me... oops). All is good on master now and I'll make a patch release soon.

amckay1 commented 4 years ago

Wonderful, thank you!