EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Documentation about the reverse rule customization needs to be improved #2132

Open GiggleLiu opened 6 days ago

GiggleLiu commented 6 days ago

@vchuravy got the following code for differentiating the einsum! function in OMEinsum working. He also pointed out that the relevant documentation could be improved. I hope this code snippet helps.

using Enzyme, Enzyme.EnzymeRules, OMEinsum

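function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(einsum!)}, ::Type,
        code::Const, xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const)
    @assert sx.val == 1 && sy.val == 0 "Only α = 1 and β = 0 is supported, got: $sx, $sy"
    # Always run the primal: einsum! mutates ys.val, so it must execute
    # even when the caller does not need the returned value
    result = func.val(code.val, xs.val, ys.val, sx.val, sy.val, size_dict.val)
    primal = EnzymeRules.needs_primal(config) ? result : nothing
    # einsum! returns ys.val itself, so the shadow of the return value is ys.dval
    shadow = EnzymeRules.needs_shadow(config) ? ys.dval : nothing
    # Save xs in the tape if it may be overwritten before the reverse pass;
    # xs.val is a tuple of arrays, so copy each array rather than the tuple
    tape = EnzymeRules.overwritten(config)[3] ? map(copy, xs.val) : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end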

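function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(einsum!)}, dret::Type{<:Annotation}, tape,
        code::Const,
        xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const)
    # Use the saved copy of xs if it may have been overwritten after the forward pass
    xval = EnzymeRules.overwritten(config)[3] ? tape : xs.val
    # Accumulate (never assign) the gradient of each input tensor
    for i in 1:length(xs.val)
        xs.dval[i] .+= OMEinsum.einsum_grad(OMEinsum.getixs(code.val),
            xval, OMEinsum.getiy(code.val), size_dict.val, conj(ys.dval), i)
    end
    # With β = 0 the old contents of y do not affect the result, so the
    # adjoint stored in y's shadow has been fully consumed and must be reset
    ys.dval .= 0
    # One entry per argument after func: code, xs, ys, sx, sy, size_dict
    return (nothing, nothing, nothing, nothing, nothing, nothing)
end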

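x = randn(3, 3)
gx = zero(x)   # shadow of x; the gradient is accumulated here

function testf2(x)
    y = zeros(size(x, 1))
    # y[i] = x[i, i]: extract the diagonal in place (α = 1, β = 0)
    einsum!(ein"ii->i", (x,), y, 1, 0, Dict('i'=>3))
    return sum(y)
end

autodiff(ReverseWithPrimal, testf2, Active, Duplicated(x, gx))
gx   # testf2 computes the trace of x, so this should be the 3×3 identity matrix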
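The reverse rule can be motivated with a short adjoint derivation. As an assumption (suggested by the α/β names in the assert message, not confirmed by OMEinsum's documentation here), treat einsum! as a BLAS-style update where E denotes the contraction described by code and J_i its Jacobian with respect to x_i. For real arrays:

```math
\begin{aligned}
y_{\text{new}} &= \alpha\, E(x_1,\dots,x_n) + \beta\, y_{\text{old}} \\
\bar{x}_i &\mathrel{+}= \alpha\, J_i^{\top} \bar{y}_{\text{new}}, \qquad J_i = \partial E / \partial x_i \\
\bar{y}_{\text{old}} &= \beta\, \bar{y}_{\text{new}}
\end{aligned}
```

With α = 1 and β = 0 this reduces to accumulating J_i^T ȳ into each input shadow and then setting y's shadow to zero. For ii->i we have y_j = x_{jj}, so (J^T ȳ)_{jk} = δ_{jk} ȳ_j; since sum(y) gives ȳ = (1, 1, 1), the gradient of testf2 is the identity matrix.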

The function signature of einsum! is

einsum!(code::EinCode, xs::Tuple, y, sx, sy, size_dict::Dict=get_size_dict(getixs(code), xs))

The argument y is modified in place, and the function returns that same array y.
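The aliasing between the mutated argument and the return value can be shown without OMEinsum. The sketch below uses a hypothetical diag_einsum! that mimics ein"ii->i" and assumes BLAS-style scaling y .= sx .* result .+ sy .* y for the sx and sy arguments; both the function and the scaling convention are illustrative assumptions, not OMEinsum's implementation.

```julia
# Hypothetical stand-in for einsum!(ein"ii->i", (x,), y, sx, sy, ...).
# Assumes BLAS-style scaling: y[i] = sx * x[i, i] + sy * y[i].
function diag_einsum!(y::AbstractVector, x::AbstractMatrix, sx, sy)
    n = size(x, 1)
    @assert size(x, 2) == n && length(y) == n "x must be square and match y"
    for i in 1:n
        y[i] = sx * x[i, i] + sy * y[i]
    end
    return y   # the very object that was passed in, not a copy
end

x = [1.0 2.0; 3.0 4.0]
y = zeros(2)
r = diag_einsum!(y, x, 1, 0)
r === y   # true: the return value aliases the mutated argument
y         # [1.0, 4.0]
```

Because r === y, a reverse rule has to treat the shadow of the return value and the shadow of y as the same memory.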

wsmoses commented 6 days ago

perhaps make a PR on OMEinsum.jl with your rule?

GiggleLiu commented 5 days ago


Yes, that is exactly what we are doing. @vchuravy mentioned that a few things need to be documented, e.g.

  1. How to create the shadow correctly.
  2. How to correctly handle the output of an in-place function: here the array y is mutated and is also returned by the function.
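Both points can be illustrated with a minimal sketch of a reverse rule for a hypothetical in-place function square!(y, x), which overwrites y with x.^2 and returns y. It follows the same pattern as the einsum! rule and Enzyme's custom-rule example, but square!, loss, and the return-activity annotation used here are assumptions for illustration and have not been checked against every Enzyme version. The key ideas: the primal call runs unconditionally because the mutation is a side effect, the shadow of the returned value is y.dval because the return value aliases y, and the reverse pass accumulates into x.dval before zeroing y.dval.

```julia
using Enzyme
using Enzyme.EnzymeRules

# Hypothetical in-place function: overwrites y with x.^2 and returns y itself.
function square!(y, x)
    y .= x .^ 2
    return y
end

function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(square!)}, ::Type{<:Annotation},
        y::Duplicated, x::Duplicated)
    # Always run the primal: mutating y must happen even if the result is unused
    func.val(y.val, x.val)
    primal = EnzymeRules.needs_primal(config) ? y.val : nothing
    # The return value *is* y, so its shadow is y's shadow (same memory)
    shadow = EnzymeRules.needs_shadow(config) ? y.dval : nothing
    # Keep a copy of x only if it may be overwritten before the reverse pass
    tape = EnzymeRules.overwritten(config)[3] ? copy(x.val) : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(square!)}, ::Type{<:Annotation}, tape,
        y::Duplicated, x::Duplicated)
    xval = EnzymeRules.overwritten(config)[3] ? tape : x.val
    # Accumulate (never assign) into x's shadow: d(x^2)/dx = 2x
    x.dval .+= 2 .* xval .* y.dval
    # The old value of y was overwritten, so its adjoint has been consumed
    y.dval .= 0
    # One entry per argument after func: y, x
    return (nothing, nothing)
end

# Hypothetical loss that calls the in-place function
function loss(x)
    y = similar(x)
    square!(y, x)
    return sum(y)
end

x = [1.0, 2.0, 3.0]
dx = zero(x)
autodiff(Reverse, loss, Active, Duplicated(x, dx))
dx   # the gradient of sum(x .^ 2) is 2x
```

Zeroing y.dval in the reverse rule is what makes the shared shadow safe: anything that later reads y's adjoint sees only contributions that have not yet been propagated.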

Do you want to add more? @vchuravy

wsmoses commented 5 days ago

oh yeah for sure we definitely need more docs on custom rules.

Since you went through the first time process recently, would you be interested in giving it a go?

I think here would be the place to add text: https://github.com/EnzymeAD/Enzyme.jl/blob/main/examples/custom_rule.jl