SciML / LinearSolve.jl

LinearSolve.jl: High-Performance Unified Interface for Linear Solvers in Julia. Easily switch between factorization and Krylov methods and add preconditioners, all in one interface.
https://docs.sciml.ai/LinearSolve/stable/

Rework default algorithm to be fully type stable #307

Closed ChrisRackauckas closed 1 year ago

ChrisRackauckas commented 1 year ago

There are two possible approaches. One is to put all algorithms and caches in the default algorithm into a unityper. The other is to make a "mega algorithm" that carries runtime information about the choice. This PR takes the latter approach because it keeps most of the package code intact and ensures that any explicit algorithm choice by the user will not have any runtime branching behavior.

This uses an enum inside the algorithm struct to choose the actual solver for the given problem. In init and solve, static dispatching is done through hardcoded branches.
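An illustrative sketch of this pattern (the names here are hypothetical, not the actual LinearSolve internals): the default algorithm stores an enum, and solving branches on it with hardcoded ifs, so every arm is statically dispatched and the return type infers concretely.

```julia
using LinearAlgebra

# Hypothetical sketch of the enum-in-struct dispatch pattern, not the
# actual LinearSolve internals.
@enum DefaultChoice LUChoice QRChoice

struct DefaultSolver
    alg::DefaultChoice
end

function solve_default(A::Matrix{Float64}, b::Vector{Float64}, s::DefaultSolver)
    # Runtime branch, but each arm is statically dispatched and both
    # return Vector{Float64}, so the whole function infers concretely.
    if s.alg === LUChoice
        return lu(A) \ b
    else
        return qr(A) \ b
    end
end

A = [4.0 1.0; 1.0 3.0]
b = [1.0, 2.0]
x = solve_default(A, b, DefaultSolver(LUChoice))
```

The key point is that the choice is data (an enum value), not a type parameter, so the default algorithm has one concrete type regardless of which solver it picks at runtime.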


The one thing that may be a blocker is the init cost. While I don't think it will be an issue, we will need to see whether constructing all of the empty arrays simply to hold a value of the right type is too costly. Basically, what we may need is for Float64[] to be optimized into a zero-allocation no-op, in which case essentially the whole cache structure should be free to build. As it stands, this may be a source of overhead, since init needs to build all of the potential caches even if only one is ever used.
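A quick way to probe this concern (a standalone check, not tied to the LinearSolve internals): each Float64[] is still a heap allocation today, so eagerly building one empty work array per candidate solver adds a constant per-init overhead.

```julia
# Stand-in for eagerly constructing one empty work array per candidate
# solver in init. Each `Float64[]` is a heap allocation, not a no-op.
build_caches(n) = [Float64[] for _ in 1:n]

build_caches(8)                  # warm up so compilation is not measured
bytes = @allocated build_caches(8)
```

If the compiler ever learns to elide unused empty-array allocations, `bytes` would drop toward zero and the eager-cache design would be essentially free.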

ChrisRackauckas commented 1 year ago

Inference Works:

using LinearSolve

A = rand(4, 4)
b = rand(4)
prob = LinearProblem(A, b)
sol = solve(prob)
using Test
@inferred solve(prob)
@inferred init(prob, nothing)
ChrisRackauckas commented 1 year ago
julia> @benchmark solve($prob)
BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range (min … max):  2.653 μs …  1.069 ms  ┊ GC (min … max):  0.00% … 74.47%
 Time  (median):     2.875 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   4.443 μs ± 36.512 μs  ┊ GC (mean ± σ):  24.39% ±  2.96%

    ▃▅▆▇█████▇▆▆▅▅▄▃▃▂▂▂▁▁▁▁▁  ▁ ▁▁▁▁                        ▃
  ▆████████████████████████████████████████████▇█▇▇▅▇▆▇▇▅▆▅▅ █
  2.65 μs      Histogram: log(frequency) by time     4.02 μs <

 Memory estimate: 8.83 KiB, allocs estimate: 101.

julia> @benchmark solve($prob, $(LUFactorization()))
BenchmarkTools.Trial: 10000 samples with 112 evaluations.
 Range (min … max):  762.652 ns … 35.605 μs  ┊ GC (min … max): 0.00% … 96.61%
 Time  (median):     778.643 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   858.785 ns ±  1.231 μs  ┊ GC (mean ± σ):  8.18% ±  5.54%

  ▂▅▇██▇▇▅▄▃▂▂▂▂▂▁▁▁                                           ▂
  ██████████████████████▇▆▇▇▇▇█▇▇▇▇▆▇▆▆▆▆▆▆▇▇▆▆▆▆▆▆▆▆▅▅▆▆▅▅▅▃▅ █
  763 ns        Histogram: log(frequency) by time       966 ns <

 Memory estimate: 1.47 KiB, allocs estimate: 15.
ChrisRackauckas commented 1 year ago

Needs https://github.com/JuliaArrays/ArrayInterface.jl/pull/415

ChrisRackauckas commented 1 year ago
using LinearSolve

A = rand(100, 100)
b = rand(100)
prob = LinearProblem(A, b)

using BenchmarkTools
@benchmark solve($prob)
@benchmark solve($prob, $(LUFactorization()))
julia> @benchmark solve($prob)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  66.166 μs …  2.565 ms  ┊ GC (min … max): 0.00% … 89.21%
 Time  (median):     72.355 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   77.852 μs ± 88.509 μs  ┊ GC (mean ± σ):  4.89% ±  4.16%

            ▃██▅▂                                              
  ▂▄▅▄▂▂▂▁▁▄█████▆▆▆█▇▆▅▄▃▃▃▂▂▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  66.2 μs         Histogram: frequency by time        91.7 μs <

 Memory estimate: 110.23 KiB, allocs estimate: 124.

julia> @benchmark solve($prob, $(LUFactorization()))
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  107.875 μs …  6.946 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     116.917 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   121.663 μs ± 92.707 μs  ┊ GC (mean ± σ):  1.87% ± 3.81%

        ▆█▁  ▆▂                                                 
  ▁▁▃▂▂▅███▄███▅▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  108 μs          Histogram: frequency by time          163 μs <

 Memory estimate: 81.80 KiB, allocs estimate: 16.

The problem seems to be that we need a way to lazily initialize the GMRES cache.
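One way that lazy initialization could look (a sketch under assumed names, not the actual implementation): hold `nothing` in the cache field and pay the workspace allocation only if GMRES is actually selected.

```julia
# Hypothetical lazy-cache sketch: the field stays `nothing` until the
# solver is actually used, so init allocates no workspace up front.
mutable struct LazyGMRESCache{T}
    workspace::Union{Nothing, Vector{T}}
end

LazyGMRESCache{T}() where {T} = LazyGMRESCache{T}(nothing)

function get_workspace!(cache::LazyGMRESCache{T}, n::Integer) where {T}
    if cache.workspace === nothing
        cache.workspace = Vector{T}(undef, n)  # allocation happens only here
    end
    return cache.workspace::Vector{T}
end

cache = LazyGMRESCache{Float64}()   # cheap: no workspace allocated yet
ws = get_workspace!(cache, 100)     # first use pays the allocation
```

The `Union{Nothing, Vector{T}}` field is a small union, which Julia handles without boxing, and the type assertion on return keeps downstream code type-stable.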

ChrisRackauckas commented 1 year ago
using LinearSolve

A = rand(100, 100)
b = rand(100)
prob = LinearProblem(A, b)

using BenchmarkTools
@benchmark solve($prob)
@benchmark solve($prob, $(RFLUFactorization()))

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  65.333 μs …  2.537 ms  ┊ GC (min … max): 0.00% … 87.08%
 Time  (median):     70.583 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   76.125 μs ± 91.330 μs  ┊ GC (mean ± σ):  4.66% ±  3.77%

   ▄▅▄▁   ▁▆▇█▇▆▅▃▄▄▅▅▄▄▃▃▃▂▂▁▁▁▁▂▂▁▁                         ▂
  ▇████▆▅▂█████████████████████████████████████▇▆▇▇▇▅▇▆▆▄▅▄▅▅ █
  65.3 μs      Histogram: log(frequency) by time      90.2 μs <

 Memory estimate: 88.84 KiB, allocs estimate: 105.

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  63.458 μs …  2.124 ms  ┊ GC (min … max): 0.00% … 95.95%
 Time  (median):     68.208 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   72.105 μs ± 57.144 μs  ┊ GC (mean ± σ):  3.27% ±  3.97%

   ▁       ▅█▅                                                 
  ▄█▂▁▁▁▁▁▅████▄▃▃▃▄▄▄▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  63.5 μs         Histogram: frequency by time        87.5 μs <

 Memory estimate: 81.81 KiB, allocs estimate: 16.

Reasonable scaling.

ChrisRackauckas commented 1 year ago
using LinearSolve

A = rand(4, 4)
b = rand(4)
prob = LinearProblem(A, b)

solve(prob).alg.alg

using BenchmarkTools
@benchmark solve($prob)
@benchmark solve($prob, $(GenericLUFactorization()))

Small scale is not okay.

BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range (min … max):  2.653 μs … 904.287 μs  ┊ GC (min … max):  0.00% … 72.30%
 Time  (median):     2.796 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   4.310 μs ±  33.458 μs  ┊ GC (mean ± σ):  21.23% ±  2.74%

  ▁▅▇██▇▇▅▄▃▁     ▁ ▁▁  ▁ ▁▁▁▁▁▁                              ▂
  ███████████████████████████████▇▇▇▆▇▆▆▆▆▆▇▆▆▆▆▄▆▆▆▄▆▆▄▅▅▅▅▅ █
  2.65 μs      Histogram: log(frequency) by time      4.47 μs <

 Memory estimate: 8.59 KiB, allocs estimate: 105.

BenchmarkTools.Trial: 10000 samples with 140 evaluations.
 Range (min … max):  705.057 ns …  18.727 μs  ┊ GC (min … max): 0.00% … 95.22%
 Time  (median):     727.971 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   813.099 ns ± 938.705 ns  ┊ GC (mean ± σ):  7.74% ±  6.35%

   ▂▅▇██▇▄▃▃▃▄▄▄▄▂▂▂▂▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁                        ▂
  ▆████████████████████████████████████████████▇█▅▇▇▆▆▇▆▆▆▅▆▆▆▅ █
  705 ns        Histogram: log(frequency) by time        934 ns <

 Memory estimate: 1.47 KiB, allocs estimate: 15.
ChrisRackauckas commented 1 year ago

This is actually very good now.

using LinearSolve

A = rand(4, 4)
b = rand(4)
prob = LinearProblem(A, b)

solve(prob).alg.alg

using BenchmarkTools
@benchmark solve($prob)
@benchmark solve($prob, $(GenericLUFactorization(LinearSolve.RowMaximum())))
BenchmarkTools.Trial: 10000 samples with 202 evaluations.
 Range (min … max):  379.743 ns …   7.574 μs  ┊ GC (min … max): 0.00% … 93.00%
 Time  (median):     398.099 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   427.764 ns ± 306.002 ns  ┊ GC (mean ± σ):  4.29% ±  5.58%

       ▇█▄                                                       
  ▁▁▂▃████▇▆▆▅▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  380 ns           Histogram: frequency by time          516 ns <

 Memory estimate: 880 bytes, allocs estimate: 8.

BenchmarkTools.Trial: 10000 samples with 197 evaluations.
 Range (min … max):  452.198 ns …  11.249 μs  ┊ GC (min … max): 0.00% … 92.86%
 Time  (median):     463.411 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   498.177 ns ± 441.242 ns  ┊ GC (mean ± σ):  4.74% ±  5.09%

   ▃▇██▆▄▃▃▃▄▄▃▃▃▂▁▁▁▂▂▂▁▁▁▁▁   ▁                               ▂
  ▇█████████████████████████████████▇▇█▇▇▇▇█▇▆▆▄▇▇▅▆▅▄▆▅▅▄▅▅▅▄▄ █
  452 ns        Histogram: log(frequency) by time        599 ns <

 Memory estimate: 704 bytes, allocs estimate: 9.

Not sure why the default method is faster, but it's pretty consistently faster. And inference works:

using Test
@inferred solve(prob)
@inferred init(prob, nothing)
codecov[bot] commented 1 year ago

Codecov Report

Merging #307 (73dfa2c) into main (cb31d58) will increase coverage by 2.99%. The diff coverage is 78.91%.

@@            Coverage Diff             @@
##             main     #307      +/-   ##
==========================================
+ Coverage   73.39%   76.38%   +2.99%     
==========================================
  Files          15       15              
  Lines        1026     1207     +181     
==========================================
+ Hits          753      922     +169     
- Misses        273      285      +12     
Impacted Files                Coverage Δ
ext/LinearSolveHYPREExt.jl    90.80% <ø> (ø)
src/LinearSolve.jl            54.54% <ø> (-28.22%) ↓
src/default.jl                67.26% <68.55%> (+30.19%) ↑
src/factorization.jl          79.67% <85.62%> (-0.46%) ↓
src/iterative_wrappers.jl     78.97% <94.44%> (-0.03%) ↓
src/common.jl                 92.30% <100.00%> (+1.19%) ↑


ChrisRackauckas commented 1 year ago

Next up I want to make IterativeSolvers and KrylovKit into extensions, and then set up every solver with docstrings, and that's 2.0.
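For reference, package extensions in Julia 1.9+ are declared through weakdeps in Project.toml; a hypothetical fragment for one of the two (the extension module name is illustrative, and the real UUID comes from the General registry):

```toml
[weakdeps]
IterativeSolvers = "..."  # UUID from the General registry

[extensions]
LinearSolveIterativeSolversExt = "IterativeSolvers"
```

The extension module then lives in ext/LinearSolveIterativeSolversExt.jl and only loads when the user also loads IterativeSolvers.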

chriselrod commented 1 year ago

Not sure why the default method is faster, but it's pretty consistently faster. And inference works:

using Test
@inferred solve(prob)
@inferred init(prob, nothing)

Does it pass JET.@test_opt?
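For context, JET.@test_opt reports any call that falls back to runtime dispatch, which is a stricter check than @inferred. A sketch of how the check might look here (assumes JET.jl is installed alongside LinearSolve):

```julia
using Test, JET, LinearSolve

A = rand(4, 4)
b = rand(4)
prob = LinearProblem(A, b)

# Fails the test if any call reachable from `solve` hits runtime dispatch
JET.@test_opt solve(prob)
```

Unlike @inferred, which only checks the return type of the top-level call, this walks the whole call graph.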

ChrisRackauckas commented 1 year ago

I didn't know about that, but I did check with Cthulhu and it worked. If you can PR to add the JET testing, that would be helpful.

ChrisRackauckas commented 1 year ago

https://github.com/SciML/LinearSolve.jl/pull/318

Some work; others hit random issues like print.