FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Slow backward pass when the forward pass touches a large array #323

Open tkf opened 5 years ago

tkf commented 5 years ago

Here is a MWE:

julia> function bench(n)
           c = ones(n)
           _, back = forward(p -> c[1] * p, 1)
           @benchmark $back(1)
       end
bench (generic function with 1 method)

julia> bench(10 ^ 3) |> display
       bench(10 ^ 4) |> display
       bench(10 ^ 5) |> display
BenchmarkTools.Trial:
  memory estimate:  7.94 KiB
  allocs estimate:  1
  --------------
  minimum time:     708.962 ns (0.00% GC)
  median time:      772.438 ns (0.00% GC)
  mean time:        1.396 μs (20.10% GC)
  maximum time:     20.507 μs (85.36% GC)
  --------------
  samples:          10000
  evals/sample:     130
BenchmarkTools.Trial:
  memory estimate:  78.20 KiB
  allocs estimate:  2
  --------------
  minimum time:     2.706 μs (0.00% GC)
  median time:      5.779 μs (0.00% GC)
  mean time:        7.328 μs (14.44% GC)
  maximum time:     195.151 μs (95.96% GC)
  --------------
  samples:          10000
  evals/sample:     6
BenchmarkTools.Trial:
  memory estimate:  781.33 KiB
  allocs estimate:  2
  --------------
  minimum time:     32.468 μs (0.00% GC)
  median time:      54.223 μs (0.00% GC)
  mean time:        63.464 μs (12.27% GC)
  maximum time:     1.255 ms (73.20% GC)
  --------------
  samples:          10000
  evals/sample:     1

I see a similar effect with an alternative implementation:

julia> function bench(n)
           c = ones(n)
           proj = zeros(n)
           proj[1] = 1
           _, back = forward(p -> (c'proj) * p, 1)
           @benchmark $back(1)
       end

As you can see, the computation time of back grows with length(c) even though the majority of c does not participate in the computation. Is it possible to avoid this problem?

tkf commented 5 years ago

I started to wonder whether problems like this are unavoidable by design. For example, the adjoint of A::AbstractMatrix * B::AbstractMatrix is defined by

https://github.com/FluxML/Zygote.jl/blob/d74f3cf5ed3c185969ff7787ee36d15c105e6653/src/lib/array.jl#L207-L209

Wouldn't it be wasteful if (say) A is a large constant matrix and only B depends on the input (the argument of the outer-most forward call)?

Maybe it is better to use struct-of-pullbacks rather than pullback-of-struct? In the above case, it would be something like

@adjoint function(A::AbstractMatrix * B::AbstractMatrix)
  return A * B, (Δ -> Δ * B', Δ -> A' * Δ)
end

instead. This gets a bit harder when you end up needing derivatives w.r.t. all arguments and some of the computations for them can be shared (see also https://github.com/JuliaDiff/ChainRulesCore.jl/issues/8). But I think that can be solved relatively easily by sharing state between the pullbacks.

Does my argument make sense? Or is there already a similar facility to not compute irrelevant intermediate derivatives?

(The function * in the OP uses the method *(::Number, ::Number), so it is probably not the direct cause. But I wondered if the derivative w.r.t. a struct (hence a closure) has a similar problem.)
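
To make the two designs concrete, here is a minimal sketch in plain Julia, without Zygote. The names matmul_joint and matmul_split are hypothetical and exist only for this illustration; they are not part of any library.

```julia
# Pullback-of-struct: one closure returns both cotangents, so both
# products are always computed, even if the caller discards one.
function matmul_joint(A::AbstractMatrix, B::AbstractMatrix)
    return A * B, Δ -> (Δ * B', A' * Δ)
end

# Struct-of-pullbacks: one closure per argument, so the caller can
# invoke only the ones it needs.
function matmul_split(A::AbstractMatrix, B::AbstractMatrix)
    return A * B, (Δ -> Δ * B', Δ -> A' * Δ)
end

A = [1.0 2.0; 3.0 4.0]   # treated as a constant
B = [0.5 0.0; 0.0 2.0]   # the only input we differentiate with respect to
Δ = ones(2, 2)

_, back = matmul_joint(A, B)
dA, dB = back(Δ)          # dA is computed and then thrown away

_, (backA, backB) = matmul_split(A, B)
dB_only = backB(Δ)        # the product Δ * B' is never formed
```

Both routes give the same gradient for B; the difference is only in how much work is done for A.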

willtebbutt commented 5 years ago

I think the issue is to do with not representing the sparsity in the adjoint properly. We could use a 1-hot-like vector (or variant thereof where the non-zero element is allowed to have any value) for getindex with scalar entries in the first instance.

The issue regarding struct-of-pullbacks vs pullback-of-structs is orthogonal to this I think. We're actually going to move towards Zygote's pullback-of-structs route in ChainRules as it's much easier to share computations between adjoints if you go down this route.

Separately, we have thunks to avoid actually performing certain bits of computation that we don't want. E.g., in the matrix multiplication example, we'll have a pullback that, when invoked, returns two thunks.
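
As a rough illustration of the idea (not the ChainRules implementation), a thunk can be sketched as a wrapper around a zero-argument closure that runs only when asked. LazyThunk, force, matmul_lazy and the calls counter are all hypothetical names used for this sketch.

```julia
# A hypothetical minimal thunk: wraps a zero-argument closure and runs it
# only when `force` is called.
struct LazyThunk{F}
    f::F
end
force(t::LazyThunk) = t.f()

calls = Ref(0)  # counts how many matrix products are actually computed

function matmul_lazy(A::AbstractMatrix, B::AbstractMatrix)
    function back(Δ)
        dA = LazyThunk(() -> (calls[] += 1; Δ * B'))
        dB = LazyThunk(() -> (calls[] += 1; A' * Δ))
        return dA, dB
    end
    return A * B, back
end

A = rand(100, 100)   # constant
B = rand(100, 3)     # the input we care about
_, back = matmul_lazy(A, B)
dA, dB = back(ones(100, 3))
grad_B = force(dB)   # only this product is evaluated; dA is never forced
```

Note that this only helps if whoever consumes the pullback's output knows that dA will never be needed, which is the point raised later in the thread.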

willtebbutt commented 5 years ago

Yeah, so if you look at the adjoint for getindex, you'll see that we're literally returning an array the size of the input with only one non-zero element in the case that the forwards pass pulls out a single element of the array. This is a prime target for optimisation, and it would be great to see a PR to sort this out, at least in the single-element-index case to start with.
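
For the single-element-index case, the optimisation could look roughly like the sketch below. It assumes a hypothetical OneHotGrad type and a hypothetical getindex_with_pullback function; neither is Zygote's actual code.

```julia
# Hypothetical sparse cotangent for `x[i]`: an n-vector that is zero
# everywhere except at `index`. Storage is O(1) regardless of `len`.
struct OneHotGrad{T} <: AbstractVector{T}
    value::T
    index::Int
    len::Int
end
Base.size(g::OneHotGrad) = (g.len,)
Base.getindex(g::OneHotGrad{T}, i::Int) where {T} =
    i == g.index ? g.value : zero(T)

# The forward pass is O(1); with OneHotGrad the pullback is O(1) as well.
function getindex_with_pullback(x::AbstractVector, i::Int)
    return x[i], Δ -> OneHotGrad(Δ, i, length(x))
end

c = ones(10^5)
y, back = getindex_with_pullback(c, 1)
g = back(2.0)

# Dense equivalent, which allocates length(c) elements on every call:
dense = zeros(length(c))
dense[1] = 2.0
```

Because OneHotGrad is an AbstractVector, downstream code that only indexes or iterates over it keeps working, while the pullback itself no longer scales with length(c).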

tkf commented 5 years ago

@willtebbutt Thanks a lot for clarifying the current situation around these problems (esp. what is planned around ChainRules and thunk).

But it's still not clear to me whether everything in the OP can be explained by getindex. I was vaguely aware of the getindex issue, which is why I mentioned the non-getindex variant. Don't we need a separate fix for the c'proj example? (See below for benchmarks.)

What I forgot to emphasize was that getindex and c'proj were just easy examples of "touching large array/object." I'd rather want to know if there is a general solution to it (migration to ChainRulesCore maybe?). There are other ways to touch "large objects" (e.g., getproperty, globals, ...).

julia> function bench(n)
           c = ones(n)
           proj = zeros(n)
           proj[1] = 1
           _, back = forward(p -> (c'proj) * p, 1)
           @benchmark $back(1)
       end
bench (generic function with 1 method)

julia> bench(10 ^ 3) |> display
       bench(10 ^ 4) |> display
       bench(10 ^ 5) |> display
BenchmarkTools.Trial:
  memory estimate:  15.88 KiB
  allocs estimate:  2
  --------------
  minimum time:     1.045 μs (0.00% GC)
  median time:      1.712 μs (0.00% GC)
  mean time:        2.718 μs (17.00% GC)
  maximum time:     286.861 μs (98.12% GC)
  --------------
  samples:          10000
  evals/sample:     10
BenchmarkTools.Trial:
  memory estimate:  156.41 KiB
  allocs estimate:  4
  --------------
  minimum time:     10.380 μs (0.00% GC)
  median time:      12.935 μs (0.00% GC)
  mean time:        16.034 μs (12.74% GC)
  maximum time:     1.029 ms (93.46% GC)
  --------------
  samples:          10000
  evals/sample:     1
BenchmarkTools.Trial:
  memory estimate:  1.53 MiB
  allocs estimate:  4
  --------------
  minimum time:     110.699 μs (0.00% GC)
  median time:      129.600 μs (0.00% GC)
  mean time:        162.228 μs (14.29% GC)
  maximum time:     1.541 ms (55.01% GC)
  --------------
  samples:          10000
  evals/sample:     1

willtebbutt commented 5 years ago

Ah, I see what you're saying. My immediate response is: do we care about optimising for this kind of situation? This seems like a hard problem to optimise away because, from the perspective of Zygote (which only really gets to reason about what goes on inside forward), you're computing the dot product between two dense vectors, so there's not a lot else that it could do. @MikeInnes might have a more insightful view on this though.

That said, were you to move the creation of c and proj inside the function in the forward call, I would say you've got more of a point. Then again, I would ask whether you've designed the forwards pass sensibly -- if you designed proj and c to be 1-hot and a Fill respectively, then you should reasonably be able to avoid the scaling issues because you should be able to perform adjoint-local optimisations to get good performance, as opposed to Zygote having to analyse the entire function.

I guess my opinion on the matter is that the asymptotic complexity of the reverse pass should generally be the same as that of the forwards pass, and this follows immediately from each @adjoint having the same forwards and reverse complexity. So if you've got a situation where there's a single operation whose forward pass is O(1) and whose reverse pass is O(N), then there's a problem. That's why getindex is straightforward to fix, but the c / proj example less so.

As regards getproperty: we should be reasonably well optimised there anyway -- representing the adjoint as a NamedTuple of nothings, with one non-nothing element is fine for most things. I'm sure that you could construct a scenario in which there's more overhead than would be ideal, but is it one that we care about?

tkf commented 5 years ago

I think there are other important situations where the interactions with constant variables are not sparse. For example, input/output data to neural nets and the random variables for dropout and GANs are all constants from the point of view of auto-differentiation. They interact densely with the variables that depend on the "input" with respect to which derivatives are calculated. Furthermore, the size of those constants is comparable to that of the variables that depend on the "inputs", especially, I think, in the context of "scientific AI" where the models are not as big as those huge deep neural nets. See also @antoine-levitt's real-application example where this matters https://github.com/JuliaNLSolvers/NLsolve.jl/issues/205#issuecomment-526822248. It looks to me that those cases need the structure-of-pullbacks approach (or equivalently a thunk-based approach?) to defer the computation until needed.

(Alternatively, I suppose the AD engine can somehow "mark" the values that depend on the "input" during the forward pass so that the adjoint function knows with respect to which arguments it has to take derivatives. But this sounds more complicated than structure-of-pullbacks to me.)

Of course, the examples I noted are very basic in a machine learning setup, so it is very possible that I am just ignorant of existing solutions to them. But I thought there was some non-zero chance that the design of Zygote has not been fully reviewed because it is not yet the main auto-differentiation engine of Flux.

tkf commented 5 years ago

Now that I mention GANs, another example is taking the derivative with respect to the generator. It is wasteful to take the derivative w.r.t. the discriminator parameters in this case.

willtebbutt commented 5 years ago

Hmm, I see your point. In a tape-based system you get to know which variables are involved in AD, and which aren't, for free. In Zygote's world you don't though. One option is to use dropgrad (is that still the name?). This could be implemented at the Zygote-rule level and used to manually mark up code to say that you don't need to propagate stuff through a particular bit of code. This combined with thunks should get you 90% of the way. Not a great solution, but I don't really know how hard it would be to do the code analysis required to automatically drop gradients. Again, @MikeInnes would know better.

mcabbott commented 5 years ago

Does Zygote know internally which arguments of a function need to be tracked back to gradient's inputs? If so: Could @adjoint make available some __istracked__(x), roughly equivalent to x isa TrackedArray, to insert near expensive calculations? Could it recognise tuples in Δ -> (dx, dy) and simply not evaluate dx when this is not required? I suppose the second could break something, but surely not many things.

tkf commented 5 years ago

@willtebbutt In https://github.com/JuliaDiff/ChainRulesCore.jl/pull/30, @oxinabox explained to me that that's what thunk is for and how the rules are implemented using it. You've already mentioned thunks but I guess I didn't really get it enough while writing the last comment. Now I'm convinced that thunks solve the issue I brought up.

@mcabbott I was thinking of a similar solution too. But I've started to think that using thunks can solve most of the problems.

Closing, since #291 will take care of this issue.

mcabbott commented 5 years ago

I see, thanks for the link. Maybe I finally understand what a thunk is, simpler than I imagined.

willtebbutt commented 5 years ago

> Now I'm convinced that thunks solve the issue I brought up.

As per my previous comment, I'm not sure that they really do unless they are used in conjunction with more general knowledge about whether or not parents (in the computational graph sense) need to compute their own adjoints, i.e. via some __istracked__ function as @mcabbott suggests. Again, not sure how that would work in Zygote.

tkf commented 5 years ago

So my understanding is that code like this

A = rand(10, 10)
x = randn(10)
y, back = forward(x -> A * x.^2, x)
back(ones(10))

is equivalent to

x1 = x.^2
back11 = Δ -> 2 .* x .* Δ
back12 = Δ -> x1 .* log.(x) .* Δ

back = back11  # derivative is taken w.r.t x

y = A * x1
back21 = Δ -> Δ * x1'
back22 = Δ -> A' * Δ

back = back ∘ back22  # derivative is taken w.r.t x1

back(ones(10))

when the thunks are used (i.e., backXX functions would be wrapped in thunks).

Since the final back is back11 ∘ back22, which does not run back12 or back21, this would solve my concern. That is to say, there is no penalty for "touching" a large object like A (back21 is not run).

Admittedly, I've never checked where the closure back is composed in Zygote, so I am filling in the details with my imagination. Maybe I am assuming too much here?
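
The composition above can be checked by hand in plain Julia. This sketch keeps only the two pullbacks on the path back to x and compares the result with the gradient of sum(A * x.^2) written out directly; the concrete values of A and x are arbitrary, and nothing here reflects how Zygote composes closures internally.

```julia
A = [1.0 2.0; 3.0 4.0; 5.0 6.0]   # constant
x = [0.5, 2.0]                    # input we differentiate with respect to

# Forward pass, recording only the pullbacks on the path back to x.
x1 = x .^ 2
back11 = Δ -> 2 .* x .* Δ         # x .^ 2 with respect to x
y = A * x1
back22 = Δ -> A' * Δ              # A * x1 with respect to x1

back = back11 ∘ back22            # back21 (gradient for A) is never built or run
g = back(ones(3))

# Gradient of sum(A * x.^2) with respect to x, written out by hand:
# each x[j] contributes 2 * x[j] * (sum of column j of A).
expected = 2 .* x .* vec(sum(A, dims=1))
```

Here g and expected agree, and no array with the shape of A is ever allocated in the backward pass.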

tkf commented 5 years ago

FYI, I need an immediate solution, so I created https://github.com/tkf/ChainCutters.jl which implements cut and uncut functions to denote constants and variables, respectively. It's a method dispatch-based solution, so I guess it's a bit fragile (say, compared to contextual dispatch), but it works for me and it speeds things up by something like 20x when customized broadcasting is involved.
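
The general idea of marking constants through dispatch can be sketched as follows. This is not the ChainCutters.jl implementation; Constant and mul_rule are hypothetical names chosen for this illustration.

```julia
# Hypothetical wrapper marking a value as constant for differentiation.
struct Constant{T}
    value::T
end

# Generic rule for an elementwise product: cotangents for both arguments.
mul_rule(a::AbstractVector, b::AbstractVector) =
    (a .* b, Δ -> (Δ .* b, Δ .* a))

# Specialised rule: `a` is marked constant, so its cotangent is `nothing`
# and the product Δ .* b is skipped entirely.
mul_rule(a::Constant, b::AbstractVector) =
    (a.value .* b, Δ -> (nothing, Δ .* a.value))

a = [1.0, 2.0, 3.0]
b = [4.0, 5.0, 6.0]

y, back = mul_rule(Constant(a), b)
da, db = back(ones(3))   # da === nothing, db == a
```

The fragility mentioned above shows up here too: any function that receives a Constant needs a method that knows how to unwrap it.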

tkf commented 5 years ago

A bit more explanation: There are two ingredients for making the adjoint of broadcasting fast. The first ingredient is a ForwardDiff-based adjoint of broadcasting, based on the code in Zygote used for CuArrays. However, since this effectively makes differentiation "eager" for all arguments, it is not an ideal approach if the arity of the broadcasted function is large compared to the number of variables being differentiated. So, the second ingredient is to do forward differentiation only with respect to the non-constant variables. This is only implemented for the BroadcastableCallable type.
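
The second ingredient can be illustrated with a hand-rolled dual number. This sketch does not use ForwardDiff or ChainCutters.jl; Dual and poly are hypothetical names, only the two arithmetic methods this example needs are defined, and map stands in for the broadcast.

```julia
# Minimal hypothetical dual number: a value plus one derivative component.
struct Dual
    val::Float64
    der::Float64
end
Base.:+(a::Real, b::Dual) = Dual(a + b.val, b.der)
Base.:*(a::Dual, b::Real) = Dual(a.val * b, a.der * b)

poly(c0, c1, c2, x) = c0 + c1 * x + c2 * x^2

c0, c1, c2 = 1.0, 2.0, 3.0
xs = [0.0, 1.0, 2.0]

# Seed only c2. c0, c1 and x stay plain numbers and carry no derivative,
# so no work is spent on gradients for the constants.
duals = map(x -> poly(c0, c1, Dual(c2, 1.0), x), xs)

y = sum(d.val for d in duals)        # value of sum(poly.(c0, c1, c2, xs))
dy_dc2 = sum(d.der for d in duals)   # derivative with respect to c2, i.e. sum(xs .^ 2)
```

If every argument were seeded, each dual would need one derivative slot per argument, which is the "eager" cost described above.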

To check the efficiency of the approach I took, I ran a benchmark with the three equivalent functions below, each implemented differently. f_man is the manual implementation. f_nocut is equivalent code that relies on p being callable. f_cut uses constant annotations from ChainCutters.jl.

f_man(p, x) = function(c)
    @unpack c0, c1, c3, c4, c5, c6, c7, c8, c9 = p
    c2 = c
    y = @. c0 + c1 * x +
        c2 * x^2 +
        c3 * x^3 +
        c4 * x^4 +
        c5 * x^5 +
        c6 * x^6 +
        c7 * x^7 +
        c8 * x^8 +
        c9 * x^9
    return sum(y)
end

f_nocut(p, x) = function(c)
    q = @set p.c2 = c
    q :: Poly9  # FYI
    sum(q.(x))
end

f_cut(p, x) = function(c)
    q = cut(@set p.c2 = uncut(c))
    sum(q.(cut(x)))
end

xs = rand(1000)
p = Poly9(rand(10)...)
suite["f_cut"] = @benchmarkable Zygote.gradient($(f_cut(p, xs)), 1.0)
suite["f_nocut"] = @benchmarkable Zygote.gradient($(f_nocut(p, xs)), 1.0)
suite["f_man"] = @benchmarkable Zygote.gradient($(f_man(p, xs)), 1.0)

Here p is a struct of type Poly9 representing a 9th-order polynomial. It subtypes BroadcastableCallable, which is supported by ChainCutters.jl. @set is a macro from Setfield.jl that "mutates" a field of an immutable object. You can read the full benchmark script here: https://github.com/tkf/ChainCutters.jl/blob/master/benchmark/bench_broadcast.jl

Here is the result:

3-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "f_man" => BenchmarkTools.Trial:
          memory estimate:  1.44 MiB
          allocs estimate:  41179
          --------------
          minimum time:     1.466 ms (0.00% GC)
          median time:      1.627 ms (0.00% GC)
          mean time:        2.025 ms (10.43% GC)
          maximum time:     456.652 ms (1.78% GC)
          --------------
          samples:          2466
          evals/sample:     1
  "f_nocut" => BenchmarkTools.Trial:
          memory estimate:  64.36 MiB
          allocs estimate:  1356032
          --------------
          minimum time:     343.052 ms (13.19% GC)
          median time:      346.456 ms (13.60% GC)
          mean time:        750.353 ms (10.50% GC)
          maximum time:     3.167 s (8.84% GC)
          --------------
          samples:          7
          evals/sample:     1
  "f_cut" => BenchmarkTools.Trial:
          memory estimate:  37.23 KiB
          allocs estimate:  176
          --------------
          minimum time:     114.540 μs (0.00% GC)
          median time:      121.496 μs (0.00% GC)
          mean time:        152.005 μs (4.88% GC)
          maximum time:     197.495 ms (4.08% GC)
          --------------
          samples:          10000
          evals/sample:     1

As you can see, constant annotations can make differentiation about 3000x faster than the code without annotations and even about 12x faster than the "manually expanded" code.

I'd imagine Zygote.jl could automatically insert some kind of constant annotation during the forward pass (something equivalent to cut). Is that a reasonable approach? Are there better ways to support the adjoint of broadcasting?

oxinabox commented 4 years ago

#603 is required to fix this, by enabling support for Thunks.

@mzgubic is going to be working on that in the coming weeks.