Closed: QiyaoWei closed this issue 3 years ago.
using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
using Random
using Statistics
# Placeholder initial state for the ODE problem; `predict` replaces it with `[u0; x]`
# when it calls `solve`, so the state actually integrated is 4-dimensional.
u0 = zeros(Float32, 2)
tspan = (0f0, 10f0)

# Two 4 → 4 layers, matching the 4-dimensional state `[u0; x]` used by `predict`.
ann = FastChain(FastDense(4, 4, relu), FastDense(4, 4, tanh))
p1 = initial_params(ann)
n = length(p1)

# Trainable vector: the network weights followed by a copy of `u0`. Only the first
# `n` entries are read by `dudt_`.
# NOTE(review): the trailing `u0` entries appear to be unused — confirm intent.
ps = Float32[p1; u0]

# Out-of-place right-hand side du/dt = ann(u) - u; a steady state (du/dt = 0) is a
# fixed point u = ann(u).
function dudt_(u, p, t)
    weights = p[1:n]
    return ann(u, weights) - u
end
# Equivalent in-place form (must mutate `du` rather than rebind it):
# function dudt_(du, u, p, t)
#     du .= ann(u, p[1:n]) - u
# end

ode = ODEProblem(dudt_, u0, tspan, ps)
ss = SteadyStateProblem(ode)
"""
    predict(x)

Solve the steady-state problem `ss` starting from the state `[u0; x]`, with the global
parameter vector `ps`, and return the steady state as an `Array`.

The solver is `DynamicSS(Rodas5())` (time-stepping to steady state). Gradients are
requested through `SteadyStateAdjoint()`.

NOTE(review): `x` only enters through the initial condition, not the right-hand side.
If the fixed point of `u = ann(u)` is unique, the result may not depend on `x` at all;
confirm this is the intended DEQ formulation.
"""
function predict(x)
    initial_state = vcat(u0, x)
    solution = solve(ss, DynamicSS(Rodas5());
                     u0 = initial_state, p = ps, sensealg = SteadyStateAdjoint())
    return Array(solution)
end
# https://medium.com/coffee-in-a-klein-bottle/deep-learning-with-julia-e7f15ad5080b
# Auxiliary functions for generating our data

"""
    generate_real_data([rng::AbstractRNG,] n)

Sample `n` points from the "real" distribution: `x1` is uniform on `[-0.5, 0.5)` and
`x2 = 3 * x1^2` plus Gaussian noise with standard deviation `0.1`.

Return a `2 × n` matrix with one sample per column (row 1 is `x1`, row 2 is `x2`).
Pass an explicit `rng` for reproducible data.
"""
function generate_real_data(rng::AbstractRNG, n)
    # Draw 1 × n rows so `vcat` stacks the coordinates into a 2 × n matrix; `rand(n)`
    # would give vectors and `vcat` would then return a single 2n-long vector.
    x1 = rand(rng, 1, n) .- 0.5
    x2 = 3 .* x1 .^ 2 .+ 0.1 .* randn(rng, 1, n)
    return vcat(x1, x2)
end

generate_real_data(n) = generate_real_data(Random.default_rng(), n)
"""
    generate_fake_data([rng::AbstractRNG,] n)

Sample `n` "fake" points inside the disk of radius `1/3` centred at `(0, 0.5)`.

The angle is uniform on `[0, 2π)` and the radius is uniform on `[0, 1/3)`. Because the
radius (not its square) is uniform, points are denser near the centre.

Return a `2 × n` matrix with one sample per column (row 1 is `x1`, row 2 is `x2`).
Pass an explicit `rng` for reproducible data.
"""
function generate_fake_data(rng::AbstractRNG, n)
    # 1 × n rows so `vcat` yields a 2 × n sample matrix (see `generate_real_data`).
    θ = 2 * π * rand(rng, 1, n)
    r = rand(rng, 1, n) / 3
    x1 = @. r * cos(θ)
    x2 = @. r * sin(θ) + 0.5
    return vcat(x1, x2)
end

generate_fake_data(n) = generate_fake_data(Random.default_rng(), n)
# Creating our data
train_size = 1
# Named `real_data`/`fake_data` rather than `real`/`fake` so the script does not shadow
# `Base.real` in `Main`.
real_data = generate_real_data(train_size)
fake_data = generate_fake_data(train_size)

# Organizing the data in batches
# Columns of X are samples: the real ones first, then the fake ones.
X = hcat(real_data, fake_data)
# One label per sample (1 = real, 0 = fake), in the same column order as X.
labels = vcat(ones(train_size), zeros(train_size))
# The model has 4 outputs, so every output row gets the sample's label. `repeat` copies
# the label row 4 times, giving a 4 × (2 * train_size) target matrix whose real columns
# are all ones and whose fake columns are all zeros. Stacking the label vector and
# reshaping would interleave labels so that both classes get the same target.
Y = repeat(permutedims(labels), 4)
data = Flux.Data.DataLoader((X, Y), batchsize=1, shuffle=true)
opt = ADAM(0.05)
"""
    loss(x, y)

Return the sum of squared differences between the targets `y` and `predict(x)`.
The value is also printed with `@show` on every call.
"""
loss(x, y) = let ŷ = predict(x)
    @show sum((y .- ŷ).^2)
end
# Train the model for a fixed number of passes over `data`; `loss` prints each batch's
# value via `@show`.
epochs = 1000
# `ps` is the only trainable array. Build the `Params` collection once instead of
# re-creating an identical one on every epoch; the same `ps` array is trained either way.
trainable = Flux.Params([ps])
for _ in 1:epochs
    Flux.train!(loss, trainable, data, opt)
end
Works. The main issue, the stack overflow, was a recently introduced bug with a quick fix (https://github.com/SciML/SciMLBase.jl/pull/56), which is now tagged. Other issues you had in here:
- `du .= ann(u, p[1:n]) - u`. Remember that `f(du,u,p,t)` is a mutating function, so it needs to mutate the output. In this case it makes more sense to use the non-mutating form `f(u,p,t)`, i.e. `dudt_(u, p, t) = ann(u, p[1:n]) - u`.
- `println(mean(ann([u0;real],ps[1:n])),mean(ann([u0;fake],ps[1:n])))`: whatever you were printing wasn't taking in the values.
- What is the `u0` part doing here?
Hi all,
I am trying to use deep equilibrium (DEQ) models to fit one toy data point by wrapping the neural net in a `SteadyStateProblem`. However, when I run code like this,
I get the following error. Any help would be appreciated!