JuliaMath / AbstractFFTs.jl

A Julia framework for implementing FFTs
MIT License

Adding EnzymeRules #99

Open sethaxen opened 1 year ago

sethaxen commented 1 year ago

Similar to the ChainRulesCore support, we could use EnzymeCore.EnzymeRules to define Forward and Reverse mode rules for Enzyme in an extension.

EnzymeCore requires at least Julia v1.6. Making it a hard dependency on Julia versions older than v1.9 (as is done with ChainRulesCore) would then only be possible if AbstractFFTs adds a Julia v1.6 lower bound. But since EnzymeCore's sole dependency (Adapt) depends on Requires, it may make more sense to load it conditionally on pre-v1.9 using Requires.jl: https://pkgdocs.julialang.org/dev/creating-packages/#Requires.jl

Unlike the ChainRulesCore support, it probably only makes sense to define rules for StridedArray inputs, to avoid doing the wrong thing for sparse or structured arrays.

ChrisRackauckas commented 1 year ago

I don't think anyone would be against it. Make it an extension package using v1.9 extensions and there's no dep added here. It's more about getting it done.

wsmoses commented 1 year ago

Enzyme rules require v1.7 or above. cc @vchuravy

EnzymeCore is designed to be the lightweight, flexible package, like ChainRulesCore, for importing and adding rules.

sethaxen commented 1 year ago

Okay, I'm happy to tackle this once I understand some confusing behavior of complex array rules (see https://github.com/EnzymeAD/Enzyme.jl/issues/744 and https://github.com/EnzymeAD/Enzyme.jl/pull/739#discussion_r1173527477).

Unlike the ChainRules support, I would restrict Enzyme rules to StridedArray inputs. That is what FFTW promotes everything to, and it avoids doing the wrong thing for structurally sparse array inputs.

sethaxen commented 1 year ago

A few observations after starting work on this:

wsmoses commented 1 year ago

Why would defining rules for fft/etc cause more work to be done?

sethaxen commented 1 year ago

> unlike ChainRules, we should not define rules for fft, ifft, bfft, etc., since this will cover up the primal functions or any custom methods and do more work than is necessary

Let me clarify. Let f be fft, fft!, etc. Assume no other package pirates f for an array type defined in the standard library. Then it is safe for us to define forward- and reverse-mode rules for f for all array types defined in the standard library (or maybe even just StridedArray). To avoid doing extra work, when f is not in-place, we can use the same plan for the primal and tangent. Because FFTW works by promoting any array to a StridedArray, this would cover all cases covered by ChainRules when FFTW is the backend.
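To illustrate the plan-sharing idea, here is a minimal sketch, not the proposed rule itself. It assumes FFTW.jl as the backend and complex vector inputs; the variable names are only illustrative. Because fft is linear, the plan built for the primal can be applied unchanged to the tangent.

```julia
using FFTW  # assumed backend; it provides plan_fft for dense arrays

x  = randn(ComplexF64, 8)   # primal input
dx = randn(ComplexF64, 8)   # tangent of x (same size and element type)

p  = plan_fft(x)            # build the plan once
y  = p * x                  # primal output
dy = p * dx                 # tangent output: fft is linear, so the same plan applies

# Both results agree with the plan-free calls.
@assert y ≈ fft(x)
@assert dy ≈ fft(dx)
```

An actual rule would also have to handle in-place variants and the region argument, which this sketch ignores.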

For AbstractArray inputs, we cannot define reverse-mode rules for f, because the input might be structured; these rules need to be defined in the backend package.
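As a sketch of why structure matters, consider a dense complex vector, writing $F$ for the unnormalized DFT matrix and assuming the standard inner product (Enzyme's conventions for complex cotangents may differ in detail). The primal, tangent and cotangent maps are then:

$$
y = F x, \qquad \dot{y} = F \dot{x}, \qquad \bar{x} = F^{H} \bar{y} = \operatorname{bfft}(\bar{y}), \qquad F_{jk} = e^{-2\pi i\, jk/n}.
$$

Here $F^{H} = \overline{F}$ because $F$ is symmetric, which is exactly the unnormalized backward transform. The cotangent $\bar{x}$ is dense in general, so for a structured input (say, one storing only a diagonal) it has entries that correspond to no stored degree of freedom, and a generic rule cannot know how to project it.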

> Why would defining rules for fft/etc cause more work to be done?

We can also define forward-mode rules for f for AbstractArray inputs. But since a package might overload f for a custom array type with a method that does not use a plan, such a rule would just call the primal function and not construct a plan. If the primal does use a plan, then this would construct at least one unnecessary plan and do more work than is necessary. In the batched case, this could construct many more plans than are needed.
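The batched case can be sketched as follows, again assuming FFTW.jl and complex vectors, with a hypothetical batch of four tangents. As far as I can tell, the generic fallback for fft(x) in AbstractFFTs builds a plan and applies it, so the plan-free route plans once per call, while the plan-based route plans once in total.

```julia
using FFTW  # assumed backend

x   = randn(ComplexF64, 16)
dxs = [randn(ComplexF64, 16) for _ in 1:4]   # a batch of tangents (illustrative)

# Plan-free: each fft call goes through its own plan construction (1 + 4 here).
y_a   = fft(x)
dys_a = map(fft, dxs)

# Plan-sharing: one plan serves the primal and every tangent in the batch.
p     = plan_fft(x)
y_b   = p * x
dys_b = [p * dx for dx in dxs]

@assert y_a ≈ y_b
@assert all(isapprox(a, b) for (a, b) in zip(dys_a, dys_b))
```

The two routes give the same numbers; the difference is only in how many plans get constructed, which is the extra work referred to above.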