Open sethaxen opened 1 year ago
I don't think anyone would be against it. Make it an extension package using v1.9 extensions and there's no dep added here. It's more about getting it done.
Enzyme rules require v1.7 or above. cc @vchuravy
EnzymeCore is designed to be the light, flexible package, like ChainRulesCore, for importing and adding rules.
Okay, I'm happy to tackle this once I understand some confusing behavior of complex array rules (see https://github.com/EnzymeAD/Enzyme.jl/issues/744 and https://github.com/EnzymeAD/Enzyme.jl/pull/739#discussion_r1173527477).
Unlike the ChainRules support, I would restrict Enzyme rules to `StridedArray` inputs, which is what FFTW promotes everything to, and which avoids doing the wrong thing for structurally sparse array inputs.
A few observations after starting work on this:
- `fft`, `ifft`, `bfft` etc.: unlike ChainRules, we should not define rules for these, since this will cover up the primal functions or any custom methods and do more work than is necessary.
- `*(p::Plan, x::StridedVector)`: we have no way of knowing whether `p` was constructed as an in-place plan or not, and the rule changes depending on whether this is the case.
- `mul!(y::AbstractArray, plan::Plan, x::AbstractArray)`: we don't know what the plan is and therefore don't know how to normalize.
- `mul!(y::AbstractArray, plan::Plan, x::AbstractArray)`: as long as every possible FFT is linear, the pushforward is the same as the primal, and we know that `plan` must be out-of-place.

Why would defining rules for fft/etc cause more work to be done?
> unlike ChainRules, we should not define rules for `fft`, `ifft`, `bfft` etc, since this will cover up the primal functions or any custom methods and do more work than is necessary
Let me clarify. Let `f` be `fft`, `fft!`, etc. Assume no one pirates `f` for an array type defined in the standard library in some other package. Then it is safe for us to define forward- and reverse-mode rules for `f` for all array types defined in the standard lib (or maybe even just `StridedArray`). To avoid doing extra work, when `f` is not in-place, we can use the same plan for the primal and tangent. Because FFTW works by promoting any array to a `StridedArray`, this would cover all cases covered by ChainRules when FFTW is the backend.
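To make the "same plan for primal and tangent" point concrete, here is a minimal sketch in plain Julia. It does not use FFTW or Enzyme: `dft_matrix` is a hypothetical helper that builds a dense, unnormalized DFT matrix and stands in for an out-of-place plan. The sketch only checks that, for a linear transform, the pushforward is the plan applied to the tangent, and that the pullback is the adjoint of the plan, which differs from the inverse by a factor of `n`.

```julia
using LinearAlgebra  # dot, inv

# Hypothetical stand-in for an out-of-place plan: the dense, unnormalized DFT matrix.
dft_matrix(n) = [cis(-2π * j * k / n) for j in 0:n-1, k in 0:n-1]

n = 8
plan = dft_matrix(n)           # built once, shared by the primal and the tangent
x = randn(ComplexF64, n)       # primal input
dx = randn(ComplexF64, n)      # tangent of the input

y = plan * x                   # primal output
dy = plan * dx                 # forward mode: the map is linear, so reuse the plan

# Finite-difference check of the pushforward.
h = 1e-6
fd = (plan * (x + h * dx) - y) / h

# Reverse mode: the pullback of a linear map is its adjoint.
ybar = randn(ComplexF64, n)    # cotangent of the output
xbar = plan' * ybar            # cotangent of the input

# Adjoint identity <ybar, plan*dx> == <xbar, dx>.
lhs = dot(ybar, plan * dx)
rhs = dot(xbar, dx)

# The adjoint is n times the inverse, so a rule has to know the normalization.
adjoint_is_scaled_inverse = plan' ≈ n * inv(plan)
```

The last line is the reason a rule that cannot see how the plan is normalized cannot know which scale factor to apply in the pullback.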
For `AbstractArray` inputs, we cannot define reverse-mode rules for `f`, because the input might be structured; these rules need to be defined in the backend package.
> Why would defining rules for fft/etc cause more work to be done?
We can also define forward-mode rules for `f` for `AbstractArray` inputs, but since a package might overload `f` for a custom array type to use a method that does not use a plan, we just call the primal function and don't construct a plan. If the primal does use a plan, then this would construct at least one unnecessary plan and do more work than is necessary. In the batched case, this could construct many more plans than are needed.
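A toy illustration of the extra work, under stated assumptions: `make_plan`, `toy_fft`, `generic_frule` and `planned_frule` are all hypothetical names, the "plan" is just a dense DFT matrix, and a counter records how many plans get built. None of this is the Enzyme or AbstractFFTs API; it only contrasts a generic rule that calls the primal on every tangent with one that shares a single plan.

```julia
# Counts how many "plans" are constructed (illustration only).
PLAN_COUNT = Ref(0)

# Hypothetical plan constructor: a dense, unnormalized DFT matrix.
function make_plan(n)
    PLAN_COUNT[] += 1
    return [cis(-2π * j * k / n) for j in 0:n-1, k in 0:n-1]
end

# A primal that builds a fresh plan on every call.
toy_fft(x) = make_plan(length(x)) * x

# Generic forward rule for a linear f: call the primal on x and on each tangent.
generic_frule(f, x, dxs...) = (f(x), map(f, dxs))

# Plan-aware forward rule: build one plan and reuse it for x and every tangent.
function planned_frule(x, dxs...)
    plan = make_plan(length(x))
    return (plan * x, map(dx -> plan * dx, dxs))
end

x = randn(ComplexF64, 8)
dxs = ntuple(_ -> randn(ComplexF64, 8), 3)   # a batch of three tangents

PLAN_COUNT[] = 0
y1, dys1 = generic_frule(toy_fft, x, dxs...)
generic_plans = PLAN_COUNT[]    # one plan per call: 1 primal + 3 tangents

PLAN_COUNT[] = 0
y2, dys2 = planned_frule(x, dxs...)
shared_plans = PLAN_COUNT[]     # a single shared plan
```

Both rules return the same values; the generic one just pays for a plan per tangent, which is the cost that grows in the batched case.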
Similar to the ChainRulesCore support, we could use `EnzymeCore.EnzymeRules` to define `Forward` and `Reverse` mode rules for Enzyme in an extension.

EnzymeCore requires at least Julia v1.6. Making it a dependency for Julia versions older than v1.9 (as is done with ChainRulesCore) would then only be possible if AbstractFFTs adds a Julia v1.6 version bound. But since EnzymeCore's sole dependency (Adapt) depends on Requires, it may make more sense to conditionally load on pre-v1.9 using Requires.jl: https://pkgdocs.julialang.org/dev/creating-packages/#Requires.jl
Unlike the ChainRulesCore support, it probably only makes sense to define rules for `StridedArray` inputs, to avoid doing the wrong thing for sparse or structured arrays.
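As a rough sketch of what restricting to `StridedArray` means at the dispatch level, the hypothetical helper `rule_applies` below (not part of any package) shows which inputs a method constrained to `StridedArray` would catch and which would fall through to a generic fallback.

```julia
using LinearAlgebra  # Diagonal

# Hypothetical helper: true only where a StridedArray-restricted rule would apply.
rule_applies(::StridedArray) = true
rule_applies(::AbstractArray) = false

dense = rand(4)                          # Vector: strided
strided_view = view(rand(4, 4), 1:2, :)  # range-indexed view: strided
diag = Diagonal(rand(4))                 # structured matrix: not strided
gather_view = view(rand(4), [1, 3])      # vector-indexed view: not strided

results = map(rule_applies, (dense, strided_view, diag, gather_view))
# results == (true, true, false, false)
```

Structured types such as `Diagonal` never reach the restricted method, so a rule written for dense memory cannot silently be applied to them.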