Open sdobber opened 3 years ago
Flux's `stack` just calls `cat(xs...; dims)` with some reshaping:
https://github.com/FluxML/Flux.jl/blob/master/src/utils.jl#L321
I think replacing that line with `x -> reshape(reduce(vcat, x), :, 1, 1)` helps a bit; Zygote has rules for these more efficient functions.
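For reference, a base-Julia sketch (no Flux needed) of what that replacement computes; the toy slices `xs` are made up for illustration and stand in for the per-slice outputs inside the Chain:

```julia
# Base-Julia sketch of `x -> reshape(reduce(vcat, x), :, 1, 1)`.
xs = [Float32[1, 2], Float32[3, 4], Float32[5, 6]]

flat = reduce(vcat, xs)        # one efficient concatenation: 6-element vector
out  = reshape(flat, :, 1, 1)  # restore the trailing singleton dims: 6×1×1
```

`reduce(vcat, ...)` has a dedicated Zygote rule, unlike splatting many slices into `cat`.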
You can also move the `Flux.unstack(x, 2)` out of the Chain, as making all these slices has a relatively expensive gradient, which I think you don't need. Still only a little faster than your `model2` for me:
```julia
model4 = Chain(x -> lstm.(x),
               x -> reshape(reduce(vcat, x), :, 1, 1))
loss4(x, y) = Flux.mse(model4(x), y)
@btime Flux.train!(loss4, Flux.params(lstm), Iterators.repeated((eachcol(x), y), 5), ADAM())
```
Thank you for the faster version. Do you know if a similar faster variant exists for `Flux.unstack`? In my actual use case, I cannot move it out of the chain, because the function operates on the output of another part of the NN.
We should really teach Zygote about `eachcol` and friends. But for now there are some ideas here:
https://github.com/mcabbott/SliceMap.jl/blob/master/src/SliceMap.jl#L185
`using JuliennedArrays, SliceMap` will give you lazy slice/glue functions which ought to work with Zygote. Or you can copy `collecteachcol`. Or you can use TensorCast.jl, which contains similar code.
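As a rough idea of what collecting column slices looks like, here is a hand-rolled sketch (`collectcols` is a made-up name, not the `collecteachcol` implementation from SliceMap):

```julia
# Hand-rolled sketch: collect the columns of a matrix into a vector of
# arrays, the same shape of output that Flux.unstack(x, 2) produces.
collectcols(x::AbstractMatrix) = [x[:, i] for i in axes(x, 2)]

x = [1 3 5;
     2 4 6]
cols = collectcols(x)       # three length-2 column vectors
```

Gluing the slices back with `reduce(hcat, cols)` recovers the original matrix.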
The issue (btw) is that making slices involves `view`/`getindex`, whose gradient (at present) allocates something like `zero(x)` before writing into it, for every slice. That's wasteful: if you are going to iterate over the whole `x`, you really only need one array to write into.
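The allocation pattern can be illustrated without Zygote. A hedged sketch with made-up function names, comparing one full-size `zero(x)` per slice (what the per-slice adjoints effectively do) against a single shared output array:

```julia
# Per-slice version: one full-size zero array allocated per column slice.
function slice_grads_naive(x::AbstractMatrix, dys)
    total = zero(x)
    for (i, dy) in enumerate(dys)
        g = zero(x)          # fresh full-size allocation for every slice
        g[:, i] .= dy
        total .+= g          # the per-slice pieces get summed anyway
    end
    total
end

# Buffered version: iterate over all columns, writing into one array.
function slice_grads_buffered(x::AbstractMatrix, dys)
    g = zero(x)              # a single allocation
    for (i, dy) in enumerate(dys)
        g[:, i] .= dy
    end
    g
end
```

Both produce the same gradient array; the second makes one allocation instead of one per column.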
(This issue is based on the discussion https://github.com/sdobber/FluxArchitectures/issues/15.) I've been advised to avoid `Zygote.Buffer` due to it being slow; however, one alternative based on `Flux.stack` and `Flux.unstack` turns out to be even slower. This example is of course rather academic; in more complex network architectures the performance penalty turns out to be even worse in my experience.