FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

create a function that lets you change the typing on the internal arrays in a model #399

Closed isaac-rstor closed 5 years ago

isaac-rstor commented 6 years ago

I'm experimenting with different AbstractFloat data types and have some code that lets you switch up the data type on the model. I'll give a bit of a snippet here:

# Rebuild each layer with its arrays converted to element type T; plain
# functions (e.g. activations) are passed through unchanged.
typeswitch(dense::Dense, T::Type) = Dense(T.(dense.W), T.(dense.b), dense.σ)
typeswitch(f::Function, T::Type) = f
typeswitch(model::Chain, T::Type) = Chain([typeswitch(l, T) for l in model.layers]...)
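
For example, on a small throwaway model it would be used roughly like this (the Chain/Dense model here is just for illustration):

using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2))
m64 = typeswitch(m, Float64)   # same structure, weights and biases converted to Float64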

Happy to do a PR on what I've got... but we might want to think about changing the name first. Any thoughts before I send the PR?

MikeInnes commented 6 years ago

We could definitely do with some more utilities here, but let me mention some things that'll make it easier to implement. The kind of mapping-over-models that you're doing is something that we try to abstract over with mapleaves.

typeswitch(T, x) = x
typeswitch(T, x::Number) = T(x)
typeswitch(T, x::AbstractArray) = T.(x)

mapleaves(x -> typeswitch(Float32, x), m)

This has the benefit that it'll work with any treelike layer, without defining a method for each one.
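
For example, a custom layer only needs to be registered as treelike for the same call to work (a sketch only; Affine is the toy layer from the docs, and this assumes the @treelike macro from recent Flux versions):

using Flux

struct Affine
  W
  b
end

Flux.@treelike Affine

(a::Affine)(x) = a.W*x .+ a.b

a = Affine(rand(3, 3), rand(3))
a32 = mapleaves(x -> typeswitch(Float32, x), a)   # no Affine-specific typeswitch method needed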

Also, you'll probably want to use Adapt.jl for the actual conversion; then we can easily implement the right behaviour for booleans, tracked arrays, etc. (for which your current code is technically incorrect in an unfortunately subtle way).
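
A rough sketch of how that could look (paramtype is just a placeholder name, and the adapt_storage rule below is only a stand-in for the per-type rules we'd actually want for booleans, tracked arrays, etc.):

using Adapt
using Flux

# Stand-in rule: ordinary floating-point arrays get their element type converted.
Adapt.adapt_storage(T::Type{<:AbstractFloat}, xs::AbstractArray{<:AbstractFloat}) = convert.(T, xs)

# Walk the model with mapleaves and let Adapt.jl do the per-leaf conversion.
paramtype(T::Type{<:AbstractFloat}, m) = mapleaves(x -> adapt(T, x), m)

m = Chain(Dense(4, 64, relu), Dense(64, 4))
m64 = paramtype(Float64, m)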

x-ref #225

rs1909 commented 5 years ago

I am now trying to switch back to Float64 (from the current default of Float32) because I need the extra accuracy. It does not work; when training I get:

ERROR: LoadError: back! was already used
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] back_(::Flux.Tracker.Call{Missing,Tuple{}}, ::Array{Float64,2}, ::Bool) at /Users/rs1909/.julia/packages/Flux/U8AZD/src/tracker/back.jl:30

The code is:

# One step of a linear test map: two decoupled, lightly damped oscillators in
# companion form (eigenvalues r*exp(±i*phi) for each pair of coordinates).
function rhs(y)
  r1 = 0.99
  phi1 = 2*pi/sqrt(2)
  r2 = 0.98
  phi2 = 2*pi/sqrt(3)

  x = copy(y)
  x[1] = y[2]
  x[2] = -r1^2*y[1] + 2*r1*cos(phi1)*y[2]
  x[3] = x[4]   # x[4] still equals y[4] here, since x is a copy of y
  x[4] = -r2^2*y[3] + 2*r2*cos(phi2)*y[4]
  return x
end

using Flux
using LinearAlgebra

typeswitch(T, x) = x
typeswitch(T, x::Number) = T(x)
typeswitch(T, x::AbstractArray) = T.(x)

function construct()
  mzero = Chain( Dense(4, 64), Dense(64, 64, NNlib.relu ), Dense(64, 4))

  m = mapleaves(x -> typeswitch(Float64, x), mzero)

  loss(x, y) = Flux.mse(m(x), y)

  val = rand(4)
  xs = [val]
  ys = [rhs(val)]
  for k=1:30
      val = rand(4)
      push!(xs,val)
      push!(ys, rhs(val))
      for k=1:200
          # follow trajectories from initial conditions
          val = last(ys)
          push!(xs, val)
          push!(ys, rhs(val) )
      end
  end

  data = zip(xs, ys)

  opt = Flux.Descent(0.1)

  println(loss(xs[1], ys[1]))   # initial loss; loss already applies m to its first argument

  for s=1:12
      Flux.train!(loss, params(m), data, opt)
      sum = 0.0
      for k=1:length(xs)
          err = m(xs[k]) - ys[k]
          erro = rhs(xs[k]) - ys[k]
          sum += dot(err, err)   # accumulate the squared prediction error
      end
      println("Average training error ", sum/length(xs))
  end
  return m
end

m = construct()

m([0.0, 0.0, 0.0, 0.0])
j1 = Array(Flux.jacobian(m,[0.0, 0.0, 0.0, 0.0]))