EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Custom pullbacks for PP AD example #26

Closed femtomc closed 2 years ago

femtomc commented 3 years ago

Just to clarify my question: the way my AD objective is set up, there's a layer of indirection between the targets I'm taking the gradient with respect to and the computation of the objective function. E.g. say I'm taking the gradient with respect to a random choice (let's ignore hierarchical calls and just focus on a single call):

@inline function (ctx::ChoiceBackpropagateContext)(call::typeof(trace), 
                                                   addr::T, 
                                                   d::Distribution{K}) where {T <: Address, K}
    haskey(ctx.target, addr) || return get_value(get_sub(ctx.call, addr))
    s = read_choice(ctx, addr)
    increment!(ctx, logpdf(d, s))
    return s
end

The user provides a selection of "what choices do you want grads for", and then this context will re-trace the program. read_choice has a custom adjoint, and the objective is kept in the ctx; here increment! accumulates the logpdf of the choice onto the objective. So AD is somewhat messy in this case.
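For concreteness, the kind of "glue" pullback described above might look roughly like the following Zygote sketch. This is hedged: `ctx.target`, `ctx.choice_grads`, and the `accumulate!` helper are hypothetical names inferred from the snippets in this issue, not the actual library API.

```julia
# Hedged sketch of a custom adjoint for read_choice; the ctx fields and
# accumulate! are hypothetical, inferred from the surrounding snippets.
Zygote.@adjoint function read_choice(ctx, addr)
    s = read_choice(ctx, addr)
    function read_choice_pullback(s̄)
        # Only accumulate gradients for addresses the user targeted;
        # untargeted choices behave as constants.
        haskey(ctx.target, addr) && accumulate!(ctx.choice_grads, addr, s̄)
        (nothing, nothing)
    end
    s, read_choice_pullback
end
```

The key design point is that the pullback closes over the context, so gradient filtering by target happens inside the adjoint rather than in the objective itself.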

Then, everything is tied together by calling pullback here:

function accumulate_choice_gradients!(fillables::S, initial_params::P, choice_grads, choice_target::K, cl::DynamicCallSite, ret_grad...) where {S <: AddressMap, P <: AddressMap, K <: Target}
    fn = (args, choices) -> begin
        ctx = ChoiceBackpropagate(cl, fillables, initial_params, choices, choice_grads, choice_target)
        ret = ctx(cl.fn, args...)
        (ctx.weight, ret)
    end
    blank = Store()
    _, back = Zygote.pullback(fn, cl.args, blank)
    arg_grads, grad_ref = back((1.0, ret_grad...))
    choice_vals = filter_acc!(choice_grads, cl, grad_ref, choice_target)
    return choice_vals, arg_grads
end

So here, the function which I pullback takes care of instantiating the context, then accumulates the objective (here, ctx.weight), etc. The "glue" pullbacks for read_choice make sure that only terms which the user targets are accumulated.

I think this is an "advanced" Julia AD example which might be interesting to think about.

wsmoses commented 3 years ago

Taking a deeper dive on this than our earlier call, I have the following thoughts:

1) As discussed, we need to add support for Julia custom gradients.
2) I'm not as well versed in how Julia implements indirect function calls, but differentiating the call to cl.fn may require special care to ensure the bitcode is available @vchuravy
3) From there, all you should need to do is use Enzyme on fn (or even better, cl.fn) with appropriate creation of active/inactive arguments as necessary for the choice of gradients required (thereby only computing the required adjoints).

vchuravy commented 3 years ago

I'm not as well versed in how Julia implements indirect function calls, but differentiating the call to cl.fn may require special care to ensure the bitcode is available

There are two forms here. The first one is that DynamicCallSite is parameterized on fn::Fn, and thus at compile time it is clear that cl.fn::Fn; Julia will then potentially inline the call, and if not, Enzyme.jl will collect it into the bitcode file given to Enzyme. If at the callsite cl.fn::Function, we have a dynamic call, and we need a custom adjoint for jl_apply_generic and need to reason much more directly about jl_value_t*.
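The distinction between the two cases can be sketched with two illustrative structs (these are not the actual library definitions):

```julia
# Illustrative only. With a parametric field, the concrete function type
# is part of the struct's type, so calls through fn are statically
# resolvable (and potentially inlined).
struct ParametricSite{Fn}
    fn::Fn
end

# With an abstract field type, the compiler only knows fn isa Function,
# so calls through it go via dynamic dispatch (jl_apply_generic).
struct AbstractSite
    fn::Function
end

double(x) = 2x
p = ParametricSite(double)  # p isa ParametricSite{typeof(double)}
a = AbstractSite(double)    # a.fn's concrete type is erased to Function
```

In the parametric case the callee is pinned down by the type system, which is what lets Enzyme.jl find and collect its bitcode.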

Assuming the former, the question becomes: how do I define a custom adjoint for read_choice and communicate it from Julia all the way down to Enzyme? This is challenging in its own right, since Julia can do inlining before Enzyme.jl sees the code, and maintaining function/method identity between Julia and LLVM is non-trivial.

I wonder if we might be able to side-step this by using function pointers...

struct EnzymeFn{rt, tt}
  primal::Ptr{Cvoid}
  adjoint::Ptr{Cvoid}
end

function (ef::EnzymeFn{rt, tt})(args...) where {rt, tt}
  attach_adjoint(ef.primal, ef.adjoint) # hypothetical Enzyme intrinsic
  # complicated ABI handling elided; a real implementation would need to
  # expand the argument types from tt into the ccall signature
  ccall(ef.primal, rt, (tt...), args...)
end

ef = EnzymeFn{Int, Tuple{Int, Int}}(
  @cfunction(+, Int, (Int, Int)),
  @cfunction(-, #= adjoint calling convention elided =#)
)

wsmoses commented 3 years ago

So the first thing that I'm going to do is make c-extern versions of the relevant gradient functions available.

@vchuravy If I have an enum, is it accessible to you, or should I just have them be ints?

Also, by default, is it reasonable to have the functions assume the default argument-type deductions (a double arg is really a double), or is it preferred to have them be null sets?

vchuravy commented 3 years ago

If I have an enum, is it accessible to you, or should I just have them be ints?

Enum is fine, we just have to keep the defs in sync.
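A minimal sketch of what "keeping the defs in sync" could look like on the Julia side; the names and values here are hypothetical and would have to mirror exactly whatever Enzyme's C header defines:

```julia
# Hypothetical mirror of a C enum; the values must match the C header
# byte-for-byte, since only the integer crosses the ccall boundary.
@enum CDiffeType::Cint begin
    DFT_OUT_DIFF = 0   # e.g. an active return value
    DFT_DUP_ARG  = 1   # e.g. a duplicated argument with a shadow
    DFT_CONSTANT = 2   # e.g. an inactive value
end

# At the ccall boundary the enum lowers to a plain Cint, e.g.
# ccall((:EnzymeSomeEntryPoint, libenzyme), Cvoid, (Cint,),
#       Cint(DFT_DUP_ARG))   # entry point name is a placeholder
```

Because the enum is `Cint`-backed, passing it through `ccall` costs nothing; the only maintenance burden is keeping the two definitions aligned.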

Also, by default, is it reasonable to have the functions assume the default argument-type deductions (a double arg is really a double), or is it preferred to have them be null sets?

Not sure... I would be conservative and expect the frontend to fill in all of the info, and if it doesn't -> null set

wsmoses commented 3 years ago

So if we have it as an option to fill it in from the frontend, there'd need to be an Enzyme TypeTree passed in (a C++ object). I suppose we could make a C builder for it, though.

vchuravy commented 3 years ago

Yeah, which would also allow us to get rid of the custom handling of Julia's TBAA.

wsmoses commented 3 years ago

I don’t think this would be quite sufficient. This alone would enable the type information at the boundaries, whereas to eliminate custom TBAA handling we'd also need to enable a custom Type Analysis rule, which would be very easy if one passed in an extended type analysis.

vchuravy commented 3 years ago

custom Type Analysis rule

Not sure what form that would take, but keep in mind that you can call Julia from C and as such Enzyme.jl could provide arbitrary callbacks. There would need to be a fairly robust C-API, but as an example we have custom LLVM passes written in pure Julia ;)
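The callback pattern mentioned here is standard `@cfunction` plumbing; a hedged sketch follows, where the `register_type_rule` C entry point is hypothetical and only the mechanism itself is standard Julia:

```julia
# A pure-Julia function exposed to C code as a function pointer. The
# C-side consumer `register_type_rule` is a hypothetical placeholder.
function my_type_rule(call::Ptr{Cvoid})::Cint
    # inspect the opaque LLVM value and return a result code
    return 0
end

# A C-callable pointer with an explicit return type and argument types.
const my_type_rule_ptr = @cfunction(my_type_rule, Cint, (Ptr{Cvoid},))

# Handing it to a (hypothetical) registration entry point:
# ccall((:register_type_rule, libenzyme), Cvoid, (Ptr{Cvoid},),
#       my_type_rule_ptr)
```

This is the same mechanism that lets Enzyme.jl run custom LLVM passes written in pure Julia: C holds only an opaque pointer and calls back into Julia-compiled code.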

wsmoses commented 3 years ago

OK, the primitive C API I just exported should be sufficient for an initial version of this (once we rewrite Enzyme.jl to use the Enzyme C functions to create gradients rather than running the optimization pass).

This should probably be moved to a separate issue, but it's also worth discussing which custom Type propagation rules we'd want to enable. I presume a custom rule for a function call is sufficient, but perhaps there's utility in doing more (like load/store rules for eliminating TBAA).

wsmoses commented 2 years ago

Now duplicate of https://github.com/EnzymeAD/Enzyme.jl/issues/172