We could definitely do with some more utilities here, but let me mention some things that'll make it easier to implement. The kind of mapping-over-models that you're doing is something that we try to abstract over with mapleaves.
typeswitch(T, x) = x                      # fallback: leave other leaves untouched
typeswitch(T, x::Number) = T(x)           # convert scalar parameters
typeswitch(T, x::AbstractArray) = T.(x)   # convert arrays elementwise

mapleaves(x -> typeswitch(Float32, x), m)
This has the benefit that it'll work with any treelike layer, without defining a method for each one.
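For instance, a custom layer opts into the tree walk with @treelike (MyLayer here is just a made-up example):

using Flux

struct MyLayer
    W
    b
end

Flux.@treelike MyLayer   # mapleaves now recurses into W and b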
Also, you'll probably want to use Adapt.jl for the actual conversion; then we can easily implement the right behaviour for booleans, tracked arrays, etc. (for which your current code is technically incorrect in an unfortunately subtle way).
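Roughly, the Adapt-based version could look like this (a sketch only; EltypeAdaptor is a made-up name, not something Adapt or Flux defines):

using Adapt

struct EltypeAdaptor{T} end

# only floating-point leaves are converted; integers and booleans pass through
Adapt.adapt_storage(::EltypeAdaptor{T}, x::AbstractArray{<:AbstractFloat}) where {T} = T.(x)
Adapt.adapt_storage(::EltypeAdaptor{T}, x::AbstractFloat) where {T} = T(x)

mapleaves(x -> adapt(EltypeAdaptor{Float32}(), x), m)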
x-ref #225
I am now trying to switch back to Float64 (from the current default Float32), because I need the extra accuracy. It does not work. When training I get:
ERROR: LoadError: back! was already used
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] back_(::Flux.Tracker.Call{Missing,Tuple{}}, ::Array{Float64,2}, ::Bool) at /Users/rs1909/.julia/packages/Flux/U8AZD/src/tracker/back.jl:30
The code is:
# One step of a discrete-time linear system: two independent damped
# rotations (radius r, angle phi), each in companion form on a pair
# of coordinates.
function rhs(y)
    r1 = 0.99
    phi1 = 2*pi/sqrt(2)
    r2 = 0.98
    phi2 = 2*pi/sqrt(3)
    x = copy(y)
    x[1] = y[2]
    x[2] = -r1^2*y[1] + 2*r1*cos(phi1)*y[2]
    x[3] = y[4]
    x[4] = -r2^2*y[3] + 2*r2*cos(phi2)*y[4]
    return x
end
using Flux
using LinearAlgebra
typeswitch(T, x) = x
typeswitch(T, x::Number) = T(x)
typeswitch(T, x::AbstractArray) = T.(x)
function construct()
    mzero = Chain(Dense(4, 64), Dense(64, 64, NNlib.relu), Dense(64, 4))
    # convert every parameter of the model to Float64
    m = mapleaves(x -> typeswitch(Float64, x), mzero)
    loss(x, y) = Flux.mse(m(x), y)
    val = rand(4)
    xs = [val]
    ys = [rhs(val)]
    for k = 1:30
        val = rand(4)
        push!(xs, val)
        push!(ys, rhs(val))
        for j = 1:200
            # follow trajectories from the initial conditions
            val = last(ys)
            push!(xs, val)
            push!(ys, rhs(val))
        end
    end
    data = zip(xs, ys)
    opt = Flux.Descent(0.1)
    # loss applies m itself, so pass the raw input rather than m(xs[1])
    println(loss(xs[1], ys[1]))
    for s = 1:12
        Flux.train!(loss, params(m), data, opt)
        total = 0.0
        for k = 1:length(xs)
            # Flux.data strips the tracker, so plain Float64s accumulate
            err = Flux.data(m(xs[k])) - ys[k]
            total += dot(err, err)
        end
        println("Average training error ", total/length(xs))
    end
    return m
end
m = construct()
m([0.0, 0.0, 0.0, 0.0])
j1 = Array(Flux.jacobian(m, [0.0, 0.0, 0.0, 0.0]))
I'm experimenting with different AbstractFloat data types and have some code that lets you switch the data type of a model. Roughly, it looks like this (a minimal sketch; datatype is just a placeholder name):
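using Flux

# sketch only: one conversion method per layer type; datatype is a placeholder name
datatype(T, c::Chain) = Chain(map(l -> datatype(T, l), c.layers)...)
datatype(T, d::Dense) = Dense(param(T.(Flux.data(d.W))), param(T.(Flux.data(d.b))), d.σ)
datatype(T, x) = x   # fallback for layers without parameters

# usage: m64 = datatype(Float64, m)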
Happy to do a PR with what I've got, but I might think about changing the name first. Any thoughts before I send the PR?