SciML / OperatorLearning.jl

No need to train, he's a smooth operator
https://operatorlearning.sciml.ai/dev
MIT License

`Zygote` won't work with FFT Plans #11

Closed: pzimbrod closed this issue 2 years ago

pzimbrod commented 3 years ago

When taking the Jacobian of a sample FFT pipeline, Zygote complains about a missing field in the corresponding FFT plan:

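using Zygote, FFTW

n = rand(100);                        # real input signal
f = plan_rfft(n);                     # precomputed forward (real-to-complex) plan
fi = plan_irfft(rfft(n), length(n));  # precomputed scaled inverse plan
fi * (f * n);                         # round trip: recovers n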

jacobian(x -> fi * (f * x), n)
ERROR: type ScaledPlan has no field region
Stacktrace:
  [1] getproperty(x::AbstractFFTs.ScaledPlan{ComplexF64, FFTW.rFFTWPlan{ComplexF64, 1, false, 1, UnitRange{Int64}}, Float64}, f::Symbol)
    @ Base ./Base.jl:33
  [2] (::Zygote.var"#931#932"{AbstractFFTs.ScaledPlan{ComplexF64, FFTW.rFFTWPlan{ComplexF64, 1, false, 1, UnitRange{Int64}}, Float64}, Vector{ComplexF64}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/array.jl:788
  [3] (::Zygote.var"#3479#back#933"{Zygote.var"#931#932"{AbstractFFTs.ScaledPlan{ComplexF64, FFTW.rFFTWPlan{ComplexF64, 1, false, 1, UnitRange{Int64}}, Float64}, Vector{ComplexF64}}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./REPL[6]:1 [inlined]
  [5] (::typeof(∂(#1)))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [6] (::Zygote.var"#209#210"{Tuple{Tuple{Nothing}}, typeof(∂(#1))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/lib.jl:203
  [7] (::Zygote.var"#1746#back#211"{Zygote.var"#209#210"{Tuple{Tuple{Nothing}}, typeof(∂(#1))}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [8] Pullback
    @ ./operators.jl:938 [inlined]
  [9] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), var"#1#2"}(Zygote._jvec, var"#1#2"()))))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#46#47"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), var"#1#2"}(Zygote._jvec, var"#1#2"())))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41
 [11] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/grad.jl:162
 [12] jacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/grad.jl:140
 [13] top-level scope
    @ REPL[6]:1

However, when calling `rfft` and `irfft` directly, without precomputed plans, the error vanishes:

jacobian(x -> irfft(rfft(x), length(x)), n)
([1.0 0.0 … 0.0 0.0; 0.0 1.0 … 0.0 6.187926798518962e-19; … ; 0.0 0.0 … 1.0 0.0; 0.0 5.204170427930421e-18 … 0.0 1.0],)

This is a problem because we compute the FFT and its inverse many times, so the speed gain from using pre-planned FFTs should be considerable.

pzimbrod commented 3 years ago

This problem is mentioned in FluxML/Zygote.jl#899 and subsequently in JuliaMath/FFTW.jl#182. Some plan types lack the `region` field that Zygote's adjoint reads; in the stacktrace above, it is the `AbstractFFTs.ScaledPlan` returned by `plan_irfft`.
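
A quick way to see which plan lacks the field is to inspect the struct types directly. This is a minimal sketch; it assumes that `plan_irfft` returns a `ScaledPlan` whose wrapped plan sits in a field named `p`. That matches the types in the stacktrace above, but it is an internal detail that may differ between package versions.

```julia
using FFTW

x   = rand(100)
fi  = plan_irfft(rfft(x), length(x))  # scaled inverse plan, as in the report
fib = plan_brfft(rfft(x), length(x))  # unscaled inverse plan

# `plan_irfft` returns a wrapper holding an unscaled plan plus the 1/n factor.
println(nameof(typeof(fi)))
# The wrapper struct itself has no `region` field, so `fi.region` throws...
println(hasfield(typeof(fi), :region))
# ...while the unscaled FFTW plans (assumed internal field `p`) do store it.
println(hasfield(typeof(fi.p), :region))
println(hasfield(typeof(fib), :region))
```

`hasfield` checks the struct definition rather than calling `getproperty`, so it answers the question without triggering the error.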

It seems plausible that the problem can be fixed by switching to the unscaled inverse transform and doing the scaling manually afterwards.
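
The relation between the scaled and unscaled inverses can be written out explicitly. For a real signal of length $n$, `rfft` returns the half spectrum $y_0, \dots, y_{\lfloor n/2 \rfloor}$; let $Y$ be its Hermitian extension to the full spectrum, with $Y_k = y_k$ for $0 \le k \le \lfloor n/2 \rfloor$ and $Y_{n-k} = \overline{y_k}$. With FFTW's sign convention for backward transforms,

$$
\mathrm{brfft}(y, n)_j = \sum_{k=0}^{n-1} Y_k \, e^{2\pi i jk/n},
\qquad
\mathrm{irfft}(y, n)_j = \frac{1}{n}\,\mathrm{brfft}(y, n)_j,
\qquad j = 0, \dots, n-1,
$$

so dividing the output of the unscaled transform by $n$ reproduces the inverse.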

pzimbrod commented 3 years ago

Apparently, you can use `plan_brfft` for the inverse instead and scale afterwards:

fib = plan_brfft(rfft(n), length(n));

# Check for same results
irfft(f * n, length(n)) ≈ fib * (f * n) ./ length(n)
true

# Does Zygote run now?
jacobian(x -> fib * (f * x) ./ length(x), n)
([0.51 0.0 … 0.0 0.0; 0.0 0.51 … 0.0 3.996891658508158e-19; … ; 0.0 0.0 … 0.51 0.0; 0.0 8.673617379884035e-19 … 0.0 0.51],)
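
Zygote now runs, but the Jacobian printed above has 0.51 rather than 1.0 on its diagonal. Because the composition is the identity map on real vectors, its Jacobian should be the identity matrix. The following minimal sketch checks a workaround's derivatives against that expectation and against the plan-free version from the first comment. It reuses the names from the thread; `Id` and `J_ref` are illustrative additions.

```julia
using FFTW, Zygote, LinearAlgebra

n   = rand(100)
f   = plan_rfft(n)
fib = plan_brfft(rfft(n), length(n))
Id  = Matrix{Float64}(I, length(n), length(n))

# The round trip x -> brfft(rfft(x)) / length(x) is the identity on real
# vectors, so a correct Jacobian must equal the identity matrix.
J = jacobian(x -> fib * (f * x) ./ length(x), n)[1]
println(isapprox(J, Id; atol = 1e-10))

# Reference: the plan-free version that worked in the first comment.
J_ref = jacobian(x -> irfft(rfft(x), length(x)), n)[1]
println(isapprox(J_ref, Id; atol = 1e-10))
println(maximum(abs, J - J_ref))  # largest entrywise discrepancy
```

A check like this is cheap for small `n` and catches adjoints that run without error but return wrong values.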

pzimbrod commented 2 years ago

Another workaround could be to define custom adjoints for Zygote, as partially discussed here. That will probably require some serious fiddling, though.
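
As a rough illustration of the custom-adjoint route, here is a sketch, not a tested implementation. It hides each plan behind a plain function and gives that function a ChainRulesCore `rrule`, which Zygote picks up, so Zygote never inspects the plan struct. The wrapper names `planned_rfft` and `planned_brfft` are hypothetical, and the example assumes ChainRulesCore is installed alongside Zygote and FFTW. The pullbacks are derived under the convention Zygote and ChainRules use for complex outputs, δL = Re(conj(ȳ)·δy).

```julia
using FFTW, Zygote, ChainRulesCore, LinearAlgebra

# Hypothetical wrappers: hiding each plan behind a plain function lets us attach
# our own rules, so Zygote never looks inside the plan struct.
planned_rfft(P, x) = P * x    # P from plan_rfft (real -> half spectrum)
planned_brfft(P, z) = P * z   # P from plan_brfft (unscaled inverse)

function ChainRulesCore.rrule(::typeof(planned_rfft), P, x::AbstractVector{<:Real})
    y = P * x
    n, m = length(x), length(y)
    function planned_rfft_pullback(ȳ)
        # Adjoint of the half-spectrum DFT: zero-pad the cotangent to length n,
        # apply the unnormalised backward FFT and keep the real part.
        z = zeros(ComplexF64, n)
        z[1:m] .= unthunk(ȳ)
        return NoTangent(), NoTangent(), real(bfft(z))
    end
    return y, planned_rfft_pullback
end

function ChainRulesCore.rrule(::typeof(planned_brfft), P, z::AbstractVector{<:Complex})
    y = P * z
    n, m = length(y), length(z)
    # Each interior bin stands for a conjugate pair of the full spectrum, so it
    # is weighted by 2; DC (and Nyquist, for even n) appear only once.
    w = fill(2.0, m)
    w[1] = 1.0
    iseven(n) && (w[m] = 1.0)
    function planned_brfft_pullback(ȳ)
        ȳr = real.(collect(unthunk(ȳ)))
        return NoTangent(), NoTangent(), rfft(ȳr) .* w
    end
    return y, planned_brfft_pullback
end

# Usage: the round trip from the thread, routed through the wrappers.
x   = rand(100)
f   = plan_rfft(x)
fib = plan_brfft(rfft(x), length(x))
roundtrip(v) = planned_brfft(fib, planned_rfft(f, v)) ./ length(v)

J = jacobian(roundtrip, x)[1]
println(isapprox(J, Matrix{Float64}(I, length(x), length(x)); atol = 1e-10))
```

For brevity the pullbacks call the unplanned `bfft` and `rfft`; precomputing plans for those as well would presumably keep the backward pass as fast as the forward one.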