Documentation about the reverse rule customization needs to be improved #2132

GiggleLiu commented 6 days ago

@vchuravy got the following code for differentiating the einsum! function in OMEinsum to work. He also pointed out that the relevant documentation could be improved. Hope this code snippet helps.

using Enzyme, Enzyme.EnzymeRules, OMEinsum

function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(einsum!)}, ::Type,
        code::Const, xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const)
    @assert sx.val == 1 && sy.val == 0 "Only α = 1 and β = 0 is supported, got: $sx, $sy"
    # Compute primal
    if EnzymeRules.needs_primal(config)
        primal = func.val(code.val, xs.val, ys.val, sx.val, sy.val, size_dict.val)
    else
        # ys still has to be mutated even if the return value is not needed
        func.val(code.val, xs.val, ys.val, sx.val, sy.val, size_dict.val)
        primal = nothing
    end
    # Save x in tape if x will be overwritten
    if EnzymeRules.overwritten(config)[3]
        tape = copy(xs.val)
    else
        tape = nothing
    end
    # The return value aliases ys, so its shadow is the shadow of ys
    shadow = ys.dval
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
               func::Const{typeof(einsum!)}, dret::Type{<:Annotation}, tape,
               code::Const,
               xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const)

   # Use the saved copy of x if x was overwritten after the forward pass
   xval = EnzymeRules.overwritten(config)[3] ? tape : xs.val

   # Accumulate the gradient with respect to each input tensor
   for i=1:length(xs.val)
       xs.dval[i] .+= OMEinsum.einsum_grad(OMEinsum.getixs(code.val),
             xval, OMEinsum.getiy(code.val), size_dict.val, conj(ys.dval), i)
   end
   # One entry per argument: code, xs, ys, sx, sy, size_dict
   return (nothing, nothing, nothing, nothing, nothing, nothing)
end

x = randn(3, 3);
y = randn(3);
gx = zero(x);
gy = zero(y);

function testf2(x)
    y = zeros(size(x, 1))
    einsum!(ein"ii->i", (x,), y, 1, 0, Dict('i'=>3))
    return sum(y)
end

autodiff(ReverseWithPrimal, testf2, Duplicated(x, gx))
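
As a sanity check on what the call above should produce: assuming ein"ii->i" extracts the diagonal of x, testf2 is just the sum of the diagonal, so gx should end up close to the identity matrix. A minimal sketch that confirms this with finite differences, without Enzyme or OMEinsum (trace_like and fd_gradient are hypothetical helpers, not part of either package):

```julia
using LinearAlgebra

# Same computation as testf2, written without OMEinsum: sum of the diagonal.
trace_like(x) = sum(diag(x))

# Central finite-difference gradient (hypothetical helper, for checking only).
function fd_gradient(f, x; h=1e-6)
    g = zero(x)
    for k in eachindex(x)
        xp = copy(x); xp[k] += h
        xm = copy(x); xm[k] -= h
        g[k] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = randn(3, 3)
g = fd_gradient(trace_like, x)

# The gradient of sum(diag(x)) with respect to x is the identity matrix.
@assert isapprox(g, Matrix(1.0I, 3, 3); atol=1e-6)
```

Comparing gx from the autodiff call against this matrix is a quick way to see whether the custom rule accumulates into xs.dval correctly.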

The function signature of einsum! is

einsum!(code::EinCode, xs::Tuple, y, sx, sy, size_dict::Dict=get_size_dict(getixs(code), xs))

The input y is modified in place, and the return value is y itself.
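
To illustrate that aliasing without OMEinsum, here is a small sketch with a hypothetical stand-in function, fill_diag!, which follows the same convention (it is not part of OMEinsum):

```julia
# Hypothetical stand-in for einsum!: overwrites y and returns it.
function fill_diag!(y, x)
    for i in eachindex(y)
        y[i] = x[i, i]
    end
    return y
end

x = [1.0 2.0; 3.0 4.0]
y = zeros(2)
ret = fill_diag!(y, x)

@assert ret === y        # the return value is the same object as the input buffer
@assert y == [1.0, 4.0]  # and the buffer now holds the result
```

Because the returned array and y are the same object, the rule above passes ys.dval as the shadow of the return value rather than allocating a new array.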

wsmoses commented 6 days ago

perhaps make a PR on OMEinsum.jl with your rule?

GiggleLiu commented 5 days ago

Yeah, this is exactly what we were doing. @vchuravy mentioned that there are some things that need to be documented, e.g.

  1. how to create the shadow correctly.
  2. how to handle the output of an in-place function correctly; here the array y is mutated in place and is also returned by the function.
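
Both points can be sketched on a toy function that follows the same pattern as the einsum! rule above. This is only a sketch under assumptions, not the recommended implementation: square! is a hypothetical function, and the rule signatures are assumed to match the EnzymeRules version used in this thread, so they may need adjusting for other Enzyme releases.

```julia
using Enzyme
import Enzyme.EnzymeRules: augmented_primal, reverse, AugmentedReturn,
                           RevConfigWidth, needs_primal, needs_shadow, overwritten

# Hypothetical in-place function: overwrites y and returns it, like einsum!.
function square!(y, x)
    y .= x .^ 2
    return y
end

function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(square!)},
                          ::Type{<:Annotation}, y::Duplicated, x::Duplicated)
    # Save x only if it may be overwritten before the reverse pass.
    tape = overwritten(config)[3] ? copy(x.val) : nothing
    # Always run the function: y has to be mutated even if the result is unused.
    func.val(y.val, x.val)
    # The return value aliases y, so its shadow is y.dval (no new allocation).
    primal = needs_primal(config) ? y.val : nothing
    shadow = needs_shadow(config) ? y.dval : nothing
    return AugmentedReturn(primal, shadow, tape)
end

function reverse(config::RevConfigWidth{1}, func::Const{typeof(square!)},
                 ::Type{<:Annotation}, tape, y::Duplicated, x::Duplicated)
    xval = overwritten(config)[3] ? tape : x.val
    # d(x^2)/dx = 2x, accumulated into the shadow of x.
    x.dval .+= 2 .* xval .* y.dval
    # y was fully overwritten by square!, so its incoming shadow is consumed.
    y.dval .= 0
    return (nothing, nothing)
end

function loss(x)
    y = zeros(length(x))
    square!(y, x)
    return sum(y)
end

x = [1.0, 2.0, 3.0]
dx = zero(x)
autodiff(Reverse, loss, Active, Duplicated(x, dx))
```

If the rule is picked up, dx should match 2 .* x. The two choices worth documenting are that the shadow of the returned array is the existing y.dval, and that the mutation has to happen in augmented_primal whether or not the primal is requested.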

Do you want to add more? @vchuravy

wsmoses commented 5 days ago

Oh yeah, for sure, we definitely need more docs on custom rules.

Since you recently went through the process for the first time, would you be interested in giving it a go?

I think here would be the place to add text: https://github.com/EnzymeAD/Enzyme.jl/blob/main/examples/custom_rule.jl