ptiede opened 1 month ago
FFTW.jl support would be hugely appreciated. As it's a wrapper around a C library, I suppose it will require custom rules similar to BLAS. Is that on the horizon at all, @wsmoses?
x-ref #369
As a stopgap, one can import the ChainRules `rrule`s defined in AbstractFFTs and make the example work as follows:
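```julia
using Enzyme
using FFTW
using AbstractFFTs
using ChainRulesCore   # must be loaded for Enzyme.@import_rrule to work

# Reuse the ChainRules rrules that AbstractFFTs defines for applying a plan
Enzyme.@import_rrule(typeof(*), AbstractFFTs.Plan, AbstractArray)
Enzyme.@import_rrule(typeof(*), AbstractFFTs.ScaledPlan, AbstractArray)

function test(p, x)
    y = p * x
    return sum(abs2, y)
end

x = rand(ComplexF64, 256)
p = plan_fft(x)          # out-of-place plan: these rules don't handle in-place plans
@show test(p, x)

x = rand(ComplexF64, 256)
dx = make_zero(x)        # shadow array that accumulates the gradient w.r.t. x
autodiff(Reverse, test, Active, Const(p), Duplicated(x, dx))   # the plan is held constant
display(dx)
```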
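One way to sanity-check the imported rule, sketched under the assumption that Enzyme reports the gradient of a real-valued loss with respect to a complex input as ∂f/∂Re x + i ∂f/∂Im x (the convention ChainRules also uses): by Parseval's theorem an unnormalized FFT satisfies `sum(abs2, fft(x)) == N * sum(abs2, x)`, so the gradient of the loss in the example should be exactly `2N * x`. `loss` is just a renamed copy of `test` so the block runs on its own.

```julia
using Enzyme
using FFTW
using AbstractFFTs
using ChainRulesCore

# Same import as in the workaround: reuse AbstractFFTs' rrule for plan * array.
Enzyme.@import_rrule(typeof(*), AbstractFFTs.Plan, AbstractArray)

# Same loss as `test`, renamed for this standalone check.
loss(p, x) = sum(abs2, p * x)

N = 256
x = rand(ComplexF64, N)
p = plan_fft(x)                 # out-of-place, unnormalized forward FFT plan
dx = make_zero(x)               # the gradient accumulates here
autodiff(Reverse, loss, Active, Const(p), Duplicated(x, dx))

# Parseval: loss(p, x) == N * sum(abs2, x), hence d(loss)/dx == 2N * x
# (assuming the Re + i*Im gradient convention for a real-valued loss).
expected = 2N .* x
@show isapprox(dx, expected)
```

If this comparison fails, the rule (or the gradient convention) is not doing what the example assumes.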
These rules don't support in-place plans, so I changed your MWE to use an out-of-place plan.
x-ref: the stalled AbstractFFTs.jl PR that adds Enzyme rules there, https://github.com/JuliaMath/AbstractFFTs.jl/pull/103. That seems like the right place for any Julia-native rules.
However, the base FFTW library is such a cornerstone of computational code across languages that it might justify having rules in base Enzyme, similar to BLAS, if and when anyone has time to write them.
Thanks for pointing this out! Having some rules for FFTW seems like a good idea. Also, the number of rules I'd need to write for FFTW is potentially much smaller than for AbstractFFTs.
Just a note that you can skip importing the rule for `ScaledPlan`; Enzyme can differentiate through that wrapper just fine. To get full reverse-mode support for everything that's supported elsewhere in Julia, all you should need is
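```julia
# Rule for applying any plan to an array
Enzyme.@import_rrule(typeof(*), AbstractFFTs.Plan, AbstractArray)
# Creating a plan needs no derivative, so mark the planning function inactive
EnzymeRules.inactive(::typeof(plan_fft), args...) = nothing
# ... similar `inactive` methods for the other plan functions, like plan_dct
```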
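A minimal sketch of what the "similar `inactive` methods" comment could expand to, assuming every planning function should be treated as non-differentiable. The list below is illustrative, not exhaustive; the in-place `plan_*!` variants could be added the same way, although the imported rules don't support in-place plans anyway.

```julia
using Enzyme
using FFTW   # re-exports the AbstractFFTs planning functions and adds plan_dct etc.

# Planning functions only build an FFTW plan and carry no derivative information,
# so define an `inactive` method for each one via @eval.
for f in (:plan_fft, :plan_ifft, :plan_bfft,
          :plan_rfft, :plan_irfft, :plan_brfft,
          :plan_dct, :plan_idct)
    @eval Enzyme.EnzymeRules.inactive(::typeof($f), args...) = nothing
end
```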
Some parts of FFTW are not yet covered by this, notably the r2r transforms. What's missing are the corresponding implementations of `AbstractFFTs.AdjointStyle`; see https://github.com/JuliaMath/FFTW.jl/issues/272.
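As I understand it, the AbstractFFTs pullback for `p * x` applies the adjoint plan `p'` to the incoming cotangent, and `AdjointStyle` is what tells AbstractFFTs how to build that adjoint. A quick numerical check, using only FFTW.jl's exported transforms and `LinearAlgebra.dot`, shows which adjoints would need to be encoded: FFTW.jl's `dct` uses the unitary normalization, so its adjoint is `idct`, while the unnormalized `fft` has `bfft` (not `ifft`) as its adjoint. Concluding that a unitary style would fit FFTW's DCT plans is my assumption from this check, not something FFTW.jl currently declares; general r2r kinds are unnormalized and would need their own treatment.

```julia
using FFTW, LinearAlgebra

# Unitary DCT-II: adjoint == inverse, so <dct(x), y> == <x, idct(y)>.
x = randn(64)
y = randn(64)
@show isapprox(dot(dct(x), y), dot(x, idct(y)))

# Unnormalized complex FFT: the adjoint is the unnormalized backward transform.
z = randn(ComplexF64, 64)
w = randn(ComplexF64, 64)
@show isapprox(dot(fft(z), w), dot(z, bfft(w)))
@show isapprox(dot(fft(z), w), dot(z, ifft(w)))   # false: ifft differs by a factor of N
```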
Here is a MWE
The `inactive` rule is there to prevent another `ijl_lazy_load_and_lookup` error, but that one is easy to get around since that function shouldn't be differentiated anyway. Unless I misunderstand something, this is another missing rule. If so, I'll take a crack at writing some rules in the next couple of days.