dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia
MIT License

RFC: don't automatically unthunk grad outputs #134

Closed ToucheSir closed 1 year ago

ToucheSir commented 1 year ago

This is the other change I made while optimizing https://github.com/jeremiedb/ADTests.jl/blob/main/experiments/yota/dense.jl. By not unthunking all outputs by default, we avoid materializing the gradient of the input x. This saves a significant amount of both compute and memory. It's also highly breaking, hence the draft PR. Would there be any interest in having this as an option in grad or in a lower-level API?
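
For context, a minimal sketch of what skipping the unthunk buys (sizes as in the dense benchmark; the names here are illustrative, not Yota internals):

using ChainRulesCore: @thunk, unthunk

# A rule can hand back the gradient of the input `x` as delayed work:
dx_thunk = @thunk(randn(256, 4096) .* 2.0)  # nothing computed or allocated yet

# grad currently unthunks every output, so the array is always materialised:
dx = unthunk(dx_thunk)                      # ~8 MB of Float64s

# The proposal: return the thunk as-is, so callers that only want the
# parameter gradients never pay for the gradient of x.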

mcabbott commented 1 year ago

FWIW, https://github.com/JuliaDiff/Diffractor.jl/pull/79 went the other way. There's something to be said for the high-level functions never giving you a thunk.

You can of course write val, g = grad(m -> loss(m, x), m) to discard the gradient of x.

ToucheSir commented 1 year ago

grad(m -> loss(m, x), m) doesn't work because the getproperty pullback will unthunk <closure type>.x. In my mind making that not unthunk is a much bigger change, but perhaps it would bring other advantages too.
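
A toy illustration of the closure point (illustrative code, not the ADTests file):

# The anonymous function in grad(m -> loss(m, x), m) is itself a struct that
# captures x as a field, so the tape records a getproperty call on it, and
# Yota's getfield pullback currently unthunks the tangent of that field.
f = let x = randn(3, 3)
    m -> sum(m * x)
end
fieldnames(typeof(f))  # (:x,)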

Either way, I don't think any of this interferes with https://github.com/JuliaDiff/ChainRulesCore.jl/pull/568, does it? The mention of Diffractor using escape analysis instead initially excited me, but given how far that's (not) developed and where Diffractor sits in the AD performance ranking, seeing thunks bring tangible performance benefits is way more appealing right now.

dfdx commented 1 year ago

Is it correct to say that thunks only save compute and memory during differentiation, but eventually they need to be materialized, and at that point they take the same amount of resources as eager execution would?

ToucheSir commented 1 year ago

I think InplaceableThunk can save resources when materializing compared to eager execution, but not 100% sure.

mcabbott commented 1 year ago

Yes. In practice, using + to accumulate gradients (when multiple functions take the same x as input) will never use InplaceThunks at all right now; see e.g. these allocs. Their only function at present is to avoid doing some work / using some memory for gradients which aren't needed at all.

My Diffractor link above discusses many things, sorry. But also shows Yota (at the time) doing this right, here with grad(x -> f(x, b), a). So I don't follow what this comment is saying:

doesn't work because the getproperty pullback will unthunk <closure type>.x

Is this some quirk of the * rule which doesn't apply to that f? Or has getproperty been changed to be more eager about unthunking internally somewhere?

ToucheSir commented 1 year ago

I remember seeing that, which is why I was surprised when it didn't work on the ADTests example! The extra amount allocated (8MB on top of 65MB) is exactly the size of x. Here's a modified version of that file which shows this behaviour:

using Yota
using BenchmarkTools
using Random: seed!

seed!(123)
const bs = 4096
const f = 256
const h1 = 512

struct Linear{A,B}
    w::A
    b::B
end

(m::Linear)(x) = exp.(m.w * x .+ m.b)

let
    w1 = randn(h1, f) .* 0.01;
    b1 = randn(h1);
    x = randn(f, bs) .* 0.01;

    m = Linear(w1, b1)
    loss(m) = sum(m(x))

    @btime grad($loss, $m; seed=1.0); # 41.837 ms (64 allocations: 73.01 MiB)

    tape_func = Yota.GRAD_CACHE |> values |> first
    @bprofile $tape_func($loss, $m);
    # VSCodeServer.view_profile()
end

Running Cthulhu on tape_func, I see the following line:

%75 = invoke Yota.var"#getfield_pullback#34"{var"#loss#1"{Matrix{Float64}}, Symbol, typeof(identity)}(:x, identity)(%71::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1485#1490"{Matrix{Float64}, Matrix{Float64}}}, ChainRules.var"#1484#1489"{Matrix{Float64}, Matrix{Float64}}})::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.Tangent{var"#loss#1"{Matrix{Float64}}}, ChainRulesCore.NoTangent}

Which is clearly unthunking if we look at https://github.com/dfdx/Yota.jl/blob/v0.8.2/src/rulesets.jl#L63. Git blame attributes this to https://github.com/dfdx/Yota.jl/commit/928dbce57cd852c38ffdd5dd325664725dee990f, which seems to have happened after your comment on the Diffractor PR.
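
To spell out what that line is doing, here's a paraphrase (not the exact Yota source) of the current pullback next to the non-unthunking variant mentioned above:

using ChainRulesCore: NoTangent, Tangent, unthunk

# Paraphrase of the current behaviour: the tangent for field fld of s::S is
# forced with unthunk before being wrapped in a Tangent.
function getfield_pullback_current(dy, ::Type{S}, fld::Symbol) where {S}
    nt = NamedTuple{(fld,)}((unthunk(dy),))
    return NoTangent(), Tangent{S}(; nt...), NoTangent()
end

# The "much bigger change" would keep the thunk inside the Tangent until
# something downstream actually asks for its value.
function getfield_pullback_lazy(dy, ::Type{S}, fld::Symbol) where {S}
    nt = NamedTuple{(fld,)}((dy,))
    return NoTangent(), Tangent{S}(; nt...), NoTangent()
end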

dfdx commented 1 year ago

Here's my take on thunks. I believe that in the short term we may use thunks for micro-optimisations like this, but in the longer term it's more advantageous to get rid of them.

I expect Julia AD to eventually converge to 2-3 large engines with different tradeoffs: perhaps one similar to Diffractor, with heavy compiler-level optimizations, targeting SciML; another, more traditional one based on computational graphs, targeting conventional deep learning; and maybe one more experimental engine. But all these engines will be pretty sophisticated, bringing a lot of optimisations by themselves.

These optimisations may include, for example, gradient checkpointing or training with mixed precision. These are huge optimisations that can reduce training time and memory by several times - much more than thunks. And these optimisations are easier to implement without thunks involved.

So in this bright future I imagine ChainRules to be a minimalistic repository of rules without any added machinery, and AD engines to take care of all the optimisations.

At the same time, we are not there yet, and an 11% memory reduction (8 MB out of 65 + 8 MB) is quite a lot. If it's reproducible on other tests too (maybe not 11%, but 2-3% is still good), let's incorporate the change.

dfdx commented 1 year ago

In fact, we can indeed add an option to ignore the gradients of some inputs and not unthunk, or even not calculate, their derivatives. This should be a non-breaking change.
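
Something like this, purely as a hypothetical sketch (the keyword and its convention are invented here, nothing is implemented):

# Hypothetical keyword: inputs whose derivatives may be skipped, so they are
# never unthunked (or never computed at all). Exact indexing convention TBD.
loss2(m, x) = sum(m(x))
val, g = grad(loss2, m, x; ignore=(3,))  # skip the gradient slot for x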

mcabbott commented 1 year ago

Thanks for the digging above. It sounds like the problem is getproperty, then, not the overall thunking story. Can it be fixed more narrowly?

I think that never returning thunks from user-facing functions is probably a good policy. You ask for a gradient, you get a gradient, not some weird internal delayed object.

Re benchmarks: for typical neural network things, materialising the gradient with respect to the data is likely to always be a tiny effect. It only matters for the first layer; all subsequent layers cannot avoid this work.

The places where something like thunks can have a large effect are (1) calculations with many static inputs along the way (e.g. suppose you wanted only the gradient with respect to x and not the weights of a neural network; then this would be a factor of 2 effect) and (2) calculations with repeated use of the same inputs (where InplaceThunks, or some other way of writing rules to use mul! not *, can materialise once instead of N times).
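
For reference, a rough sketch of the two paths an InplaceableThunk offers, using the actual * rule (sizes as in the benchmark above; whether add!! really mutates in place depends on the destination type):

using ChainRules                      # provides the rrule for *
using ChainRulesCore: rrule, add!!, unthunk

A, B = randn(512, 256), randn(256, 4096)
dY = ones(512, 4096)                  # incoming cotangent for Y = A * B

Y, pb = rrule(*, A, B)
_, dA, dB = pb(dY)                    # dA, dB are thunks: no gradient work done yet

acc = zeros(size(A))                  # a pre-existing gradient accumulator
acc = add!!(acc, dA)                  # in-place path, roughly acc .+= dY * B' via mul!
dB_mat = unthunk(dB)                  # materialised only when actually needed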

ToucheSir commented 1 year ago

Agreed with all the above. There are certainly cases where in-place accumulation could save a decent amount of memory (Parallel comes to mind), but this is not one of them. Also, I'm not aware of any CR-enabled AD which uses add!!. Either way, between https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592 and activity annotations, there should be plenty to explore when it comes to reducing unnecessary computation and allocations.