Missing rules for FFTW `unsafe_execute` #1717

Open · ptiede opened this issue 1 month ago

ptiede commented 1 month ago

Here is an MWE:

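```julia
using Enzyme
using FFTW

# assert_applicable only validates the plan against x, so mark it inactive
Enzyme.EnzymeRules.inactive(::typeof(FFTW.assert_applicable), args...) = nothing

function test(p, x)
    y = p * x
    return sum(abs2, y)
end

x = rand(ComplexF64, 256)
p = plan_fft!(x)

test(p, x)

dx = zero(x)  # the shadow must start at zero for reverse mode
autodiff(Enzyme.Reverse, test, Active, Const(p), Duplicated(x, dx))
```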

```
Enzyme compilation failed.

No augmented forward pass found for ijl_lazy_load_and_lookup
 at context:   %fftw_execute_dft.found = call void ()* @ijl_lazy_load_and_lookup({} addrspace(10)* nonnull %11, i8* noundef getelementptr inbounds ([17 x i8], [17 x i8]* @_j_str1, i32 0, i32 0)) #21, !dbg !41

 [1] unsafe_execute!
   @ ~/.julia/packages/FFTW/6nZei/src/fft.jl:518
 [2] *
   @ ~/.julia/packages/FFTW/6nZei/src/fft.jl:835
 [3] *
   @ ~/.julia/packages/AbstractFFTs/4iQz5/src/definitions.jl:224
 [4] test
   @ ~/Research/EnzymeTest/fft.jl:8

  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/aEyGD/src/compiler.jl:1696
  [2] unsafe_execute!
    @ ~/.julia/packages/FFTW/6nZei/src/fft.jl:518 [inlined]
  [3] *
    @ ~/.julia/packages/FFTW/6nZei/src/fft.jl:835 [inlined]
  [4] *
    @ ~/.julia/packages/AbstractFFTs/4iQz5/src/definitions.jl:224 [inlined]
  [5] test
    @ ~/Research/EnzymeTest/fft.jl:8 [inlined]
  [6] diffejulia_test_5412wrap
    @ ~/Research/EnzymeTest/fft.jl:0
  [7] macro expansion
    @ ~/.julia/packages/Enzyme/aEyGD/src/compiler.jl:6673 [inlined]
  [8] enzyme_call
    @ ~/.julia/packages/Enzyme/aEyGD/src/compiler.jl:6273 [inlined]
  [9] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/aEyGD/src/compiler.jl:6150 [inlined]
 [10] autodiff
    @ ~/.julia/packages/Enzyme/aEyGD/src/Enzyme.jl:314 [inlined]
 [11] autodiff(::ReverseMode{…}, ::typeof(test), ::Type{…}, ::Const{…}, ::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/aEyGD/src/Enzyme.jl:326
 [12] top-level scope
    @ ~/Research/EnzymeTest/fft.jl:20
Some type information was truncated. Use `show(err)` to see complete types.
```

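The `inactive` rule works around a different `ijl_lazy_load_and_lookup` error, raised from `FFTW.assert_applicable`. That one is easy to avoid, since `assert_applicable` only checks that the plan fits the input and never needs to be differentiated.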

Unless I misunderstand something, this is another missing rule. If so, I'll take a crack at writing some rules in the next couple of days.

danielwe commented 1 month ago

FFTW.jl support would be hugely appreciated. As it's a wrapper around a C library, I suppose it will require custom rules similar to BLAS. Is that on the horizon at all, @wsmoses?

x-ref #369

danielwe commented 1 month ago

As a stopgap, you can import the ChainRules `rrule`s defined in AbstractFFTs and make the example work as follows:

```julia
using Enzyme
using FFTW
using AbstractFFTs
using ChainRulesCore

# Reuse the ChainRules rrules that AbstractFFTs defines for applying a plan
Enzyme.@import_rrule(typeof(*), AbstractFFTs.Plan, AbstractArray)
Enzyme.@import_rrule(typeof(*), AbstractFFTs.ScaledPlan, AbstractArray)

function test(p, x)
    y = p * x
    return sum(abs2, y)
end

x = rand(ComplexF64, 256)
p = plan_fft(x)  # out-of-place plan

@show test(p, x)

x = rand(ComplexF64, 256)
dx = make_zero(x)

autodiff(Reverse, test, Active, Const(p), Duplicated(x, dx))
```

These rules don't support in-place plans, so I changed your MWE to use an out-of-place plan (`plan_fft` instead of `plan_fft!`).

danielwe commented 1 month ago

x-ref the stalled AbstractFFTs.jl PR that adds Enzyme rules there: https://github.com/JuliaMath/AbstractFFTs.jl/pull/103. That seems like the right place for any Julia-native rules.

However, the underlying FFTW C library is such a cornerstone of computational code across languages that it might justify having rules in base Enzyme, similar to BLAS, if and when anyone has time to write them.

ptiede commented 4 weeks ago

Thanks for pointing this out! Having some rules for FFTW seems like a good idea. Also, the number of rules I'd need to write for FFTW is potentially much smaller than for AbstractFFTs.

danielwe commented 4 weeks ago

Just a note that you can skip importing the rule for `ScaledPlan`; Enzyme can differentiate through that wrapper just fine. To get full reverse-mode support for everything that's supported elsewhere in Julia, all you should need is:

```julia
Enzyme.@import_rrule(typeof(*), AbstractFFTs.Plan, AbstractArray)
EnzymeRules.inactive(::typeof(plan_fft), args...) = nothing
# ... similar `inactive` methods for the other plan functions, like plan_dct
```
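Spelling out that `# ...` placeholder for the other planners might look like the loop below. This is an untested sketch: the list of planner names is an assumption (the standard AbstractFFTs planners that FFTW.jl re-exports, plus FFTW's DCT planners), and `@eval` is only a convenience for defining one `inactive` method per function. The r2r planners are left out because the imported rrule does not cover them yet.

```julia
using Enzyme
using FFTW

# A plan carries no differentiable data of its own, so planning is marked
# inactive and the plan is passed to autodiff as Const.
# Assumed list: common complex, real, and DCT planners available after `using FFTW`.
for planner in (:plan_fft, :plan_ifft, :plan_bfft,
                :plan_fft!, :plan_ifft!, :plan_bfft!,
                :plan_rfft, :plan_irfft, :plan_brfft,
                :plan_dct, :plan_idct, :plan_dct!, :plan_idct!)
    # Defines Enzyme.EnzymeRules.inactive(::typeof(plan_xxx), args...) = nothing
    @eval Enzyme.EnzymeRules.inactive(::typeof($planner), args...) = nothing
end
```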

Some parts of FFTW are not yet covered by this, notably the r2r transforms. What's missing are the corresponding `AbstractFFTs.AdjointStyle` implementations; see https://github.com/JuliaMath/FFTW.jl/issues/272.
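For context on why `AdjointStyle` matters: as I understand it, the AbstractFFTs `rrule` pulls a cotangent `ȳ` back through `y = p * x` by applying the adjoint plan, `x̄ = p' * ȳ`, and `AdjointStyle` tells AbstractFFTs how to apply `p'`. For the unnormalized forward DFT that adjoint is the conjugate transpose, i.e. the unnormalized backward transform (what `bfft` computes), not the inverse. The sketch below checks this numerically with naive O(n²) DFT matrices, using only LinearAlgebra and no FFTW; the helper `dft_matrix` is hypothetical.

```julia
using LinearAlgebra

# Hypothetical helper: dense unnormalized forward DFT matrix,
# F[j, k] = exp(-2πi (j-1)(k-1) / n), the same sign convention as fft.
dft_matrix(n) = [cis(-2π * (j - 1) * (k - 1) / n) for j in 1:n, k in 1:n]

n = 16
F = dft_matrix(n)

# Unnormalized backward DFT (the bfft convention): opposite sign, no 1/n factor.
B = [cis(2π * (j - 1) * (k - 1) / n) for j in 1:n, k in 1:n]

# The adjoint (conjugate transpose) of the forward DFT is the backward DFT.
@assert F' ≈ B

# Adjoint identity <ȳ, F x> = <F' ȳ, x>: this is why the pullback of
# y = F * x can be computed as x̄ = F' * ȳ.
x = randn(ComplexF64, n)
ȳ = randn(ComplexF64, n)
@assert dot(ȳ, F * x) ≈ dot(F' * ȳ, x)

# The inverse DFT is B / n, so F' equals n * inv(F), not inv(F) itself.
@assert F' ≈ n * inv(F)
```

A real-to-real transform needs the same kind of statement about its adjoint before `p'` can be applied, which is the piece the linked FFTW.jl issue tracks.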