PumasAI / SimpleChains.jl


Compatibility with Flux #59

Open · staticfloat opened this issue 2 years ago

staticfloat commented 2 years ago

I think it would be really useful to allow interoperability with things like Flux, Zygote, etc...

I think much of the benefit of this package would carry over to other projects if a SimpleChain could be used without attaching a loss, an optimizer, or anything else, and simply be used for its allocation-free forward and backward passes. So you should be able to do something like:

using SimpleChains, Flux

model = SimpleChain(8, TurboDense(SimpleChains.tanh, 8))
p = SimpleChains.init_params(model)
Flux.train!(
    # 'loss' function, returns the value to be minimized
    (x, y) -> Flux.Losses.mse(model(x, p), y),
    # Parameters to be optimized over
    p,
    # dataset
    [(randn(8,1), randn(8,1))],
    # optimizer
    Flux.Optimise.ADAM()
)

While we would still deal with the overhead of Zygote, Flux's optimizers, etc., we would at least be able to eliminate our model's allocation burden, which may be helpful for many users.

ToucheSir commented 2 years ago

If I'm not mistaken, the definitions in https://github.com/PumasAI/SimpleChains.jl/blob/main/src/chain_rules.jl should be more than enough for this purpose. For example, your code snippet already works if I change p -> Flux.params([p]); that one-line change is spelled out after the longer example below. More complex model configurations are also possible:

using Flux: mse # only required for the loss function. If you're writing your own layers in a library, the imports below are all you need:
using Functors # for interop with Flux's module system
using Optimisers, Zygote # training loop essentials sans Flux
using SimpleChains

# Example drop-in layer that will work wherever a Flux model is expected.
# Zygote and Optimisers will let you train w.r.t. the params vector directly,
# but bundling state and behaviour opens up the possibility of working with higher-level libraries like FastAI.jl.
struct WrappedSimpleChain{M,P}
  model::M
  params::P
end
@functor WrappedSimpleChain

(m::WrappedSimpleChain)(x) = m.model(x, m.params)

loss(m, x, y) = mse(m(x), y)

let
  sc = SimpleChain(8, TurboDense(identity, 8))
  p = SimpleChains.init_params(sc)
  model = WrappedSimpleChain(sc, p)
  opt_state = Optimisers.setup(Optimisers.ADAM(), model)

  x = randn(Float32, 8, 1)
  y = 0.75x
  for i in 1:10
    grads, = gradient(model) do m
      loss(m, x, y)
    end
    opt_state, model = Optimisers.update!(opt_state, model, grads)
    @info i, loss(model, x, y)
  end
end
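
For reference, the shorter route mentioned above, i.e. only swapping p for Flux.params([p]) in the original snippet, would look like the following. This is a sketch assuming Flux's implicit-parameters train! API (as in the Flux 0.13 era), not code taken verbatim from this thread.

using SimpleChains, Flux

model = SimpleChain(8, TurboDense(SimpleChains.tanh, 8))
p = SimpleChains.init_params(model)
Flux.train!(
    # 'loss' function, returns the value to be minimized
    (x, y) -> Flux.Losses.mse(model(x, p), y),
    # implicit parameters: wrap the flat parameter vector so Flux/Zygote track it
    Flux.params([p]),
    # dataset
    [(randn(8, 1), randn(8, 1))],
    # optimizer
    Flux.Optimise.ADAM()
)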

So not only does the compatibility appear to be there, but it looks pretty future-proof as well!

chriselrod commented 2 years ago

Note that SimpleChains' memory management is not thread-safe, and thus requires manual management.

The train_*! methods manage memory manually so they can multithread safely.
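
For context, that built-in path looks roughly like the following. This is a sketch based on the package README; add_loss, SquaredLoss, alloc_threaded_grad, train_unbatched!, and SimpleChains.ADAM are SimpleChains' own API, and exact signatures may differ between versions.

using SimpleChains

sc = SimpleChain(8, TurboDense(SimpleChains.tanh, 8))
p = SimpleChains.init_params(sc)

X = randn(Float32, 8, 64)   # 64 training samples
Y = 0.75f0 .* X             # matching targets

scloss = SimpleChains.add_loss(sc, SquaredLoss(Y))  # chain with a loss attached
G = SimpleChains.alloc_threaded_grad(scloss)        # per-thread gradient buffers

# multithreaded training; memory is managed internally, so this is safe
SimpleChains.train_unbatched!(G, p, scloss, X, SimpleChains.ADAM(3e-4), 1_000)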

But plain calls/gradient calculations are (a) not multithreaded, and (b) if you manually batch and call them with Threads.@threads, Threads.@spawn, or Polyester.@batch, you will get corrupted/wrong results. You can work around this by having one SimpleChain per task.
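
A rough sketch of that per-task workaround, not taken from this thread: it assumes that constructing a separate SimpleChain per task gives each task its own internal buffer, and it reuses the Zygote route shown earlier in this issue.

using SimpleChains, Zygote

make_chain() = SimpleChain(8, TurboDense(SimpleChains.tanh, 8))
p = SimpleChains.init_params(make_chain())

# one (x, y) batch per chunk of work
batches = [(randn(Float32, 8, 16), randn(Float32, 8, 16)) for _ in 1:4]
grads = Vector{Any}(undef, length(batches))

Threads.@threads for i in eachindex(batches)
  sc = make_chain()  # task-local chain, so no internal memory is shared across tasks
  x, y = batches[i]
  g, = Zygote.gradient(q -> sum(abs2, sc(x, q) .- y), p)
  grads[i] = copy(g)  # copy in case the gradient aliases the chain's internal buffer
end

total_grad = reduce(+, grads)  # accumulate however your training loop needs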