Open staticfloat opened 2 years ago
If I'm not mistaken, the definitions in https://github.com/PumasAI/SimpleChains.jl/blob/main/src/chain_rules.jl should be more than enough for this purpose. For example, your code snippet already works if I change `p -> Flux.params([p])`. More complex model configurations are also possible:
```julia
using Flux: mse  # only required for the loss function; if you're writing your own layers in a library, just use:
using Functors   # for interop with Flux's module system
using Optimisers, Zygote  # training-loop essentials sans Flux
using SimpleChains

# Example drop-in layer that will work wherever a Flux model is expected.
# Zygote and Optimisers will let you train wrt. the params vector directly,
# but bundling state and behaviour opens up the possibility of working with
# higher-level libraries like FastAI.jl.
struct WrappedSimpleChain{M,P}
    model::M
    params::P
end
@functor WrappedSimpleChain

(m::WrappedSimpleChain)(x) = m.model(x, m.params)

loss(m, x, y) = mse(m(x), y)

let
    sc = SimpleChain(8, TurboDense(identity, 8))
    p = SimpleChains.init_params(sc)
    model = WrappedSimpleChain(sc, p)
    opt_state = Optimisers.setup(Optimisers.ADAM(), model)

    x = randn(Float32, 8, 1)
    y = 0.75x

    for i in 1:10
        grads, = gradient(model) do m
            loss(m, x, y)
        end
        opt_state, model = Optimisers.update!(opt_state, model, grads)
        @info i, loss(model, x, y)
    end
end
```
So not only does the compatibility appear to be there, but it looks pretty future-proof as well!
Note that SimpleChains' memory management is not threadsafe, and thus requires manual management. The `train_*!` methods manage memory manually so they can multithread safely. But plain calls/gradient calculations are (a) not multithreaded, and (b) not safe to batch manually with `Threads.@threads`, `Threads.@spawn`, or `Polyester.@batch`; doing so will give corrupted/wrong results. You can work around this by having one `SimpleChain` per task.
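The one-chain-per-task workaround might look like the sketch below. It assumes that copying the chain (here via `deepcopy`, though any per-task construction would do) gives each task its own scratch buffer; the parameter vector `p` is read-only during the forward pass, so it can be shared:

```julia
using SimpleChains

sc = SimpleChain(8, TurboDense(identity, 8))
p = SimpleChains.init_params(sc)
xs = [randn(Float32, 8, 1) for _ in 1:4]

# One independent SimpleChain (and hence one scratch buffer) per task,
# so concurrent forward passes don't corrupt shared memory.
tasks = map(xs) do x
    Threads.@spawn begin
        local_chain = deepcopy(sc)  # task-local copy (assumption: this duplicates the buffer)
        local_chain(x, p)           # params are only read, so sharing `p` is fine
    end
end
outputs = fetch.(tasks)
```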
I think it would be really useful to allow interoperability with things like Flux, Zygote, etc. Much of the benefit of this package would be applicable to other projects if a `SimpleChain` could be used without attaching a loss or optimizer or anything else, but simply used for its allocation-free forward and backward pass. So you should be able to do something like:

While we would still deal with the overhead of Zygote, Flux's optimizers, etc., we would at least be able to eliminate our model's allocation burden, which may be helpful for many users.