SciML / LinearSolve.jl

LinearSolve.jl: High-Performance Unified Interface for Linear Solvers in Julia. Easily switch between factorization and Krylov methods and add preconditioners, all in one interface.
https://docs.sciml.ai/LinearSolve/stable/

Make the rrule's outer product lazy #484

Closed mohamed82008 closed 6 months ago

mohamed82008 commented 6 months ago

Additional context

This PR implements the suggestion in https://discourse.julialang.org/t/how-do-you-speed-up-the-linear-sparse-solver-in-zygote/111801/41?u=mohamed82008. To be safe, I also bumped the major version because this PR changes the output type of Zygote.gradient with respect to the matrix. No new documentation was added, since this PR is purely a performance improvement.

mohamed82008 commented 6 months ago

After this PR, the gradient wrt the matrix requires only O(n) memory, rather than the O(n^2) of a dense n-by-n matrix.
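
For reference, a short derivation of why O(n) storage can suffice for a scalar function like the one in the example that follows. The symbols $u$ and $\lambda$ are introduced here for illustration, and whether the rrule stores exactly these two vectors is an assumption based on the O(n) claim, not something stated in this thread. Take $f(A) = a^\top A^{-1} b$, with $u = A^{-1} b$ and $\lambda = A^{-\top} a$:

$$
\mathrm{d}f = a^\top \,\mathrm{d}(A^{-1})\, b = -a^\top A^{-1} \,\mathrm{d}A\, A^{-1} b = -\lambda^\top \,\mathrm{d}A\, u
\quad\Longrightarrow\quad
\frac{\partial f}{\partial A} = -\lambda u^\top .
$$

The gradient is a rank-one outer product, so it is fully determined by the two length-$n$ vectors $\lambda$ and $u$ ($2n$ numbers), whereas materializing $-\lambda u^\top$ densely takes $n^2$ numbers.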

```julia
using LinearAlgebra, LinearSolve, Zygote

n = 100; A = rand(n, n); b1 = rand(n); b2 = rand(n);

function invquad(a, A, b)
    prob = LinearProblem(A, b)
    sol = solve(
        prob,
        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.RFLUFactorization),
    )
    return dot(a, sol.u)
end

db1, dA, db2 = Zygote.gradient(invquad, b1, A, b2);

Base.summarysize(dA)
# 1752

Base.summarysize(A)
# 80040
```
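
To illustrate the idea behind the numbers above, here is a minimal, self-contained sketch of a lazy outer product that stores only two vectors and computes entries on demand. The `LazyNegOuter` type is hypothetical and written only for this illustration; it is not part of LinearSolve.jl, and the lazy representation the PR actually uses may differ.

```julia
using LinearAlgebra

# Hypothetical type for illustration only; not part of LinearSolve.jl.
# It represents the matrix -λ * u' without allocating n^2 entries.
struct LazyNegOuter{T, V <: AbstractVector{T}} <: AbstractMatrix{T}
    λ::V  # adjoint solution, A' \ a
    u::V  # primal solution, A \ b
end

Base.size(M::LazyNegOuter) = (length(M.λ), length(M.u))
# Entries are computed on demand (real-valued vectors assumed).
Base.getindex(M::LazyNegOuter, i::Int, j::Int) = -M.λ[i] * M.u[j]

n = 100
A = rand(n, n) + n * I   # shifted diagonal so the test matrix is well conditioned
a = rand(n); b = rand(n)

u = A \ b                # primal solve
λ = A' \ a               # adjoint solve
dA = LazyNegOuter(λ, u)  # gradient of dot(a, A \ b) with respect to A

# The lazy object holds 2n numbers; the dense matrix holds n^2.
Base.summarysize(dA), Base.summarysize(A)
```

The lazy object still behaves like a matrix (indexing, `size`, conversion with `Matrix(dA)`), which is the sense in which only the type of the gradient changes, not its values.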

codecov[bot] commented 6 months ago

Codecov Report

Attention: Patch coverage is 0%, with 2 lines in your changes missing coverage. Please review.

Project coverage is 25.10%. Comparing base (c08f2e9) to head (8d0fd26).

| Files | Patch % | Lines |
|---|---|---|
| src/adjoint.jl | 0.00% | 2 Missing :warning: |

Additional details and impacted files

```diff
@@             Coverage Diff             @@
##             main     #484       +/-   ##
===========================================
- Coverage   64.22%   25.10%   -39.13%
===========================================
  Files          28       28
  Lines        2200     2167       -33
===========================================
- Hits         1413      544      -869
- Misses        787     1623      +836
```


mohamed82008 commented 6 months ago

Tests seem to pass. I added version 3 to the docs toml file to hopefully fix the docs build.

mohamed82008 commented 6 months ago

The test failure is new and not related to this PR. I only added a version number in the last commit but the test failure is a method ambiguity failure. If you re-run the tests on master, you will probably get the same failure. Could be a dependency that upgraded and broke things between the 2 commits.

ChrisRackauckas commented 6 months ago

It seems like the tolerance is just set too tight in that test, and multithreading in BLAS changes the result at that level.

mohamed82008 commented 6 months ago

new release?

ChrisRackauckas commented 6 months ago

I was going to handle some downgrade and test tolerance stuff in https://github.com/SciML/LinearSolve.jl/pull/485 and release in a little bit.

ChrisRackauckas commented 6 months ago

Wait, why was this a major?

mohamed82008 commented 6 months ago

To be safe, I bumped the major version because the output type of Zygote.gradient wrt the matrix is changed in this PR.

ChrisRackauckas commented 6 months ago

I don't think we guaranteed the type on the pullback anywhere, just that it has the right actions in the derivative, and the Zygote overload was only added a release ago. Together, I don't think this constitutes a major bump, but rather a minor one.

mohamed82008 commented 6 months ago

Feel free to revert it. I was being safe. I wouldn't want the type to change on me if I am a user.