dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia

Primitivize rrules #103

dfdx closed 2 years ago

dfdx commented 2 years ago

Let's take the rrule for matrix multiplication as an example. At the moment we differentiate it by rewriting:

y = A * B

with

rr = rrule(*, A, B)
y = getfield(rr, 1)
pb = getfield(rr, 2)
...
drr = pb(dy)
dA = getfield(drr, 2)
dB = getfield(drr, 3)
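
For concreteness, here is a minimal sketch of the same sequence run by hand on small dense matrices. It assumes ChainRules and ChainRulesCore are installed; the matrices and the seed dy are made up for illustration, and the unthunk calls are there because the returned tangents may be lazy thunks.

```julia
using ChainRules, ChainRulesCore   # assumption: both packages are installed

A  = [1.0 2.0; 3.0 4.0]            # illustrative inputs, not from the issue
B  = [0.5 -1.0; 2.0 0.0]
dy = ones(2, 2)                    # seed cotangent for y

rr = rrule(*, A, B)                # tuple of (primal value, pullback closure)
y  = getfield(rr, 1)
pb = getfield(rr, 2)

drr = pb(dy)                       # tuple of (NoTangent(), dA, dB); dA and dB may be thunks
dA  = unthunk(getfield(drr, 2))
dB  = unthunk(getfield(drr, 3))
```

Note that everything between the rrule call and the final getfield calls is opaque to the tape: the actual multiplications happen inside pb.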

There are several issues with this approach:

  1. The pullback pb is a closure and thus cannot be serialized e.g. to ONNX.
  2. Since rrule is a single call, we cannot
  3. The code becomes much harder to read, and inconsistencies or mistakes become harder to find.

If we take a look at this rrule's code:

function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    project_A = ProjectTo(A)
    project_B = ProjectTo(B)
    function times_pullback(ȳ)
        Ȳ = unthunk(ȳ)
        dA = @thunk(project_A(Ȳ * B'))
        dB = @thunk(project_B(A' * Ȳ))
        return NoTangent(), dA, dB
    end
    return A * B, times_pullback
end

we can see that for ordinary dense matrices it can be replaced with this:

y = A * B
...
dA = dy * B'
dB = A' * dy

which is much easier to work with.
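
As a sanity check on the dense formulas, the sketch below compares them against central finite differences using only the standard library. The helper fd_grad and the scalar loss are hypothetical names introduced for this check; the loss is chosen so that its gradient with respect to y = A * B is exactly dy.

```julia
using LinearAlgebra   # standard library; provides dot for matrices

A  = [1.0 2.0; 3.0 4.0; 5.0 6.0]                 # 3×2
B  = [0.5 -1.0 2.0; 1.5 0.0 -0.5]                # 2×3
dy = [1.0 0.0 2.0; -1.0 3.0 0.5; 0.0 1.0 1.0]    # 3×3, same shape as A * B

# Scalar function whose gradient with respect to y = A * B is dy.
loss(A, B) = dot(dy, A * B)

# Central finite difference of f with respect to every entry of X.
function fd_grad(f, X; h=1e-6)
    G = similar(X)
    for i in eachindex(X)
        Xp = copy(X); Xp[i] += h
        Xm = copy(X); Xm[i] -= h
        G[i] = (f(Xp) - f(Xm)) / (2h)
    end
    return G
end

# The "primitivized" pullback from the text.
dA = dy * B'
dB = A' * dy

@assert isapprox(dA, fd_grad(X -> loss(X, B), A); atol=1e-6)
@assert isapprox(dB, fd_grad(X -> loss(A, X), B); atol=1e-6)
```

This only covers real dense matrices; the ProjectTo calls in the original rule exist for structured or differently typed inputs, which this check does not exercise.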

I'm not sure whether it will work well in the general case, but one way to implement it is to tweak record_primitive!() to trace rrule() and split its primal and pullback code into two separate lists of operations. Something like:

function record_primitive!(tape::Tape{GradCtx}, v_fargs...)
    v_f, v_args... = v_fargs
    f, args... = [v isa V ? tape[v].val : v for v in v_fargs]
    if isprimitive(ChainRulesCtx(), f, args...)
        t = tape.c.tracer   # a bit weird backref, but let it be for this example
        res = trace!(t, get_code_info(f, args...), v_fargs...)
        v_val, v_pb = tape[res].args    # destructure tuple constructed as the return value from rrule
        tape.c.pullbacks[v_val] = v_pb        
        return v_val
    else
        return push!(tape, mkcall(v_fargs...))
    end
end

Then, during the reverse pass, we can trace the saved pullback and re-map captured values to variables from the primal subtape.
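
To illustrate what "captured values" means here, the sketch below uses a hypothetical make_pullback (not Yota or ChainRules code). It relies on the fact that Julia lowers closures to structs whose fields are the captured variables, which is an implementation detail rather than a public API, so treat it as an illustration only.

```julia
# Hypothetical stand-in for a rule's pullback; not Yota or ChainRules code.
function make_pullback(A, B)
    pullback(dy) = (dy * B', A' * dy)   # captures A and B from the enclosing scope
    return pullback
end

A  = [1.0 2.0; 3.0 4.0]
B  = [0.0 1.0; 1.0 0.0]
pb = make_pullback(A, B)

captured   = fieldnames(typeof(pb))     # names of the captured variables
A_captured = getfield(pb, :A)           # the same array object as A
B_captured = getfield(pb, :B)           # the same array object as B

@assert Set(captured) == Set([:A, :B])
@assert A_captured === A && B_captured === B
```

In a tracer, each such field would presumably have to be matched to the tape variable that produced it in the primal subtape, so that the traced pullback refers to tape variables instead of closure fields.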

This is a pretty sophisticated approach, but so far it looks doable.

(Todo: check out how JAX implements it)

dfdx commented 2 years ago

With the current vision, this idea is unlikely to land in the foreseeable future.