SciML / LinearSolve.jl

LinearSolve.jl: High-Performance Unified Interface for Linear Solvers in Julia. Easily switch between factorization and Krylov methods and add preconditioners, all in one interface.
https://docs.sciml.ai/LinearSolve/stable/

linearsolve using the transpose of the factorization #92

Open MKAbdElrahman opened 2 years ago

MKAbdElrahman commented 2 years ago

To define an rrule, I need to reuse the factorization from the forward pass. The pullback is also a linear solve, but with the transpose of the linear system. How can I do this without defining a new LinearProblem?

This is my current code. Using the LinearSolve interface would let it use other algorithms and make it cleaner.


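using LinearAlgebra, ChainRulesCore

function ChainRulesCore.rrule(::typeof(efield), sim::Simulation, ϵ)
    linsys = ConstructLinearSystem(sim, ϵ)
    A, b = linsys.A, linsys.b
    x = similar(b); x_adj = similar(b)
    F = lu(A)                  # factorize once in the forward pass
    ldiv!(x, F, b)             # forward solve: A * x = b
    sim.E = x
    function efield_pullback(ȳ)
        # reuse F for the transposed solve: transpose(A) * x_adj = conj(ȳ)
        ldiv!(x_adj, transpose(F), conj.(ȳ))
        f̄ = NoTangent()        # tangent w.r.t. efield itself
        sim̄ = NoTangent()      # no tangent propagated to sim
        ϵ̄ = real(x .* x_adj)   # tangent w.r.t. ϵ (form depends on how A depends on ϵ)
        return f̄, sim̄, ϵ̄
    end
    return x, efield_pullback
end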

Thanks!
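With plain LinearAlgebra, the LU factorization from the forward pass can already be reused for the transposed solve in the pullback, without refactorizing. A minimal sketch on a small made-up dense system (the matrices and vectors are illustrative, not taken from the issue; c stands in for the pullback seed ȳ):

```julia
using LinearAlgebra

# Illustrative nonsymmetric test system (made-up data)
A = [4.0 1.0 0.0;
     2.0 5.0 1.0;
     0.0 1.0 3.0]
b = [1.0, 2.0, 3.0]
c = [0.5, -1.0, 2.0]      # plays the role of the pullback seed ȳ

F = lu(A)                 # factorize once, in the "forward pass"

x = F \ b                 # forward solve:    A * x = b
λ = transpose(F) \ c      # transposed solve: transpose(A) * λ = c, reusing F

println(A * x ≈ b)              # true
println(transpose(A) * λ ≈ c)   # true
```

transpose(F) wraps the existing LU factors rather than recomputing them; the in-place ldiv! form used in the code above does the same solve without allocating the result.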
ChrisRackauckas commented 2 years ago

Oh yes, this would be good to add. It might be hard for Krylov methods, though, because it requires the operator to have an adjoint defined, and these operators are often defined Jacobian-free. In that case it will just throw an appropriate error if the adjoint is not defined.

I think the right thing to do would be to add a boolean transpose to the LinearProblem, false by default, and then set up the algorithms to specialize on it. Many solvers, such as Pardiso, have a lazy transpose option, so we'd use that bool to flip the option.
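A rough sketch of how a transpose flag on the problem could drive the solve. ToyLinearProblem and toy_solve are hypothetical names for illustration only; they are not LinearSolve.jl's LinearProblem or solve, and the sketch assumes a dense LU backend:

```julia
using LinearAlgebra

# Hypothetical problem type; NOT LinearSolve.jl's LinearProblem.
struct ToyLinearProblem{TA,Tb}
    A::TA
    b::Tb
    transpose::Bool
end
ToyLinearProblem(A, b; transpose::Bool = false) = ToyLinearProblem(A, b, transpose)

# Hypothetical solver: factorize A once and let the flag choose
# between A \ b and transpose(A) \ b using the same LU factors.
function toy_solve(prob::ToyLinearProblem)
    F = lu(prob.A)
    return prob.transpose ? transpose(F) \ prob.b : F \ prob.b
end

A = [2.0 1.0; 0.0 3.0]
b = [1.0, 1.0]
x  = toy_solve(ToyLinearProblem(A, b))                     # A * x = b
xt = toy_solve(ToyLinearProblem(A, b; transpose = true))   # transpose(A) * xt = b
```

A backend with a native lazy transpose option would flip that option instead of wrapping the factorization.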

vpuri3 commented 2 years ago

@MKAbdElrahman did the issue get resolved?

ChrisRackauckas commented 2 years ago

It did not.

vpuri3 commented 2 years ago

In LinearCache, we could add a flag symmetric (defaulting to false) and a field Atransp (defaulting to Adjoint(A), the lazy wrapper). The adjoint problem could then be solved via

solve(prob, alg, adjoint=true)
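The lazy wrapper mentioned for the proposed Atransp field is what Base's adjoint already returns for a matrix: an Adjoint view that shares storage with A instead of copying it. This sketch only shows that LinearAlgebra behavior; it does not model LinearCache or the proposed adjoint=true keyword:

```julia
using LinearAlgebra

A = [1.0 2.0; 3.0 4.0]

# adjoint(A) (also written A') returns a lazy Adjoint wrapper, not a copy
At = adjoint(A)
println(At isa Adjoint)      # true
println(parent(At) === A)    # true: the wrapper shares storage with A

# Mutating A is visible through the wrapper
A[1, 2] = 10.0
println(At[2, 1])            # 10.0
```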

vpuri3 commented 2 years ago

Maybe we should change the terminology to make it clearer for complex eltypes.
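For complex eltypes, transpose and adjoint (conjugate transpose) are different operators, which is why the naming matters. They are related by a conjugation identity: solving with transpose(A) on conj(y) gives the conjugate of the adjoint solve, which is what the conj.(ȳ) in the original pullback relies on. A small check with made-up complex data:

```julia
using LinearAlgebra

A = [1.0+2.0im  0.5-1.0im;
     0.0+1.0im  3.0+0.0im]
y = [1.0+1.0im, 2.0-0.5im]

# For complex matrices, transpose and adjoint differ
println(transpose(A) == adjoint(A))   # false

# transpose(A) * λ_t = conj(y)  and  A' * λ_a = y  imply  λ_t = conj(λ_a)
F = lu(A)
λ_t = transpose(F) \ conj.(y)
λ_a = adjoint(F) \ y
println(λ_t ≈ conj.(λ_a))             # true
```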

vpuri3 commented 2 years ago

@ChrisRackauckas is Adjoint(::DiffEqArrayOperator) defined?

ChrisRackauckas commented 2 years ago

No. What needs to be done is that the operator interface documentation should get a note stating that adjoint(::SciMLOperator) is part of the (optional) interface and is required for reverse-mode automatic differentiation. Then adjoint(::DiffEqArrayOperator) should be added to SciMLBase by simply taking the adjoint of the internal array. Many other operators would have to be handled too, but that can be done over time.
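The "just take the adjoint of the internal array" approach can be sketched on a toy wrapper type. ToyArrayOperator is hypothetical and only stands in for an array-backed operator; it is not SciMLBase's DiffEqArrayOperator:

```julia
using LinearAlgebra

# Hypothetical stand-in for an array-backed operator; NOT SciMLBase's DiffEqArrayOperator.
struct ToyArrayOperator{T<:AbstractMatrix}
    A::T
end

# Forward action: L * v
Base.:*(L::ToyArrayOperator, v::AbstractVector) = L.A * v

# Adjoint defined by wrapping the (lazy) adjoint of the internal array
Base.adjoint(L::ToyArrayOperator) = ToyArrayOperator(adjoint(L.A))

L = ToyArrayOperator([1.0 2.0; 3.0 4.0])
v = [1.0, -1.0]
u = [2.0, 0.5]
w = L' * v     # L' calls the adjoint method defined above
```

A correct adjoint satisfies dot(u, L * v) == dot(L' * u, v) for all u and v, which is a cheap test for each operator added over time.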

vpuri3 commented 2 years ago

Sounds good. Is solve(prob, alg, adjoint=true) the standard interface for solving an adjoint problem in DiffEq?

ChrisRackauckas commented 2 years ago

Not in solve; it's not a solve-level thing. It changes the result, so it's not a solver control but a property of the problem. The real question is whether the answer should simply be LinearProblem(A',b) or LinearProblem(A,b,adjoint=true). The advantage of the latter is that it makes it easier to write general adjoint implementations for operators that only define A*x, since it can then use AbstractDifferentiation.jl behind the scenes. However, that isn't needed if adjoint(A) already exists. So we need has_adjoint(A::AbstractSciMLOperator) and adjoint(A::AbstractSciMLOperator) in the interface to write it effectively. Also, I'm not sure we want the AD overloads, just because of the dependencies they would bring in. So it would need to do something like DiffEqSensitivity, where by default an error is raised telling the user to load DiffEqSensitivity to get the adjoint overloads, unless has_adjoint(A). All of this complexity is eliminated if we just assume LinearProblem(A',b) is the way to do it.
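A minimal sketch of the has_adjoint gating described above, assuming a trait function of that name (the name comes from the discussion; it is defined locally here, not imported). ToyFunctionOperator and adjoint_operator are hypothetical helpers: a matrix-free operator that only knows A*x is rejected with a clear error, while a plain matrix takes the LinearProblem(A', b) route of simply solving with A':

```julia
using LinearAlgebra

# Hypothetical matrix-free operator: only the forward action A*x is known.
struct ToyFunctionOperator{F}
    f::F
end
Base.:*(L::ToyFunctionOperator, v::AbstractVector) = L.f(v)

# Sketch of the proposed trait (defined locally for illustration)
has_adjoint(::AbstractMatrix) = true
has_adjoint(::ToyFunctionOperator) = false

# Hypothetical helper: return the adjoint operator, or fail with a clear
# message when none is available (where an AD-based fallback could hook in).
function adjoint_operator(L)
    has_adjoint(L) || throw(ArgumentError(
        "adjoint not defined for $(typeof(L)); define adjoint or load an AD-based fallback"))
    return adjoint(L)
end

A = [2.0 1.0; 0.0 3.0]
b = [1.0, 1.0]
λ = adjoint_operator(A) \ b      # the LinearProblem(A', b) route: solve with A'

matfree = ToyFunctionOperator(v -> A * v)
err = try
    adjoint_operator(matfree)
    nothing
catch e
    e
end
```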

vpuri3 commented 2 years ago

LinearProblem(A, b, adjoint=true) would work well for factorizations, since adjoint is defined for LinearAlgebra.Factorization. For iterative methods, etc., we can fall back to the lazy adjoint wrapper with a has_adjoint check.
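The factorization half of this is already supported by LinearAlgebra: one LU factorization answers both the forward and the adjoint problem, including for complex eltypes. A short check with made-up data (only standard LinearAlgebra calls; no LinearSolve.jl API is assumed):

```julia
using LinearAlgebra

A = [4.0+1.0im  1.0-1.0im;
     2.0+0.0im  5.0+2.0im]
b = [1.0+0.0im, 0.0+1.0im]

F = lu(A)          # one factorization serves both problems
x = F \ b          # A  * x = b
λ = F' \ b         # A' * λ = b via the adjoint of the factorization (no refactorization)
```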