SciML / LinearSolve.jl

LinearSolve.jl: High-Performance Unified Interface for Linear Solvers in Julia. Easily switch between factorization and Krylov methods, add preconditioners, and more, all in one interface.
https://docs.sciml.ai/LinearSolve/stable/

Add Enzyme extension #377

Closed wsmoses closed 12 months ago

wsmoses commented 12 months ago

Requires current Enzyme main for a custom rules fix.

wsmoses commented 12 months ago

Sample call:

using Enzyme

using LinearSolve, LinearAlgebra

n = 4
A = rand(n, n);
dA = zeros(n, n);  # shadow of A: the gradient w.r.t. A accumulates here
b1 = rand(n);
db1 = zeros(n);    # shadow of b1
b2 = rand(n);
db2 = zeros(n);    # shadow of b2

function f(A, b1, b2; alg = LUFactorization())
    prob = LinearProblem(A, b1)

    sol1 = solve(prob, alg)

    s1 = sol1.u
    norm(s1)  # note: b2 is unused here, so db2 should come back zero
end

f(A, b1, b2) # Uses BLAS

Enzyme.autodiff(Reverse, f, Duplicated(A, dA), Duplicated(b1, db1), Duplicated(b2, db2))

@show dA, db1, db2
codecov[bot] commented 12 months ago

Codecov Report

Merging #377 (89e10df) into main (5a25b7d) will increase coverage by 48.24%. Report is 6 commits behind head on main. The diff coverage is 1.11%.

@@             Coverage Diff             @@
##             main     #377       +/-   ##
===========================================
+ Coverage   20.01%   68.25%   +48.24%     
===========================================
  Files          14       24       +10     
  Lines        1444     1884      +440     
===========================================
+ Hits          289     1286      +997     
+ Misses       1155      598      -557     
Files Changed                 Coverage Δ
ext/LinearSolveEnzymeExt.jl   0.00% <0.00%> (ø)
src/init.jl                   57.14% <50.00%> (-17.86%) ↓

... and 22 files with indirect coverage changes


ChrisRackauckas commented 12 months ago

It looks like this only handles the case of solve, but not solve!. So I presume this case would still not work:

using LinearSolve, LinearAlgebra
# using MKL_jll

n = 100
A = rand(n, n)
b1 = rand(n);
b2 = rand(n);

function f(A, b1, b2; alg = LUFactorization())
    prob = LinearProblem(A, b1)

    linsolve = init(prob, alg)
    sol1 = solve!(linsolve)

    s1 = copy(sol1.u)

    linsolve.b = b2
    sol2 = solve!(linsolve)

    s2 = copy(sol2.u)
    norm(s1 + s2)
end

f(A, b1, b2) # Uses BLAS
f(A, b1, b2; alg=RFLUFactorization()) # Uses loops
f(A, b1, b2; alg=MKLLUFactorization()) # Requires `using MKL_jll`

using Enzyme

dA = zero(A)
db1 = zero(b1)
db2 = zero(b2)
Enzyme.autodiff(Reverse, f, Duplicated(A,dA), 
                Duplicated(b1, db1), Duplicated(b2, db2))

which is https://github.com/EnzymeAD/Enzyme.jl/issues/1065.

I at least added a test for the solve case, but the most common case is solve!, so it would be good to figure out how to do that. It's the same thing except solve!(cache) has cache.A and cache.b, where cache.isfresh == false means A is already factorized. Is there a way to define the derivative w.r.t. fields of the mutable cache? Or should this be done with a solve!_up type thing?
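For reference on the field question: when a mutable struct is passed to Enzyme as Duplicated, the shadow is a struct of the same type, and the per-field derivatives accumulate in the shadow's fields. A minimal sketch with a hypothetical cache type (not LinearSolve's actual LinearCache, and a plain-loop kernel so no BLAS rules are needed):

using Enzyme

mutable struct MyCache            # hypothetical stand-in for a solver cache
    A::Matrix{Float64}
    b::Vector{Float64}
end

function g(c::MyCache)            # scalar output: sum(A * b), written as loops
    s = 0.0
    for j in axes(c.A, 2), i in axes(c.A, 1)
        s += c.A[i, j] * c.b[j]
    end
    return s
end

c  = MyCache(rand(3, 3), rand(3))
dc = MyCache(zeros(3, 3), zeros(3))   # shadow cache, zero-initialized

Enzyme.autodiff(Reverse, g, Active, Duplicated(c, dc))

dc.A, dc.b   # derivatives w.r.t. the cache's fields land here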

wsmoses commented 12 months ago

Pushed the extension for solve! and init now.

While I was at it, I also added batch mode support.

ChrisRackauckas commented 12 months ago

While I was at it, I also added batch mode support.

What in here was required for batch mode support?

wsmoses commented 12 months ago

Not specializing to just Duplicated but also supporting BatchDuplicated, which has dval as a tuple of shadows.
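For context, a minimal sketch of the difference (just constructing the annotations, nothing PR-specific): Duplicated carries a single shadow, while BatchDuplicated carries an NTuple of shadows, which is why the rules branch on EnzymeRules.width(config).

using Enzyme

x = rand(4)

d  = Duplicated(x, zeros(4))
bd = BatchDuplicated(x, (zeros(4), zeros(4), zeros(4)))

typeof(d.dval)    # Vector{Float64}: one shadow
typeof(bd.dval)   # NTuple{3, Vector{Float64}}: one shadow per batch element
length(bd.dval)   # 3, the batch width a rule sees via EnzymeRules.width(config)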

ChrisRackauckas commented 12 months ago

As a tuple, does that have an issue scaling to, say, a batch of 100 or 1000 things?

wsmoses commented 12 months ago

For conservative correctness, yes. A may be modified between the forward and reverse pass.

The overwritten set of bools says whether the outermost struct pointer is overwritten; it carries no information about internal members being overwritten.

As Julia's (and other) alias analysis improves, or we get an ImmutableArray type or something, this can be elided in the future.

(Replying to ChrisRackauckas's review comment in ext/LinearSolveEnzymeExt.jl, https://github.com/SciML/LinearSolve.jl/pull/377#discussion_r1334858450:)

if EnzymeRules.width(config) == 1
    dres.u .= 0
else
    for dr in dres
        dr.u .= 0
    end
end
resvals = if EnzymeRules.width(config) == 1
    dres.u
else
    (dr.u for dr in dres)
end
cache = (copy(linsolve.val.A), res, resvals)

Is this copy necessary?


wsmoses commented 12 months ago

It supports being used with arbitrary sizes.

In practice, of course, some sizes could be better than others, e.g. a power of two for vectorization's sake. Likewise, if a computation can be reused for all batch elements, that could improve performance: e.g. if transpose(A) generated a new matrix and not a view, we could do that once for all batches.
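A sketch of that reuse (hypothetical code, not what the extension does today): materialize and factorize the transpose once, then reuse it for every batch shadow.

using LinearAlgebra

n, width = 4, 3
A = rand(n, n)
dys = ntuple(_ -> rand(n), width)   # one adjoint seed per batch element

Ft  = lu(Matrix(transpose(A)))      # materialize and factorize once: O(n^3)
dbs = map(dy -> Ft \ dy, dys)       # each batch element then costs only O(n^2)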


wsmoses commented 12 months ago

Sure, if you know a better set of things to cache, we can choose those instead. I don't know much about the internals of solve, so I went with this form.

(Replying to ChrisRackauckas's review comment in ext/LinearSolveEnzymeExt.jl, https://github.com/SciML/LinearSolve.jl/pull/377#discussion_r1334860509:)

end
dAs = if EnzymeRules.width(config) == 1
    (linsolve.dval.A,)
else
    (dval.A for dval in linsolve.dval)
end
dbs = if EnzymeRules.width(config) == 1
    (linsolve.dval.b,)
else
    (dval.b for dval in linsolve.dval)
end
for (dA, db, dy) in zip(dAs, dbs, dys)
    invprob = LinearSolve.LinearProblem(transpose(A), dy)

In the forward pass the matrix A is factorized, so in theory we don't need to factorize it again, just transpose A from the forward pass. Is there a way to grab that?


ChrisRackauckas commented 12 months ago

The key that I'm pointing out here is similar to the top of https://docs.sciml.ai/LinearSolve/stable/tutorials/caching_interface/. But here, what solve! is doing is solving:

_A = lu!(A)
_A \ b1

and then the backpass is:

_At = lu!(A')
_At \ db1

but we also have that (essentially) _At = _A', or at least it can be computed in O(n) time, whereas a factorization is O(n^3) and thus lu! is one of the most expensive operations.

So what I'm wondering is if it's safe to assume that linsolve is the same linsolve object from the forward pass, or if it may have been further mutated.

wsmoses commented 12 months ago

So what I'm wondering is if it's safe to assume that linsolve is the same linsolve object from the forward pass, or if it may have been further mutated.

It's the same Julia object, but it's possible its fields may have been modified. If it's immutable, then it's the same.

wsmoses commented 12 months ago

Even if it's overwritten, however, you can still add whatever is relevant from the LU into the cache and use that as a starting point.

ChrisRackauckas commented 12 months ago

Awesome, I'll leave that as a follow-up, no need to handle it in this PR. But the tests do need to get fixed.

ChrisRackauckas commented 12 months ago

The transpose of the factorization is the factorization of the transpose:

using LinearAlgebra
A = rand(4,4)
luA = lu(A)

At = transpose(A)
luAt = lu(At)

b = rand(4)

x  = A \ b
x2 = A' \ b
x3 = luA \ b
x4 = luAt \ b
x5 = luA' \ b   # adjoint of the cached factorization: no second lu needed

x ≈ x3          # true
x2 ≈ x4 ≈ x5    # true

Confirmed from https://web.mit.edu/18.06/www/Spring17/Transposes.pdf. We can use this to generalize and optimize a bit.
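Concretely, that means the reverse pass could reuse the factorization cached from the forward pass instead of refactorizing. A sketch under that assumption (dA = -db * u' is the standard adjoint of a linear solve, not code from this PR):

using LinearAlgebra

n = 4
A  = rand(n, n)
b  = rand(n)
dy = rand(n)                 # incoming adjoint of the solution u

F = lu(A)                    # O(n^3), done once in the forward pass
u = F \ b                    # forward solve, O(n^2)

db = F' \ dy                 # reverse solve against A' reusing the same factorization
dA = -db * u'                # adjoint w.r.t. A

db ≈ lu(transpose(A)) \ dy   # true: matches the refactorize-and-solve approach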

ChrisRackauckas commented 12 months ago

The solving twice tests are a bit odd:

julia> db1
4-element Vector{Float64}:
 0.0
 0.0
 0.0
 0.0

julia> db2
4-element Vector{Float64}:
  2.1215949279204196
 -3.7095838683317943
 -1.2286715744423384
  5.967859589815037

It doubles db2 and has db1 = 0. I think it's because solve!(linsolve).u aliases between the two solves. The forward pass is fine because of the copy, but the Enzyme rule likely needs to copy something as well?
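A minimal check of the suspected aliasing (assuming the cache hands back its internal u buffer on every solve!):

using LinearSolve, LinearAlgebra

n = 4
A = rand(n, n); b1 = rand(n); b2 = rand(n)

linsolve = init(LinearProblem(A, b1), LUFactorization())
u1 = solve!(linsolve).u   # no copy taken here

linsolve.b = b2
u2 = solve!(linsolve).u

u1 === u2   # presumably true: both solutions wrap the same cache buffer,
            # so the second solve! clobbers the first solution in place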

ChrisRackauckas commented 12 months ago

We can skip over that last test to merge, but do you know why that one algorithm would be treated so differently by Enzyme? I would've thought it didn't care if we're capturing stuff in rules, but it treats this algorithm particularly differently:

https://github.com/SciML/LinearSolve.jl/actions/runs/6290016689/job/17077077461?pr=377#step:6:807