Closed: wsmoses closed this pull request 12 months ago
Sample call:
using Enzyme
using LinearSolve, LinearAlgebra
n = 4
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);
b2 = rand(n);
db2 = zeros(n);
function f(A, b1, b2; alg = LUFactorization())
prob = LinearProblem(A, b1)
sol1 = solve(prob, alg)
s1 = sol1.u
norm(s1)
end
f(A, b1, b2) # Uses BLAS
Enzyme.autodiff(Reverse, f, Duplicated(A, dA), Duplicated(b1, db1), Duplicated(b2, db2))
@show dA, db1, db2
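For this particular `f`, the expected adjoints can be derived by hand, which gives an independent check on whatever the rule returns (a sketch, not code from the PR; it uses the standard adjoint of `x = A \ b` and redefines its own inputs):

```julia
using LinearAlgebra

# For f(A, b) = norm(A \ b), reverse mode gives, with x = A \ b:
#   g  = x / norm(x)   (gradient of norm at x)
#   db = A' \ g        (adjoint w.r.t. b)
#   dA = -db * x'      (adjoint w.r.t. A)
n = 4
A = rand(n, n)
b = rand(n)
x = A \ b
g = x ./ norm(x)
db_expected = A' \ g              # what db1 should come out as
dA_expected = -db_expected * x'   # what dA should come out as
```

Comparing Enzyme's `dA`, `db1` against these closed forms (or against finite differences) is a quick way to validate the rule.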
Merging #377 (89e10df) into main (5a25b7d) will increase coverage by 48.24%. Report is 6 commits behind head on main. The diff coverage is 1.11%.
@@ Coverage Diff @@
## main #377 +/- ##
===========================================
+ Coverage 20.01% 68.25% +48.24%
===========================================
Files 14 24 +10
Lines 1444 1884 +440
===========================================
+ Hits 289 1286 +997
+ Misses 1155 598 -557
Files Changed | Coverage Δ |
---|---|
ext/LinearSolveEnzymeExt.jl | 0.00% <0.00%> (ø) |
src/init.jl | 57.14% <50.00%> (-17.86%) :arrow_down: |

... and 22 files with indirect coverage changes
It looks like this only handles the case of `solve`, but not `solve!`. So I presume this case would still not work:
using LinearSolve, LinearAlgebra
# using MKL_jll
n = 100
A = rand(n, n)
b1 = rand(n);
b2 = rand(n);
function f(A, b1, b2; alg = LUFactorization())
prob = LinearProblem(A, b1)
linsolve = init(prob, alg)
sol1 = solve!(linsolve)
s1 = copy(sol1.u)
linsolve.b = b2
sol2 = solve!(linsolve)
s2 = copy(sol2.u)
norm(s1 + s2)
end
f(A, b1, b2) # Uses BLAS
f(A, b1, b2; alg=RFLUFactorization()) # Uses loops
f(A, b1, b2; alg=MKLLUFactorization()) # Requires `using MKL_jll`
using Enzyme
dA = zero(A)
db1 = zero(b1)
db2 = zero(b2)
Enzyme.autodiff(Reverse, f, Duplicated(A,dA),
Duplicated(b1, db1), Duplicated(b2, db2))
which is https://github.com/EnzymeAD/Enzyme.jl/issues/1065.

I at least added a test for the `solve` case, but the most common case is on `solve!`, so it would be good to figure out how to do that. It's the same thing except `solve!(cache)` has `cache.A` and `cache.b1`, where `cache.isfresh == true` means `A` is already factorized. Is there a way to define the derivative w.r.t. fields of the mutable cache? Or should this be done with a `solve!_up` type thing?
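One way to read the `solve!_up` idea (a hypothetical sketch; `solve_up` is not an actual LinearSolve function): pull the math out into a pure function of the cache's fields, so a rule can be written against plain arrays and the mutating `solve!` becomes a thin wrapper:

```julia
using LinearAlgebra

# Hypothetical out-of-place core: takes the cache *fields*, returns the solution.
function solve_up(A::AbstractMatrix, b::AbstractVector)
    lu(A) \ b
end

# solve!(cache) would then reduce to something like
#   cache.u = solve_up(cache.A, cache.b)
# and the Enzyme rule would only need to be defined for solve_up,
# sidestepping differentiation w.r.t. fields of the mutable cache.
A = rand(4, 4)
b = rand(4)
solve_up(A, b) ≈ A \ b    # matches the direct solve
```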
Pushed the extension for `solve!` and `init` now.
While I was at it, I also added batch mode support.
What in here was required for batch mode support?
Not specializing to just `Duplicated` but also supporting `BatchDuplicated`, which has `dval` as a tuple of shadows.
As a tuple, does that have an issue scaling to, say, a batch of 100 or 1000 things?
For conservative correctness yes. A may be modified between the forward and reverse pass.
The overwritten set of bools says if the outermost struct pointer is overwritten and has no information about internal members being overwritten.
As Julia and other Alias analysis is improved (or we have an ImmutableArray type or something), this can be elided in the future.
On Fri, Sep 22, 2023 at 5:28 PM, Christopher Rackauckas commented on this pull request:
In ext/LinearSolveEnzymeExt.jl https://github.com/SciML/LinearSolve.jl/pull/377#discussion_r1334858450:
if EnzymeRules.width(config) == 1
    dres.u .= 0
else
    for dr in dres
        dr.u .= 0
    end
end
resvals = if EnzymeRules.width(config) == 1
    dres.u
else
    (dr.u for dr in dres)
end
cache = (copy(linsolve.val.A), res, resvals)
Is this copy necessary?
It supports being used at arbitrary sizes.

In practice, of course, some sizes could be better than others, e.g. a power of two for vectorization's sake. Likewise, if a computation can be reused for all batch elements, that could improve perf: e.g. if `transpose(A)` generated a new matrix and not a view, we could do that once for all batches.
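A toy version of that reuse (a plain-Julia sketch, no Enzyme involved): with `BatchDuplicated` the shadows arrive as a tuple, and anything shared across batch elements, like the adjoint factorization, can be hoisted out of the per-shadow loop:

```julia
using LinearAlgebra

A = rand(4, 4)
dbs = (rand(4), rand(4), rand(4))   # stand-in for a tuple of shadow vectors
Ft = lu(transpose(A))               # factorize once for the whole batch
dys = map(db -> Ft \ db, dbs)       # per-shadow work is just back-substitution
all(dys[i] ≈ transpose(A) \ dbs[i] for i in 1:3)
```

The per-element cost drops from O(n^3) to O(n^2) once the shared factorization is lifted out of the loop.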
On Fri, Sep 22, 2023 at 5:29 PM, Christopher Rackauckas wrote:
As a tuple, does that have an issue scaling to say batch of a 100 or 1000 things?
Sure, if you know a better set of things to cache, we can choose those instead. I don’t know much about the internals of solve so I went for this form.
On Fri, Sep 22, 2023 at 5:31 PM, Christopher Rackauckas commented on this pull request:
In ext/LinearSolveEnzymeExt.jl https://github.com/SciML/LinearSolve.jl/pull/377#discussion_r1334860509:
end
dAs = if EnzymeRules.width(config) == 1
    (linsolve.dval.A,)
else
    (dval.A for dval in linsolve.dval)
end
dbs = if EnzymeRules.width(config) == 1
    (linsolve.dval.b,)
else
    (dval.b for dval in linsolve.dval)
end
for (dA, db, dy) in zip(dAs, dbs, dys)
    invprob = LinearSolve.LinearProblem(transpose(A), dy)
In the forward pass the matrix A is factorized, so in theory we don't need to factorize it again, just transpose A from the forward pass. Is there a way to grab that?
The key that I'm pointing out here is similar to the top of https://docs.sciml.ai/LinearSolve/stable/tutorials/caching_interface/. But here, what `solve!` is doing is solving:

_A = lu!(A)
_A \ b1

and then the backpass is:

_At = lu!(A')
_At \ db1

but we also have that (essentially) `_At = _A'`, or at least it can be computed in O(n) time, whereas a factorization is O(n^3) and thus `lu!` is one of the most expensive operations.

So what I'm wondering is if it's safe to assume that `linsolve` is the same `linsolve` object from the forward pass, or if it may have been further mutated.
It's the same Julia object, but it's possible its fields may have been modified. If it's immutable, then it's the same.
Even if it's overwritten, however, you can still add whatever is relevant from the LU into the cache and use that as a starting point.
Awesome, I'll leave that as a follow-up, no need to handle it in this PR. But the tests do need to get fixed.
The transpose of the factorization is the factorization of the transpose:
using LinearAlgebra
A = rand(4,4)
luA = lu(A)
At = transpose(A)
luAt = lu(At)
b = rand(4)
x = A \ b
x2 = A' \ b
x3 = luA \ b
x4 = luAt \ b
x5 = luA' \ b
x ≈ x3
x2 ≈ x4 ≈ x5
Confirmed from https://web.mit.edu/18.06/www/Spring17/Transposes.pdf. We can use this to generalize and optimize a bit.
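Concretely, the optimization that fact enables (a sketch of what the reverse rule could do, assuming it has cached the forward factorization):

```julia
using LinearAlgebra

A = rand(4, 4)
db = rand(4)
F = lu(A)                          # O(n^3), already paid in the forward pass
dy_reuse  = F' \ db                # adjoint solve via the same factorization: O(n^2)
dy_refact = lu(transpose(A)) \ db  # the naive backpass: a second O(n^3) factorization
dy_reuse ≈ dy_refact
```

So the cached `lu(A)` from the forward pass is sufficient for the adjoint system; no refactorization is needed in the reverse pass.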
The solving twice tests are a bit odd:
julia> db1
4-element Vector{Float64}:
0.0
0.0
0.0
0.0
julia> db2
4-element Vector{Float64}:
2.1215949279204196
-3.7095838683317943
-1.2286715744423384
5.967859589815037
It doubles `db2` and has `db1 = 0`. I think it's because the `solve!(linsolve).u` aliases between the two. The forward pass is fine because of the `copy`, but the Enzyme rule likely needs to copy something as well?
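A pure-Julia toy of that aliasing (all names hypothetical; no Enzyme or LinearSolve involved): if the rule records the live `u` buffer instead of a snapshot, both tape entries end up pointing at whatever the buffer holds after the last solve:

```julia
mutable struct ToyCache
    b::Vector{Float64}
    u::Vector{Float64}
end

# Stand-in for solve!: writes into the cache's buffer and returns it.
toy_solve!(c::ToyCache) = (c.u .= 2 .* c.b; c.u)

c = ToyCache([1.0], zeros(1))
tape = Vector{Float64}[]
push!(tape, toy_solve!(c))   # records the buffer itself, not a copy
c.b = [10.0]
push!(tape, toy_solve!(c))   # mutates the same buffer in place

tape[1] === tape[2]          # true: the first record was silently overwritten
```

In the reverse pass, both "solves" then contribute gradients through the second solve's values, which is consistent with `db1 = 0` and `db2` doubling; a `copy` inside the rule's cache would break the aliasing.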
We can skip over that last test to merge, but do you know why that one algorithm would be treated so differently by Enzyme? I would've thought it didn't care if we're capturing stuff in rules, but it treats this algorithm particularly differently:
https://github.com/SciML/LinearSolve.jl/actions/runs/6290016689/job/17077077461?pr=377#step:6:807
Requires current Enzyme main for a custom rules fix.