mohamed82008 closed this 6 months ago
After this PR, the gradient wrt the matrix requires only O(n) memory instead of the O(n²) needed for a dense n×n array.
```julia
using LinearAlgebra, LinearSolve, Zygote

n = 100; A = rand(n, n); b1 = rand(n); b2 = rand(n);

function invquad(a, A, b)
    prob = LinearProblem(A, b)
    sol = solve(
        prob,
        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.RFLUFactorization),
    )
    return dot(a, sol.u)
end

db1, dA, db2 = Zygote.gradient(invquad, b1, A, b2);

Base.summarysize(dA)
# 1752
Base.summarysize(A)
# 80040
```
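A rough sketch of why O(n) storage can be enough, assuming the matrix gradient is kept as a lazy outer product of two vectors (the symbols $u$ and $\lambda$ are introduced here for illustration and are not names from the PR). For `invquad`, which computes $f(a, A, b) = a^\top A^{-1} b$:

$$
\begin{aligned}
u &= A^{-1} b, \qquad \lambda = A^{-\top} a, \\
\mathrm{d}f &= -a^\top A^{-1}\,\mathrm{d}A\,A^{-1} b = -\lambda^\top\,\mathrm{d}A\,u, \\
\frac{\partial f}{\partial A} &= -\lambda u^\top, \qquad
\frac{\partial f}{\partial a} = u, \qquad
\frac{\partial f}{\partial b} = \lambda .
\end{aligned}
$$

The gradient wrt $A$ is rank one, so it is fully determined by the two length-$n$ vectors $\lambda$ and $u$. With $n = 100$ and `Float64` entries that is $2 \times 100 \times 8 = 1600$ bytes of payload, which is in line with the 1752 bytes reported for `dA` above, versus roughly 80000 bytes for a materialized dense matrix.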
Attention: Patch coverage is 0% with 2 lines in your changes missing coverage. Please review.
Project coverage is 25.10%. Comparing base (c08f2e9) to head (8d0fd26).
Files | Patch % | Lines |
---|---|---|
src/adjoint.jl | 0.00% | 2 Missing :warning: |
Tests seem to pass. I added version 3 to the docs toml file to hopefully fix the docs build.
The test failure is new and not related to this PR. I only added a version number in the last commit but the test failure is a method ambiguity failure. If you re-run the tests on master, you will probably get the same failure. Could be a dependency that upgraded and broke things between the 2 commits.
It seems like the tolerance is just set too tight in that test, and multithreading in BLAS changes the result at that level.
new release?
I was going to handle some downgrade and test tolerance stuff https://github.com/SciML/LinearSolve.jl/pull/485 and release in a little bit.
Wait, why was this a major?
To be safe, I bumped the major version because the output type of Zygote.gradient wrt the matrix is changed in this PR.
I don't think we guaranteed the type on the pullback anywhere, just that it has the right actions in the derivative, and the Zygote overload was just added a release ago, so together I don't think this constitutes a major bump but instead a minor.
Feel free to revert it. I was being safe. I wouldn't want the type to change on me if I am a user.
Additional context
This PR implements the suggestion in https://discourse.julialang.org/t/how-do-you-speed-up-the-linear-sparse-solver-in-zygote/111801/41?u=mohamed82008. To be safe, I also bumped the major version because the output type of `Zygote.gradient` wrt the matrix is changed in this PR. No new documentation was added as this PR is just a performance improvement.