Let's take the rrule for matrix multiplication as an example. At the moment we differentiate it by replacing:
y = A * B
with
rr = rrule(*, A, B)
y = getfield(rr, 1)
pb = getfield(rr, 2)
...
drr = pb(dy)
dA = getfield(drr, 2)
dB = getfield(drr, 3)
There are several issues with this approach:
The pullback pb is a closure and thus cannot be serialized e.g. to ONNX.
Since rrule is a single call, we cannot
The code becomes much harder to read, and inconsistencies or mistakes become harder to find.
If we take a look at this rrule's code:
function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    project_A = ProjectTo(A)
    project_B = ProjectTo(B)
    function times_pullback(ȳ)
        Ȳ = unthunk(ȳ)
        dA = @thunk(project_A(Ȳ * B'))
        dB = @thunk(project_B(A' * Ȳ))
        return NoTangent(), dA, dB
    end
    return A * B, times_pullback
end
we can see that for ordinary dense matrices it can be replaced with this:
y = A * B
...
dA = dy * B'
dB = A' * dy
which is much easier to work with.
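To make the equivalence concrete, here is a minimal sketch that compares the closure-based form with the flattened form for dense real matrices. `toy_mul_rrule` is a hypothetical stand-in written for this example, not the ChainRules implementation: it omits `ProjectTo`, thunks and `NoTangent()`.

```julia
# Hypothetical stand-in for rrule(*, A, B) on dense real matrices.
# It is NOT the ChainRules rule: no ProjectTo, no thunks, no NoTangent().
function toy_mul_rrule(A, B)
    toy_pullback(dy) = (dy * B', A' * dy)   # closure that captures A and B
    return A * B, toy_pullback
end

A, B = rand(2, 3), rand(3, 4)
y, pb = toy_mul_rrule(A, B)
dy = ones(size(y))

# Closure-based form: what the rewritten code does today
dA_pb, dB_pb = pb(dy)

# Flattened form: plain operations with no closure involved
dA = dy * B'
dB = A' * dy

@assert dA_pb ≈ dA
@assert dB_pb ≈ dB

# The pullback is a callable struct whose fields hold the captured values,
# which is what gets in the way of serialization.
captured = fieldnames(typeof(pb))
```

The flattened form consists only of `*` and `'` calls on known variables, so every step is visible as a separate operation rather than hidden inside `pb`.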
I'm not sure if it will work well in the general case, but one way to implement it is to tweak record_primitive!() to trace rrule() and split its primal and pullback code into two separate lists of operations. Something like:
function record_primitive!(tape::Tape{GradCtx}, v_fargs...)
    v_f, v_args... = v_fargs
    f, args... = [v isa V ? tape[v].val : v for v in v_fargs]
    if isprimitive(ChainRulesCtx(), f, args...)
        t = tape.c.tracer   # a bit weird backref, but let it be for this example
        res = trace!(t, get_code_info(f, args...), v_fargs...)
        v_val, v_pb = tape[res].args   # destructure tuple constructed as the return value from rrule
        tape.c.pullbacks[v_val] = v_pb
        return v_val
    else
        return push!(tape, mkcall(v_fargs...))
    end
end
Then, during the reverse pass, we can trace the saved pullback and re-map captured values to variables from the primal subtape.
This is a pretty sophisticated approach, but so far it looks doable.
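The re-mapping step could look roughly like the sketch below. Everything in it is hypothetical and independent of the actual tape types: the primal subtape is modelled as a plain list of `(id, value)` pairs, and `remap_captured` matches each value captured by the pullback closure to a primal variable by object identity (`===`). Immutable captures such as numbers would need a different matching strategy.

```julia
# Hypothetical stand-in for an rrule; see the assumptions above.
function toy_mul_rrule(A, B)
    toy_pullback(dy) = (dy * B', A' * dy)   # captures A and B
    return A * B, toy_pullback
end

A, B = rand(2, 3), rand(3, 4)
y, pb = toy_mul_rrule(A, B)

# Toy model of the primal subtape: variable id => recorded value
primal = [(:v1, A), (:v2, B), (:v3, y)]

# For every field captured by the closure, find the primal variable
# that holds the very same object.
function remap_captured(pb, primal)
    mapping = Dict{Symbol,Symbol}()
    for name in fieldnames(typeof(pb))
        captured = getfield(pb, name)
        for (id, val) in primal
            if val === captured
                mapping[name] = id
            end
        end
    end
    return mapping
end

mapping = remap_captured(pb, primal)
@assert mapping == Dict(:A => :v1, :B => :v2)
```

With such a mapping, operations traced from the pullback body can refer to `v1` and `v2` directly instead of reading fields of `pb`.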
(Todo: check out how JAX implements it)