SciML / LinearSolve.jl

LinearSolve.jl: High-Performance Unified Interface for Linear Solvers in Julia. Easily switch between factorization and Krylov methods and add preconditioners, all in one interface.
https://docs.sciml.ai/LinearSolve/stable/

Adjoints for Linear Solve #449

Closed: avik-pal closed this 6 months ago

avik-pal commented 8 months ago

Fixes #198, Fixes #322

TODOs:

Example:

using LinearSolve, Zygote, BenchmarkTools  # BenchmarkTools provides @btime

A = rand(4, 4)
b = rand(4)

# Baseline: differentiate through Julia's built-in backslash solve
test_func_1(A, b) = sum(abs2, A \ b)

test_func_1(A, b)

∂A_1, ∂b_1 = @btime Zygote.gradient(test_func_1, copy(A), copy(b))
display(∂A_1)
display(∂b_1)

# Same loss, but solved through the LinearSolve.jl LinearProblem/solve interface
function test_func_2(A, b)
    prob = LinearProblem(A, b)
    sol = solve(prob)
    return sum(abs2, sol.u)
end

test_func_2(A, b)

∂A_2, ∂b_2 = @btime Zygote.gradient(test_func_2, copy(A), copy(b))
display(∂A_2)
display(∂b_2)
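For reference, if both paths are differentiated correctly, test_func_1 and test_func_2 should return the same gradients. The standard reverse-mode rule for a real linear solve follows from differentiating A x = b, with the loss L = sum(abs2, x) used in the example:

```math
\begin{aligned}
x &= A^{-1} b, \qquad L = \sum_i x_i^2, \qquad \bar{x} = \frac{\partial L}{\partial x} = 2x, \\
A\,dx &= db - dA\,x \;\Rightarrow\; dL = \bar{x}^{\top} A^{-1}\,(db - dA\,x), \\
\lambda &= A^{-\top} \bar{x}, \qquad \bar{b} = \lambda, \qquad \bar{A} = -\lambda\, x^{\top}.
\end{aligned}
```

So the backward pass costs one extra solve with the transpose of A, which can reuse the factorization computed in the forward pass.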

In the following case, the cache stores the correct gradients, but they are not propagated to A and b. @ChrisRackauckas, any idea how to fix this?

# Pre-build a solver cache; passing `nothing` lets LinearSolve pick its default algorithm
cache = init(LinearProblem(copy(A), copy(b)), nothing);
function test_func_3(cache, A, b)
    cache.A = A
    cache.b = b
    sol = solve!(cache)
    return sum(abs2, sol.u)
end

test_func_3(cache, copy(A), copy(b))

∂cache, ∂A_3, ∂b_3 = @btime Zygote.gradient(test_func_3, cache, copy(A), copy(b))
∂cache.A
∂cache.b
display(∂A_3)  # nothing
display(∂b_3)  # nothing
codecov[bot] commented 8 months ago

Codecov Report

Attention: Patch coverage is 6.66667%, with 42 lines in your changes missing coverage. Please review.

Project coverage is 22.96%. Comparing base (a206054) to head (06c09a3).

:exclamation: Current head 06c09a3 differs from pull request most recent head 7671369. Consider uploading reports for the commit 7671369 to get more accurate results.

| Files | Patch % | Lines |
|---|---|---|
| src/adjoint.jl | 2.32% | 42 Missing :warning: |
Additional details and impacted files

```diff
@@             Coverage Diff              @@
##             main     #449       +/-   ##
===========================================
- Coverage   66.12%   22.96%   -43.17%
===========================================
  Files          27       28        +1
  Lines        2146     2147        +1
===========================================
- Hits         1419      493      -926
- Misses        727     1654      +927
```


ChrisRackauckas commented 8 months ago

> In the following case, the cache stores the correct gradients, but they are not propagated to A and b. @ChrisRackauckas, any idea how to fix this?

Is this not just an inherent limitation of Zygote with mutation? I would presume we just need to stay away from that and only support `solve` via ChainRulesCore (CRC) rules.
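To make the CRC suggestion concrete, here is a minimal sketch, assuming a hypothetical `linsolve_with_pullback` helper (not a LinearSolve.jl or ChainRulesCore function). It returns the solution together with a pullback closure, the same (primal, pullback) shape a ChainRulesCore `rrule` produces. When a rule covers the whole solve, the AD system calls the rule instead of tracing the solver internals. The sketch uses only the LinearAlgebra stdlib and checks the cotangents against central finite differences.

```julia
using LinearAlgebra, Random

# Hypothetical helper (not part of LinearSolve.jl or ChainRulesCore):
# solves A x = b and returns x with a pullback closure, mirroring the
# (primal, pullback) shape of a ChainRulesCore rrule.
function linsolve_with_pullback(A::AbstractMatrix, b::AbstractVector)
    F = lu(A)              # factorize once in the forward pass
    x = F \ b
    function pullback(x̄)
        λ = F' \ x̄         # adjoint solve Aᵀ λ = x̄, reusing the LU factors
        Ā = -λ * x'        # cotangent for A (outer product)
        b̄ = λ              # cotangent for b
        return Ā, b̄
    end
    return x, pullback
end

Random.seed!(0)
A = rand(4, 4) + 4I        # shift the diagonal to keep A well conditioned
b = rand(4)

# Loss from the PR example: L = sum(abs2, x), so ∂L/∂x = 2x
x, back = linsolve_with_pullback(A, b)
Ā, b̄ = back(2 .* x)

# Independent check with central finite differences
loss(A, b) = sum(abs2, A \ b)
h = 1e-6
fd_A = similar(A)
for i in eachindex(A)
    E = zeros(size(A)); E[i] = h
    fd_A[i] = (loss(A + E, b) - loss(A - E, b)) / (2h)
end
fd_b = similar(b)
for i in eachindex(b)
    e = zeros(length(b)); e[i] = h
    fd_b[i] = (loss(A, b + e) - loss(A, b - e)) / (2h)
end

println(isapprox(Ā, fd_A; rtol = 1e-5), " ", isapprox(b̄, fd_b; rtol = 1e-5))
```

Reusing the LU factors means the reverse pass adds only a pair of triangular solves with the transposed factors, not a second factorization.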