FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

Errors in adjoints for irfft(xs,d,dims) #755

Closed wkearn closed 4 years ago

wkearn commented 4 years ago

In #751 I added some tests to make sure that the adjoints for FFTs preserve the types of their inputs, but I had trouble with the following test:

using Zygote, FFTW, Test
x = randn(16,16)
@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}

which fails with the error

MethodError: no method matching realfloat(::Array{Complex{Float64},2})
Closest candidates are:
  realfloat(!Matched::Union{DenseArray{#s12,N}, Base.ReinterpretArray{#s12,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Union{Tuple{Vararg{Real,N} where N}, Tuple{AbstractUnitRange,Vararg{Any,N} where N}} where A<:DenseArray where N where T, DenseArray}, Base.ReshapedArray{#s12,N,A,MI} where MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64},N} where N} where A<:Union{Base.ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Union{Tuple{Vararg{Real,N} where N}, Tuple{AbstractUnitRange,Vararg{Any,N} where N}} where A<:DenseArray where N where T, DenseArray} where N where T, SubArray{T,N,A,I,true} where I<:Union{Tuple{Vararg{Real,N} where N}, Tuple{AbstractUnitRange,Vararg{Any,N} where N}} where A<:DenseArray where N where T, DenseArray}, SubArray{#s12,N,A,I,L} where L where I<:Tuple{Vararg{Union{Int64, AbstractRange{Int64}, Base.AbstractCartesianIndex},N} where N} where A<:Union{Base.ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Union{Tuple{Vararg{Real,N} where N}, Tuple{AbstractUnitRange,Vararg{Any,N} where N}} where A<:DenseArray where N where T, DenseArray} where N where T, Base.ReshapedArray{T,N,A,MI} where MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64},N} where N} where A<:Union{Base.ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Union{Tuple{Vararg{Real,N} where N}, Tuple{AbstractUnitRange,Vararg{Any,N} where N}} where A<:DenseArray where N where T, DenseArray} where N where T, SubArray{T,N,A,I,true} where I<:Union{Tuple{Vararg{Real,N} where N}, Tuple{AbstractUnitRange,Vararg{Any,N} where N}} where A<:DenseArray where N where T, DenseArray} where N where T, DenseArray}} where N where #s12<:Union{Float32, Float64}) at C:\Users\wkearney\.julia\packages\AbstractFFTs\mhQvY\src\definitions.jl:26
  realfloat(!Matched::AbstractArray{T,N} where N) where T<:Real at C:\Users\wkearney\.julia\packages\AbstractFFTs\mhQvY\src\definitions.jl:33

Stacktrace:
 [1] #plan_rfft#14(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(plan_rfft), ::Array{Complex{Float64},2}, ::Array{Int64,0}) at C:\Users\wkearney\.julia\packages\AbstractFFTs\mhQvY\src\definitions.jl:205
 [2] plan_rfft(::Array{Complex{Float64},2}, ::Array{Int64,0}) at C:\Users\wkearney\.julia\packages\AbstractFFTs\mhQvY\src\definitions.jl:205
 [3] rfft(::Array{Complex{Float64},2}, ::Array{Int64,0}) at C:\Users\wkearney\.julia\packages\AbstractFFTs\mhQvY\src\definitions.jl:51
 [4] (::Zygote.var"#1051#1052"{Array{Complex{Float64},2}})(::Array{Complex{Float64},2}) at C:\Users\wkearney\.julia\dev\Zygote\src\lib\array.jl:937
 [5] (::Zygote.var"#3707#back#1053"{Zygote.var"#1051#1052"{Array{Complex{Float64},2}}})(::Array{Complex{Float64},2}) at C:\Users\wkearney\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
 [6] #273 at .\In[173]:2 [inlined]
 [7] (::typeof(∂(#273)))(::Float64) at C:\Users\wkearney\.julia\dev\Zygote\src\compiler\interface2.jl:0
 [8] (::Zygote.var"#43#44"{typeof(∂(#273))})(::Float64) at C:\Users\wkearney\.julia\dev\Zygote\src\compiler\interface.jl:45
 [9] gradient(::Function, ::Array{Float64,2}) at C:\Users\wkearney\.julia\dev\Zygote\src\compiler\interface.jl:54
 [10] top-level scope at In[173]:2

A more minimal example shows that the problem is in the adjoint for irfft(xs,d,dims):

F = rfft(x,1)
gradient(x->sum(abs2,irfft(F,16,1)),x)

which throws the same error.

This fails regardless of the input types, so I commented out the relevant tests in #751.
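
Frame [3] of the stack trace above shows rfft being called on an Array{Complex{Float64},2}. The following is a minimal sketch, assuming FFTW.jl is installed, that reproduces that failure outside Zygote. The array Δ is a hypothetical stand-in for the cotangent that reaches the pullback, not a value taken from Zygote itself.

```julia
using FFTW

# Hypothetical complex-valued cotangent, same shape as x in the test above.
Δ = complex.(randn(16, 16))

# rfft is only defined for real input, so this call is expected to throw.
threw = try
    rfft(Δ, 1)
    false
catch err
    println(typeof(err))   # print whatever error type is raised
    true
end

# The same call on the real part goes through.
real_ok = rfft(real.(Δ), 1)
println(size(real_ok))     # (9, 16)
```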

wkearn commented 4 years ago

Upon doing a little more experimentation, I found that this error is thrown in https://github.com/FluxML/Zygote.jl/blob/08e8122bab50729edb04f22492ba30ab6a670289/src/lib/array.jl#L937

because rfft only accepts real input, so Δ needs to be real. Changing this line to

return (1/N * AbstractFFTs.rfft(real.(Δ), dims), nothing, nothing) 

as in the other adjoint definitions for irfft, gets rid of the realfloat error above. However, another error emerges when running the test:

julia> gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)

ERROR: AssertionError: osize[d1] == d >> 1 + 1
Stacktrace:
 [1] brfft_output_size(::Array{Complex{Float64},2}, ::Int64, ::Array{Int64,0}) at C:\Users\wkearney\.julia\packages\AbstractFFTs\mhQvY\src\definitions.jl:329
 [2] #plan_brfft#20(::UInt32, ::Float64, ::typeof(plan_brfft), ::Array{Complex{Float64},2}, ::Int64, ::Array{Int64,0}) at C:\Users\wkearney\.julia\packages\FFTW\5DZuu\src\fft.jl:678
 [3] plan_brfft at C:\Users\wkearney\.julia\packages\FFTW\5DZuu\src\fft.jl:678 [inlined]
 [4] #plan_irfft#19 at C:\Users\wkearney\.julia\packages\AbstractFFTs\mhQvY\src\definitions.jl:334 [inlined]
 [5] plan_irfft at C:\Users\wkearney\.julia\packages\AbstractFFTs\mhQvY\src\definitions.jl:334 [inlined]
 [6] irfft(::Array{Complex{Float64},2}, ::Int64, ::Array{Int64,0}) at C:\Users\wkearney\.julia\packages\AbstractFFTs\mhQvY\src\definitions.jl:284
 [7] (::Zygote.var"#1047#1048"{Array{Float64,2}})(::Array{Complex{Float64},2}) at C:\Users\wkearney\.julia\dev\Zygote\src\lib\array.jl:929
 [8] (::Zygote.var"#3689#back#1049"{Zygote.var"#1047#1048"{Array{Float64,2}}})(::Array{Complex{Float64},2}) at C:\Users\wkearney\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
 [9] #5 at .\REPL[6]:1 [inlined]
 [10] (::typeof(∂(#5)))(::Float64) at C:\Users\wkearney\.julia\dev\Zygote\src\compiler\interface2.jl:0
 [11] (::Zygote.var"#43#44"{typeof(∂(#5))})(::Float64) at C:\Users\wkearney\.julia\dev\Zygote\src\compiler\interface.jl:45
 [12] gradient(::Function, ::Array{Float64,2}) at C:\Users\wkearney\.julia\dev\Zygote\src\compiler\interface.jl:54
 [13] top-level scope at REPL[6]:1

This is the error that I ran into while working on #751.
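
The assertion osize[d1] == d >> 1 + 1 encodes the size relation between a real signal and its half-spectrum: for a real length d along the transformed dimension, rfft returns d >> 1 + 1 entries, and irfft(F, d, dims) expects an input of exactly that size. A short sketch of that relation, assuming FFTW.jl is installed (the variable names are illustrative only):

```julia
using FFTW

d = 16
x = randn(d, d)

# Forward real FFT along dimension 1: the first dimension shrinks to d >> 1 + 1.
F = rfft(x, 1)
println(size(F))      # (9, 16)
println(d >> 1 + 1)   # 9, since >> binds tighter than + in Julia

# irfft needs the original length d to undo the transform.
y = irfft(F, d, 1)
println(size(y))      # (16, 16)
println(y ≈ x)        # true
```

Passing irfft an array whose first dimension is not d >> 1 + 1 for the given d violates this relation, which is what the assertion in brfft_output_size checks.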

wkearn commented 4 years ago

It turns out that the adjoint definitions for irfft(xs,d,dims) and brfft(xs,d,dims) call the wrong function:

https://github.com/FluxML/Zygote.jl/blob/a0f9c8ea789362b5998b3117ba64b55f97be5f5c/src/lib/array.jl#L933-L939

Line 934 should be

return AbstractFFTs.irfft(xs, d, dims), function(Δ)

and likewise for brfft. A pull request is forthcoming.