SciML / LinearSolve.jl

LinearSolve.jl: High-Performance Unified Interface for Linear Solvers in Julia. Easily switch between factorization and Krylov methods and add preconditioners, all in one interface.
https://docs.sciml.ai/LinearSolve/stable/

LinearSolve with SciMLSensitivity Solution Handling Requires sol.u #483

Open marklau34 opened 3 months ago

marklau34 commented 3 months ago

I wanted to try something simple with SciMLSensitivity.jl to find the sensitivities of the solution to a LinearProblem with respect to parameters. However, I get an unexpected error, which I outlined in a post.

using Zygote
using SciMLSensitivity
using ForwardDiff
using LinearSolve
import Random
Random.seed!(1234)

N = 2

function test_func(x::AbstractVector{T}) where {T<:Real}
    A = reshape(x[1:N*N], (N,N))
    b = x[N*N+1:end]
    # This works:
    # sol = A\b
    # But this fails under Zygote:
    prob = LinearProblem(A, b)
    sol = solve(prob)
    return sum(sol)
end

# Random Point
x0 = rand(N*N+N)

# Try with Zygote
grad_zygote = Zygote.gradient(test_func, x0)
display(grad_zygote[1])

# Compare with ForwardDiff
grad_forwarddiff = ForwardDiff.gradient(test_func, x0)
display(grad_forwarddiff)

The following error occurs:

ERROR: type Fill has no field u
Stacktrace:
  [1] getproperty
    @ .\Base.jl:37 [inlined]
  [2] (::LinearSolve.var"#∇linear_solve#103"{…})(∂sol::FillArrays.Fill{…})
    @ LinearSolve C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\adjoint.jl:58
  [3] ZBack
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
  [4] (::Zygote.var"#291#292"{Tuple{…}, Zygote.ZBack{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206
  [5] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
  [6] #solve#5
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:188 [inlined]
  [7] (::Zygote.Pullback{Tuple{…}, Any})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
  [8] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
  [9] #2169#back
    @ C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [10] solve
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:186 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [12] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [13] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [14] #solve#4
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:183 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Any})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [16] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [17] #2169#back
    @ C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [18] solve
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:182 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [20] test_func
    @ c:\GitCode\marklau\sandbox\sensitivity\test_zygote.jl:17 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [22] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91
 [23] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:148
 [24] top-level scope
    @ c:\GitCode\marklau\sandbox\sensitivity\test_zygote.jl:25
Some type information was truncated. Use `show(err)` to see complete types.

As suggested by @avik-pal, returning sum(sol.u) instead fixes the problem, so this may be a bug in which the getindex(sol, sym) rrule is not handled correctly.
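Whichever formulation is used, the result can be checked against a hand-derived gradient. For y = A \ b and f = sum(y), the adjoint solve λ = A' \ ones(N) gives ∂f/∂b = λ and ∂f/∂A = -λ y'. The sketch below uses only the LinearAlgebra standard library (no LinearSolve, Zygote or ForwardDiff); the helper names `f_ref`, `grad_ref` and `grad_fd` are hypothetical and exist only for this check.

```julia
using LinearAlgebra

const N = 2  # same problem size as in the issue

# Same objective as `test_func`, using plain `\` (the path reported as working).
function f_ref(x::AbstractVector)
    A = reshape(x[1:N*N], (N, N))
    b = x[N*N+1:end]
    return sum(A \ b)
end

# Hand-derived reverse-mode gradient of f(x) = sum(A \ b).
# The cotangent of `sum` is ones(N), so:
#   λ = A' \ ones(N)   (adjoint solve)
#   ∂f/∂b = λ
#   ∂f/∂A = -λ * y'    where y = A \ b
function grad_ref(x::AbstractVector)
    A = reshape(x[1:N*N], (N, N))
    b = x[N*N+1:end]
    y = A \ b
    λ = A' \ ones(N)
    Abar = -λ * y'
    # vec is column-major, matching how reshape read A out of x
    return vcat(vec(Abar), λ)
end

# Central finite differences, used only to cross-check grad_ref.
function grad_fd(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end
```

At a fixed, well-conditioned point such as `x0 = [2.0, 0.5, 0.3, 1.5, 1.0, -1.0]`, `grad_ref(x0)` should agree with `grad_fd(f_ref, x0)`, and it is the vector that `ForwardDiff.gradient(test_func, x0)` and a working `Zygote.gradient(test_func, x0)[1]` would be expected to reproduce. A fixed point is used instead of `rand` so that A cannot be nearly singular.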

Is this a bug, or is there a reason sol.u should be used in this case?

ChrisRackauckas commented 3 months ago

We just need a similar getindex overload https://github.com/SciML/SciMLBase.jl/blob/master/ext/SciMLBaseZygoteExt.jl#L97-L109
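The stack trace shows the failure mode: the pullback `∇linear_solve` receives the cotangent of `sum(sol)` as a bare array (a `FillArrays.Fill`) and then reads `∂sol.u`. The mock below reproduces that mismatch in plain Julia and shows one simple way to bridge it: wrapping an array-shaped cotangent into a `(; u = ...)` named tuple. This is a hypothetical illustration only. `mock_pullback` and `normalize_cotangent` are made-up names, a plain `Vector` stands in for `Fill`, and it is not the SciMLBase `getindex` overload linked above, whose contents are not reproduced here.

```julia
# Hypothetical pullback mimicking the one in the stack trace:
# it assumes the cotangent is structural and has a `u` field.
mock_pullback(∂sol) = ∂sol.u

# Hypothetical fix: an array-shaped cotangent (what `sum(sol)` produces)
# is wrapped so that it carries a `u` field; anything else passes through.
normalize_cotangent(∂sol::AbstractArray) = (; u = ∂sol)
normalize_cotangent(∂sol) = ∂sol

# A bare array cotangent reproduces the "has no field u" error...
Δ = fill(1.0, 2)   # stands in for FillArrays.Fill(1.0, 2)
try
    mock_pullback(Δ)
catch err
    println("without normalization: ", sprint(showerror, err))
end

# ...while the normalized cotangent is accepted.
println("with normalization: ", mock_pullback(normalize_cotangent(Δ)))
```

The point of the sketch is that `sum(sol)` and `sum(sol.u)` produce differently shaped cotangents for the same solution object, so whatever rule sits between them has to map the array form back onto the `u` field.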