sipposip closed this issue 4 years ago.
The error is nothing fundamental; it is just an artefact of sum producing a FillArray (and FFTW apparently does not work on arbitrary abstract arrays). A different case works fine:
julia> Zygote.@adjoint function FFTW.fft(xs)
           return FFTW.fft(xs), function (Δ)
               return (FFTW.ifft(Δ),)
           end
       end
julia> Zygote.gradient((x) -> abs(FFTW.fft(x)[1]), [1,2,3])
(Complex{Float64}[0.3333333333333333 + 0.0im, 0.3333333333333333 + 0.0im, 0.3333333333333333 + 0.0im],)
So all you need to do is define the missing method, fft(::FillArray); fft(collect(x)) should work, though since the output is trivial you could also just hand-code it.
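Because a FillArray repeats a single value, its DFT is zero everywhere except the first (zero-frequency) bin, which equals N times that value. The following is a minimal, stdlib-only sketch of hand-coding that result. fft_of_constant is a hypothetical helper (not part of FFTW or Zygote), and naive_dft is an O(N²) DFT used only as a cross-check. A real fft(::Zygote.FillArray) method would have to read the value and length out of the FillArray; its field names are not assumed here.

```julia
using LinearAlgebra  # provides norm, which ≈ uses when comparing arrays

# Hypothetical helper (not part of FFTW or Zygote): DFT of a length-N vector whose
# entries all equal c. Only the zero-frequency bin is nonzero: X[1] = N * c.
function fft_of_constant(c::Number, N::Integer)
    X = zeros(complex(float(typeof(c))), N)  # every frequency bin starts at zero
    X[1] = N * c                             # the zero-frequency bin is the sum of N copies of c
    return X
end

# Naive O(N^2) DFT with FFTW's sign convention, used only to cross-check the result.
function naive_dft(x::AbstractVector)
    N = length(x)
    return [sum(x[n] * cis(-2π * (k - 1) * (n - 1) / N) for n in 1:N) for k in 1:N]
end

fft_of_constant(2, 4)                                  # first entry 8.0 + 0.0im, the rest zero
fft_of_constant(1.5, 3) ≈ naive_dft(fill(1.5, 3))      # true
```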
Would be great to have these definitions in a PR.
OK, now I see. With the following definitions, both fft and ifft seem to work:
FFTW.ifft(x::Zygote.FillArray) = FFTW.ifft(collect(x))
FFTW.fft(x::Zygote.FillArray) = FFTW.fft(collect(x))
Zygote.@adjoint FFTW.fft(xs) = (FFTW.fft(xs), (Δ)-> (FFTW.ifft(Δ),))
Zygote.@adjoint FFTW.ifft(xs) = (FFTW.ifft(xs), (Δ)-> (FFTW.fft(Δ),))
I will open a PR after doing some more testing.
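I found out that the definitions in my last comment were wrong: fft is unnormalized while ifft divides by N, so each pullback needs a compensating factor of N or 1/N. It should be: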
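# fft is unnormalized, so its pullback is its conjugate transpose, N * ifft
Zygote.@adjoint function FFTW.fft(xs)
    return FFTW.fft(xs), function(Δ)
        N = length(Δ)
        return (N * FFTW.ifft(Δ),)
    end
end
# ifft divides by N, so its pullback is fft / N
Zygote.@adjoint function FFTW.ifft(xs)
    return FFTW.ifft(xs), function(Δ)
        N = length(Δ)
        return (FFTW.fft(Δ) / N,)
    end
end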
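The corrected pullbacks can be sanity-checked numerically against the adjoint identity dot(Δ, A*x) == dot(pullback(Δ), x), which any pullback of a linear map A must satisfy. This is a stdlib-only sketch under stated assumptions: the naive O(N²) functions dft and idft stand in for FFTW.fft and FFTW.ifft so it runs without packages, and the names dft, idft, fft_pullback and ifft_pullback are hypothetical.

```julia
using LinearAlgebra  # for dot

# Naive O(N^2) stand-ins for FFTW.fft / FFTW.ifft, following FFTW's conventions:
#   fft(x)[k]  = Σₙ x[n] exp(-2πi (k-1)(n-1) / N)          (unnormalized)
#   ifft(X)[n] = (1/N) Σₖ X[k] exp(+2πi (k-1)(n-1) / N)    (divides by N)
function dft(x::AbstractVector)
    N = length(x)
    return [sum(x[n] * cis(-2π * (k - 1) * (n - 1) / N) for n in 1:N) for k in 1:N]
end

function idft(X::AbstractVector)
    N = length(X)
    return [sum(X[k] * cis(2π * (k - 1) * (n - 1) / N) for k in 1:N) / N for n in 1:N]
end

# Pullbacks mirroring the corrected adjoints: each applies the conjugate
# transpose of its forward map.
fft_pullback(Δ)  = length(Δ) * idft(Δ)   # F' = N * F⁻¹
ifft_pullback(Δ) = dft(Δ) / length(Δ)    # (F⁻¹)' = F / N

# Adjoint identity on arbitrary complex test vectors
x = [1.0 + 2.0im, -0.5im, 3.0 + 0.0im, 0.25 - 1.0im]
Δ = [0.5 + 0.0im, 1.0 - 1.0im, -2.0im, 0.75 + 0.0im]
dot(Δ, dft(x)) ≈ dot(fft_pullback(Δ), x)     # true
dot(Δ, idft(x)) ≈ dot(ifft_pullback(Δ), x)   # true

# Output sensitivity of abs(fft(x)[1]) at x = [1, 2, 3] is [1, 0, 0]
fft_pullback(ComplexF64[1, 0, 0])            # ≈ [1, 1, 1]
```

The last line revisits the earlier gradient example: since fft(x)[1] is just sum(x), the true gradient of abs(fft(x)[1]) at [1, 2, 3] is [1, 1, 1], which the N-scaled pullback returns, whereas the unscaled ifft adjoint gives 1/3 in each entry.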
I am currently working on the other functions in FFTW and will open a PR when I have everything.
I opened a pull request for this: https://github.com/FluxML/Zygote.jl/pull/215
This has been closed by #215.
I am trying to implement custom adjoints for Fast Fourier Transform (FFT) functions. Since the DFT (discrete Fourier transform) is linear, the gradient with respect to its input is, up to a factor of N, the inverse DFT of the gradient with respect to its output; and since FFT is just a fast way to compute a DFT, this is in principle not very complicated.
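A short derivation of where the factor N enters, assuming the unnormalized convention of FFTW's fft (with ifft carrying the 1/N) and that the pullback of a linear map x ↦ Ax is Δ ↦ Aᴴ Δ, the conjugate transpose, which is how Zygote treats complex linear maps:

$$
\begin{aligned}
X &= F x, \qquad F_{kn} = e^{-2\pi i\,kn/N}, \qquad k, n = 0, \dots, N-1 \\
(F^{\mathsf H} F)_{mn} &= \sum_{k=0}^{N-1} e^{2\pi i\,k(m-n)/N} = N\,\delta_{mn}
\quad\Longrightarrow\quad F^{-1} = \tfrac{1}{N}\,F^{\mathsf H} \\
\bar{x} &= F^{\mathsf H}\,\bar{X} = N\,F^{-1}\bar{X}
\qquad \text{(pullback of fft: } N \cdot \mathrm{ifft}(\Delta)\text{)} \\
\bar{x} &= (F^{-1})^{\mathsf H}\,\bar{y} = \tfrac{1}{N}\,F\,\bar{y}
\qquad \text{(pullback of } y = \mathrm{ifft}(x)\text{: } \mathrm{fft}(\Delta)/N\text{)}
\end{aligned}
$$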
The following implementation (based on https://github.com/FluxML/Flux.jl/issues/410) works, but is inefficient for large dimensions because the DFT for the gradient is computed via a matrix instead of using FFT.
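The original snippet is not reproduced in this thread. As a rough, stdlib-only illustration of why the matrix approach scales poorly, the sketch below materializes the dense N×N DFT matrix and applies its conjugate transpose to the output gradient. dft_matrix and matrix_pullback are hypothetical names, not the code from the linked Flux issue.

```julia
using LinearAlgebra  # adjoint (') products and the identity I used in checks

# Hypothetical helper: the dense N×N DFT matrix, F[k, n] = exp(-2πi (k-1)(n-1) / N).
dft_matrix(N::Integer) = [cis(-2π * (k - 1) * (n - 1) / N) for k in 1:N, n in 1:N]

# Matrix-based pullback of the DFT: multiply the output gradient by F' (the
# conjugate transpose). Building F costs O(N^2) memory and the product O(N^2) time.
matrix_pullback(Δ::AbstractVector) = dft_matrix(length(Δ))' * Δ

matrix_pullback(ComplexF64[1, 0, 0])  # ≈ [1, 1, 1]
```

An FFT-based N * ifft(Δ) computes the same vector in O(N log N) time without storing F, which is the motivation for the approach below.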
Ideally, one would use FFT to compute the DFT of the gradient.
This throws the following error, because FFTW.ifft is not implemented for Zygote.FillArray:
MethodError: no method matching plan_bfft(::Zygote.FillArray{Complex{Float64},1}, ::UnitRange{Int64})
Is there another way to do this? For example, can one extract the data from the FillArray? This boils down to a more fundamental question: is it possible to call a function that Zygote cannot differentiate inside that function's own adjoint definition, passing it the incoming gradient as input?