EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
439 stars, 62 forks

How to use ChainRules rrules with autodiff? #583

Closed maxfreu closed 4 months ago

maxfreu commented 1 year ago

I couldn't find a better title, so let me explain: I have a function interpolated with BSplineKit from a lookup table. BSplineKit provides rrules for differentiating the interpolated function. Now I want to call this function inside code that I'd like to differentiate with Enzyme, because Zygote is too slow. Is that already possible? My first lazy attempts resulted in segfaults :(

vchuravy commented 1 year ago

No, that is currently not possible. #172 is the issue to follow, but we are not targeting ChainRules support from the get-go; rather, we want to allow users to provide rules in an Enzyme-compatible fashion.

maxfreu commented 1 year ago

Thanks for the hint & the development work!

CarloLucibello commented 7 months ago

I would suggest reopening this, unless there is some big technical blocker. A lot of work has gone into defining a large set of rules in https://github.com/JuliaDiff/ChainRules.jl, and it may take some time for EnzymeRules to catch up.

Moreover, if we want Flux users, and the ML ecosystem generally, to transition smoothly to Enzyme (see #805), it would be nice to support the custom rules that people have been writing with ChainRules for years.

wsmoses commented 7 months ago

I'm not opposed to this, but there are several challenges and limitations.

1) Most ChainRules rules implicitly assume that mutation does not occur. For example, consider the A * B rule. The pullback would store A and B (by reference). However, if A is overwritten between the forward and reverse passes, the data in A will have changed, and the ChainRule will silently get the answer wrong as a result.

2) Most functions that have rules in ChainRules should not have rules in Enzyme, because Enzyme doesn't usually need rules for all Julia functions: working from the lower level up, derivatives for most code can be generated from its definition. As a consequence, the question is no longer whether a function needs a rule, but whether a function should have a rule. When it should, it is often (though not always) for performance reasons, which means you probably want the EnzymeRule to be fast.
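
To make point 1 concrete, here is a minimal pure-Julia sketch (no ChainRules or Enzyme required) of an rrule-style pair for `A * B`. The names `mul_rrule` and `mul_pullback` are hypothetical; the only assumption carried over from the comment above is that the pullback closes over `A` and `B` by reference rather than copying them.

```julia
# Hypothetical rrule-style function for C = A * B. Like a typical ChainRules
# pullback, the closure captures A and B by reference; nothing is copied.
function mul_rrule(A::AbstractMatrix, B::AbstractMatrix)
    C = A * B
    # Reverse pass: dA = dC * B', dB = A' * dC, read from A and B at call time.
    mul_pullback(dC) = (dC * B', A' * dC)
    return C, mul_pullback
end

A  = [1.0 2.0; 3.0 4.0]
B  = [5.0 6.0; 7.0 8.0]
dC = ones(2, 2)

C, pullback = mul_rrule(A, B)
dA_ok, dB_ok = pullback(dC)      # A untouched so far: correct gradients

A .= 0.0                         # A overwritten in place between forward and reverse
dA_bad, dB_bad = pullback(dC)    # dB is now computed from the zeroed A

println(dB_ok)    # [4.0 4.0; 6.0 6.0]
println(dB_bad)   # [0.0 0.0; 0.0 0.0], silently wrong, no error raised
```

`dA` happens to survive because it only reads `B`; which gradient goes stale depends on which input gets mutated.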

I previously wrote ChainRules-to-EnzymeRules importers for forward and reverse mode in this PR: https://github.com/EnzymeAD/Enzyme.jl/pull/996/files (see import_frule and import_rrule). These may be helpful for a quick conversion, but they have both functional and performance limitations.

Specifically:

    import_frule(::fn, tys...)
Automatically import a ChainRules.frule as a custom forward mode EnzymeRule. When called in batch mode, this
will end up calling the primal multiple times, which may result in incorrect behavior if the function mutates,
and always results in slower code. Importing the rule from ChainRules is also likely to be slower than writing
your own rule, and may also be slower than not having a rule at all.

Use with caution.

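using Enzyme, ChainRules  # assumed setup: ChainRules supplies the frule for sort

Enzyme.@import_frule(typeof(Base.sort), Any);
x = [1.0, 2.0, 0.0]; dx = [0.1, 0.2, 0.3]; ddx = [0.01, 0.02, 0.03];
# Batch width 2: primal plus both tangents, then tangents only
Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx, ddx)))
Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx, ddx)))
# Batch width 1: tangent only, then primal plus tangent
Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx,)))
Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,)))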
# output
(var"1" = [0.0, 1.0, 2.0], var"2" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02]))
(var"1" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02]),)
(var"1" = [0.3, 0.1, 0.2],)
(var"1" = [0.0, 1.0, 2.0], var"2" = [0.3, 0.1, 0.2])
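
The batch-mode caveat in the `import_frule` docstring can be illustrated without Enzyme. The sketch below is a hypothetical stand-in for what an imported rule has to do when handed several tangents: a ChainRules-style `frule` accepts one tangent at a time, so the wrapper calls it, and therefore reruns the primal, once per tangent. `sort_frule`, `batched_frule`, and the `PRIMAL_CALLS` counter are illustrative names, not Enzyme or ChainRules API.

```julia
# Counts how many times the primal computation runs.
const PRIMAL_CALLS = Ref(0)

# Hypothetical frule-style function for sort: returns (primal, tangent) for ONE tangent.
function sort_frule(x::Vector{Float64}, dx::Vector{Float64})
    PRIMAL_CALLS[] += 1
    p = sortperm(x)          # the primal work
    return x[p], dx[p]       # the tangent is permuted the same way as the values
end

# Hypothetical batched wrapper: one frule call, hence one primal run, per tangent.
function batched_frule(x::Vector{Float64}, dxs::Tuple)
    results = map(dx -> sort_frule(x, dx), dxs)
    y = first(results[1])             # primal value from the first call
    return y, map(last, results)      # tuple of tangents, one per batch lane
end

x   = [1.0, 2.0, 0.0]
dx  = [0.1, 0.2, 0.3]
ddx = [0.01, 0.02, 0.03]

y, dys = batched_frule(x, (dx, ddx))
println(y)               # [0.0, 1.0, 2.0]
println(dys)             # ([0.3, 0.1, 0.2], [0.03, 0.01, 0.02])
println(PRIMAL_CALLS[])  # 2: the primal ran once per tangent
```

The values match the doctest output above; the cost is the repeated primal, and if the primal mutated its input, the second call would see the state left by the first.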

    import_rrule(::fn, tys...)
Automatically import a ChainRules.rrule as a custom reverse mode EnzymeRule. When called in batch mode, this
will end up calling the primal multiple times, which results in slower code. This macro assumes that the underlying
function to be imported is read-only and returns a Duplicated or Const object. It also assumes that the
inputs permit a .+= operation and that the output has a valid Enzyme.Compiler.make_zero function defined.
Finally, this macro almost always falls back to caching all of the inputs, even when they may not be needed for
the derivative computation.
As a result, this auto-importer is also likely to be slower than writing your own rule, and may also be slower
than not having a rule at all.

Use with caution.

Enzyme.@import_rrule(typeof(Base.sort), Any);
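
The caching and accumulation assumptions in the `import_rrule` docstring can be sketched in plain Julia. The hypothetical `cached_mul_rrule` below copies its inputs during the forward pass, the conservative "cache everything" strategy the docstring describes, so overwriting `A` afterwards (the A * B hazard from earlier in the thread) no longer corrupts the gradient. It also accumulates into caller-owned shadow arrays with `.+=` instead of assigning. The names and the shadow-passing interface are illustrative assumptions, not the code Enzyme generates.

```julia
# Hypothetical rrule-style function that caches its inputs, mimicking the
# conservative "cache all inputs" strategy described in the docstring.
function cached_mul_rrule(A::AbstractMatrix, B::AbstractMatrix)
    A_tape, B_tape = copy(A), copy(B)   # extra memory proportional to input size
    C = A * B
    # Reverse pass: accumulate (.+=) into caller-owned shadows, never overwrite them.
    function pullback!(dA, dB, dC)
        dA .+= dC * B_tape'
        dB .+= A_tape' * dC
        return nothing
    end
    return C, pullback!
end

A  = [1.0 2.0; 3.0 4.0]
B  = [5.0 6.0; 7.0 8.0]
dA = zeros(2, 2)          # stand-in for the shadow of A
dB = zeros(2, 2)          # stand-in for the shadow of B

C, pullback! = cached_mul_rrule(A, B)
A .= 0.0                            # overwrite A after the forward pass
pullback!(dA, dB, ones(2, 2))
pullback!(dA, dB, ones(2, 2))       # a second contribution adds on top of the first

println(dB)   # [8.0 8.0; 12.0 12.0]: uses the original A, accumulated twice
```

The copy buys correctness under mutation, but it is paid on every call whether or not the input is actually overwritten, which is the performance concern raised in the thread.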

I'd potentially be okay with this being made into an extension package, marked as deprecated, with the imported rules emitting warnings when run.

cc @vchuravy for your thoughts.

In any case, the PR needs a small amount of rebase work before it can be merged, if you'd be interested in helping.

CarloLucibello commented 7 months ago

An opt-in approach like the one in #996 would already be valuable. I understand the performance limitations, but slow is better than not working at all. Relatedly, it seems to me that writing an Enzyme rule is much more difficult than writing a ChainRules one, so I would really appreciate a quick way to translate them; in many cases typical of deep learning, the performance hit would probably be negligible.

wsmoses commented 7 months ago

@CarloLucibello go for it, reviving the PR would be a welcome contribution.

I'm more skeptical that the performance hit would be negligible, because the cost of copying and caching a lot of unnecessary memory scales with the size of the tensors.
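
A rough sketch of the scaling argument: the memory a cached copy holds grows linearly with the number of elements, so for an n x n matrix it grows quadratically in n. `Base.summarysize` reports the bytes reachable from an object; `cache_bytes` is a hypothetical helper and the sizes are arbitrary examples.

```julia
# Bytes kept alive when a rule caches (copies) one n x n Float64 input matrix.
function cache_bytes(n::Integer)
    A = rand(n, n)          # stand-in for a forward-pass input
    tape = copy(A)          # the cached copy a "cache all inputs" rule would hold
    return Base.summarysize(tape)
end

for n in (10, 100, 1000)
    println(n, "x", n, " matrix: ", cache_bytes(n), " bytes cached (element data alone: ",
            8 * n^2, " bytes)")
end
```

Each cached input adds roughly this much per call, on top of any memory the derivative computation genuinely needs.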

CarloLucibello commented 7 months ago

> @CarloLucibello go for it, reviving the PR would be a welcome contribution.

I'm not familiar with Enzyme's internals (honestly, I have little familiarity with Enzyme in general so far), so I don't think I can help much. Do you want help with testing it?

wsmoses commented 4 months ago

With https://github.com/EnzymeAD/Enzyme.jl/pull/996 now merged, we have macros to import both ChainRules frules and rrules, so I am now closing this issue.