SciML / OperatorLearning.jl

No need to train, he's a smooth operator
https://operatorlearning.sciml.ai/dev
MIT License

Work with pre-allocated arrays and `mul!` when doing FFT #14

Closed pzimbrod closed 2 years ago

pzimbrod commented 3 years ago

The AbstractFFTs.jl docs say that it is preferable to apply FFT plans to pre-allocated arrays using `mul!` from the LinearAlgebra package.

As this step is taken on every pass through the network, this is well worth investigating.
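To make the idea concrete, here is a minimal sketch of how plans and buffers could be created once and reused on every pass. The names `RFFTCache` and `roundtrip!` are hypothetical and not part of OperatorLearning.jl; the sketch assumes a 1D real input on the CPU and uses the normalized `plan_irfft` so that the round trip returns the input.

```julia
using FFTW, LinearAlgebra

# Hypothetical container: holds the plans and the pre-allocated buffers
# so that repeated passes do not allocate new arrays.
struct RFFTCache{P,Q,T}
    fwd::P                     # forward real-to-complex plan
    bwd::Q                     # normalized complex-to-real plan
    freq::Vector{Complex{T}}   # frequency-domain scratch buffer
    out::Vector{T}             # real-valued output buffer
end

function RFFTCache(x::Vector{T}) where {T<:AbstractFloat}
    fwd  = plan_rfft(x)
    freq = Vector{Complex{T}}(undef, length(x) ÷ 2 + 1)
    bwd  = plan_irfft(freq, length(x))
    return RFFTCache(fwd, bwd, freq, similar(x))
end

# x -> rfft -> irfft, writing only into the cached buffers.
# Note: the complex-to-real step may overwrite `c.freq`, which is fine
# here because it is only used as scratch space.
function roundtrip!(c::RFFTCache, x)
    mul!(c.freq, c.fwd, x)
    mul!(c.out, c.bwd, c.freq)
    return c.out
end

x = rand(100)
cache = RFFTCache(x)
y = roundtrip!(cache, x)   # y aliases cache.out; copy it if it must outlive the next call
```

Because `roundtrip!` returns the cached buffer itself, every call overwrites the previous result. That is the price of avoiding allocations, and it is also the kind of mutation that AD tools such as Zygote generally cannot trace through.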

pzimbrod commented 3 years ago

using Zygote, FFTW, LinearAlgebra, BenchmarkTools

# Create the plans
n = rand(100);
f = plan_rfft(n);
fi = plan_brfft(rfft(n), length(n));

# Create the pre-allocated mutable arrays
pref = Array{Complex{eltype(n)}}(undef, length(n) ÷ 2 + 1);
prefi = Array{eltype(n)}(undef, length(n));

# Benchmarks for the FT

@benchmark f * n
BechmarkTools.Trial: 10000 samples with 585 evaluations.
 Range (min … max):  197.414 ns …   2.603 μs   ┊ GC (min … max): 0.00% … 78.65%
 Time  (median):         226.930 ns                      ┊ GC (median):    0.00%
 Time  (mean ± σ):      250.871 ns ± 164.226 ns  ┊ GC (mean ± σ):  5.91% ±  8.13%

  ▆█▅▃                                                          ▁
  █████▇▇▆▅▆▅▆▅▃▃▁▃▁▄▇▆▆▅▅▅▅▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▃▄▅▄ █
  197 ns        Histogram: log(frequency) by time       1.51 μs <

 Memory estimate: 896 bytes, allocs estimate: 1.

@benchmark mul!(pref, f, n)
BechmarkTools.Trial: 10000 samples with 723 evaluations.
 Range (min … max):  175.148 ns … 491.354 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):         198.685 ns                       ┊ GC (median):    0.00%
 Time  (mean ± σ):      203.636 ns ±  19.412 ns   ┊ GC (mean ± σ):  0.00% ± 0.00%

   ▂    ▂  ▄ ▂█▂▄▄▄▄▂▁▂▁▂  ▁                                    ▁
  ▇█▃██▇█▃██████████████████▇█▇▆▆█▆▅▄▅▄▅▃▄▇▅▃▅▄▆▃▃▃▂▃▆▃▄▃▃▄▇▄▄▅ █
  175 ns        Histogram: log(frequency) by time        287 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

# Benchmarks for everything

@benchmark fi * (f * n)
BechmarkTools.Trial: 10000 samples with 167 evaluations.
 Range (min … max):  603.240 ns …   9.211 μs   ┊ GC (min … max): 0.00% … 85.68%
 Time  (median):         702.231 ns                        ┊ GC (median):    0.00%
 Time  (mean ± σ):      812.338 ns ± 645.850 ns  ┊ GC (mean ± σ):  5.40% ±  7.04%

  ▅▇█▇▅▅▃▁                                                      ▂
  █████████▇▇▆▇▇▇▅▆▇▇▇▇▅▅▆▅█▇██▆▇▆▆▆▆▆▆▆▄▁▁▄▅▅▆▅▆▆▆▅▆▆▄▆▆▆▅▇▃▃▅ █
  603 ns        Histogram: log(frequency) by time       2.76 μs <

 Memory estimate: 1.75 KiB, allocs estimate: 2.

@benchmark mul!(prefi, fi, mul!(pref, f, n))
BechmarkTools.Trial: 10000 samples with 194 evaluations.
 Range (min … max):  512.505 ns …  11.209 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):         575.735 ns                       ┊ GC (median):    0.00%
 Time  (mean ± σ):      611.997 ns ± 245.942 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▃▄▇█▇▄▃▂▃▂▁       ▁                                           ▂
  █████████████▇█▇▇▆█▇▆▆▄▁▃▁▁▁▃▃▁▁▁▃▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▄▃▄▃▄▆▇▇▇▇ █
  513 ns        Histogram: log(frequency) by time       1.62 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

All in all, there's some memory allocation up for grabs and also a moderate speed improvement: for the full round trip, roughly 25 % on the mean time and roughly 18 % on the median.

pzimbrod commented 3 years ago

However, this breaks AD, which is absolutely necessary:

jacobian(x -> mul!(prefi, fi, mul!(pref, f, x)), n)
ERROR: Can't differentiate foreigncall expression
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] Pullback
    @ ~/.julia/packages/FFTW/kKdEk/src/fft.jl:486 [inlined]
  [3] (::typeof(∂(unsafe_execute!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [4] Pullback
    @ ~/.julia/packages/FFTW/kKdEk/src/fft.jl:791 [inlined]
  [5] (::typeof(∂(mul!)))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./REPL[32]:1 [inlined]
  [7] (::typeof(∂(#5)))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [8] (::Zygote.var"#209#210"{Tuple{Tuple{Nothing}}, typeof(∂(#5))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/lib.jl:203
  [9] (::Zygote.var"#1746#back#211"{Zygote.var"#209#210"{Tuple{Tuple{Nothing}}, typeof(∂(#5))}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [10] Pullback
    @ ./operators.jl:938 [inlined]
 [11] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), var"#5#6"}(Zygote._jvec, var"#5#6"()))))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#46#47"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), var"#5#6"}(Zygote._jvec, var"#5#6"())))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41
 [13] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/grad.jl:162
 [14] jacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/grad.jl:140
 [15] top-level scope
    @ REPL[32]:1

pzimbrod commented 3 years ago

Note: as explained here in section 3.3, `mul!` should transfer to the GPU effortlessly, since it calls either BLAS or cuBLAS depending on the input types.

using LinearAlgebra, CUDA

mul!(CUDA.zeros(Float32,2,2),
                   cu(rand(Float32,2,2)),
                   cu(rand(Float32,2,2)))
2×2 CuArray{Float32, 2}:
 0.609215  1.2757
 0.813715  0.897835

mul!(zeros(Float32,2,2),
                 rand(Float32,2,2),
                 rand(Float32,2,2))
2×2 Matrix{Float32}:
 0.753383  0.263215
 1.16842   0.309142

pzimbrod commented 2 years ago

Related to #11. Resolving the Diff problem likely takes care of both issues, maybe even #12.

The problem seems to be that Zygote needs to be supplied with custom adjoints: FFTW.jl calls the external FFTW library, which Zygote cannot see into and therefore cannot differentiate. But, as explained here, AD potentially isn't even necessary.
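For illustration, a custom adjoint for a planned, in-place transform could look roughly like the sketch below. It is not the approach taken in this repository. It assumes ChainRulesCore.jl (whose `rrule`s Zygote picks up), a hypothetical wrapper named `fft_into!`, and a complex-to-complex plan over all dimensions; the `rfft` case from the benchmarks needs a different adjoint and is not covered. The rule relies on the fact that the adjoint of the unnormalized forward DFT is the unnormalized backward DFT (`bfft`).

```julia
using FFTW, LinearAlgebra, ChainRulesCore

# Hypothetical wrapper: apply the complex FFT plan `p` to `x`,
# writing the result into the pre-allocated buffer `out`.
fft_into!(out, p, x) = mul!(out, p, x)

# Custom reverse rule, so that AD never has to look inside the ccall.
function ChainRulesCore.rrule(::typeof(fft_into!), out, p, x)
    y = fft_into!(out, p, x)
    function fft_into_pullback(dy)
        # Adjoint of the unnormalized forward DFT: the unnormalized backward DFT.
        dx = bfft(unthunk(dy))
        # `out` is fully overwritten, so its previous contents get a zero tangent;
        # the plan itself is not differentiable.
        return NoTangent(), ZeroTangent(), NoTangent(), dx
    end
    return y, fft_into_pullback
end

# Usage: call the rule directly (Zygote would do this internally).
N = 8
x = rand(ComplexF64, N)
out = similar(x)
p = plan_fft(x)
y, back = rrule(fft_into!, out, p, x)
dy = rand(ComplexF64, N)
_, _, _, dx = back(dy)
```

One caveat: the rule only hides the mutation from the AD system. If `out` is reused across calls while other pullbacks still hold a reference to it, their stored values are silently overwritten, so buffer reuse needs extra care in a real network.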

pzimbrod commented 2 years ago

A custom implementation of NNlib's batched_mul! is given in 096bab9d8cfad3a4115a35bae885f1e60a307553. It is however a huge pain to get the batching working efficiently with CUDA.

Maybe this should be left aside for now. Doing regular FFT comes with the additional advantage that AD works well.

It appears that the subsequent caching of the FFT arrays doesn't happen within Julia, but rather after FFTW is called via unsafe_execute! and ccall. Hence, it's probably not worthwhile working around these.

pzimbrod commented 2 years ago

For completeness' sake, here are the same benchmarks using a 3D array of shape (200, 128, 100) with CUDA:

@benchmark f * n
BechmarkTools.Trial: 10000 samples with 3 evaluations.
 Range (min … max):   7.358 μs …  3.094 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     63.611 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   62.548 μs ± 57.216 μs  ┊ GC (mean ± σ):  0.63% ± 0.84%

                                                            █  
  ▃▂▂▂▂▁▂▂▁▁▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▂▂▂▂▁▂▁▁▂▁▂▁▂▁▂▁▂▂▂▂▁▂▁▂▂▂▁▂▂▂▂▂▅█ ▂
  7.36 μs         Histogram: frequency by time        64.4 μs <

 Memory estimate: 816 bytes, allocs estimate: 22.

julia> @benchmark mul!(pref, f, n)
BechmarkTools.Trial: 10000 samples with 5 evaluations.
 Range (min … max):   6.135 μs … 900.818 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     63.982 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   63.333 μs ±  11.964 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▁                                                         ▁█ ▁
  █▃▁▁▁▁▁▁▁▁▁▁▃▁▄▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▃▄▁▃▃▃▁▁▁▄▃▅▆██ █
  6.14 μs       Histogram: log(frequency) by time      64.5 μs <

 Memory estimate: 496 bytes, allocs estimate: 14.

julia> @benchmark fi * (f * n)
BechmarkTools.Trial: 10000 samples with 1 evaluations.
 Range (min … max):   17.151 μs …   4.731 ms  ┊ GC (min … max): 0.00% … 30.70%
 Time  (median):     153.542 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   150.429 μs ± 113.811 μs  ┊ GC (mean ± σ):  0.58% ±  0.76%

                                                              █  
  ▃▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▂▁▁▂▁▁▁▁▁▂▂▁▁▁▂▂▁▁▁▁▁▁▁▁▁▂▁▂▂▁▂▂▂▂▂▂▂▂▂▁▂▂██ ▂
  17.2 μs          Histogram: frequency by time          155 μs <

 Memory estimate: 1.52 KiB, allocs estimate: 44.

julia> @benchmark mul!(prefi, fi, mul!(pref, f, n))
BechmarkTools.Trial: 10000 samples with 1 evaluations.
 Range (min … max):   15.249 μs … 225.914 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     153.672 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   150.568 μs ±  20.788 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                                                             ▃█  
  ▃▂▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▂▁▁▁▂▁▁▁▁▁▂▁▁▁▂▂▂▁▂▂▂▂▂▂▂▂██ ▂
  15.2 μs          Histogram: frequency by time          156 μs <

 Memory estimate: 1.38 KiB, allocs estimate: 36.

julia> @benchmark rfft(n,3)
BechmarkTools.Trial: 10000 samples with 1 evaluations.
 Range (min … max):  81.273 μs …   7.313 ms  ┊ GC (min … max): 0.00% … 24.65%
 Time  (median):     84.526 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   90.103 μs ± 172.512 μs  ┊ GC (mean ± σ):  1.44% ±  0.75%

        ▁▅█▇█▇▇▆▃▃                                              
  ▂▂▂▃▄▆██████████▇▆▅▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▂▂▂▂▂▂▂▂ ▃
  81.3 μs         Histogram: frequency by time         98.3 μs <

 Memory estimate: 1.91 KiB, allocs estimate: 48.

julia> @benchmark irfft(rfft(n,3), size(n,3), 3)
BechmarkTools.Trial: 10000 samples with 1 evaluations.
 Range (min … max):  167.979 μs …   7.797 ms  ┊ GC (min … max): 0.00% … 23.84%
 Time  (median):     172.844 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   184.898 μs ± 259.530 μs  ┊ GC (mean ± σ):  1.48% ±  1.03%

     ▁██▃                                                        
  ▂▂▄████▅▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▂▂▁▁▂▁▁▂▂▁▁▁▂▂▁▁▁▁▂▁▁▂ ▃
  168 μs           Histogram: frequency by time          224 μs <

 Memory estimate: 4.09 KiB, allocs estimate: 101.

julia> @benchmark irfft(rfft(n,3), 100, 3)
BechmarkTools.Trial: 10000 samples with 1 evaluations.
 Range (min … max):  168.265 μs …   7.826 ms  ┊ GC (min … max): 0.00% … 23.62%
 Time  (median):     172.484 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   184.944 μs ± 259.261 μs  ┊ GC (mean ± σ):  1.48% ±  1.03%

        ▃▆█▇▆▃                                                   
  ▁▁▂▃▅▇███████▅▃▃▂▂▂▂▂▂▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  168 μs           Histogram: frequency by time          195 μs <

 Memory estimate: 4.09 KiB, allocs estimate: 101.

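Bottom line: the speedup from pre-allocated arrays is debatable on the GPU, since the median times for `mul!` and `*` are practically identical. We allocate a little less memory, similar to the CPU implementation. The largest improvement shows up for the whole pipeline x -> rfft -> irfft: compared with calling rfft and irfft directly (4.09 KiB), the planned `mul!` version (1.38 KiB) cuts the allocated memory to about a third.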