Closed by wkearn 4 years ago
After a little more experimentation, I found that this error is thrown at https://github.com/FluxML/Zygote.jl/blob/08e8122bab50729edb04f22492ba30ab6a670289/src/lib/array.jl#L937 because Δ needs to be real. Changing this line to
return (1/N * AbstractFFTs.rfft(real.(Δ), dims), nothing, nothing)
as in the other adjoint definitions for irfft, gets rid of the realfloat error above. However, another error emerges when running the test.
julia> gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)
ERROR: AssertionError: osize[d1] == d >> 1 + 1
Stacktrace:
[1] brfft_output_size(::Array{Complex{Float64},2}, ::Int64, ::Array{Int64,0}) at C:\Users\wkearney\.julia\packages\AbstractFFTs\mhQvY\src\definitions.jl:329
[2] #plan_brfft#20(::UInt32, ::Float64, ::typeof(plan_brfft), ::Array{Complex{Float64},2}, ::Int64, ::Array{Int64,0}) at C:\Users\wkearney\.julia\packages\FFTW\5DZuu\src\fft.jl:678
[3] plan_brfft at C:\Users\wkearney\.julia\packages\FFTW\5DZuu\src\fft.jl:678 [inlined]
[4] #plan_irfft#19 at C:\Users\wkearney\.julia\packages\AbstractFFTs\mhQvY\src\definitions.jl:334 [inlined]
[5] plan_irfft at C:\Users\wkearney\.julia\packages\AbstractFFTs\mhQvY\src\definitions.jl:334 [inlined]
[6] irfft(::Array{Complex{Float64},2}, ::Int64, ::Array{Int64,0}) at C:\Users\wkearney\.julia\packages\AbstractFFTs\mhQvY\src\definitions.jl:284
[7] (::Zygote.var"#1047#1048"{Array{Float64,2}})(::Array{Complex{Float64},2}) at C:\Users\wkearney\.julia\dev\Zygote\src\lib\array.jl:929
[8] (::Zygote.var"#3689#back#1049"{Zygote.var"#1047#1048"{Array{Float64,2}}})(::Array{Complex{Float64},2}) at C:\Users\wkearney\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
[9] #5 at .\REPL[6]:1 [inlined]
[10] (::typeof(∂(#5)))(::Float64) at C:\Users\wkearney\.julia\dev\Zygote\src\compiler\interface2.jl:0
[11] (::Zygote.var"#43#44"{typeof(∂(#5))})(::Float64) at C:\Users\wkearney\.julia\dev\Zygote\src\compiler\interface.jl:45
[12] gradient(::Function, ::Array{Float64,2}) at C:\Users\wkearney\.julia\dev\Zygote\src\compiler\interface.jl:54
[13] top-level scope at REPL[6]:1
This is the error that I ran into while working on #751.
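For context, the assertion in the trace compares the size of the complex input along the first transformed dimension with d >> 1 + 1, which Julia parses as (d >> 1) + 1. Below is a minimal sketch of that size relationship. It assumes FFTW.jl is installed and uses a made-up 16×3 input, since the original x is not shown in this thread.

```julia
using FFTW  # provides rfft / irfft (assumed to be installed)

# Hypothetical input: the x used in the REPL session above is not shown,
# so a 16×3 real array is assumed here.
x = randn(16, 3)
d = 16

# rfft along dimension 1 keeps only the non-redundant half of the spectrum.
X = rfft(x, 1)

# The transformed dimension has length (d >> 1) + 1, i.e. 9 for d = 16.
# In Julia, >> binds more tightly than +, so no parentheses are needed.
size(X, 1) == d >> 1 + 1

# irfft with the matching d recovers the original real array.
irfft(X, d, 1) ≈ x
```

An array whose first dimension has length 16 rather than 9, such as the real-sized Δ arriving in the pullback, would not satisfy this check when passed to irfft with d = 16.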
It turns out that the adjoint definitions for irfft(xs,d,dims) and brfft(xs,d,dims) call the wrong function: line 934 should be
return AbstractFFTs.irfft(xs, d, dims), function(Δ)
and the same for brfft. Pull request forthcoming.
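One way to sanity-check the gradient independently of Zygote is a finite-difference comparison. The sketch below is not part of the Zygote test suite. It assumes FFTW.jl is installed, uses a made-up 16×3 input, and introduces a hypothetical helper fd_gradient. Because irfft(rfft(x,1),16,1) reproduces x when the first dimension has length 16, the function reduces to sum(abs2, x) and its gradient should be 2x.

```julia
using FFTW  # provides rfft / irfft (assumed to be installed)

# Same function as in the gradient call above.
f(x) = sum(abs2, irfft(rfft(x, 1), 16, 1))

# Hypothetical input: first dimension of length 16 so the round trip is the identity.
x = randn(16, 3)

# Since the round trip returns x, f(x) ≈ sum(abs2, x) and the gradient is 2x.
expected = 2 .* x

# Central finite differences (hypothetical helper, not a Zygote or FFTW function).
function fd_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

g = fd_gradient(f, x)

# Largest deviation from the analytic gradient; this should be close to zero.
err = maximum(abs, g .- expected)
```

A corrected adjoint could be compared against expected in the same way, for example by checking that the result of gradient(f, x)[1] is approximately 2x.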
In #751 I added some tests to make sure that the adjoints for FFTs preserved the types of the inputs, but I had some trouble with the following test
which fails with the error
A more minimal example, which throws the same error, shows that the problem is in the adjoint for irfft(xs,d,dims).
This fails regardless of the input types, so I commented out the relevant tests in #751.