aicenter / GenerativeModels.jl

Generative Models with trainable conditional distributions in Julia!
MIT License

Zygote and DiffEqFlux #38

Closed: nmheim closed this issue 4 years ago

nmheim commented 4 years ago

DiffEqFlux has not been updated to Zygote yet. There is a pull request for this, but it has not been worked on for a few months. Applying the changes from that pull request to the current DiffEqFlux and running:

using Flux, DiffEqFlux, DifferentialEquations, Plots

## Setup ODE to optimize
function lotka_volterra(du,u,p,t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end
u0 = [1.0,1.0]
tspan = (0.0,10.0)
p = [1.5,1.0,3.0,1.0]
prob = ODEProblem(lotka_volterra,u0,tspan,p)

# Verify ODE solution
sol = solve(prob,Tsit5())

# Generate data from the ODE
sol = solve(prob,Tsit5(),saveat=0.1)
A = sol[1,:] # length 101 vector
t = 0:0.1:10.0

# Build a neural network that sets the cost as the difference from the
# generated data and 1

p = [2.2, 1.0, 2.0, 0.4] # Initial Parameter Vector
function predict_rd() # Our 1-layer neural network
  solve(prob, Tsit5(), p=p, saveat=0.1)[1,:]
end
loss_rd() = sum(abs2,x-1 for x in predict_rd()) # loss function

# Optimize the parameters so the ODE's solution stays near 1

data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function () # callback function to observe training
  display(loss_rd())
  # using `remake` to re-create our `prob` with current parameters `p`
  display(plot(solve(remake(prob, p=p), Tsit5(), saveat=0.1), ylim=(0,6)))
end

# Display the ODE with the initial parameter values.
cb()

Flux.train!(loss_rd, [p], data, opt, cb = cb)

results in the following:

ERROR: LoadError: Need an adjoint for constructor ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1}
,Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}. Gradient is of type Array{Float64,2}
nmheim commented 4 years ago

We will do this with ForwardDiff for now; implemented in #36.
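A minimal sketch of the forward-mode route, assuming ForwardDiff.jl is installed. It is not the code from #36. To stay self-contained it replaces DifferentialEquations' `Tsit5` with a hypothetical hand-written fixed-step RK4 integrator (`rk4_first_component`), so ForwardDiff's dual numbers simply flow through ordinary Julia arithmetic and no adjoint for `ODESolution` is needed. The objective mirrors `loss_rd` above.

```julia
using ForwardDiff

# Lotka-Volterra right-hand side, written out-of-place so that it also
# works when `p` holds ForwardDiff.Dual numbers.
function lv(u, p)
    x, y = u
    α, β, δ, γ = p
    return [α*x - β*x*y, -δ*y + γ*x*y]
end

# Hypothetical stand-in for `solve(prob, Tsit5(), p=p, saveat=0.1)[1,:]`:
# classic fixed-step RK4 with step `dt`, recording the first state
# component every `saveevery` steps (t = 0, 0.1, ..., 10.0 by default).
function rk4_first_component(p; u0 = [1.0, 1.0], dt = 0.01, tend = 10.0, saveevery = 10)
    u = u0 .* one(eltype(p))   # promote u0 to Float64 or Dual, matching p
    xs = [u[1]]
    for i in 1:round(Int, tend / dt)
        k1 = lv(u, p)
        k2 = lv(u .+ (dt/2) .* k1, p)
        k3 = lv(u .+ (dt/2) .* k2, p)
        k4 = lv(u .+ dt .* k3, p)
        u = u .+ (dt/6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        i % saveevery == 0 && push!(xs, u[1])
    end
    return xs
end

# Same objective as loss_rd: keep the first component near 1.
loss(p) = sum(abs2, x - 1 for x in rk4_first_component(p))

p0 = [2.2, 1.0, 2.0, 0.4]              # initial parameters from the issue
g = ForwardDiff.gradient(loss, p0)     # forward-mode gradient dloss/dp
```

Forward mode costs grow with the number of parameters, so this is cheap for the four Lotka-Volterra parameters but scales poorly to large neural networks; that is where a reverse-mode (Zygote) adjoint would pay off.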