LuxDL / Lux.jl

Elegant and Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Mixed-Precision Matrix Multiply Performance Regression #847

Closed marcobonici closed 3 months ago

marcobonici commented 3 months ago

In recent releases of Lux, I have seen worse performance on some small NNs.

Here is an MWE:

using BenchmarkTools
using Lux
using Random

rng = Xoshiro(2)

model = Chain(Dense(6, 64, tanh), Dense(64, 64, tanh), Dense(64, 64, tanh), Dense(64, 64, tanh), Dense(64, 64, tanh), Dense(64, 4999))

x = rand(6)
ps, st = Lux.setup(rng, model)

@benchmark Lux.apply(model, x, ps, st)

When using Lux@0.5.10, I get

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  46.278 μs …  1.356 ms  ┊ GC (min … max): 0.00% … 91.56%
 Time  (median):     50.907 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   56.941 μs ± 24.005 μs  ┊ GC (mean ± σ):  1.09% ±  3.37%

  ▄█▇▅▄▄▄▃▃▄▅▃▄▃▃▃▂▃▃▂▁▂▂▂▁▂▁▁▁▁▁                             ▂
  ██████████████████████████████████▇▇▇▇▇▆▇▆▆▆▆▆▄▃▆▄▄▂▄▄▂▄▂▃▄ █
  46.3 μs      Histogram: log(frequency) by time       109 μs <

 Memory estimate: 87.14 KiB, allocs estimate: 32.

With the latest release, I find

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  333.805 μs …  2.110 ms  ┊ GC (min … max): 0.00% … 70.15%
 Time  (median):     358.978 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   385.809 μs ± 82.572 μs  ┊ GC (mean ± σ):  0.82% ±  3.93%

  ▃▇▇█▆▅▅▄▄▄▄▄▅▄▄▃▃▃▃▂▂▂▂▁▁▁▁▁  ▁▁▁▁   ▁                       ▂
  ██████████████████████████████████████▇▇▇▇▆▇▆▇▇▅▆▅▅▆▅▁▅▄▅▃▃▅ █
  334 μs        Histogram: log(frequency) by time       664 μs <

 Memory estimate: 198.69 KiB, allocs estimate: 94.

Lux.jl went from being only 20% slower than SimpleChains.jl on the equivalent NN to 7 times slower. Is this expected, given some recent developments? If useful, I can try to pin down which specific release introduced the performance issue.

Cheers, Marco

avik-pal commented 3 months ago

Sorry for this inconvenience. It was caused by an error in the hardware detection. I will release a version of LuxLib by tonight that fixes this. Meanwhile can you install LuxLib#main and Lux#main to confirm that it is fixed on main for you?

Sidenote: I just set up https://luxdl.github.io/LuxLib.jl/benchmarks/ to prevent exactly these problems from happening 😓

marcobonici commented 3 months ago

Hi @avik-pal , thank you for your (lightning fast!) answer. No issue at all :)

I installed both Lux and LuxLib from main. Here is the benchmark:

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  307.991 μs …   2.027 ms  ┊ GC (min … max): 0.00% … 57.99%
 Time  (median):     378.190 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   420.085 μs ± 132.718 μs  ┊ GC (mean ± σ):  0.30% ±  2.25%

  ▅█▆    ▁                                                       
  ███▇▅▅██▇▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  308 μs           Histogram: frequency by time          917 μs <

 Memory estimate: 70.08 KiB, allocs estimate: 68.
avik-pal commented 3 months ago

That's strange. Let's try LuxLib#ap/act_fuse2 once.

Can you share the following?

julia> versioninfo()
julia> LuxLib.System.L1CacheSize
julia> LuxLib.System.L2CacheSize
julia> LuxLib.System.L3CacheSize
julia> LuxLib.System.INTEL_HARDWARE
julia> LuxLib.System.AMD_RYZEN_HARDWARE
julia> LuxLib.System.use_octavian()
avik-pal commented 3 months ago

I am getting

Float32

julia> @benchmark Lux.apply(model, $x, $ps, $st)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  36.540 μs …  7.864 ms  ┊ GC (min … max): 0.00% … 98.81%
 Time  (median):     40.156 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   42.397 μs ± 88.964 μs  ┊ GC (mean ± σ):  2.82% ±  1.39%

          ▃▆▇███▇▆▅▄▄▃▃▂▂▂▂▂▂▂▂▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁ ▁ ▁▁          ▃
  ▃▃▁▁▁▃▄██████████████████████████████████████████████▇▇▇▇▇▇ █
  36.5 μs      Histogram: log(frequency) by time      52.3 μs <

 Memory estimate: 22.41 KiB, allocs estimate: 39.

Float64

julia> @benchmark Lux.apply(model, x, ps, st)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  71.279 μs …   7.377 ms  ┊ GC (min … max): 0.00% … 98.06%
 Time  (median):     75.448 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   79.545 μs ± 114.839 μs  ┊ GC (mean ± σ):  3.47% ±  2.84%

             ▃▆██▆▂
  ▂▅▅▄▄▄▄▅▄▆███████▇▄▃▃▂▃▂▂▂▂▂▁▁▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁ ▃
  71.3 μs         Histogram: frequency by time         88.8 μs <

 Memory estimate: 43.05 KiB, allocs estimate: 38.
avik-pal commented 3 months ago

Almost all of the recent changes were made to make Lux faster on smaller models. For example, if your last layer has 4 outputs instead of 4999:

Lux 0.5.10

julia> @benchmark Lux.apply(model, $x, $ps, $st)
BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range (min … max):  2.684 μs … 606.403 μs  ┊ GC (min … max): 0.00% … 93.67%
 Time  (median):     2.863 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.137 μs ±   8.433 μs  ┊ GC (mean ± σ):  6.94% ±  2.90%

       ▃▅▇███▇▇▅▄▄▂▂▁
  ▁▁▂▄▇██████████████▇▇▇▆▆▅▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁ ▄
  2.68 μs         Histogram: frequency by time        3.46 μs <

 Memory estimate: 5.67 KiB, allocs estimate: 31.

Lux#main and LuxLib#ap/act_fuse2

julia> @benchmark Lux.apply(model, $x, $ps, $st)
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.475 μs …  2.240 ms  ┊ GC (min … max):  0.00% … 99.80%
 Time  (median):     1.610 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   2.297 μs ± 25.193 μs  ┊ GC (mean ± σ):  14.70% ±  1.41%

   ▅█▅▁
  ▄████▆▄▃▂▂▂▂▁▁▁▁▂▂▂▂▂▂▂▃▃▃▂▂▂▂▃▂▂▃▃▃▃▂▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  1.48 μs        Histogram: frequency by time        2.88 μs <

 Memory estimate: 2.86 KiB, allocs estimate: 38.
marcobonici commented 3 months ago

Here is what I get, with the branches you asked me to use

BenchmarkTools.Trial: 9895 samples with 1 evaluation.
 Range (min … max):  379.379 μs …   2.589 ms  ┊ GC (min … max): 0.00% … 75.62%
 Time  (median):     446.982 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   501.847 μs ± 139.619 μs  ┊ GC (mean ± σ):  0.29% ±  2.18%

   ██▁  ▁                   ▃                                    
  ▁███▇▅█▆▅▄▄▃▂▂▂▂▂▂▁▁▁▁▁▂▂▂█▄▂▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  379 μs           Histogram: frequency by time         1.03 ms <

 Memory estimate: 70.08 KiB, allocs estimate: 68.

Sharing the output of what you asked me in a sec.

marcobonici commented 3 months ago
julia> versioninfo()
Julia Version 1.10.0
Commit 3120989f39b (2023-12-25 18:01 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 20 × 13th Gen Intel(R) Core(TM) i7-13700H
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, goldmont)
  Threads: 23 on 20 virtual cores
Environment:
  LD_GOLD = /home/marcobonici/miniconda3/bin/x86_64-conda-linux-gnu-ld.gold

julia> LuxLib.System.L1CacheSize
32768

julia> LuxLib.System.L2CacheSize
1310720

julia> LuxLib.System.L3CacheSize
25165824

julia> LuxLib.System.INTEL_HARDWARE
static(true)

julia> LuxLib.System.AMD_RYZEN_HARDWARE
static(false)

julia> LuxLib.use_octavian()
ERROR: UndefVarError: `use_octavian` not defined
Stacktrace:
 [1] getproperty(x::Module, f::Symbol)
   @ Base ./Base.jl:31
 [2] top-level scope
   @ REPL[8]:1
avik-pal commented 3 months ago

Threads: 23 on 20 virtual cores

Maybe start Julia with fewer threads, e.g. --threads=12? LoopVectorization is probably oversubscribing the threads.
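To check the oversubscription suspicion, the thread counts can be inspected from inside the session. The following is a minimal sketch that uses only Base and the LinearAlgebra standard library; the variable names are illustrative, and it does not inspect LoopVectorization's or Octavian's own scheduling.

```julia
using LinearAlgebra  # provides BLAS.get_num_threads

# Julia threads in the default pool (set via `--threads` or JULIA_NUM_THREADS).
julia_threads = Threads.nthreads()

# Threads the BLAS backend is configured to use.
blas_threads = BLAS.get_num_threads()

# Logical CPU threads reported by the operating system.
cpu_threads = Sys.CPU_THREADS

@info "Thread configuration" julia_threads blas_threads cpu_threads

# True when Julia was started with more threads than logical cores.
oversubscribed = julia_threads > cpu_threads
```

If `oversubscribed` is true, restarting Julia with a smaller `--threads` value is a cheap first experiment.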

marcobonici commented 3 months ago

It didn't change the result.

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  358.917 μs …  2.160 ms  ┊ GC (min … max): 0.00% … 76.43%
 Time  (median):     393.664 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   414.230 μs ± 61.950 μs  ┊ GC (mean ± σ):  0.27% ±  2.19%

      █ ▅                                                       
  ▄▃▄▅█▇█▆▆▄▅▄▄▄▃▆▅▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂ ▃
  359 μs          Histogram: frequency by time          634 μs <

 Memory estimate: 70.08 KiB, allocs estimate: 68.
julia> versioninfo()
Julia Version 1.10.0
Commit 3120989f39b (2023-12-25 18:01 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 20 × 13th Gen Intel(R) Core(TM) i7-13700H
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, goldmont)
  Threads: 1 on 20 virtual cores
Environment:
  LD_GOLD = /home/marcobonici/miniconda3/bin/x86_64-conda-linux-gnu-ld.gold
avik-pal commented 3 months ago

What happens if your model doesn't use 4999 as the last dim and instead uses 4?

Also can you show the output of a profiler? @profview if you are using VSCode.
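Outside VSCode, a text profile can be collected with the Profile standard library. This is a minimal sketch: `workload` is a hypothetical stand-in for the model call, and should be replaced by the actual `Lux.apply(model, x, ps, st)` invocation wrapped in a loop so that enough samples are collected.

```julia
using Profile  # standard-library sampling profiler

# Hypothetical stand-in workload; replace with repeated calls to the model.
function workload(n)
    acc = 0.0
    for i in 1:n
        acc += tanh(i / n)
    end
    return acc
end

workload(10)                      # run once so compilation is not profiled
Profile.clear()                   # drop samples left over from earlier runs
@profile workload(50_000_000)     # collect samples while the workload runs
Profile.print(format = :flat, mincount = 10)  # flat text summary of hot frames
```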

avik-pal commented 3 months ago

The allocations are bothering me: "Memory estimate: 70.08 KiB, allocs estimate: 68." It is going down a code path it shouldn't. On my machines it always gives "Memory estimate: 43.05 KiB, allocs estimate: 38."

marcobonici commented 3 months ago

If I use 4 rather than 4999, I get

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  19.544 μs … 40.848 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     21.276 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   21.388 μs ±  1.528 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

    ▂▄      ▆█▆▄▂                                              
  ▂▅███▆▄▃▃▇█████▆▅▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  19.5 μs         Histogram: frequency by time        28.5 μs <

 Memory estimate: 5.36 KiB, allocs estimate: 63.
avik-pal commented 3 months ago

Okay, let's try to break it down. Can you run:

julia> using BenchmarkTools, LuxLib

julia> N = 2 .^ (1:12)

julia> for xdim in N
           x = rand(Float32, xdim, xdim)
           @info xdim
           @btime LuxLib.Impl.matmul($x, $x)
           @btime $x * $x
       end

(I need to step away from my computer for a couple of hrs, I will get back to this in the evening)

marcobonici commented 3 months ago
[ Info: 2
  26.013 ns (1 allocation: 80 bytes)
  21.012 ns (1 allocation: 80 bytes)
[ Info: 4
  28.621 ns (1 allocation: 128 bytes)
  82.369 ns (1 allocation: 128 bytes)
[ Info: 8
  41.386 ns (1 allocation: 336 bytes)
  140.383 ns (1 allocation: 336 bytes)
[ Info: 16
  127.166 ns (1 allocation: 1.06 KiB)
  412.754 ns (1 allocation: 1.06 KiB)
[ Info: 32
  925.400 ns (1 allocation: 4.12 KiB)
  2.177 μs (1 allocation: 4.12 KiB)
[ Info: 64
  6.182 μs (1 allocation: 16.12 KiB)
  14.286 μs (1 allocation: 16.12 KiB)
[ Info: 128
  49.376 μs (2 allocations: 64.05 KiB)
  73.540 μs (2 allocations: 64.05 KiB)
[ Info: 256
  381.197 μs (2 allocations: 256.05 KiB)
  328.398 μs (2 allocations: 256.05 KiB)
[ Info: 512
  3.152 ms (2 allocations: 1.00 MiB)
  2.060 ms (2 allocations: 1.00 MiB)
[ Info: 1024
  25.814 ms (2 allocations: 4.00 MiB)
  17.013 ms (2 allocations: 4.00 MiB)
[ Info: 2048
  141.340 ms (2 allocations: 16.00 MiB)
  137.267 ms (2 allocations: 16.00 MiB)
[ Info: 4096
  1.159 s (2 allocations: 64.00 MiB)
  1.189 s (2 allocations: 64.00 MiB)
avik-pal commented 3 months ago

Oh that ran fast, I think I know what is happening here. Can I get a profview profile? I think SLEEFPirates is not great on your hardware

marcobonici commented 3 months ago

[Screenshot from 2024-08-13 19-00-21]

What do I need to show, specifically?

avik-pal commented 3 months ago

If you can share a profile generated with https://github.com/tkluck/StatProfilerHTML.jl, I can take it from there.

marcobonici commented 3 months ago

[Attachments: statprof.zip, image]

Here they are!

marcobonici commented 3 months ago

If I stop answering, it is because I am going to bed (I am in Europe currently).

avik-pal commented 3 months ago

Ah figured it out (and finally reproduced locally)

Turns out that if you do Float32 x Float64, Julia silently converts it to Float64 x Float64, allowing it to hit BLAS. This is easy to fix; I will land a fix later tonight.

The machine I was using before was a server CPU. (Struck out: it had a massive L2 cache and we kept using Octavian or LoopVectorization, so it was never hitting the slow Julia fallback.) The actual reason is that native matrix multiply in Julia 1.11 is really fast.

As a sidenote, I recommend that users set https://lux.csail.mit.edu/stable/api/Lux/utilities#Lux.match_eltype to warn by default (or to error, if you are using Lux in performance-critical code).
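The promotion described above can be reproduced without Lux. The following is a minimal sketch in plain Julia, assuming Float32 weights (as in the Float32 benchmark earlier in the thread) and the Float64 input that `rand(6)` produces in the MWE; the variable names are illustrative only.

```julia
W = rand(Float32, 64, 6)   # weights stored as Float32
x64 = rand(6)              # `rand(6)` yields a Vector{Float64}, as in the MWE
x32 = Float32.(x64)        # explicit conversion so the input matches the weights

y_mixed = W * x64          # mixed precision: the result is promoted
y_match = W * x32          # matching precision: the result stays Float32

println(eltype(y_mixed))   # prints Float64
println(eltype(y_match))   # prints Float32
```

Creating the input with `rand(Float32, 6)` in the first place avoids the mixed-precision path entirely.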

marcobonici commented 3 months ago

Great! Thanks for sorting this out quickly:)

marcobonici commented 3 months ago

First and foremost, thanks! Now the timings are much better and even improved over the old ones!

Before, using 0.5.10, I found

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  43.618 μs …  1.576 ms  ┊ GC (min … max): 0.00% … 95.93%
 Time  (median):     45.809 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   49.805 μs ± 33.522 μs  ┊ GC (mean ± σ):  2.15% ±  3.59%

   ▂▄▆██▇▅▃▂▂▂▂▁                       ▂▃▃▃▃▃▃▃▃▂▂▂▂▁         ▂
  ▇██████████████▇▇▇▆▆▆▅▆▄▆▅▃▃▅▄▅▁▁▁▅▇███████████████████▇█▇▇ █
  43.6 μs      Histogram: log(frequency) by time      65.7 μs <

 Memory estimate: 87.25 KiB, allocs estimate: 33.

Now, with 0.5.64, I have

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  32.010 μs …  1.506 ms  ┊ GC (min … max): 0.00% … 95.73%
 Time  (median):     33.059 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   34.433 μs ± 22.344 μs  ┊ GC (mean ± σ):  1.59% ±  2.81%

    ▂▅▇██▇▆▄▃▁▁                           ▁▁▁▁▁▂▂▂▁▁          ▂
  ▄▇██████████████▇▇▇▇▆▅▆▆▆▆▆▅▆▅▄▅▆▅▆▅▆▇▇▇███████████▇▇▇▇▇▆▅▆ █
  32 μs        Histogram: log(frequency) by time      41.5 μs <

 Memory estimate: 44.48 KiB, allocs estimate: 65.

The only residual issue I see is with multithreading. If I launch julia with

julia --project=. -t 16

on 0.5.10 I get

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  42.952 μs …  1.285 ms  ┊ GC (min … max): 0.00% … 93.94%
 Time  (median):     45.000 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   47.024 μs ± 21.948 μs  ┊ GC (mean ± σ):  1.40% ±  3.49%

    ▃▅▇███▇▅▄▂▁▁▁▁                              ▁▁▂▂▂▃▂▂▂▂▁▁  ▃
  ▆█████████████████▇▇▇▇▇▇▅▅▄▆▅▅▄▄▃▅▁▅▃▃▃▁▁▁▁▃▅▇█████████████ █
  43 μs        Histogram: log(frequency) by time      60.7 μs <

 Memory estimate: 87.25 KiB, allocs estimate: 33.

on 0.5.64 I get

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  32.097 μs … 39.295 ms  ┊ GC (min … max): 0.00% … 4.86%
 Time  (median):     33.026 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   72.568 μs ±  1.209 ms  ┊ GC (mean ± σ):  1.12% ± 0.08%

  ▁▆██▇▅▅▃▂▁▁               ▁▁▁▂▂▂▁▁▁▁                        ▂
  ████████████████▇██▇▇▇▇▇█▇████████████▇▆▇▇▅▆▆▆▅▄▅▅▄▆▅▆▆▄▅▅▅ █
  32.1 μs      Histogram: log(frequency) by time      47.2 μs <

 Memory estimate: 44.48 KiB, allocs estimate: 65.

I also tried to use Chairmarks.jl to perform the benchmark, but the results (after 10'000 evaluations) are pretty much consistent with BenchmarkTools.jl.

In my use case I can circumvent the issue by launching multiple processes (this way the performance does not worsen) and using distributed computing, but I wonder whether this can also be fixed for the general audience.

Thank you again @avik-pal for your support up to now :)

avik-pal commented 3 months ago

What is the issue with multithreading?

marcobonici commented 3 months ago

I get a higher mean execution time.

avik-pal commented 3 months ago

Right, it comes from our use of a Julia-native matrix multiplication, which leads to one-off high compile times, but that shouldn't show up after the first run (not sure why you are getting it in multiple runs). Now, on why we made the switch:

  1. The idea is similar to SimpleChains: we want really good performance for smaller problems, where BLAS has a sizeable overhead, so we use either a custom LoopVectorization implementation or Octavian for mid-sized problems.
  2. However, in contrast to SimpleChains, we also need to be fast at scale, and for larger problems we do use BLAS. In your case we would have hit BLAS, but the matmul is mixed-precision. To use BLAS we would need to make a copy of the matrix (while this doesn't show up very often in the benchmarks, it leads to significantly more pressure on the GC), so instead we use Octavian, which gives good performance for arbitrary types.
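The selection logic in the two points above can be sketched roughly as follows. This is an illustration only, not LuxLib's code: the thresholds `SMALL_LIMIT` and `MEDIUM_LIMIT`, the function names, and the backend labels are all hypothetical, a plain triple loop stands in for the LoopVectorization kernel, and Base `*` stands in for both Octavian and BLAS.

```julia
# Hypothetical size thresholds; the real heuristics are not shown in this thread.
const SMALL_LIMIT = 16
const MEDIUM_LIMIT = 256

# Plain triple loop, standing in for a hand-written small-matrix kernel.
function loop_matmul(A::AbstractMatrix, B::AbstractMatrix)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("inner dimensions differ"))
    T = promote_type(eltype(A), eltype(B))
    C = zeros(T, size(A, 1), size(B, 2))
    for j in axes(B, 2), k in axes(A, 2), i in axes(A, 1)
        C[i, j] += A[i, k] * B[k, j]
    end
    return C
end

# Pick a backend label from the problem size and the element types.
function choose_backend(A::AbstractMatrix, B::AbstractMatrix)
    n = max(size(A, 1), size(A, 2), size(B, 2))
    n <= SMALL_LIMIT && return :loop
    # Mid-sized or mixed-precision problems avoid BLAS (no copy needed).
    (n <= MEDIUM_LIMIT || eltype(A) != eltype(B)) && return :generic
    return :blas
end

function sketch_matmul(A::AbstractMatrix, B::AbstractMatrix)
    choose_backend(A, B) === :loop && return loop_matmul(A, B)
    # Both remaining labels fall back to Base `*` in this sketch.
    return A * B
end
```

For instance, under these assumed thresholds, `choose_backend(rand(Float32, 512, 512), rand(Float64, 512, 512))` returns `:generic` rather than `:blas`, mirroring the mixed-precision case discussed here.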
marcobonici commented 3 months ago

Thanks for the explanation @avik-pal! So, should the current performance be considered "good", or do you expect to try to improve it? As I said, in my use case I can circumvent the issue I found by using distributed computing (also locally).

avik-pal commented 3 months ago

No, this should be pretty much it. You could try playing around with the thread count, but generally the backends (LoopVectorization and Octavian) are smart enough not to use too many threads.

marcobonici commented 3 months ago

Makes sense. Thanks again for the detailed explanations and the amazing library :)