Taking a deeper dive on this than our earlier call, I have the following thoughts:
1. As discussed, we need to add support for Julia custom gradients.
2. I'm not as well versed in how Julia implements indirect function calls, but differentiating the call to `cl.fn` may require special care to ensure the bitcode is available @vchuravy
3. From there, all you should need to do is use Enzyme on `fn` (or better, `cl.fn`) with appropriate creation of active/inactive arguments as necessary for the choice of gradients required, thereby only computing the required adjoints (see the sketch below).
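For point 3, a tiny sketch of what the activity selection might look like with Enzyme.jl's `autodiff` (using today's API; `f` stands in for `cl.fn`, and marking the second argument `Const` skips its adjoint entirely):

```julia
using Enzyme

f(x, y) = x * y  # stand-in for cl.fn

# Only x is Active, so Enzyme computes just that adjoint; Const(y)
# marks y inactive and its gradient slot comes back as nothing.
grads = autodiff(Reverse, f, Active, Active(2.0), Const(3.0))
dx = grads[1][1]   # == 3.0
```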
> I'm not as well versed in how Julia implements indirect function calls, but differentiating the call to `cl.fn` may require special care to ensure the bitcode is available
There are two forms here. The first one is that `DynamicCallSite` is parameterized on `fn::Fn`, so at compile time it is clear that `cl.fn::Fn`; then Julia will potentially inline the call, and if not, Enzyme.jl will collect it into the bitcode file given to Enzyme.

If at the call site `cl.fn::Function`, we have a dynamic call, and we need a custom adjoint for `jl_apply_generic` and to reason much more directly about `jl_value_t*`.
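A small illustration of the two forms (struct names here are illustrative, not the actual types):

```julia
# Form 1: parameterized on the function's concrete type. cl.fn::Fn is known
# at compile time, so Julia can specialize and potentially inline the call.
struct TypedSite{Fn}
    fn::Fn
end

# Form 2: the field is only known to be a Function, so calling cl.fn lowers
# to a dynamic dispatch through jl_apply_generic.
struct UntypedSite
    fn::Function
end

invoke_site(cl, x, y) = cl.fn(x, y)

invoke_site(TypedSite(+), 1.0, 2.0)    # direct, specializable call
invoke_site(UntypedSite(+), 1.0, 2.0)  # dynamic call
```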
Assuming the former, the question becomes: how do I define a custom adjoint for `read_choice` and communicate it from Julia all the way down to Enzyme? This is challenging in its own right, since Julia can do inlining before Enzyme.jl sees the code, and maintaining function/method identity between Julia and LLVM is non-trivial.
I wonder if we might be able to side-step this by using function pointers...
```julia
struct EnzymeFn{rt, tt}
    primal::Ptr{Cvoid}
    adjoint::Ptr{Cvoid}
end

function (ef::EnzymeFn{rt, tt})(args...) where {rt, tt}
    attach_adjoint(ef.primal, ef.adjoint) # Enzyme intrinsic
    # complicated ABI handling elided; the ccall below is schematic
    ccall(ef.primal, rt, (tt...,), args...)
end

ef = EnzymeFn{Int, Tuple{Int, Int}}(
    @cfunction(+, Int, (Int, Int)),
    @cfunction(-, #= Adjoint calling convention =#)
)
```
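One appeal of this shape, if it can be made to work, is that the `@cfunction` pointers pin down specific compiled entry points, so the primal/adjoint pairing would survive whatever inlining Julia performs before Enzyme.jl sees the IR.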
So the first thing that I'm going to do is make C-extern versions of the relevant gradient functions available.
@vchuravy If I have an enum, is this accessible to you, or should I just have them be ints?

Also, by default, is it reasonable to have the functions assume the default argument-type deductions (a double arg really is a double), or is it preferred to have them be null sets?
> If I have an enum, is this accessible to you, or should I just have them be ints?
Enum is fine, we just have to keep the defs in sync.
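For instance, the Julia side could mirror the C enum with an `@enum` whose values must match the header by hand (the names and values below follow Enzyme's `CApi.h` as I understand it; treat them as an assumption):

```julia
# Julia mirror of the C-side activity enum; integer values must stay in
# sync with the C header manually.
@enum CDIFFE_TYPE::Cint begin
    DFT_OUT_DIFF   = 0  # differentiable return value
    DFT_DUP_ARG    = 1  # duplicated argument (shadow passed alongside)
    DFT_CONSTANT   = 2  # inactive argument
    DFT_DUP_NONEED = 3  # duplicated, but primal result unneeded
end
```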
> Also, by default, is it reasonable to have the functions assume the default argument-type deductions (a double arg really is a double), or is it preferred to have them be null sets?
Not sure... I would be conservative and expect the frontend to fill in all of the info, and if it doesn't, fall back to the null set.
So if we have it as an option to fill it in from the frontend, there'd need to be an Enzyme TypeTree passed in (a C++ object). I suppose we could make a C builder for it, though.
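For concreteness, such a builder used from Julia might look roughly like this (a sketch: `EnzymeNewTypeTreeCT` and `EnzymeTypeTreeOnlyEq` exist in Enzyme's C API, but the exact signatures and the `DT_Double` value here are assumptions):

```julia
using LLVM
using Enzyme_jll: libEnzyme  # Enzyme shared library from the JLL package

const DT_Double = Cint(5)    # assumed CConcreteType value from CApi.h

ctx = LLVM.Context()
# Build a TypeTree saying "Double", then spread it to every offset behind
# a pointer, i.e. {[-1]:Double}.
tt = ccall((:EnzymeNewTypeTreeCT, libEnzyme), Ptr{Cvoid},
           (Cint, LLVM.API.LLVMContextRef), DT_Double, ctx)
ccall((:EnzymeTypeTreeOnlyEq, libEnzyme), Cvoid, (Ptr{Cvoid}, Int64), tt, -1)
```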
Yeah, which would also allow us to get rid of the custom handling of Julia's TBAA.
I don't think this would be quite sufficient. This alone would enable the type information at the boundaries, whereas to eliminate the custom TBAA handling we'd also need to enable a custom Type Analysis rule, which would be very easy if one passed in an extended type analysis.
> custom Type Analysis rule
Not sure what form that would take, but keep in mind that you can call Julia from C and as such Enzyme.jl could provide arbitrary callbacks. There would need to be a fairly robust C-API, but as an example we have custom LLVM passes written in pure Julia ;)
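To sketch the callback shape (only `@cfunction`/`ccall` are real here; the registration entry point is made up for illustration):

```julia
# A Julia function exposed to C as a plain function pointer; the C side
# would call it when type analysis hits a call it wants a custom rule for.
function custom_call_rule(inst::Ptr{Cvoid})::Cint
    # inspect the LLVM instruction behind `inst`, return a status code
    return Cint(0)
end

rule = @cfunction(custom_call_rule, Cint, (Ptr{Cvoid},))

# Hypothetical registration hook; Enzyme's actual C API entry point differs:
# ccall((:EnzymeRegisterCallRule, libEnzyme), Cvoid, (Ptr{Cvoid},), rule)
```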
Ok, the primitive C API I just exported should be sufficient for an initial version of this (once we rewrite Enzyme.jl to use the Enzyme C functions to create gradients rather than running the optimization pass).
This probably should be moved to a separate issue, but it's also worth discussing what custom type propagation rules we'd want to enable. I presume a custom rule for a function call is sufficient, but perhaps there's utility in doing more (like load/store rules for eliminating TBAA).
Now duplicate of https://github.com/EnzymeAD/Enzyme.jl/issues/172
Just to clarify my question: the way my AD objective is set up, there's a layer of indirection between the targets I'm taking the gradient with respect to and the computation of the objective function. E.g., say I'm taking the gradient with respect to a random choice (let's ignore hierarchical calls and just focus on a single call).
The user provides a selection of "what choices do you want grads for" and then this context will re-trace the program.
`read_choice` has a custom adjoint, and the objective is kept in the `ctx` - here `increment!` accumulates the logpdf of the choice onto the objective. So AD is sort of messy in this case. Then everything is tied together by calling `pullback` here:
So here, the function which I pull back takes care of instantiating the context, then accumulating the objective (here, `ctx.weight`), etc. The "glue" pullbacks for `read_choice` make sure that only terms which the user targets are accumulated.
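A stripped-down illustration of that shape (names and structure are stand-ins, not the actual code; the real version threads the context and the user's selection through):

```julia
using Zygote, Distributions

# The closure we pull back builds its own "context": the objective is the
# weight accumulated from the logpdf of each targeted choice.
function objective(x)
    weight = 0.0                            # stand-in for ctx.weight
    weight += logpdf(Normal(0.0, 1.0), x)   # increment! for the choice
    return weight
end

val, back = Zygote.pullback(objective, 1.5)
(grad,) = back(1.0)  # gradient of the accumulated weight w.r.t. the choice
```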
I think this is an "advanced" Julia AD example which might be interesting to think about.