FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Flux.stack / Flux.unstack slower than Zygote.Buffer #1475

Open sdobber opened 3 years ago

sdobber commented 3 years ago

(This issue is based on the discussion at https://github.com/sdobber/FluxArchitectures/issues/15.) I've been advised to avoid Zygote.Buffer because it is slow. However, an alternative based on Flux.stack and Flux.unstack turns out to be even slower:

```julia
using Flux
using BenchmarkTools

# Create some data and model
x = rand(Float32, 10, 200)
y = rand(Float32, 200, 1)
lstm = LSTM(10, 1)

# Model 1 using Flux.stack and Flux.unstack
model1 = Chain(x -> Flux.unstack(x, 2),   # split x into 200 column vectors
               x -> lstm.(x),             # run the LSTM on each time step
               x -> Flux.stack(x, 1))     # glue the 200 outputs into a 200×1 array
loss(x, y) = Flux.mse(model1(x), y)

@btime Flux.train!(loss, Flux.params(lstm), Iterators.repeated((x, y), 5), ADAM())
# 138.182 ms (647903 allocations: 41.70 MiB)

# Model 2 using Zygote.Buffer
function model2(x)
    out = Flux.Zygote.Buffer(x, size(x, 2), 1)   # mutable, differentiable 200×1 buffer
    for i = 1:size(x, 2)
        out[i] = lstm(x[:, i])[1]
    end
    return copy(out)                             # turn the buffer into an ordinary array
end
loss2(x, y) = Flux.mse(model2(x), y)

@btime Flux.train!(loss2, Flux.params(lstm), Iterators.repeated((x, y), 5), ADAM())
# 106.545 ms (376463 allocations: 30.53 MiB)
```

This example is of course rather academic; in my experience, the performance penalty is even worse in more complex network architectures.

mcabbott commented 3 years ago

Flux's stack just calls cat(xs...; dims) with some reshaping (https://github.com/FluxML/Flux.jl/blob/master/src/utils.jl#L321). I think replacing the x -> Flux.stack(x, 1) step with x -> reshape(reduce(vcat, x), :,1,1) helps a bit, since Zygote has rules for these more efficient functions.
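A minimal plain-Julia sketch of the two ways of gluing per-step outputs together, assuming each step returns a length-1 vector as lstm does here (the random `outs` vector is a stand-in, not real model output). Splatting passes one argument per element to vcat, whereas reduce(vcat, v) on a vector of vectors dispatches to a specialised Base method that preallocates the output once:

```julia
# Stand-in for lstm.(xs): one length-1 output vector per time step
outs = [rand(Float32, 1) for _ in 1:200]

# Splatting: vcat receives 200 separate arguments
a = vcat(outs...)

# reduce(vcat, v) uses Base's specialised method for vectors of arrays
b = reduce(vcat, outs)
@assert a == b

# Reshape to 200×1×1, the shape produced by the replacement step
c = reshape(b, :, 1, 1)
@assert size(c) == (200, 1, 1)
```

Both forms give the same values; any speedup in training comes from which forward and gradient methods get used, so it is worth benchmarking on your own Flux/Zygote versions.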

You can also move Flux.unstack(x, 2) out of the Chain, since making all these slices has a relatively expensive gradient, which I think you don't need. Even so, this is only a little faster than your model2 for me:

```julia
# Slicing now happens outside the Chain: eachcol(x) is passed in as the input
model4 = Chain(x -> lstm.(x),
               x -> reshape(reduce(vcat, x), :, 1, 1))
loss4(x, y) = Flux.mse(model4(x), y)
@btime Flux.train!(loss4, Flux.params(lstm), Iterators.repeated((eachcol(x), y), 5), ADAM())
```

sdobber commented 3 years ago

Thank you for the faster version. Do you know if a similar faster variant exists for Flux.unstack? In my actual use case, I cannot move it out of the chain because the function operates on the output of another part of the NN.

mcabbott commented 3 years ago

We should really teach Zygote about eachcol and friends. But for now there are some ideas here:

https://github.com/mcabbott/SliceMap.jl/blob/master/src/SliceMap.jl#L185

Loading JuliennedArrays and SliceMap (using JuliennedArrays, SliceMap) gives you lazy slice/glue functions that ought to work with Zygote. Alternatively, you can copy collecteachcol, or use TensorCast.jl, which contains similar code.
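The following sketches the idea behind a helper like collecteachcol; it is not SliceMap's actual code. The hypothetical function `collect_cols` copies each column into its own vector and is paired with a custom ChainRulesCore `rrule` whose pullback writes all column cotangents into a single `zero(x)` buffer. It assumes ChainRulesCore is installed and that your Zygote version picks up ChainRulesCore rules (recent versions do).

```julia
using ChainRulesCore

# Hypothetical helper (not a Flux or SliceMap API): one Vector per column of x
collect_cols(x::AbstractMatrix) = [x[:, j] for j in axes(x, 2)]

# Custom reverse rule: accumulate every column's cotangent into ONE array,
# instead of one full-size zero array per slice
function ChainRulesCore.rrule(::typeof(collect_cols), x::AbstractMatrix)
    cols = collect_cols(x)
    function collect_cols_pullback(dcols)
        dcs = unthunk(dcols)
        dx = zero(x)                          # single allocation for the whole input
        dcs isa AbstractZero && return NoTangent(), dx
        for (j, dc) in zip(axes(x, 2), dcs)
            # Columns that did not affect the loss may arrive as `nothing` or a zero tangent
            (dc === nothing || dc isa AbstractZero) && continue
            dx[:, j] .= dc
        end
        return NoTangent(), dx
    end
    return cols, collect_cols_pullback
end
```

Inside a model you would call collect_cols(h) where Flux.unstack(h, 2) was used; Zygote should then call collect_cols_pullback once, rather than differentiating each getindex separately.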

The underlying issue, by the way, is that making slices involves view / getindex, whose gradient (at present) allocates something like zero(x) for every slice before writing into it. That's wasteful: if you are going to iterate over the whole of x, you really only need one array to write into.
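To make the cost concrete, here is a plain-Julia sketch (hypothetical function names, no AD involved) contrasting the two accumulation strategies. `grad_per_slice` mimics a per-slice gradient that materialises a full-size zero array for every column, as described above, while `grad_one_buffer` writes each column's cotangent into one shared array. Both return the same matrix; only the allocation pattern differs.

```julia
# Mimics per-slice gradients: each column's contribution is a full-size array
# that is zero everywhere except column j, and these are summed up
function grad_per_slice(x::AbstractMatrix, dcols)
    dx = zero(x)
    for (j, dc) in zip(axes(x, 2), dcols)
        g = zero(x)          # full-size allocation for every slice
        g[:, j] .= dc
        dx .+= g
    end
    return dx
end

# One buffer for the whole input, each column written in place
function grad_one_buffer(x::AbstractMatrix, dcols)
    dx = zero(x)
    for (j, dc) in zip(axes(x, 2), dcols)
        dx[:, j] .= dc
    end
    return dx
end

x = rand(Float32, 10, 200)
dcols = [rand(Float32, 10) for _ in 1:200]
@assert grad_per_slice(x, dcols) == grad_one_buffer(x, dcols)
```

For an m×n input the per-slice version makes n extra m×n allocations, so after a warm-up call, @allocated grad_per_slice(x, dcols) should report far more than @allocated grad_one_buffer(x, dcols).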