ChrisRackauckas closed this 4 years ago
using DifferentialEquations
using Distributions
using Flux, DiffEqFlux, ForwardDiff
using Flux.Tracker
# Neural Network
function f(z, p)
α, β = p
tanh.(α.*z .+ β)
end
tspan = (0.0, 10.0)
function cnf(du,u,p,t)
z = @view u[1:end-1]
du[1:end-1] = f(z, p)
du[end] = -sum(ForwardDiff.jacobian((z)->f(z, p), z))
end
prob = ODEProblem(cnf,nothing,tspan,nothing)
p = param([0.0, 0.0]) # Initial Parameter Vector
params = Params([p])
function predict_adjoint(x)
diffeq_adjoint(p,prob,Tsit5(),u0=[x,false],
saveat=0.0:0.1:10.0,
sensealg=DiffEqFlux.SensitivityAlg(quad=false,
backsolve=true,autojacvec=false))
end
function loss_adjoint(xs)
pz = Normal(0.0, 1.0)
preds = [predict_adjoint(x)[:,end] for x in xs]
z = [pred[1] for pred in preds] # TODO better slicing
delta_logp = [pred[2] for pred in preds]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
opt = ADAM(0.1)
raw_data = [[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 100);
Flux.train!(loss_adjoint, params, data, opt)
# check whether it looks standard normal
using Plots
preds = [predict_adjoint(r)[:,end] for r in raw_data[1]];
histogram([p[1].data for p in preds])
That's a version that should be suitable for library use. It still needs batching support and some testing with Flux models, though.
I am getting this error:
julia> Flux.train!(loss_adjoint, params, data, opt)
ERROR: UndefVarError: uf not defined
Stacktrace:
[1] #ODEAdjointProblem#18(::Array{Float64,1}, ::CallbackSet{Tuple{},Tuple{}}, ::LinearAlgebra.UniformScaling{Bool}, ::Function, ::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(cnf),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(cnf),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}}}, ::getfield(DiffEqFlux, Symbol("#df#27")){Bool}, ::Array{Float64,1}, ::Nothing, ::DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}}) at /Users/tpevny/.julia/packages/DiffEqSensitivity/DI6VG/src/adjoint_sensitivity.jl:177
[2] #ODEAdjointProblem at ./none:0 [inlined]
My status of DiffEqFlux is:
(v1.1) pkg> st DiffEqFlux
Status `~/.julia/environments/v1.1/Project.toml`
[79e6a3ab] Adapt v0.4.2
[aae7a2af] DiffEqFlux v0.4.0+ #master (https://github.com/JuliaDiffEq/DiffEqFlux.jl.git)
[587475ba] Flux v0.8.2+ [`~/.julia/dev/Flux`]
[f6369f11] ForwardDiff v0.10.3
[10745b16] Statistics
You need OrdinaryDiffEq, DiffEqFlux, and DiffEqSensitivity master. If anyone could help generate the Project.toml files, I will register.
If I replace that "neural network" with a Flux model, the above code fails.
# Neural Network
nn = Dense(1,1,tanh)
pp = destructure(nn)
function f(z, p)
m = restructure(nn,p)
return m(z)
end
using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, ForwardDiff, Tracker
# Neural Network
nn = Dense(1,1,tanh)
p = DiffEqFlux.destructure(nn)
function f(z, p)
m = DiffEqFlux.restructure(nn,p)
return m(z)
end
tspan = Float32.((0.0, 10.0))
function cnf(du,u,p,t)
z = @view u[1:end-1]
du[1:end-1] = f(z, p)
du[end] = -sum(Tracker.jacobian((z)->f(z, p), z))
end
prob = ODEProblem(cnf,nothing,tspan,nothing)
p = param(Float32[0.0, 0.0]) # Initial Parameter Vector
params = Params([p])
function predict_adjoint(x)
diffeq_adjoint(p,prob,Tsit5(),u0=[x;false],
saveat=0.0:0.1:10.0,
sensealg=DiffEqFlux.SensitivityAlg(quad=false,
backsolve=true,autojacvec=true))
end
function loss_adjoint(xs)
pz = Normal(0.0, 1.0)
preds = [predict_adjoint(x)[:,end] for x in xs]
z = [pred[1] for pred in preds] # TODO better slicing
delta_logp = [pred[2] for pred in preds]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
opt = ADAM(0.1)
raw_data = [Float32[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 100);
Flux.train!(loss_adjoint, params, data, opt)
# check whether it looks standard normal
using Plots
preds = [predict_adjoint(r)[:,end] for r in raw_data[1]];
The code above is probably the closest we've gotten; it just needs the nesting of Tracker fixed, i.e. finding out why the gradient is undefined.
This is a nice simplifying example. It works until the backward pass, where DiffEqSensitivity's autojacvec uses Flux to compute the VJPs. In that case u is a TrackedArray, and then it fails.
using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, ForwardDiff, Tracker
# Neural Network
nn = Dense(1,1,tanh)
p = Tracker.data(DiffEqFlux.destructure(nn))
DiffEqFlux.restructure(nn,p)([1.0])
tspan = Float32.((0.0, 10.0))
function cnf(du,u,p,t)
z = @view u[1:end-1]
m = DiffEqFlux.restructure(nn,p)
du[1:end-1] = m(z)
@show z
du[end] = -sum(Tracker.jacobian((z)->log.(z), z))
end
prob = ODEProblem(cnf,nothing,tspan,nothing)
p = param(Float32[0.0, 0.0]) # Initial Parameter Vector
params = Params([p])
function predict_adjoint(x)
diffeq_adjoint(p,prob,Tsit5(),u0=[x;false],
saveat=0.0:0.1:10.0,
sensealg=DiffEqFlux.SensitivityAlg(quad=false,
backsolve=true,autojacvec=true))
end
function loss_adjoint(xs)
pz = Normal(0.0, 1.0)
preds = [predict_adjoint(x)[:,end] for x in xs]
z = [pred[1] for pred in preds] # TODO better slicing
delta_logp = [pred[2] for pred in preds]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
Tracker.jacobian((z)->log.(3 .* z.+z.^2), [10.0])
opt = ADAM(0.1)
raw_data = [Float32[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 100);
Flux.train!(loss_adjoint, params, data, opt)
The problem is that Tracker.jacobian can't nest. How can we work around this, @MikeInnes?
using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, ForwardDiff, Tracker
# Neural Network
nn = Chain(Dense(1,1,tanh))
p = DiffEqFlux.destructure(nn)
tspan = Float32.((0.0, 10.0))
function cnf(u,p,t)
z = @view u[1:end-1]
m = DiffEqFlux.restructure(nn,p)
jac = -sum(Tracker.jacobian((z)->log.(z), z))
if u isa TrackedArray
res = Tracker.collect([m(z);jac])
else
res = Tracker.data([m(z);jac])
end
res
end
prob = ODEProblem(cnf,nothing,tspan,nothing)
params = Params([p])
function predict_adjoint(x)
diffeq_adjoint(p,prob,Tsit5(),u0=[x;false],
saveat=0.0:0.1:10.0,
sensealg=DiffEqFlux.SensitivityAlg(quad=false,
backsolve=true,autojacvec=true))
end
function loss_adjoint(xs)
pz = Normal(0.0, 1.0)
preds = [predict_adjoint(x)[:,end] for x in xs]
z = [pred[1] for pred in preds] # TODO better slicing
delta_logp = [pred[2] for pred in preds]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
opt = ADAM(0.1)
raw_data = [Float32[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 1);
loss_adjoint(raw_data[1])
Flux.train!(loss_adjoint, params, data, opt)
iszero(Tracker.grad(nn[1].W))
The above works with https://github.com/FluxML/Tracker.jl/pull/24
using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, Tracker
using LinearAlgebra: tr
using Plots
using Flux: @epochs, throttle
using Tracker: forward
# Neural Network
nn = Chain(Dense(1,1,swish), Dense(1,1,identity))
# We track the parameters.
p = Flux.data(DiffEqFlux.destructure(nn))
params = param(p)
ps = Flux.params(params)
tspan = Float32.((0.0, 10.0))
# We define tr(J) ourselves to support batching.
# It's also possible to use tr(Tracker.jacobian(m, z)); that works perfectly.
function divergence(f, x::AbstractArray)
y::AbstractArray, back = forward(f, x)
D, N = size(x)
ȳ(i) = [i == j for j = 1:D]
reduce(+, transpose([back(ȳ(i))[1][i, :] for i = 1:D]))
end
# Dynamics of the CNF
function cnf_dudt_(u::TrackedArray,p,t)
z = @view u[1:end-1, :]
m = DiffEqFlux.restructure(nn, p)
jac = -divergence(m, z)
Tracker.collect([m(z);jac])
end
function cnf_dudt_(u::AbstractArray,p,t)
z = @view u[1:end-1, :]
m = DiffEqFlux.restructure(nn, p)
jac = -divergence(m, z)
Tracker.data([m(z);jac])
end
function predict_adjoint(x)
diffeq_adjoint(params,prob,Tsit5(),u0=vcat(x, zeros(Float32, (1, size(x, 2)))),
saveat=0.0:0.1:10.0,
sensealg=DiffEqFlux.SensitivityAlg(quad=false,
backsolve=false,autojacvec=true))
end
# We want to be able to sample according to x = f^(-1)(z)
# We don't need the dynamics of log P here.
function f(u, p, t)
m = DiffEqFlux.restructure(nn,p)
return Tracker.data(m(u))
end
prob = ODEProblem(cnf_dudt_,nothing,tspan,nothing)
function invsample(x::AbstractArray)
# Remember that to train with respect to the NLL we actually went in the x -> u direction
# We want to go in the u -> x direction so we just solve the ODE backward in time.
invprob = ODEProblem(f,x,(10.0, 0.), params)
solve(invprob, Tsit5(),save_everystep=false)[2]
end
function invsample(x::Real)
invsample([x])[1]
end
opt = ADAM(0.1)
model = Normal(5., 0.1)
raw_data = [Float32.(rand(model, (1, 100)))]
data = Iterators.repeated(raw_data, 10)
function loss_adjoint(xs)
pz = Normal(0.0, 1.0)
preds = predict_adjoint(xs)[:, :,end]
z = preds[1, :]
delta_logp = preds[2, :]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
cb = function()
# You can schedule the learning rate if you want.
opt.eta *= 0.95
pz = Normal(0.0, 1.0)
preds = Tracker.data(predict_adjoint(raw_data[1])[:, :,end])
zs = preds[1, :]
delta_logp = preds[2, :]
logpz = logpdf.(pz, zs)
logpx = logpz - delta_logp
loss = -mean(logpx)
perm = sortperm(raw_data[1][1, :])
pl = plot(raw_data[1][1, perm], Tracker.data.(exp.(logpx[perm])), xlims=(4, 6), ylims=(0, 5), title="Loss = $(loss)", label="Learned density")
plot!(t -> pdf(model, t), label="Real density")
samples = invsample(Float32.(rand(pz, (1, 500))))[1, :]
histogram!(samples, normalize=:pdf, alpha=.3, fillalpha=.3, label="Model samples")
display(pl)
end
# The very first invocation of predict_adjoint is very slow because of the JIT overhead.
# I don't think this is normal.
# Just be patient.
cb()
@epochs 100 Flux.train!(loss_adjoint, ps, data, opt; cb = throttle(cb, 100))
This one works for me.
@jessebett so how should we "libraryitize" CNF? Clearly the layer should be given a nice function, but what about the loss function? Is that specific to normal distributions, or should we generalize it? Is it common, or does it change depending on the application?
I tried to make it GPU-compatible:
using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, Tracker
using LinearAlgebra: tr
using Plots
using Flux: @epochs, throttle
using Tracker: forward
using Adapt
#using CuArrays
using LinearAlgebra
# Neural Network
nn = Chain(Dense(1,1,swish), Dense(1,1,identity)) #|> gpu
# We track the parameters.
p = Flux.data(DiffEqFlux.destructure(nn))
params = param(p)
ps = Flux.params(params)
tspan = Float32.((0.0, 10.0))
function divergence(f, x::AbstractArray)
y::AbstractArray, back = forward(f, x)
D, N = size(x)
T = DiffEqFlux.gpu_or_cpu(x)
ȳ(i) = adapt(T,[i == j for j = 1:D])
tmp = [back(ȳ(i))[1][i, :] for i = 1:D]
adapt(T,reduce(+, transpose(tmp)))
end
Tracker.@grad function divergence(f, x::TrackedArray)
y::AbstractArray, back = forward(f, x)
D, N = size(x)
T = DiffEqFlux.gpu_or_cpu(x)
ȳ(i) = T([i == j for j = 1:D])
out = reduce(+, transpose([back(ȳ(i))[1][i, :] for i = 1:D]))
out, Δ -> begin
nothing, back(Δ)
end
end
# Dynamics of the CNF
function cnf_dudt_(u::TrackedArray,p,t)
z = u[1:end-1, :]
m = DiffEqFlux.restructure(nn, p)
jac = -divergence(m, z)
Tracker.collect([m(z);jac])
end
function cnf_dudt_(u::AbstractArray,p,t)
z = u[1:end-1, :]
m = DiffEqFlux.restructure(nn, p)
jac = -divergence(m, z)
Tracker.data([m(z);jac])
end
function predict_adjoint(x)
diffeq_adjoint(params,prob,Tsit5(),u0=vcat(x, zeros(Float32, (1, size(x, 2)))),
saveat=0.0:0.1:10.0,
sensealg=DiffEqFlux.SensitivityAlg(quad=false,
backsolve=false,autojacvec=true))
end
# We want to be able to sample according to x = f^(-1)(z)
# We don't need the dynamics of log P here.
function f(u, p, t)
m = DiffEqFlux.restructure(nn,p)
return Tracker.data(m(u))
end
prob = ODEProblem(cnf_dudt_,nothing,tspan,nothing)
function invsample(x::AbstractArray)
# Remember that to train with respect to the NLL we actually went in the x -> u direction
# We want to go in the u -> x direction so we just solve the ODE backward in time.
invprob = ODEProblem(f,x,(10.0, 0.), params)
solve(invprob, Tsit5(),save_everystep=false)[2]
end
function invsample(x::Real)
invsample([x])[1]
end
opt = ADAM(0.1)
model = Normal(5., 0.1)
raw_data = [Float32.(rand(model, (1, 100)))] #.|> gpu
data = Iterators.repeated(raw_data, 10)
function loss_adjoint(xs)
preds = predict_adjoint(xs)[:, :,end]
z = preds[1, :]
delta_logp = preds[2, :]
μ = 0.0
σ = 1.0
pz = Normal(μ, σ)
#logpz = logpdf.(pz, z)
logpz = -((((z .- μ) ./ σ ).^2 .+ log(2π))./2 .+ log(σ))
logpx = logpz - delta_logp
loss = -mean(logpx)
end
cb = function()
# You can schedule the learning rate if you want.
opt.eta *= 0.95
preds = Tracker.data(predict_adjoint(raw_data[1])[:, :,end])
zs = preds[1, :]
delta_logp = preds[2, :]
μ = 0.0
σ = 1.0
pz = Normal(μ, σ)
#logpz = logpdf.(pz, zs)
logpz = -((((zs .- μ) ./ σ ).^2 .+ log(2π))./2 .+ log(σ))
logpx = Array(logpz - delta_logp)
loss = -mean(logpx)
_raw_data = Array.(raw_data)
perm = sortperm(_raw_data[1][1, :])
pl = plot(_raw_data[1][1, perm], Tracker.data.(exp.(logpx[perm])), xlims=(4, 6), ylims=(0, 5), title="Loss = $(loss)", label="Learned density")
plot!(t -> pdf(model, t), label="Real density")
gendata = Float32.(rand(pz, (1, 500))) #|> gpu
samples = Array(invsample(gendata)[1, :])
histogram!(pl,samples, normalize=:pdf, alpha=.3, fillalpha=.3, label="Model samples")
display(pl)
end
#CuArrays.allowscalar(false) # uncomment together with the using CuArrays line above
# The very first invocation of predict_adjoint is very slow because of the JIT overhead.
# I don't think this is normal.
# Just be patient.
cb()
@epochs 100 Flux.train!(loss_adjoint, ps, data, opt; cb = throttle(cb, 100))
The issue is that
function divergence(f, x::AbstractArray)
y::AbstractArray, back = forward(f, x)
D, N = size(x)
T = DiffEqFlux.gpu_or_cpu(x)
ȳ(i) = adapt(T,[i == j for j = 1:D])
tmp = [back(ȳ(i))[1][i, :] for i = 1:D]
adapt(T,reduce(+, transpose(tmp)))
end
the adapt calls for some reason break the gradient, and then
Tracker.@grad function divergence(f, x::TrackedArray)
y::AbstractArray, back = forward(f, x)
D, N = size(x)
T = DiffEqFlux.gpu_or_cpu(x)
ȳ(i) = T([i == j for j = 1:D])
out = reduce(+, transpose([back(ȳ(i))[1][i, :] for i = 1:D]))
out, Δ -> begin
nothing, back(Δ)
end
end
is the wrong adjoint. The adjoint of the divergence is the gradient, so I just took a stab at it, but @MikeInnes might know how to fix this.
@jessebett so how should we "libraryitize" CNF? Clearly the layer should be given a nice function, but what about the loss function? Is that specific to normal distributions, or should we generalize it? Is it common, or does it change depending on the application?
I think the loss should be left to the user, as it's really problem dependent; on the toy problem I'm playing with, the log-likelihood isn't even computed that way. The base distribution doesn't have to be normal. I'm pretty sure 99% of people use a normal distribution for normalizing flows or VAEs, but if we think of it as a prior, then other, more sensible and problem-specific distributions are surely useful.
I started writing a NormalizingFlow library that would implement the API of Distributions.jl (after all, we are just fitting distributions by MLE; we even have access to the PDF and CDF and can sample, so we can implement all the methods), but I stopped when you said the API would change with Zygote. It would still be useful for playing with NF + Turing :)
Otherwise, the only thing I would change in the API is the ability to pass hyperparameters to the ODE / predict_adjoint and all related functions. Right now we can only pass an AbstractArray p, which is then tracked. We would need to be able to pass a tracked p1 and an untracked p2, where p2 would just be the structure of the neural net. Having the net in the global scope is just dirty, and it breaks for all kinds of reasons when playing in the REPL, too.
So I'd love it if ODEs took parameters of the form dudt(u, p::Tuple{AbstractArray, Any}, t) or dudt(u, p::AbstractArray, t; hyperparams=nothing).
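The first signature can be prototyped in plain Julia, without a global network, by letting p carry both the trainable vector and the structure. A minimal sketch: affine_tanh stands in for a restructure-style function, and it, dudt, and dudt_kw are hypothetical names, not DiffEqFlux APIs. Tuple types are covariant, so a (Vector{Float64}, Function) tuple dispatches to the Tuple{AbstractVector, Any} method.

```julia
# Hypothetical "restructure" function: rebuilds a tiny model from a flat parameter vector θ.
affine_tanh(θ) = z -> tanh.(θ[1] .* z .+ θ[2])

# Option 1: p is a tuple (trainable θ, untracked structure `re`).
function dudt(u, p::Tuple{AbstractVector, Any}, t)
    θ, re = p
    return re(θ)(u)            # rebuild the model from θ, then apply it to the state
end

# Option 2: trainable θ positional, structure passed as a keyword "hyperparameter".
dudt_kw(u, θ::AbstractVector, t; hyperparams = affine_tanh) = hyperparams(θ)(u)

dudt([1.0], ([2.0, 0.5], affine_tanh), 0.0)
```

ODE solvers call the right-hand side as f(u, p, t) without keywords, so option 2 would have to be wrapped in a closure such as (u, p, t) -> dudt_kw(u, p, t; hyperparams = re) before being handed to ODEProblem.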
Implementing CNFs that way, as Distributions.jl objects, would let CNFs be mixed into other things, for example Turing, where they have Bijectors.jl.
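A flow-defined distribution needs at least two operations: sampling (push base samples forward through the map) and an exact log-density (pull x back and add the log-Jacobian term). A minimal sketch for the simplest possible flow, a 1-D affine map over a standard-normal base, in plain Julia. AffineFlow, flow_rand, and flow_logpdf are hypothetical names; a real CNF version would presumably extend Distributions.jl's rand and logpdf, with the map given by an ODE solve instead of a closed form.

```julia
using Random

# Hypothetical sketch (not an existing package API): the 1-D affine flow
# x = exp(s) * z + m over a standard-normal base z ~ N(0, 1).
struct AffineFlow
    m::Float64  # shift
    s::Float64  # log-scale
end

# Sampling: draw from the base distribution and push the draws through the flow.
flow_rand(rng::AbstractRNG, d::AffineFlow, n::Int) = exp(d.s) .* randn(rng, n) .+ d.m

# Exact log-density by change of variables: log p_x(x) = log p_z(z) + log|dz/dx|.
function flow_logpdf(d::AffineFlow, x::Real)
    z = (x - d.m) * exp(-d.s)           # inverse map x -> z
    return -(z^2 + log(2π)) / 2 - d.s   # log|dz/dx| = -s for this map
end

d = AffineFlow(5.0, log(0.1))           # same target as Normal(5, 0.1) used in the thread
xs = flow_rand(MersenneTwister(0), d, 1_000)
```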
@ChrisRackauckas @aussetg is right about the loss function and about not requiring the base distribution to be normal.
However, to libraryize it, I think a convenience implementation could definitely be made for the normal distribution case and then generalized to other distributions. A CNF (and FFJORD?) library implementation only needs jacobian(m) of the dynamics function m to work and be composable (with higher-order AD). In the case of FFJORD, m only needs gradient to work. A library function would then take the usual dynamics given by dudt(u) = m(u) and instead give something like:
function dudt(u)
    z = u[1:end-1]                   # unpack state (last entry is the log-density term)
    dz = m(z)                        # original dynamics
    dlogpz = -tr(jacobian(m, z))     # or a Hutchinson estimate in FFJORD
    return [dz; dlogpz]
end
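The Hutchinson estimate mentioned above replaces the exact trace with a stochastic one, tr(J) ≈ E[εᵀJε] for random probes ε with zero mean and identity covariance. A minimal sketch, assuming Rademacher (±1) probes and a user-supplied function returning Jᵀε; in FFJORD that would be a reverse-mode vector-Jacobian product, while here it is a small explicit matrix so the example runs on its own. hutchinson_trace is a hypothetical name.

```julia
using LinearAlgebra, Random

# Hutchinson estimator: tr(J) ≈ average over probes ε of εᵀ J ε.
# `vjp(ε)` must return Jᵀε (a vector-Jacobian product), so J itself is never formed.
function hutchinson_trace(vjp, n::Int; nsamples::Int = 1_000, rng = Random.default_rng())
    acc = 0.0
    for _ in 1:nsamples
        ε = rand(rng, [-1.0, 1.0], n)   # Rademacher probe vector
        acc += dot(vjp(ε), ε)           # (Jᵀε)ᵀε = εᵀJε
    end
    return acc / nsamples
end

J = [2.0 0.5; -1.0 3.0]
est = hutchinson_trace(ε -> J' * ε, 2; nsamples = 10_000, rng = MersenneTwister(1))
```

The estimate is unbiased but noisy; more probes per step reduce the variance at the cost of extra VJPs.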
So as @aussetg says, the loss should definitely be left to the user. The u0 for the above dynamics consists of an initial sample from the base distribution and its log-likelihood under that base distribution, and the dynamics then transform both. Ideally, a library version of all this should not have to worry about where these samples and log-probabilities come from, so they could come from Distributions.jl. There's work to be done in both directions to make Distributions.jl more composable with our stuff, and the other way around. I like the idea of supplying a Distributions API that allows the user to sample from and evaluate likelihoods under a flow-defined distribution just like any other. However, it would also be nice if Distributions worked for us internally.
For example, logpz = -((((z .- μ) ./ σ ).^2 .+ log(2π))./2 .+ log(σ)) is hand-coded to compute logpdf(Normal(μ,σ),z), because calling that from Distributions breaks autodiff in a few ways. @willtebbutt has ideas/work on improving this.
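A hand-coded, broadcast-only version of the Normal log-density is easy to check against the closed-form PDF. A minimal sketch; normal_logpdf is a hypothetical name. The log σ term must be subtracted. With σ = 1 it vanishes, so the check below deliberately uses σ ≠ 1.

```julia
# Broadcast-only log-density of Normal(μ, σ):
# log p(z) = -((z - μ)/σ)^2 / 2 - log(2π)/2 - log(σ)
normal_logpdf(z, μ, σ) = -(((z .- μ) ./ σ) .^ 2 .+ log(2π)) ./ 2 .- log(σ)

zs = [-1.0, 0.0, 2.5]
lp = normal_logpdf(zs, 1.0, 2.0)   # works elementwise on arrays and on scalars
```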
using OrdinaryDiffEq
using Distributions, FiniteDiff
using Flux, DiffEqFlux, DistributionsAD
# Hack to fix Zygote AD of Normal
Base.Irrational{:log2π}(x::Int64) = Base.Irrational{:log2π}()
function f(z, p)
α, β = p
tanh.(α.*z .+ β)
end
u0 = Float32[0.0, 0.0]
tspan = (0f0, 10f0)
function cnf(u,p,t)
z, logpz = u
α, β = p
[f(z, p),-sum(FiniteDiff.finite_difference_jacobian((z)->f(z, p), [z]))]
end
prob = ODEProblem{false}(cnf,u0,tspan,nothing)
θinit = Float32[0.0, 0.0] # Initial Parameter Vector
function predict_adjoint(x,θ)
concrete_solve(prob,Tsit5(),[x,0f0],θ,
saveat=0f0:0.1f0:10f0)
end
function loss_adjoint(θ,xs)
pz = Normal(0.0, 1.0)
preds = [predict_adjoint(x,θ)[:,end] for x in xs]
z = [pred[1] for pred in preds] # TODO better slicing
delta_logp = [pred[2] for pred in preds]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
function cb(θ,l)
@show l
false
end
opt = ADAM(0.1)
raw_data = Iterators.cycle(([Float32[rand(Normal(2.0, 0.1)) for i in 1:100]],))
res = DiffEqFlux.sciml_train(loss_adjoint, θinit, opt, raw_data, cb=cb, maxiters=100)
# check whether it looks standard normal
using Plots
xs = first(raw_data)[1] # the training samples
preds = [predict_adjoint(x, res.minimizer)[:,end] for x in xs];
histogram([p[1] for p in preds])
The above is a working finite-difference version, and the following uses Tracker over Zygote to nest the reverse mode:
using OrdinaryDiffEq, Zygote
using Distributions, FiniteDiff, DiffEqSensitivity
using Flux, DiffEqFlux, DistributionsAD
function jacobian(f, x::AbstractVector)
y::AbstractVector, back = Zygote.pullback(f, x)
ȳ(i) = [i == j for j = 1:length(y)]
vcat([transpose(back(ȳ(i))[1]) for i = 1:length(y)]...)
end
# Hack to fix Zygote AD of Normal
Base.Irrational{:log2π}(x::Int64) = Base.Irrational{:log2π}()
function f(z, p)
α, β = p
tanh.(α.*z .+ β)
end
u0 = Float32[0.0, 0.0]
tspan = (0f0, 10f0)
function cnf(u,p,t)
z, logpz = u
α, β = p
[f(z, p),-sum(jacobian((z)->f(z, p), [z]))]
end
prob = ODEProblem{false}(cnf,u0,tspan,nothing)
θinit = Float32[0.0, 0.0] # Initial Parameter Vector
function predict_adjoint(x,θ)
concrete_solve(prob,Tsit5(),[x,0f0],θ,
saveat=0f0:0.1f0:10f0,sensealg=InterpolatingAdjoint(
autojacvec=DiffEqSensitivity.TrackerVJP()))
end
function loss_adjoint(θ,xs)
pz = Normal(0.0, 1.0)
preds = [predict_adjoint(x,θ)[:,end] for x in xs]
z = [pred[1] for pred in preds] # TODO better slicing
delta_logp = [pred[2] for pred in preds]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
function cb(θ,l)
@show l
false
end
opt = ADAM(0.01)
raw_data = Iterators.cycle(([Float32[rand(Normal(2.0, 0.1)) for i in 1:100]],))
res = DiffEqFlux.sciml_train(loss_adjoint, θinit, opt, raw_data, cb=cb, maxiters=100)
# check whether it looks standard normal
using Plots
xs = first(raw_data)[1] # the training samples
preds = [predict_adjoint(x, res.minimizer)[:,end] for x in xs];
histogram([p[1] for p in preds])
Now it's time to library-itize it.
@abhigupta768 here is a summary:
So @ChrisRackauckas wrote this after a few iterations, but I'll quickly give an overview as far as I can tell.
f is a very simple Dense neural network layer. It should be replaced by something like what Chris is calling FastDense.
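The toy f(z, p) = tanh.(α .* z .+ β) is a 1-to-1 dense layer whose weight and bias are the two entries of p. A minimal sketch of the general case, where the weights and bias of a wider layer are read from one flat parameter vector; flat_dense is a hypothetical helper, not DiffEqFlux's FastDense API.

```julia
# Hypothetical helper (not DiffEqFlux's FastDense API): a tanh dense layer whose
# weight matrix W (nout×nin) and bias b (nout) are read from one flat vector p.
function flat_dense(z::AbstractVector, p::AbstractVector, nin::Int, nout::Int)
    nW = nout * nin
    W = reshape(p[1:nW], nout, nin)   # first nout*nin entries, column-major
    b = p[nW+1:nW+nout]               # next nout entries
    return tanh.(W * z .+ b)
end

# With nin = nout = 1 and p = [α, β], this reduces to the thread's f(z, p).
flat_dense([0.5], [2.0, 0.1], 1, 1)
```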
function cnf(u,p,t)
z, logpz = u
α, β = p
[f(z, p),-sum(FiniteDiff.finite_difference_jacobian((z)->f(z, p), [z]))]
end
This CNF is the dynamics we describe in the Neural ODE paper. I can get the equation number for you, but I highly recommend you read both the "Variational Inference with Normalizing Flows" paper and the CNF section of the Neural ODE paper to understand what this is doing.
The TL;DR: the complete state u represents both the original state z, which evolves according to dz/dt = f(z,p), AND the log-probability of that state z at time t, which evolves according to -trace(J).
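The augmented state can be written out directly. A minimal sketch, assuming the thread's toy vector field f(z, p) = tanh.(α .* z .+ β) and a central finite-difference trace in the spirit of the FiniteDiff version earlier in the thread; fd_jacobian_trace is a hypothetical helper, not a FiniteDiff function. For this f, the second component equals -α(1 - tanh(αz + β)²).

```julia
# Toy vector field from the thread: f(z, p) = tanh.(α .* z .+ β) with p = [α, β]
f(z, p) = tanh.(p[1] .* z .+ p[2])

# Hypothetical helper: trace of the Jacobian of g at z via central differences
function fd_jacobian_trace(g, z::AbstractVector; h = 1e-6)
    s = 0.0
    for i in eachindex(z)
        e = zeros(length(z)); e[i] = h
        s += (g(z .+ e)[i] - g(z .- e)[i]) / (2h)   # ∂g_i/∂z_i
    end
    return s
end

# Augmented CNF dynamics: u = [z; logp], du = [f(z); -tr(∂f/∂z)]
function cnf(u, p, t)
    z = u[1:end-1]
    return [f(z, p); -fd_jacobian_trace(x -> f(x, p), z)]
end

du = cnf([0.5, 0.0], [2.0, 0.1], 0.0)
```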
The idea is that we start with some x whose log-likelihood logpx we can't easily evaluate. So instead, we put it through these cnf dynamics, and the result will hopefully follow a distribution whose likelihood we can easily evaluate, e.g. a standard Gaussian.
So we put x in, integrate it, and get z out. We also get delta_logp, which is the change in log-likelihood during the integration, coming from that second term in the dynamics.
Now that we have a z, we can evaluate its likelihood under that standard Gaussian, giving logpz.
To get logpx, we just undo the change in log-likelihood from integration: logpx = logpz - delta_logp.
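The bookkeeping logpx = logpz - delta_logp can be sanity-checked on a case with a known answer. A minimal sketch under simplifying assumptions: a 1-D linear vector field f(z) = a*z instead of a neural network, and fixed-step explicit Euler instead of Tsit5, so it runs without packages. For this flow z(T) = x e^{aT} and delta_logp = -aT, so the change-of-variables formula gives log p_x(x) = log p_z(x e^{aT}) + aT.

```julia
# Minimal sketch: CNF with the 1-D linear field f(z) = a*z, integrated with explicit Euler.
# The augmented state is (z, delta_logp), integrated in the x -> z direction.
function cnf_euler(x::Float64, a::Float64; T = 1.0, nsteps = 10_000)
    dt = T / nsteps
    z, dlogp = x, 0.0
    for _ in 1:nsteps
        J = a                  # Jacobian of f(z) = a*z is the constant a (1×1)
        z += dt * (a * z)      # dz/dt = f(z)
        dlogp += dt * (-J)     # d(delta_logp)/dt = -tr(J)
    end
    return z, dlogp
end

std_normal_logpdf(z) = -(z^2 + log(2π)) / 2

x, a, T = 0.7, 0.3, 1.0
z, delta_logp = cnf_euler(x, a; T = T)
logpx = std_normal_logpdf(z) - delta_logp    # logpx = logpz - delta_logp
```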
The loss we are trying to optimize maximizes the likelihood of x, i.e. it minimizes the negative log-likelihood, loss = -mean(logpx):
function loss_adjoint(θ,xs)
pz = Normal(0.0, 1.0)
preds = [predict_adjoint(x,θ)[:,end] for x in xs]
z = [pred[1] for pred in preds] # TODO better slicing
delta_logp = [pred[2] for pred in preds]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
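On the "TODO better slicing": if the final augmented states of a batch are kept as the columns of a 2×N matrix, as in the batched Tracker version earlier in the thread, z and delta_logp are simply its two rows. A minimal sketch; the matrix values are made-up placeholders standing in for solver output, and the standard-normal log-density is hand-coded so Distributions.jl isn't needed.

```julia
using Statistics

# Hypothetical batched layout: final augmented states of N samples as the columns
# of a 2×N matrix (row 1: z, row 2: delta_logp). Values are placeholders.
finals = [0.1 -0.3 0.5;
          0.2  0.0 -0.1]

z = finals[1, :]                     # transformed samples
delta_logp = finals[2, :]            # accumulated change in log-density
logpz = @. -(z^2 + log(2π)) / 2      # standard-normal log-density
logpx = logpz .- delta_logp
loss = -mean(logpx)
```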