JuliaDiff / ReverseDiff.jl

Reverse Mode Automatic Differentiation for Julia

Enhancement proposal: Modular tape caching #234

Open jacobusmmsmit opened 1 year ago

jacobusmmsmit commented 1 year ago

Problem

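Compiled tapes can't be used with functions that have run-time control flow: a tape records only the branch taken while tracing, so a compiled tape silently replays that branch for every input. This stops some code from taking advantage of tape compilation.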

Possible solution

Enable ReverseDiff's tape caching functionality to be used in cases with run-time control flow by introducing guarded/sub-tapes which are recompiled automatically if the instructions they contain are invalidated by a user-specified guard statement.

My implementation idea is that these guarded/sub-tapes live directly on normal compiled tapes as another type of AbstractInstruction (if I'm correct in assuming that it doesn't fit inside SpecialInstruction).

Here's a quick-and-dirty non-implementation showcasing the idea in action:

using ReverseDiff
import ReverseDiff: CompiledTape, GradientTape, compile, gradient!

mutable struct GuardedTape{F,T,G,V,C} # mutable because of `guard_value`
    f::F
    tape::CompiledTape{T}
    guard_f::G
    guard_value::V
    cache::C
end

function guarded_tape(func, guard_func, input) 
    tape = GradientTape(func, input)
    ctape = compile(tape)
    guard_value = guard_func(input)
    cache = Dict(guard_value => ctape)
    return GuardedTape(func, ctape, guard_func, guard_value, cache)
end

function gradient!(gt::GuardedTape, input)
    new_guard_value = gt.guard_f(input)
    if new_guard_value != gt.guard_value
        new_ctape = get!(gt.cache, new_guard_value) do
            println("Recompiling")
            tape = GradientTape(gt.f, input)
            compile(tape)
        end
        gt.guard_value = new_guard_value
        gt.tape = new_ctape
    end
    gradient!(gt.tape, input)    
end

f(x) = x[1] > 1 ? x[1]^3 : x[1]^2
input = [0.0]
gt = guarded_tape(f, x -> x[1]>1, input)

gradient!(gt, [0.1]) # No recompilation
gradient!(gt, [1.1]) # Recompilation triggered
gradient!(gt, [0.5]) # No recompilation

The soul of this is borrowed from JAX's static_argnums/static_argnames in jit, where users can specify one or more arguments that, if changed, trigger the lookup/recompilation step. This is essentially value dispatch. I'm not sure of its performance implications.
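To make the "value dispatch" framing concrete, here is a minimal pure-Julia sketch (no ReverseDiff involved) in which the guard value is lifted into the type domain with `Val`, so each branch of `f` becomes a separate method. The names `branch_body` and `guarded_f` are hypothetical. Because `Val(x > 1)` cannot be inferred to a concrete type, each call typically goes through dynamic dispatch, which is one concrete form the performance cost mentioned above can take.

```julia
# Each branch of `f` becomes its own method; which one runs is decided by the
# *value* of the guard, lifted into the type domain with `Val`.
branch_body(::Val{true}, x) = x^3
branch_body(::Val{false}, x) = x^2

# Hypothetical helper: evaluate the guard at run time, then dispatch on it.
# `Val(x > 1)` is not type-stable, so the call to `branch_body` is resolved
# by a run-time method lookup rather than at compile time.
guarded_f(x) = branch_body(Val(x > 1), x)

guarded_f(2.0)  # 8.0 (high branch)
guarded_f(0.5)  # 0.25 (low branch)
```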

Impact

The original context for this project is the Turing package. Gradient-based methods like HMC and NUTS are the state of the art for MCMC sampling and, as stated on Turing's GSoC projects page, their performance is greatly improved by the caching features of ReverseDiff. However, tape caching is not universally applicable: more complicated models using other packages will normally contain unavoidable control flow.

More generally, the ability to efficiently differentiate through control flow will allow ReverseDiff to be recommended more universally in packages that rely on ForwardDiff. AD backend selection is a great feature in the SciML ecosystem, and many of its packages, such as Optimization, could benefit from this contribution: picking compiled ReverseDiff for a function with branches would become at worst a performance footgun rather than a correctness one (admittedly a blatant one, but not a trivial one).

While next generation AD backends such as Diffractor and Enzyme are a hot topic in the ecosystem at the moment, ReverseDiff is a package which has stood the test of time for its reliability and performance. For workloads such as those found in Turing, "out of the box" it is almost always faster than Zygote, especially in compiled mode. Zygote may sometimes be faster, but requires far more hand-tuning to reach the necessary speeds, most of which is inaccessible to end-users.

ReverseDiff has a clear niche in the AD backend ecosystem: its target users are moderately performance sensitive with medium-to-high dimensional problems and it covers these very well with little to no hand-tuning. While Enzyme has incredible performance, which is a feature for the most performance-critical applications, it is neither trivial to use and tune, nor can it be applied in every situation due to some compatibility issues. In a similar vein, Zygote is a high performance solution that works great for applications heavy in linear algebra, but often requires significant hand-tuning.

mohamed82008 commented 1 year ago

Use the ReverseDiff.@grad macro to define an rrule for any function that has a branch. The rrule can use AbstractDifferentiation to call ReverseDiff again. This will essentially maintain a sub-tape for this particular function with dynamic control flow and will make it work even when the remaining functions' tape is compiled and cached. IIUC, this is roughly equivalent to what you are trying to do with very little engineering work.

jacobusmmsmit commented 1 year ago

Thanks for the reply. Forgive me for not understanding fully, do you think you could expand a little on how @grad could be used in combination with AD to make a "sub-tape"? In this case is the sub-tape also compiled and cached as in my toy implementation?

mohamed82008 commented 1 year ago

It would not be compiled by default but you can choose to compile 2 different tapes, one for each branch. I think you might also be able to do that lazily.

jacobusmmsmit commented 1 year ago

Could you give me a starting point that I could expand on? I'm not too familiar with AbstractDifferentiation but I'd love to build a usable MVP of this idea.

mohamed82008 commented 1 year ago

It's not easy. If you are already familiar with ReverseDiff, try reading https://github.com/JuliaDiff/AbstractDifferentiation.jl/blob/master/ext/AbstractDifferentiationReverseDiffExt.jl to understand the AD API. Then you will need to address https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/41. Then it should be easy to do a MWE. If you are interested in spending time on this, we can schedule a call to go through the work required to get it done.

jacobusmmsmit commented 1 year ago

I had a quick read of the above links as well as the AbstractDifferentiation PR about ReverseDiff. I see that it's a relatively difficult problem to solve at such a high level (for all backends) due to type stability. I'd be interested in working on it.

ToucheSir commented 1 year ago

Saw the GSoC idea this proposal is referring to, very interesting stuff. One question from me: would this help with being able to represent dynamically-bounded loops on the tape without requiring recompilation? I can think of a few cases related to sequence/time series modelling where it would be nice to not eat tracing + tape compilation latency every time the input length changes. Some mechanism for caching sub-tapes seems like a necessary prerequisite for that, but I'm not sure if it falls under the scope of this proposal.

jacobusmmsmit commented 1 year ago

Based on my (limited) understanding of the problem, I think the answer is no. That said, Mohamed may have a better idea of how to deal with it. Maybe Julia can do more than JAX in this regard?

mohamed82008 commented 1 year ago

One question from me: would this help with being able to represent dynamically-bounded loops on the tape without requiring recompilation? I can think of a few cases related to sequence/time series modelling where it would be nice to not eat tracing + tape compilation latency every time the input length changes.

If you have a specific example, we can think about it.

ToucheSir commented 1 year ago

The ultimate use case I have in mind is a RNN, but here is a simpler dependency-free example:

function f(xs)
    s = zero(eltype(xs))
    for (i, x) in enumerate(xs)
        s += i * x
    end
    return s
end

julia> tape = ReverseDiff.GradientTape(f, ones(5))
typename(ReverseDiff.GradientTape)(f)

julia> ReverseDiff.gradient!(tape, ones(5))
5-element Vector{Float64}:
 1.0
 2.0
 3.0
 4.0
 5.0

julia> ReverseDiff.gradient!(tape, ones(3))
5-element Vector{Float64}:
 1.0
 2.0
 3.0
 4.0
 5.0

julia> ReverseDiff.gradient!(tape, ones(10))
ERROR: BoundsError: attempt to access 5-element Vector{Float64} at index [1:10]
Stacktrace:
  [1] throw_boundserror(A::Vector{Float64}, I::Tuple{UnitRange{Int64}})
    @ Base ./abstractarray.jl:744
  [2] checkbounds
    @ ./abstractarray.jl:709 [inlined]
  [3] _copyto_impl!(dest::Vector{Float64}, doffs::Int64, src::Vector{Float64}, soffs::Int64, n::Int64)
    @ Base ./array.jl:325
  [4] copyto!
    @ ./array.jl:319 [inlined]
  [5] copyto!
    @ ./array.jl:342 [inlined]
  [6] value!
    @ ~/.julia/packages/ReverseDiff/wIfrd/src/tracked.jl:156 [inlined]
  [7] seeded_forward_pass!
    @ ~/.julia/packages/ReverseDiff/wIfrd/src/api/tape.jl:41 [inlined]
  [8] gradient!
    @ ~/.julia/packages/ReverseDiff/wIfrd/src/api/gradients.jl:79 [inlined]
  [9] gradient!(tape::ReverseDiff.GradientTape{typeof(f), ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, input::Vector{Float64})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/wIfrd/src/api/gradients.jl:63
 [10] top-level scope
    @ REPL[18]:1

It would be nice to have a way to specify "don't unroll this loop" when tracing so that the same tape could be re-used for different input lengths.
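A toy tracer (hypothetical, not ReverseDiff's internals) illustrates why the tape is tied to the input length: tracing appends one instruction per executed scalar operation, so the loop is unrolled and the recorded tape grows with `length(xs)`. That matches the transcript above, where the tape traced on `ones(5)` silently returns a 5-element result for `ones(3)` and hits a `BoundsError` when copying `ones(10)` into its 5-element input buffer.

```julia
# Hypothetical toy tape: just the list of recorded operations.
struct ToyTape
    ops::Vector{Symbol}
end
ToyTape() = ToyTape(Symbol[])

# Traced version of `f` above: every iteration records one :mul and one :add,
# so the loop is unrolled into the tape rather than stored as a loop.
function traced_f!(tape::ToyTape, xs)
    s = zero(eltype(xs))
    for (i, x) in enumerate(xs)
        push!(tape.ops, :mul)
        t = i * x
        push!(tape.ops, :add)
        s += t
    end
    return s
end

tape5 = ToyTape()
traced_f!(tape5, ones(5))
tape10 = ToyTape()
traced_f!(tape10, ones(10))
length(tape5.ops), length(tape10.ops)  # (10, 20)
```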

mohamed82008 commented 1 year ago

For loops are not possible to intercept with ReverseDiff because they are not functions but if wrapped in a function, the function can be intercepted. In this case, you can define an rrule for this function which calls RD with no tape caching. This is possible now with AbstractDifferentiation.

for (i, x) in enumerate(xs)
  s += i * x
end
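As a sketch of the suggestion above, assuming a hand-written rule is acceptable: once the loop lives in its own function, its derivative can be written down independently of the input length. The names `weighted_sum` and `weighted_sum_pullback` are hypothetical; in ReverseDiff such a pullback would be attached with `@grad`, as is done later in this thread.

```julia
# The loop from `f`, wrapped in its own function so a rule can be attached to it.
function weighted_sum(xs)
    s = zero(eltype(xs))
    for (i, x) in enumerate(xs)
        s += i * x
    end
    return s
end

# Hand-written pullback: d(weighted_sum)/d(xs[i]) = i, for any length of xs.
# Returns a 1-tuple of cotangents, one per argument.
weighted_sum_pullback(xs) = Δ -> ([Δ * i for i in 1:length(xs)],)

weighted_sum_pullback(ones(3))(1.0)   # ([1.0, 2.0, 3.0],)
weighted_sum_pullback(ones(10))(1.0)  # works unchanged for length 10
```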

ToucheSir commented 1 year ago

Thanks Mohamed. I'm aware of the custom rule path, but the hope was to make use of tape caching (or I'd resort to using Zygote). Perhaps this example better describes my motivation:

function scan(f, xs, init)
  ys = empty(xs)
  h = init
  for x in xs
    h, y = f(h, x)
    push!(ys, y)
  end
  return h, ys
end

@jacobusmmsmit probably recognizes this as jax.lax.scan ;) Per the suggestion, I would have to define an rrule for scan which calls ReverseDiff on f. The problem is then that I need to give up on tape caching altogether. Is there a way to create a tape for f once, compile it and then reuse it on every iteration of the loop (while keeping nice features like mutation tracking)? I could find very little in the way of resources on how to manipulate tapes, so I assumed it would require changes to ReverseDiff itself.
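The reverse pass of `scan` only needs the pullback of one step, applied once per element, which is why reusing a single compiled tape for `f` is attractive. The pure-Julia sketch below uses a hypothetical step function and a hand-written pullback as a stand-in for a compiled tape of `f`; it differentiates the final carried state with respect to every element of `xs`.

```julia
const decay = 0.5

# Hypothetical step function: new state h′ = decay*h + x^2, also emitted as output.
rnn_step(h, x) = (decay * h + x^2, decay * h + x^2)

# Hand-written pullback of `rnn_step`: given a cotangent Δh for h′, returns
# the cotangents for (h, x). It stands in for a compiled tape of the step.
rnn_step_pullback(h, x, Δh) = (decay * Δh, 2x * Δh)

# Forward pass: run the scan, keeping the state before every step.
function scan_states(step, xs, init)
    hs = [init]
    for x in xs
        h, _ = step(hs[end], x)
        push!(hs, h)
    end
    return hs
end

# Gradient of the final state with respect to every element of xs.
# The same fixed-size pullback is reused at each step, whatever length(xs) is.
function scan_final_state_grad(xs, init)
    hs = scan_states(rnn_step, xs, init)
    Δh = one(init)
    Δxs = similar(xs)
    for t in length(xs):-1:1
        Δh, Δxs[t] = rnn_step_pullback(hs[t], xs[t], Δh)
    end
    return Δxs
end

scan_final_state_grad([1.0, 2.0, 3.0], 0.0)  # [0.5, 2.0, 6.0]
```

Analytically, the gradient of the final state with respect to `xs[t]` is `2xs[t] * decay^(T - t)`, which is what the reverse sweep accumulates.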

jacobusmmsmit commented 1 year ago

After a call with Mohamed, I think the implementation we decided to try out will address scan quite nicely. I'll write up what we discussed here so it's public. For simplicity, I'm assuming that xs is a shaped array, so its shape is included in its type.

scan is a function that has a core computation, f, and some machinery around it to "catch" its outputs. In this case, scan's output shapes and concrete types are determined once f, xs, and init are known. We want to avoid recompiling the core computation f every time xs changes shape.

[The API for this part is a work in progress. So is this whole thing, but this part especially.] To address this, we wrap f in a callable struct, CachedReverseDiffBackend, that contains f and a compiled tape:

using ReverseDiff: GradientTape, compile

struct CachedReverseDiffBackend{F, T} # Could also be parametric in backend type
    func::F
    compiled_tape::T
    # Constructor to compile the tape given inputs
    function CachedReverseDiffBackend(f::F, x) where {F}
        compiled_tape = compile(GradientTape(f, x)) # record the tape once, then compile it
        T = typeof(compiled_tape)
        return new{F, T}(f, compiled_tape)
    end
end

compiled_f = CachedReverseDiffBackend(f, x) # where typeof(x) == eltype(xs)

then we make the cached backend callable (with the caveat that @grad only accepts functions, so we define call_func too):

const CRDB = CachedReverseDiffBackend # alias for brevity

(b::CRDB)(y) = call_func(b, y)
call_func(b::CRDB, y) = b.func(y)

and define a custom rule for our CRDB structs:

function call_func(b::CRDB, y::ReverseDiff.TrackedArray)
    return ReverseDiff.track(call_func, b, y)
end

import AbstractDifferentiation as AD
ReverseDiff.@grad function call_func(b::CRDB, y)
    return AD.value_and_pullback_function(b, y) # to be implemented
end

Now we can pass compiled_f to scan instead: scan(compiled_f, xs, init), and when we try to differentiate through it with ReverseDiff.gradient, it will reach compiled_f inside the loop and see that there's a custom rule for it. The custom rule we defined (making use of AD) calls ReverseDiff.gradient on compiled_f and uses the compiled tape that we created when defining compiled_f = ....

So in the end we have an outer uncompiled tape which contains calls to inner compiled tapes.

jacobusmmsmit commented 1 year ago

My previous comment was discussing the compiled tape in an uncompiled tape case, but the uncompiled tape in a compiled tape is easier to address. I'm leaving this comment as some documentation of how this is already possible but could use some development to make it easier to use.

At the end I have a question about how @grad works.

Example showing it's already possible

Some setup

using ReverseDiff
using ReverseDiff: TrackedArray, track, @grad, value, GradientTape, compile, gradient!, gradient

First we define a function with branches. Compiling a tape with branches on it is currently a very dangerous operation as it will compile without complaining but silently return the wrong answer.
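A minimal pure-Julia sketch of why this is dangerous, using a hypothetical `trace_derivative` that mimics what recording does: the branch condition is evaluated once, at trace time, and only the derivative of the branch taken is kept, so replaying it on an input from the other branch gives a wrong answer without any error.

```julia
# Hypothetical toy "tracer": evaluates the branch condition once, at trace
# time, and returns a derivative function that replays only that branch.
function trace_derivative(x0)
    if x0 > 1
        return x -> 3x^2   # derivative of x^3, the branch taken at x0
    else
        return x -> 2x     # derivative of x^2
    end
end

# Reference derivative that re-evaluates the condition on every call.
true_derivative(x) = x > 1 ? 3x^2 : 2x

replay = trace_derivative(2.1)       # "compiled" at a high-branch input
replay(1.5) == true_derivative(1.5)  # true: same branch
replay(0.5) == true_derivative(0.5)  # false: 0.75 instead of 1.0, no error raised
```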

branching_f(x) = sum(x) > 1 ? sum(x)^3 : sum(x)^2
_branching_f(x) = sum(x) > 1 ? sum(x)^3 : sum(x)^2 # function used as a reference

Then we define a custom gradient with some logging to show that the right thing is happening each time.

branching_f(x::TrackedArray) = track(branching_f, x)
@grad function branching_f(x)
    xv = value(x)
    function grad_branching(Δ)
        @show sum(xv)
        if sum(xv) > 1
            println("High branch")
            return (3*sum(xv)^2*Δ, )
        else
            println("Low branch")
            return (2*sum(xv)*Δ, )
        end
    end
    return branching_f(xv), grad_branching
end

Now we construct the tapes and test that everything is running as expected:

# Construct and compile the tape
input = [0.0, 1.1, 1.0]
branching_tape = compile(GradientTape(branching_f, input))
_branching_tape = compile(GradientTape(_branching_f, input)) # This tape should ignore the branch

# One input for each branch in the function
input_low = [0.1, 0.2, 0.3]
input_high = [1.1, 1.2, 1.3]

# Test for correctness of implementation
grad_low = gradient(_branching_f, input_low)
grad_high = gradient(_branching_f, input_high)

grad_low == gradient(branching_f, input_low)
grad_high == gradient(branching_f, input_high)

# An example of the method working
grad_low == gradient!(branching_tape, input_low) # true
grad_low == gradient!(_branching_tape, input_low) # false
grad_high == gradient!(branching_tape, input_high) # true
grad_high == gradient!(_branching_tape, input_high) # true (but for the wrong reason)

Where to go from here

So, in a way, there we go: we can already do modular tape caching! But this is all very manual. It would be very nice if this could be done automatically, for example with:

Automatic detection of branches and a warning

julia> compile(GradientTape(_branching_f, input))
Warning: woah buddy, you've got a branch in that function of yours, I don't think you want to compile it!

or automatic detection of branches and not compiling the branch sources (not ideal)

julia> compile(GradientTape(outer_function_with_inner_branch, my_input)) # Automatic modularisation
Warning: The tape of `outer_function_with_inner_branch` has branches because of `inner_function`,
this function was not compiled

or allowing users to define static arguments à la JAX

inner_function(x, y) = x > 0 ? 2y : 3y^2
sa_inner_function = @static_arguments(inner_function, x) # proposed macro, does not exist yet

outer_function_with_inner_branch(z) = sum(z) * sa_inner_function(z[1], z[2])

or ultimately automatic detection of branches and not compiling the branch sources with respect to those arguments

inner_function(x, y) = x > 0 ? 2y : 3y^2
outer_function_with_inner_branch(z) = sum(z) * inner_function(z[1], z[2])
compile(GradientTape(outer_function_with_inner_branch, my_input)) # All good, works as if it were uncompiled but with compiled performance where possible.

A question

What I'd like to ask is about how @grad works: Which parts of the @grad function are "frozen" when the tape is compiled? In my testing, everything defined outside of the grad_branching function would be frozen, but I couldn't find any documentation on this in ReverseDiff.

@grad function branching_f(x)
    xv = value(x)
    sum_xv = sum(xv) # This part is constant when compiled
    function grad_branching(Δ)
        (sum_xv > 1 ? 3*sum_xv^2*Δ : 2*sum_xv*Δ,) # Doesn't work at all
    end
    return branching_f(xv), grad_branching
end
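The behaviour in this question is consistent with ordinary closure capture, assuming (as the observations above suggest) that a compiled tape reuses the pullback closure created at record time and only refreshes the input value buffer in place. A closure that captures the array `xv` sees the updated values; one that captures a scalar computed from it keeps the recorded value. A pure-Julia sketch with hypothetical names:

```julia
# Stand-in for a tape's value buffer, overwritten in place on every forward pass.
buffer = [0.1, 0.2, 0.3]

# Pullback that captures the *array*: it reads whatever the buffer holds when called.
make_pullback_alias(xv) = Δ -> sum(xv) > 1 ? 3 * sum(xv)^2 * Δ : 2 * sum(xv) * Δ

# Pullback that captures a *scalar* computed at construction: it is frozen.
function make_pullback_snapshot(xv)
    s = sum(xv)
    return Δ -> s > 1 ? 3 * s^2 * Δ : 2 * s * Δ
end

pb_alias = make_pullback_alias(buffer)
pb_snap = make_pullback_snapshot(buffer)

buffer .= [1.1, 1.2, 1.3]  # simulate a new forward pass on a compiled tape
pb_alias(1.0)  # high branch, uses the new sum (about 3.6)
pb_snap(1.0)   # still low branch, uses the recorded sum (about 0.6)
```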

jacobusmmsmit commented 1 year ago

Ok, I've got a draft implementation for defining cached sub-tapes:

import AbstractDifferentiation as AD
using ReverseDiff

using ReverseDiff: @grad, compile, GradientTape
import AbstractDifferentiation: primal_value, pullback_function, value_and_pullback_function

struct CachedReverseDiffBackend{F,T} <: AD.AbstractBackend # Could also be parametric in backend type
    func::F
    compiled_tape::T
    # Constructor to compile the tape given inputs
    function CachedReverseDiffBackend(f::F, x) where {F}
        compiled_tape = compile(GradientTape(f, x))
        T = typeof(compiled_tape)
        return new{F,T}(f, compiled_tape)
    end
end

const CRDB = CachedReverseDiffBackend # alias for brevity

(b::CRDB)(x) = call_func(b, x)
call_func(b::CRDB, x) = b.func(x)

function call_func(b::CRDB, x::ReverseDiff.TrackedArray)
    return ReverseDiff.track(call_func, b, x)
end

@grad function call_func(b::CRDB, x)
    return value_and_pullback_function(b, x)
end

primal_value(::CRDB, xs, _) = primal_value(xs) # is this ok?

function value_and_pullback_function(cb::CRDB, x)
    xv = ReverseDiff.value(x)
    yv = cb.func(xv)

    function pullback_f(Δ)
        (Δ*ReverseDiff.gradient!(cb.compiled_tape, xv), ) # no space to cache output :/
    end
    return yv, pullback_f
end

Should this backend be a real backend i.e. should it define a @primitive?

Here's an example of how it would be used:

using BenchmarkTools
g(xs) = sum(abs2, xs)
xs = [1.0, 2.0, 3.0]
const crdb = CRDB(g, xs) # must be declared const otherwise type unstable when called
gt = compile(GradientTape(g, xs)) # RD code

# Check gradients work as intended :)
ReverseDiff.gradient(g, xs .+ 1)
ReverseDiff.gradient!(gt, xs .+ 1)
ReverseDiff.gradient!(crdb.compiled_tape, xs .+ 1)
# All return the same thing

# Define an outer function
f_nocompile(xs) = 2g(xs) # use the original `g`
f_compile(xs) = 2crdb(xs) # use the `g` with a compiled gradient

# Primal timings
@btime f_nocompile($xs) #  4.000 ns (0 allocations: 0 bytes)
@btime f_compile($xs) # 4.000 ns (0 allocations: 0 bytes)

# Gradient timings
@btime ReverseDiff.gradient(f_nocompile, $xs) # 961.750 ns (32 allocations: 1.34 KiB)
@btime ReverseDiff.gradient(f_compile, $xs) # 1.092 μs (17 allocations: 1008 bytes)

# Double-compile also works
fnc_tape = compile(GradientTape(f_nocompile, xs))
fc_tape = compile(GradientTape(f_compile, xs))

@btime ReverseDiff.gradient!(fnc_tape, $xs) # 521.266 ns (1 allocation: 80 bytes)
@btime ReverseDiff.gradient!(fc_tape, $xs) # 847.889 ns (3 allocations: 240 bytes)

As discussed in https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/41, caching interfaces should be addressed, as this is, I think, where the performance difference comes from.

mohamed82008 commented 1 year ago

First, sorry for the really late response.

In my testing, everything defined outside of the grad_branching function would be frozen, but I couldn't find any documentation on this in ReverseDiff.

Correct. Documentation PRs are welcome :)

The yv = cb.func(xv) in the code below will break higher-order AD. I would use yv = cb(xv) instead, letting dispatch do its thing.

function value_and_pullback_function(cb::CRDB, x)
    xv = ReverseDiff.value(x)
    yv = cb.func(xv)

    function pullback_f(Δ)
        (Δ*ReverseDiff.gradient!(cb.compiled_tape, xv), ) # no space to cache output :/
    end
    return yv, pullback_f
end

Should this backend be a real backend i.e. should it define a @primitive?

Your implementation right now seems to only work for scalar-valued functions. You might want to generalise it; then, yes, making it a primitive will give you all the other methods for free. Check the ReverseDiff backend implementation in AbstractDifferentiation for reference.
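Generalising to vector-valued functions means the pullback must map an output cotangent `Δ` to `J(x)' * Δ`, where `J` is the Jacobian; for a scalar output, `J` is the gradient as a single row and this reduces to the `Δ * gradient` used in the draft. A pure-Julia sketch with a hand-written Jacobian (the names are hypothetical; with ReverseDiff the Jacobian could instead come from a compiled `JacobianTape` via `jacobian!`):

```julia
# Hypothetical vector-valued function f2: R^2 -> R^2 and its Jacobian.
f2(x) = [x[1] * x[2], sin(x[1])]
jacobian_f2(x) = [x[2] x[1]; cos(x[1]) 0.0]

# General pullback: maps an output cotangent Δ (same length as f2(x))
# to the input cotangent J(x)' * Δ.
pullback_f2(x) = Δ -> jacobian_f2(x)' * Δ

# Scalar-output special case, as in the draft: J is a 1×n row, so
# J' * [Δ] is the same as Δ * gradient.
sq(x) = sum(abs2, x)
grad_sq(x) = 2 .* x
row_jacobian_sq(x) = reshape(grad_sq(x), 1, :)
pullback_sq(x) = Δ -> Δ * grad_sq(x)

pullback_f2([1.0, 2.0])([1.0, 0.0])  # [2.0, 1.0]
pullback_sq([1.0, 2.0, 3.0])(2.0)    # [4.0, 8.0, 12.0]
```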

As talked about in https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/41, caching interfaces should be addressed as this is, I think, where the performance difference comes from.

Try profiling to see where the performance difference comes from. Also try a function with more inputs, which might be more representative of when people use ReverseDiff: most people would not use ReverseDiff for a function of 3 variables. If allocations are the bottleneck in your function, then we need to consider reducing them, but let's first check 1) with profiling, that this is the case, and 2) that it's a real problem you will run into when using the package for real-sized problems.
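A quick first check before full profiling (with the `Profile` standard library, or `@btime` as used above) is to count the bytes a call allocates after a warm-up run. The sketch below compares an allocating gradient with one that writes into a caller-owned buffer; both functions are hypothetical stand-ins, not ReverseDiff calls.

```julia
grad_alloc(x) = 2 .* x                        # returns a new array on every call
grad_inplace!(out, x) = (out .= 2 .* x; out)  # reuses a caller-owned buffer

# Measure allocations of one call, after a warm-up call so that
# compilation is not counted.
function bytes_allocated(f, args...)
    f(args...)
    return @allocated f(args...)
end

x = rand(1000)
out = similar(x)
bytes_allocated(grad_alloc, x)          # at least 8000 bytes (1000 Float64s)
bytes_allocated(grad_inplace!, out, x)  # far fewer, typically 0
```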