FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

efficient gradient for fft #204

Closed: sipposip closed this issue 4 years ago

sipposip commented 5 years ago

I am trying to implement custom adjoints for the Fast Fourier Transform (FFT) functions. Since the DFT (discrete Fourier transform) is linear, the gradient with respect to its input is, up to a normalization factor, the inverse DFT of the gradient with respect to its output. The FFT is just a fast way to compute a DFT, so this is in principle not very complicated.

The following implementation (based on https://github.com/FluxML/Flux.jl/issues/410) works, but it is inefficient for large inputs because the DFT in the gradient is computed by multiplying with an explicit N×N matrix instead of using the FFT.

using Zygote
using FFTW

Zygote.@adjoint function FFTW.fft(xs)
  return FFTW.fft(xs), function (Δ)
    ns = size(xs,1)
    # the gradient of fft is the unnormalized inverse fft (ns * ifft), i.e. the
    # same matrix as the fft but without the minus sign in the exponent
    ω = [exp(2π*im*j*k/ns) for j=0:(ns-1), k=0:(ns-1)]
    return (ω*Δ,)
  end
end

Zygote.@adjoint function FFTW.ifft(xs)
  return FFTW.ifft(xs), function (Δ)
    ns = size(xs,1)
    # the gradient of ifft is the fft divided by ns
    ω = [exp(-2π*im*j*k/ns)/ns for j=0:(ns-1), k=0:(ns-1)]
    return (ω*Δ,)
  end
end

x = randn(10)
Zygote.gradient((x) -> abs(sum(FFTW.fft(x))), x)
Zygote.gradient((x) -> abs(sum(FFTW.ifft(x))), x)
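To make the cost of the matrix approach concrete, here is a minimal sketch using only the LinearAlgebra standard library (the helper `dft_matrix` is hypothetical, not part of FFTW). It builds the N×N DFT matrix with FFTW's sign convention and checks that its conjugate transpose, which is the matrix the fft pullback above applies, equals N times its inverse, i.e. an unnormalized inverse DFT. Forming and applying this matrix costs O(N²) per gradient, compared with O(N log N) for an FFT.

```julia
using LinearAlgebra

# Hypothetical helper: the N×N DFT matrix with FFTW's forward convention,
# F[j+1, k+1] = exp(-2πi·j·k/N), so that F * x matches fft(x) (unnormalized).
dft_matrix(N) = [exp(-2π * im * j * k / N) for j in 0:N-1, k in 0:N-1]

N = 8
F = dft_matrix(N)

# The pullback of y = F * x is Δ -> F' * Δ (conjugate transpose).
# F' has +2πi in the exponent and equals N * inv(F),
# i.e. it is an unnormalized inverse DFT.
@assert F' ≈ N * inv(F)

# Applying the pullback through the explicit matrix costs O(N^2) per call.
Δ = randn(ComplexF64, N)
grad_x = F' * Δ
```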

Ideally, one would use the FFT itself to compute the DFT of the gradient:

Zygote.@adjoint function FFTW.fft(xs)
  return FFTW.fft(xs), function (Δ)
    return (FFTW.ifft(Δ),)  # does not work (because ifft does not accept the FillArray that Zygote passes in as Δ)
  end
end
Zygote.gradient((x) -> abs(sum(FFTW.fft(x))), x) # fails

This throws the following error, because FFTW.ifft is not implemented for Zygote.FillArray:

MethodError: no method matching plan_bfft(::Zygote.FillArray{Complex{Float64},1}, ::UnitRange{Int64})

Is there another way to do this? For example, can one extract the data from the FillArray? This boils down to a more fundamental question: is it possible to call a function that Zygote cannot differentiate inside the definition of that function's own adjoint, passing the incoming gradient as its input?

MikeInnes commented 5 years ago

The error is nothing fundamental but just an artefact of sum producing a FillArray (and FFTW apparently does not work on arbitrary abstract arrays). A different case works fine:

julia> Zygote.@adjoint function FFTW.fft(xs)
         return FFTW.fft(xs), function (Δ)
           return (FFTW.ifft(Δ),)
         end
       end

julia> Zygote.gradient((x) -> abs(FFTW.fft(x)[1]), [1,2,3])
(Complex{Float64}[0.3333333333333333 + 0.0im, 0.3333333333333333 + 0.0im, 0.3333333333333333 + 0.0im],)

So all you need to do is define the missing method for FillArray (here ifft(::FillArray), which the pullback calls; fft(::FillArray) is needed for the reverse case). fft(collect(x)) should work, though since the output is trivial you could also just hand-code it.
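To illustrate why the output is trivial: the pullback of `sum` gives every element the same gradient, so Δ is a constant vector `fill(c, N)`. Its DFT is nonzero only at frequency zero, so fft gives `[N*c, 0, ..., 0]` and ifft gives `[c, 0, ..., 0]`. A minimal sketch, with hypothetical helper names `fft_of_constant` and `ifft_of_constant`, checked against a naive O(N²) DFT instead of FFTW so that it runs with the standard library only:

```julia
using LinearAlgebra

# Naive reference DFT following FFTW's conventions (for checking only):
# sign = -1 gives the unnormalized forward transform.
function naive_dft(x, sign)
    N = length(x)
    return [sum(x[j+1] * exp(sign * 2π * im * j * k / N) for j in 0:N-1) for k in 0:N-1]
end
naive_fft(x)  = naive_dft(x, -1)
naive_ifft(x) = naive_dft(x, +1) ./ length(x)   # ifft divides by N

# Hypothetical hand-coded transforms of the constant vector fill(c, N):
# all non-zero frequencies cancel, only index 1 (frequency 0) survives.
function fft_of_constant(c, N)
    out = zeros(ComplexF64, N)
    out[1] = N * c
    return out
end

function ifft_of_constant(c, N)
    out = zeros(ComplexF64, N)
    out[1] = c
    return out
end

c, N = 0.5 + 2.0im, 6
@assert fft_of_constant(c, N) ≈ naive_fft(fill(c, N))
@assert ifft_of_constant(c, N) ≈ naive_ifft(fill(c, N))
```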

Would be great to have these definitions in a PR.

sipposip commented 5 years ago

OK, now I see. With the following definitions, both fft and ifft seem to work:

FFTW.ifft(x::Zygote.FillArray) = FFTW.ifft(collect(x))
FFTW.fft(x::Zygote.FillArray) = FFTW.fft(collect(x))

Zygote.@adjoint FFTW.fft(xs) = (FFTW.fft(xs), (Δ)-> (FFTW.ifft(Δ),))
Zygote.@adjoint FFTW.ifft(xs) = (FFTW.ifft(xs), (Δ)-> (FFTW.fft(Δ),))

I will open a PR after doing some more testing.

sipposip commented 5 years ago

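I found out that the definitions in my last comment were wrong: FFTW's fft is unnormalized while ifft divides by N, so the pullback of fft needs a factor N and the pullback of ifft a factor 1/N. It should be: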

Zygote.@adjoint function FFTW.fft(xs)
    return FFTW.fft(xs), function(Δ)
        N = length(Δ)
        return (N * FFTW.ifft(Δ),)
    end
end

Zygote.@adjoint function FFTW.ifft(xs)
    return FFTW.ifft(xs), function(Δ)
        N = length(Δ)
        return (1/N* FFTW.fft(Δ),)
    end
end
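One way to check the factors N and 1/N in these definitions is a dot-product test: for a linear map A, the pullback A* must satisfy ⟨A x, y⟩ = ⟨x, A* y⟩. A minimal sketch, using naive DFT matrices with FFTW's conventions in place of FFTW itself (an assumption made so the check runs with the standard library only; `fft_naive` and `ifft_naive` are hypothetical names):

```julia
using LinearAlgebra

N = 16
# Naive matrices with FFTW's conventions: fft is unnormalized, ifft divides by N.
F    = [exp(-2π * im * j * k / N) for j in 0:N-1, k in 0:N-1]
Finv = [exp( 2π * im * j * k / N) for j in 0:N-1, k in 0:N-1] ./ N

fft_naive(x)  = F * x
ifft_naive(x) = Finv * x

x = randn(ComplexF64, N)
y = randn(ComplexF64, N)

# dot(a, b) conjugates its first argument: sum(conj(a) .* b).
# Pullback of fft: Δ -> N * ifft(Δ)
@assert dot(fft_naive(x), y) ≈ dot(x, N * ifft_naive(y))
# Using plain ifft as the pullback gives a result N times too small.
@assert dot(fft_naive(x), y) ≈ N * dot(x, ifft_naive(y))
# Pullback of ifft: Δ -> fft(Δ) / N
@assert dot(ifft_naive(x), y) ≈ dot(x, fft_naive(y) / N)
```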

I am currently working on the other FFTW functions and will open a PR once I have everything.

sipposip commented 5 years ago

I opened a pull request for this: https://github.com/FluxML/Zygote.jl/pull/215

CarloLucibello commented 4 years ago

This has been closed by #215.