SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Neural SDE does not train diffusion parameters #189

Closed: ric-cioffi closed this issue 4 years ago

ric-cioffi commented 4 years ago

Running the neural SDE example from the docs, the diffusion parameters returned by sciml_train come back identical to their starting values (see below). I don't know why that's the case, but the cause seems to be that Zygote returns a zero gradient for those parameters.

using Flux, DiffEqFlux, StochasticDiffEq, Plots, DiffEqBase.EnsembleAnalysis, Statistics

u0 = Float32[2.; 0.]   # initial condition
datasize = 30          # number of save points
tspan = (0.0f0,1.0f0)  # time horizon

# True drift: cubic dynamics through a fixed matrix
function trueSDEfunc(du,u,p,t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
t = range(tspan[1],tspan[2],length=datasize)
# True diffusion: multiplicative noise with coefficients mp
mp = Float32[0.2,0.2]
function true_noise_func(du,u,p,t)
    du .= mp.*u
end
prob = SDEProblem(trueSDEfunc,true_noise_func,u0,tspan)

# Training data: pointwise mean and variance over 1000 trajectories
ensemble_prob = EnsembleProblem(prob)
ensemble_sol = solve(ensemble_prob,SOSRI(),trajectories = 1000)
ensemble_sum = EnsembleSummary(ensemble_sol)
sde_data,sde_data_vars = Array.(timeseries_point_meanvar(ensemble_sol,t))

# Neural drift and diffusion networks (FastChain variants)
drift_dudt = FastChain((x,p) -> x.^3,
             FastDense(2,50,tanh),
             FastDense(50,2))
diffusion_dudt = FastChain(FastDense(2,2))
n_sde = NeuralDSDE(drift_dudt,diffusion_dudt,tspan,SOSRI(),saveat=t,reltol=1e-1,abstol=1e-1)
start_p = copy(n_sde.p)  # save the initial parameters for the post-training comparison

function predict_n_sde(p)
  Array(n_sde(u0,p))  # one sampled trajectory of the neural SDE
end
function loss_n_sde(p;n=100)
  samples = [predict_n_sde(p) for i in 1:n]
  # elementwise mean and variance across the n sampled trajectories
  means = reshape(mean.([[samples[i][j] for i in 1:length(samples)] for j in 1:length(samples[1])]),size(samples[1])...)
  vars = reshape(var.([[samples[i][j] for i in 1:length(samples)] for j in 1:length(samples[1])]),size(samples[1])...)
  loss = sum(abs2,sde_data - means) + sum(abs2,sde_data_vars - vars)
  loss,means,vars
end

cb = function (p,loss,means,vars) # callback to observe training
  display(loss)  # loss against current data
  return false
end

opt = ADAM(0.025)
res1 = DiffEqFlux.sciml_train(p -> loss_n_sde(p, n = 10), n_sde.p, opt, cb = cb, maxiters = 100)
final_p = copy(res1.minimizer)

# Evaluates to true: the diffusion parameters (indices past n_sde.len) never moved
start_p[(n_sde.len + 1):end] == final_p[(n_sde.len + 1):end]
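
The zero gradient can also be checked directly (a sketch, assuming the script above has been run; Zygote is loaded explicitly since the example does not import it):

using Zygote
# Slice of the gradient corresponding to the diffusion network
grad = Zygote.gradient(p -> loss_n_sde(p, n = 10)[1], n_sde.p)[1]
all(iszero, grad[(n_sde.len + 1):end])  # true: no signal reaches the diffusion net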
ChrisRackauckas commented 4 years ago

It looks like it's just FastChain here. MWE:

using OrdinaryDiffEq, StochasticDiffEq, DelayDiffEq, Flux, DiffEqFlux,
      Zygote, Test, DiffEqSensitivity

mp = Float32[0.1,0.1]
x = Float32[2.; 0.]
xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.]))
tspan = (0.0f0,1.0f0)
# Chain version: all gradients, including the diffusion parameters, are nonzero
dudt = Chain(Dense(2,50,tanh),Dense(50,2))
dudt2 = Chain(Dense(2,50,tanh),Dense(50,2))
NeuralDSDE(dudt,dudt2,(0.0f0,.1f0),SOSRI(),saveat=0.1)(x)
sode = NeuralDSDE(dudt,dudt2,(0.0f0,.1f0),SOSRI(),saveat=0.0:0.01:0.1)

grads = Zygote.gradient(()->sum(sode(x)),Flux.params(x,sode))
@test !iszero(grads[x])
@test !iszero(grads[sode.p])
@test !iszero(grads[sode.p][end])

# FastChain version: the tail of sode.p (the diffusion parameters) gets a zero gradient
dudt = FastChain(FastDense(2,50,tanh),FastDense(50,2))
dudt2 = FastChain(FastDense(2,50,tanh),FastDense(50,2))
NeuralDSDE(dudt,dudt2,(0.0f0,.1f0),SOSRI(),saveat=0.1)(x)
sode = NeuralDSDE(dudt,dudt2,(0.0f0,.1f0),SOSRI(),saveat=0.0:0.01:0.1)

grads = Zygote.gradient(()->sum(sode(x)),Flux.params(x,sode))
@test !iszero(grads[x])
@test !iszero(grads[sode.p][end])  # fails: the diffusion gradient is zero
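
Since the Chain version passes all of the gradient tests, swapping the FastChain layers in the original script for Flux Chains looks like a plausible stopgap until the FastChain gradient is fixed (a sketch, not a confirmed fix):

# Hypothetical drop-in for the FastChain definitions above
drift_dudt = Chain(x -> x.^3, Dense(2,50,tanh), Dense(50,2))
diffusion_dudt = Chain(Dense(2,2))
n_sde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(),
                   saveat = t, reltol = 1e-1, abstol = 1e-1)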
ChrisRackauckas commented 4 years ago

Using this as an opportunity to refresh the neural SDE image to include error bars on the data.
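
Those error bars can come straight from the ensemble variances already computed in the example, e.g. as ±1 standard deviation ribbons (a sketch with Plots; t, sde_data, and sde_data_vars as in the script above):

using Plots
# Ensemble mean of each state with a ±1 standard deviation ribbon
plot(t, sde_data', ribbon = sqrt.(sde_data_vars)',
     label = ["u1" "u2"], xlabel = "t", title = "SDE data with error bars")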

[image: neural_sde]

ChrisRackauckas commented 4 years ago

[image: neural_sde]