JuliaMath / FFTW.jl

Julia bindings to the FFTW library for fast Fourier transforms
https://juliamath.github.io/FFTW.jl/stable
MIT License

add chainrules for `r2r`, `dct` #272

Open vpuri3 opened 1 year ago

vpuri3 commented 1 year ago

How is the gradient computed for `plan_dct` if there's no `rrule` for `dct`?

```julia
using LinearAlgebra, FFTW, Zygote

x = rand(4)
C = plan_dct(x)

f(x) = C \ (C * x) |> norm
g(x) = x |> dct |> idct |> norm
h(x) = plan_dct(x) \ (plan_dct(x) * x) |> norm

@show Zygote.gradient(f, x) # ([0.7499995699183157, 0.5170775887690442, 0.3522881598130941, 0.2145331321046639],)
@show Zygote.gradient(g, x) # errors
@show Zygote.gradient(h, x) # errors
```

Error message:

```
julia> Zygote.gradient(f, x)
ERROR: Compiling Tuple{Type{FFTW.r2rFFTWPlan{Float64, Any, false, 1}}, Vector{Float64}, FFTW.FakeArray{Float64, 1}, UnitRange{Int64}, Int64, UInt32, Float64}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
```
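A minimal sketch of why `f` only needs the plan's adjoint: reverse mode pulls `y = A * x` back as `xbar = A' * ybar`, and `z = A \ y` back as `ybar = A' \ zbar`. Below, a dense orthonormal DCT-II matrix stands in for `C` (FFTW.jl documents `dct` as the type-II DCT with unitary normalization). `dct2_matrix` and `grad_f` are hypothetical helpers, not FFTW or Zygote functions. Since `C \ (C * x) == x`, the result should match `x ./ norm(x)`, which is consistent with the unit-norm gradient printed for `f`.

```julia
using LinearAlgebra

# Hypothetical helper: dense orthonormal DCT-II matrix (0-based k, j):
# Q[k+1, j+1] = s_k * cos(pi * (j + 1/2) * k / n), s_0 = sqrt(1/n), s_k = sqrt(2/n).
function dct2_matrix(n::Integer)
    Q = [cos(pi * (j + 0.5) * k / n) for k in 0:n-1, j in 0:n-1]
    Q[1, :] .*= sqrt(1 / n)
    Q[2:end, :] .*= sqrt(2 / n)
    return Q
end

# Hypothetical helper: hand-written reverse pass for f(x) = norm(C \ (C * x)).
function grad_f(C::AbstractMatrix, x::AbstractVector)
    y = C * x               # forward: transform
    z = C \ y               # forward: inverse transform (z ≈ x)
    zbar = z ./ norm(z)     # gradient of norm at z
    ybar = C' \ zbar        # pullback through z = C \ y
    xbar = C' * ybar        # pullback through y = C * x (this is where the adjoint is needed)
    return xbar
end

x = rand(4)
C = dct2_matrix(4)
@show C' * C ≈ Matrix{Float64}(I, 4, 4)   # orthonormal: the adjoint is the inverse
@show grad_f(C, x) ≈ x ./ norm(x)
```

For an orthonormal transform the adjoint coincides with the inverse, so the missing piece for r2r plans in general is an adjoint for the unnormalized kinds.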
danielwe commented 4 weeks ago

Cosigning this: having `AdjointStyle` and `adjoint_mul` for r2r plans would be great. I'm working out the one-dimensional case, but haven't quite wrapped my head around multidimensional FFTW yet.
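A quick way to sanity-check any candidate `adjoint_mul` is the dot-product test: for a linear map `A`, `dot(A * x, y)` must equal `dot(x, A' * y)` for random `x` and `y`. A minimal sketch, where `adjoint_test` is a hypothetical helper (not part of AbstractFFTs or FFTW) and a dense matrix stands in for an r2r plan. With FFTW one would instead pass closures such as `x -> p * x` and `y -> p' * y`, assuming AbstractFFTs' `adjoint` on plans dispatches to the new `adjoint_mul`.

```julia
using LinearAlgebra

# Hypothetical helper: dot-product test for a candidate adjoint.
# `apply(x)` should compute A * x and `apply_adjoint(y)` should compute A' * y.
function adjoint_test(apply, apply_adjoint, n::Integer; trials::Integer = 5)
    for _ in 1:trials
        x = randn(n)
        y = randn(n)
        lhs = dot(apply(x), y)           # <A x, y>
        rhs = dot(x, apply_adjoint(y))   # <x, A' y>
        lhs ≈ rhs || return false
    end
    return true
end

# A dense, non-symmetric matrix stands in for a transform here.
A = randn(6, 6)
good = adjoint_test(v -> A * v, v -> transpose(A) * v, 6)   # correct adjoint
bad = adjoint_test(v -> A * v, v -> A * v, 6)               # A itself is not its adjoint
@show good bad
```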

danielwe commented 4 weeks ago

Here's an extremely rudimentary implementation of `adjoint_mul` for the 1-D `REDFT10` transform, in case anyone finds it helpful as a starting point:

```julia
using AbstractFFTs
using FFTW

# Adjoint style tag for FFTW real-to-real (r2r) plans.
struct R2RFFTAdjointStyle <: AbstractFFTs.AdjointStyle end

AbstractFFTs.AdjointStyle(::FFTW.r2rFFTWPlan) = R2RFFTAdjointStyle()

function AbstractFFTs.adjoint_mul(
    p::FFTW.r2rFFTWPlan{T}, x::AbstractVector{T}, ::R2RFFTAdjointStyle
) where {T}
    (length(p.kinds) == 1) || throw(ArgumentError("Multidimensional r2r transforms not yet supported"))
    (only(p.kinds) == FFTW.REDFT10) || throw(ArgumentError("r2r kinds other than REDFT10 not yet supported"))
    # inv(p) should be the REDFT01 (DCT-III) plan, possibly wrapped in a ScaledPlan
    # holding the 1/(2n) normalization; strip the scaling to get the raw REDFT01.
    pinv = inv(p)
    unscaled_pinv = (pinv isa AbstractFFTs.ScaledPlan) ? pinv.p : pinv
    y = unscaled_pinv * x
    # Up to scaling, REDFT10 has orthogonal rows, but its first (DC) row has a different norm.
    # So the unscaled inverse REDFT01 equals the transpose of REDFT10 except in the first
    # column, which is off by a factor of 2. Adding first(x) restores the missing DC term.
    y .+= first(x)
    return y
end
```
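A dependency-free check of the DC correction, assuming FFTW's documented unnormalized definitions: REDFT10 computes `Y[k] = 2 Σ_j X[j] cos(π(j+1/2)k/n)` and REDFT01 computes `Y[k] = X[0] + 2 Σ_{j≥1} X[j] cos(πj(k+1/2)/n)`, with 0-based indices. The dense matrices below come from hypothetical helpers `redft10_matrix` and `redft01_matrix` (not FFTW functions). They show that the transpose of REDFT10 applied to `x` equals REDFT01 applied to `x` plus `first(x)` in every entry, which is exactly the `y .+= first(x)` step.

```julia
using LinearAlgebra

# Hypothetical helpers: dense matrices for FFTW's unnormalized 1-D kinds,
# written directly from the FFTW manual definitions (0-based j, k).
redft10_matrix(n) = [2cos(pi * (j + 0.5) * k / n) for k in 0:n-1, j in 0:n-1]
redft01_matrix(n) = [j == 0 ? 1.0 : 2cos(pi * j * (k + 0.5) / n) for k in 0:n-1, j in 0:n-1]

n = 8
A = redft10_matrix(n)   # forward REDFT10 (DCT-II)
B = redft01_matrix(n)   # unscaled inverse REDFT01 (DCT-III)

x = randn(n)
true_adjoint = transpose(A) * x
corrected = B * x .+ first(x)     # same correction as in adjoint_mul above

@show true_adjoint ≈ corrected    # the DC fix recovers the exact adjoint
@show B * (A * x) ≈ 2n * x        # REDFT01 inverts REDFT10 up to a factor of 2n
```

The second check mirrors FFTW's statement that REDFT01 inverts REDFT10 up to the logical size `2n`, which is the normalization the `ScaledPlan` strips off in `adjoint_mul`.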