tkf opened this issue 5 years ago
I started to wonder if problems like this are unavoidable by design. For example, the adjoint of `A::AbstractMatrix * B::AbstractMatrix` is defined by a single pullback that computes the cotangents of both arguments at once. Wouldn't it be wasteful if (say) `A` is a large constant matrix and only `B` depends on the input (the argument of the outer-most `forward` call)?
Maybe it is better to use struct-of-pullbacks rather than pullback-of-structs? In the above case, it would be something like

@adjoint function *(A::AbstractMatrix, B::AbstractMatrix)
    return A * B, (Δ -> Δ * B', Δ -> A' * Δ)
end
instead. It'd be a bit harder when it turns out that you need to compute derivatives w.r.t. all arguments in the end and some of the computations for them can be shared (see also https://github.com/JuliaDiff/ChainRulesCore.jl/issues/8). But I think that can be solved relatively easily by sharing state between the pullbacks.
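The struct-of-pullbacks shape above can be sketched with plain closures (`mul_pullbacks` is a hypothetical helper for illustration, not an actual Zygote API):

```julia
# Hypothetical sketch (not an actual Zygote API): struct-of-pullbacks
# for C = A * B, returning one pullback per argument so that a caller
# can skip the cotangent it does not need.
function mul_pullbacks(A::AbstractMatrix, B::AbstractMatrix)
    pb_A(Δ) = Δ * B'   # only evaluated if the caller needs ∂/∂A
    pb_B(Δ) = A' * Δ   # only evaluated if the caller needs ∂/∂B
    return A * B, (pb_A, pb_B)
end
```

A caller that only differentiates w.r.t. `B` would invoke `pb_B` and never pay for `Δ * B'`.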
Does my argument make sense? Or is there already a similar facility to not compute irrelevant intermediate derivatives?
(The function `*` in the OP uses the method `*(::Number, ::Number)`, so it is probably not the direct cause. But I wondered if the derivative w.r.t. a struct (hence a closure) has a similar problem.)
I think the issue is to do with not representing the sparsity in the adjoint properly. We could use a one-hot-like vector (or a variant thereof where the non-zero element is allowed to have any value) for `getindex` with scalar entries in the first instance.
The issue regarding struct-of-pullbacks vs pullback-of-structs is orthogonal to this, I think. We're actually going to move towards Zygote's pullback-of-structs route in ChainRules, as it's much easier to share computations between adjoints if you go down this route.
Separately, we have thunks to avoid actually performing certain bits of computation that we don't want. E.g. in the matrix multiplication example, we'll have a pullback that, when invoked, returns two thunks.
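The thunk idea can be sketched in a few lines (a simplified stand-in, not ChainRulesCore's actual `Thunk` implementation):

```julia
# Minimal sketch of a thunk: wrap each cotangent computation in a
# zero-argument closure so that nothing is computed until it is
# explicitly unwrapped.
struct Thunk{F}
    f::F
end
unthunk(t::Thunk) = t.f()   # force the deferred computation
unthunk(x) = x              # non-thunks pass through unchanged

# The matmul pullback then returns two thunks; a caller that only
# needs ∂/∂B never pays for Δ * B'.
mul_pullback(A, B) = Δ -> (Thunk(() -> Δ * B'), Thunk(() -> A' * Δ))
```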
Yeah, so if you look at the adjoint for `getindex`, you'll see that we're literally returning an array with only one non-zero element in the case that the forwards pass pulls out a single element of the array. This is a prime target for optimisation, and it would be great to see a PR to sort this out, at least in the single-element-index case to start with.
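The suggested single-element optimisation could look roughly like the following sketch (`OneHotGrad` and `getindex_pullback` are hypothetical names for illustration, not Zygote's implementation):

```julia
# Represent the adjoint of single-element getindex as a lazy one-hot
# vector: constructing it is O(1) instead of allocating a dense array
# of zeros the size of the input.
struct OneHotGrad{T} <: AbstractVector{T}
    len::Int   # length of the original array
    ind::Int   # position of the single non-zero entry
    val::T     # its value (the incoming cotangent)
end
Base.size(g::OneHotGrad) = (g.len,)
Base.getindex(g::OneHotGrad, i::Int) = i == g.ind ? g.val : zero(g.val)

# The pullback of x -> x[i] just records the index and the incoming Δ:
getindex_pullback(x::AbstractVector, i::Int) = Δ -> OneHotGrad(length(x), i, Δ)
```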
@willtebbutt Thanks a lot for clarifying the current situation around these problems (esp. what is planned around ChainRules and thunks).
But it's still not clear to me if everything in the OP can be explained by `getindex`. I was vaguely aware of the `getindex` issue, so that's why I mentioned the non-`getindex` variant. Don't we need another, separate fix for the `c'proj` example? (See below for benchmarks.)
What I forgot to emphasize was that `getindex` and `c'proj` were just easy examples of "touching a large array/object." I'd rather like to know if there is a general solution to it (migration to ChainRulesCore, maybe?). There are other ways to touch "large objects" (e.g., `getproperty`, globals, ...).
julia> function bench(n)
           c = ones(n)
           proj = zeros(n)
           proj[1] = 1
           _, back = forward(p -> (c'proj) * p, 1)
           @benchmark $back(1)
       end
bench (generic function with 1 method)
julia> bench(10 ^ 3) |> display
bench(10 ^ 4) |> display
bench(10 ^ 5) |> display
BenchmarkTools.Trial:
memory estimate: 15.88 KiB
allocs estimate: 2
--------------
minimum time: 1.045 μs (0.00% GC)
median time: 1.712 μs (0.00% GC)
mean time: 2.718 μs (17.00% GC)
maximum time: 286.861 μs (98.12% GC)
--------------
samples: 10000
evals/sample: 10
BenchmarkTools.Trial:
memory estimate: 156.41 KiB
allocs estimate: 4
--------------
minimum time: 10.380 μs (0.00% GC)
median time: 12.935 μs (0.00% GC)
mean time: 16.034 μs (12.74% GC)
maximum time: 1.029 ms (93.46% GC)
--------------
samples: 10000
evals/sample: 1
BenchmarkTools.Trial:
memory estimate: 1.53 MiB
allocs estimate: 4
--------------
minimum time: 110.699 μs (0.00% GC)
median time: 129.600 μs (0.00% GC)
mean time: 162.228 μs (14.29% GC)
maximum time: 1.541 ms (55.01% GC)
--------------
samples: 10000
evals/sample: 1
Ah, I see what you're saying. My immediate response is: do we care about optimising for this kind of situation? This seems like kind of a hard problem to optimise away as, from the perspective of Zygote (which only really gets to reason about what goes on inside `forward`), you're computing the dot product between two dense vectors, so there's not a lot else that it could do. @MikeInnes might have a more insightful view on this though.
That said, were you to move the creation of `c` and `proj` inside the function in the `forward` call, I would say you've got more of a point. Then again, I would ask whether you've designed the forwards pass sensibly -- if you designed `proj` and `c` to be one-hot and a `Fill` respectively, then you should reasonably be able to avoid the scaling issues, because you should be able to perform adjoint-local optimisations to get good performance, as opposed to Zygote having to analyse the entire function.
I guess my opinion on the matter is that the asymptotic complexity of the reverse pass should generally be the same as the forwards pass, and this follows immediately from each `@adjoint` having the same forwards- and reverse-complexity. So if you've got a situation where there's a single operation whose forward pass is O(1) and reverse is O(N), then there's a problem. That's why `getindex` is straightforward to fix, but the `c` / `proj` example less so.
As regards `getproperty`: we should be reasonably well optimised there anyway -- representing the adjoint as a `NamedTuple` of `nothing`s, with one non-`nothing` element, is fine for most things. I'm sure that you could construct a scenario in which there's more overhead than would be ideal, but is it one that we care about?
I think there are other important situations where the interactions with constant variables are not sparse. For example, the input/output data of neural nets and the random variables for dropout and GANs are all constants from the point of view of auto-differentiation. They interact densely with the variables that depend on the "input" with respect to which derivatives are calculated. Furthermore, the sizes of those constants are comparable to those of the variables that depend on the "inputs", especially, I think, in the context of "scientific AI" where the models are not as big as those huge deep neural nets. See also @antoine-levitt's real-application example where this matters: https://github.com/JuliaNLSolvers/NLsolve.jl/issues/205#issuecomment-526822248. It looks to me like those cases need the struct-of-pullbacks approach (or, equivalently, a thunk-based approach?) to defer the computation until needed.
(Alternatively, I suppose the AD engine could somehow "mark" the values that depend on the "input" during the forward pass, so that an adjoint function can know with respect to which arguments it has to take derivatives. But this sounds more complicated than struct-of-pullbacks to me.)
Of course, the examples I noted are very basic in a machine-learning setup, so it is very possible that I am just ignorant of existing solutions to them. But I thought there was some non-zero chance that the design of Zygote has not been fully reviewed, because it is not yet the main auto-differentiation engine of Flux.
Now that I mention GANs: another example is taking the derivative with respect to the generator. It is wasteful to take the derivative w.r.t. the discriminator parameters in this case.
Hmm, I see your point. In a tape-based system you get to know which variables are involved in AD, and which aren't, for free. In Zygote's world you don't, though. One option is to use `dropgrad` (is that still the name?) -- this could be implemented at the Zygote-rule level and used to manually mark up code to say that you don't need to propagate stuff through a particular bit of code. This, combined with thunks, should get you 90% of the way. Not a great solution, but I don't really know how hard it would be to do the code analysis required to automatically be able to drop gradients. Again, @MikeInnes would know better.
Does Zygote know internally which arguments of a function need to be tracked back to `gradient`'s inputs? If so: could `@adjoint` make available some `__istracked__(x)`, roughly equivalent to `x isa TrackedArray`, to insert near expensive calculations? Could it recognise tuples in `Δ -> (dx, dy)` and simply not evaluate `dx` when this is not required? I suppose the second could break something, but surely not many things.
@willtebbutt In https://github.com/JuliaDiff/ChainRulesCore.jl/pull/30, @oxinabox explained to me that that's what thunks are for and how the rules are implemented using them. You've already mentioned thunks, but I guess I didn't really get it while writing the last comment. Now I'm convinced that thunks solve the issue I brought up.
@mcabbott I was thinking of a similar solution too. But I've started to think that using thunks can solve most of the problems.
Closing, since #291 will take care of this issue.
I see, thanks for the link. Maybe I finally understand what a thunk is; it's simpler than I imagined.
> Now I'm convinced that thunks solve the issue I brought up.

As per my previous comment, I'm not sure that they really do without being used in conjunction with more general knowledge regarding whether or not parents (in the computational-graph sense) need to compute their own adjoints, i.e. via some `__istracked__` function as @mcabbott suggests. Again, not sure how that would work in Zygote.
So my understanding is that code like this
A = rand(10, 10)
x = randn(10)
y, back = forward(x -> A * x.^2, x)
back(ones(10))
is equivalent to
x1 = x.^2
back11 = Δ -> 2 .* x .* Δ
back12 = Δ -> x1 .* log.(x) .* Δ
back = back11 # derivative is taken w.r.t x
y = A * x1
back21 = Δ -> Δ * x1'
back22 = Δ -> A' * Δ
back = back ∘ back22 # derivative is taken w.r.t x1
back(ones(10))
when the thunks are used (i.e., the `backXX` functions would be wrapped in thunks).
Since the final `back` is `back11 ∘ back22`, which does not run `back12` or `back21`, this would solve my concern. That is to say, there is no penalty for "touching" a large object like `A` (`back21` is not run).
Admittedly, I've never checked where the closure `back` is composed in Zygote, so I am filling in the details with my imagination. Maybe I am assuming too much here?
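The sketch above can be checked with plain closures standing in for Zygote's generated pullbacks: composing only the pullbacks on the path back to `x` reproduces the vector-Jacobian product of `x -> A * x.^2`.

```julia
# Hand-built pullbacks for y = A * x.^2, mirroring the sketch above.
A = [1.0 2.0; 3.0 4.0]
x = [0.5, 2.0]
x1 = x .^ 2
back11 = Δ -> 2 .* x .* Δ   # pullback of x.^2 w.r.t. x
back22 = Δ -> A' * Δ        # pullback of A * x1 w.r.t. x1
back = back11 ∘ back22      # back12 and back21 are never invoked
Δ = [1.0, 1.0]
# Matches the hand-computed Jacobian-transpose product:
@assert back(Δ) == 2 .* x .* (A' * Δ)
```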
FYI, I needed an immediate solution, so I created https://github.com/tkf/ChainCutters.jl which implements `cut` and `uncut` functions to denote constants and variables, respectively. It's a method-dispatch-based solution, so I guess it's a bit fragile (say, compared to contextual dispatch), but it works for me and it speeds things up by something like 20x when customized broadcasting is involved.
A bit more explanation: there are two ingredients for making the adjoint of broadcasting fast. The first ingredient is a ForwardDiff-based adjoint of broadcasting, based on the code in Zygote used for CuArrays. However, since this effectively makes differentiation "eager" for all arguments, it is not an ideal approach if the arity of the broadcasted function is large compared to the number of variables differentiated. So, the second ingredient is to do forward differentiation only with respect to the non-constant variables. This is only implemented for the `BroadcastableCallable` type.
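The dispatch-based marking idea can be sketched roughly as follows (illustrative only: `Cut`, `Uncut`, and `times_adjoint` are simplified stand-ins, not ChainCutters' actual implementation):

```julia
# Wrapper types whose mere presence tells an adjoint, via dispatch,
# which arguments are constants and which are variables.
struct Cut{T}       # constant: no gradient needed
    value::T
end
struct Uncut{T}     # variable: differentiate w.r.t. this
    value::T
end

# Adjoint of an elementwise product a .* b, specialised on the marks;
# the cotangent for the Cut argument is never computed.
times_adjoint(a::Cut, b::Uncut, Δ) = (nothing, Δ .* a.value)
times_adjoint(a::Uncut, b::Cut, Δ) = (Δ .* b.value, nothing)
```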
To check the efficiency of the approach I took, I ran a benchmark with the three equivalent functions below, each with a different implementation. `f_man` is the manual implementation. `f_nocut` is equivalent code, but relying on the fact that `p` is a callable. `f_cut` uses the constant annotations from ChainCutters.jl.
f_man(p, x) = function(c)
    @unpack c0, c1, c3, c4, c5, c6, c7, c8, c9 = p
    c2 = c
    y = @. c0 + c1 * x +
           c2 * x^2 +
           c3 * x^3 +
           c4 * x^4 +
           c5 * x^5 +
           c6 * x^6 +
           c7 * x^7 +
           c8 * x^8 +
           c9 * x^9
    return sum(y)
end

f_nocut(p, x) = function(c)
    q = @set p.c2 = c
    q :: Poly9  # FYI
    sum(q.(x))
end

f_cut(p, x) = function(c)
    q = cut(@set p.c2 = uncut(c))
    sum(q.(cut(x)))
end
xs = rand(1000)
p = Poly9(rand(10)...)
suite["f_cut"] = @benchmarkable Zygote.gradient($(f_cut(p, xs)), 1.0)
suite["f_nocut"] = @benchmarkable Zygote.gradient($(f_nocut(p, xs)), 1.0)
suite["f_man"] = @benchmarkable Zygote.gradient($(f_man(p, xs)), 1.0)
Here `p` is a struct of type `Poly9` representing a 9th-order polynomial. It inherits from the `BroadcastableCallable` type, which is supported by ChainCutters.jl. `@set` is a macro from Setfield.jl that "mutates" a field of an immutable object. You can read the full benchmark script here: https://github.com/tkf/ChainCutters.jl/blob/master/benchmark/bench_broadcast.jl
Here is the result:
3-element BenchmarkTools.BenchmarkGroup:
tags: []
"f_man" => BenchmarkTools.Trial:
memory estimate: 1.44 MiB
allocs estimate: 41179
--------------
minimum time: 1.466 ms (0.00% GC)
median time: 1.627 ms (0.00% GC)
mean time: 2.025 ms (10.43% GC)
maximum time: 456.652 ms (1.78% GC)
--------------
samples: 2466
evals/sample: 1
"f_nocut" => BenchmarkTools.Trial:
memory estimate: 64.36 MiB
allocs estimate: 1356032
--------------
minimum time: 343.052 ms (13.19% GC)
median time: 346.456 ms (13.60% GC)
mean time: 750.353 ms (10.50% GC)
maximum time: 3.167 s (8.84% GC)
--------------
samples: 7
evals/sample: 1
"f_cut" => BenchmarkTools.Trial:
memory estimate: 37.23 KiB
allocs estimate: 176
--------------
minimum time: 114.540 μs (0.00% GC)
median time: 121.496 μs (0.00% GC)
mean time: 152.005 μs (4.88% GC)
maximum time: 197.495 ms (4.08% GC)
--------------
samples: 10000
evals/sample: 1
As you can see, constant annotation can make differentiation ~3000x faster than the code without annotation, and even 12x faster than the "manually expanded" code.
I'd imagine Zygote.jl could automatically insert some kind of constant annotations during the forward pass (something equivalent to `cut`). Is that a reasonable approach? Are there better ways to support the adjoint of broadcasting?
Thunks. @mzgubic is going to be working on that in the coming weeks.
Here is a MWE:
I see a similar effect with an alternative implementation:
As you can see, the computation time of `back` grows as `length(c)` grows even though the majority of `c` does not participate in the computation. Is it possible to avoid this problem?