JuliaDiff / ChainRules.jl

Forward- and reverse-mode automatic differentiation primitives for Julia Base + StdLibs

Need a GPU compatible `rrule` for Cholesky #629

Closed: Red-Portal closed this issue 2 years ago

Red-Portal commented 2 years ago

Hi, it recently turned out that the adjoint for the Cholesky factorization in Zygote.jl is not GPU-compatible because it explicitly calls BLAS.trsm!. At the moment, there doesn't seem to be a computationally efficient solution other than writing a GPU specialization that calls CUBLAS.trsm!. See the discussion in Zygote.

Currently, ChainRules.jl has the same issue: https://github.com/JuliaDiff/ChainRules.jl/blob/e4029dfe651c6483ff92da1de3241c0c94bd3256/src/rulesets/LinearAlgebra/factorization.jl#L513-L519
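A minimal CPU-only sketch of the point at issue: each explicit `BLAS.trsm!` call can be expressed with the generic `ldiv!`/`rdiv!` on triangular wrappers, which dispatch to BLAS for ordinary `Matrix` arguments. Whether a GPU array package such as CUDA.jl provides matching triangular `ldiv!`/`rdiv!` methods is an assumption not checked here; the sizes and random data are illustrative only.

```julia
using LinearAlgebra, Random

Random.seed!(0)
n = 4
U = triu(randn(n, n)) + n * I   # dense matrix holding a well-conditioned upper-triangular factor
B = randn(n, n)

# Left solve U \ B: explicit BLAS call vs. generic ldiv! on a triangular wrapper
B_blas_left = BLAS.trsm!('L', 'U', 'N', 'N', 1.0, U, copy(B))
B_generic_left = ldiv!(UpperTriangular(U), copy(B))

# Right solve B / U': explicit BLAS call vs. generic rdiv! on the adjoint wrapper
B_blas_right = BLAS.trsm!('R', 'U', 'T', 'N', 1.0, U, copy(B))
B_generic_right = rdiv!(copy(B), UpperTriangular(U)')
```

On CPU arrays both pairs should agree up to rounding; the generic form simply avoids hard-coding a BLAS-specific call.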

But it seems that something like

@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
    # CUDA-compatible specialization
end

hasn't been done yet in ChainRules.jl, which makes me wonder if that's against the current policy?

oxinabox commented 2 years ago

> hasn't been done yet in ChainRules.jl, which makes me wonder if that's against the current policy?

It is. ChainRules only contains the rules for Base and for standard libraries.

CUDA.jl should depend on ChainRulesCore.jl and add rules itself. Or we should tweak the rules in ChainRules.jl to be a bit more generic.

Red-Portal commented 2 years ago

Hi, I think

> Or we should tweak the rules in ChainRules.jl to be a bit more generic.

would be ideal, but it's not currently possible to do in an efficient way because the rrule requires a right-to-left triangular-dense multiplication, which is not a thing for rdiv!. (The Julia docs state performance reasons, but I'm not sure if this is a valid reason for this not existing.) The ideal scenario would involve upstream being willing to add a triangular-dense specialization for rdiv!, but that would involve a lot of bureaucracy. Any thoughts?

sethaxen commented 2 years ago

> but it's not currently possible to do in an efficient way because the rrule requires a right-to-left triangular-dense multiplication, which is not a thing for rdiv!. (The Julia docs state performance reasons, but I'm not sure if this is a valid reason for this not existing.) The ideal scenario would involve upstream being willing to add a triangular-dense specialization for rdiv!, but that would involve a lot of bureaucracy.

Can you clarify which rdiv! specialization is missing? Our rule can effectively be written as

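using LinearAlgebra, ChainRulesCore  # Tangent comes from ChainRulesCore

function _cholesky_pullback(A, C::Cholesky, ΔC::Tangent)
    Ā = similar(C.factors)
    U = C.U                               # upper-triangular factor, A = U'U
    Ū = ΔC.U                              # cotangent of U (assumes ΔC carries a U component)
    mul!(Ā, Ū, U')                        # Ā = Ū U'
    LinearAlgebra.copytri!(Ā, 'U', true)  # mirror the conjugated upper triangle into the lower one
    ldiv!(U, Ā)                           # Ā = U⁻¹ Ā
    rdiv!(Ā, U')                          # Ā = Ā U⁻ᴴ
    rmul!(Ā, one(eltype(Ā)) / 2)          # Ā = Ā / 2
    return Hermitian(Ā)
end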

so I assume you're referring to rdiv!(Ā, U')? What would its signature be for CUDA?
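One way to sanity-check the recipe above is a finite-difference comparison on the CPU. This is a hypothetical sketch: `cholesky_U_pullback` and `loss` are illustrative names, the cotangent is passed as a plain upper-triangular matrix rather than a ChainRulesCore `Tangent`, and the input is a small random symmetric positive definite matrix.

```julia
using LinearAlgebra, Random

# Hypothetical helper: the ldiv!/rdiv! recipe, taking the cotangent Ū of the upper factor directly.
function cholesky_U_pullback(C::Cholesky, Ū::AbstractMatrix)
    U = C.U                               # UpperTriangular factor
    Ā = Ū * U'                            # Ū Uᴴ as a dense matrix
    LinearAlgebra.copytri!(Ā, 'U', true)  # mirror the upper triangle into the lower one
    ldiv!(U, Ā)                           # Ā ← U⁻¹ Ā
    rdiv!(Ā, U')                          # Ā ← Ā U⁻ᴴ
    rmul!(Ā, one(eltype(Ā)) / 2)          # Ā ← Ā / 2
    return Hermitian(Ā)
end

# Scalar loss L(A) = ⟨Ū, U(A)⟩; its gradient with respect to A is what the pullback returns.
loss(A, Ū) = sum(Ū .* cholesky(Symmetric(A)).U)

Random.seed!(42)
n = 4
X = randn(n, n)
A = X * X' + n * I               # symmetric positive definite input
Ū = triu(randn(n, n))            # cotangent of an upper-triangular factor
Ā = cholesky_U_pullback(cholesky(Symmetric(A)), Ū)

dA = randn(n, n); dA = dA + dA'  # symmetric perturbation direction
h = 1e-6
fd = (loss(A + h * dA, Ū) - loss(A - h * dA, Ū)) / (2h)  # central finite difference
ad = sum(Ā .* dA)                                         # ⟨Ā, dA⟩ from the rule
ok = isapprox(fd, ad; rtol=1e-5, atol=1e-6)
```

Restricting `Ū` to its upper triangle matters here: the loss only sees the upper factor, while `Ū * U'` would also pick up any lower-triangular entries.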

Red-Portal commented 2 years ago

@sethaxen Oh, I thought it was supposed to be rdiv!(U', Ā). So there's no problem here? If that's the case, I'll create a PR.

It's weird that both ChainRules.jl and Zygote.jl used trsm! instead of rdiv!, which reinforced my misconception.

sethaxen commented 2 years ago

> Oh, I thought it was supposed to be rdiv!(U', Ā).

This would break the mathematical constraint that pushforwards/pullbacks are linear operators. So indeed, no problems here!
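To make the linearity argument concrete, here is a small hypothetical check (the names `lin` and `swapped` are illustrative): the rule's step Ā ↦ Ā / U' scales with Ā, whereas the swapped order Ā ↦ U' / Ā depends on Ā through its inverse, so doubling Ā halves the result instead of doubling it.

```julia
using LinearAlgebra, Random

Random.seed!(1)
n = 3
U = UpperTriangular(triu(randn(n, n)) + n * I)  # well-conditioned upper factor
Ā = randn(n, n)                                  # an arbitrary (invertible) cotangent

lin(Ā) = rdiv!(copy(Ā), U')   # the rule's step: Ā ↦ Ā / U'
swapped(Ā) = U' / Ā           # the swapped order: Ā ↦ U' / Ā, which involves inv(Ā)

lin_ok = lin(2Ā) ≈ 2 * lin(Ā)               # homogeneity holds
swapped_ok = swapped(2Ā) ≈ 2 * swapped(Ā)   # fails: the result is halved, not doubled
```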

> If that's the case, I'll create a PR.

As it so happens, I'm in the middle of a PR fixing issues raised in #611 that would also fix this issue, so no need! I'll tag you in that PR so you can check that it also resolves your issue.

> It's weird that both ChainRules.jl and Zygote.jl used trsm! instead of rdiv!, which reinforced my misconception.

I'm not certain of the provenance of this specific code, but some of our LinearAlgebra rules were adapted from Zygote, so that could be the reason. trsm! really doesn't seem to be necessary here, but I'll profile the difference in the new PR.

Red-Portal commented 2 years ago

All great. If you could ping me when merging the PR, I'll copy the changes back to Zygote.jl.

ToucheSir commented 2 years ago

If the version in #611 works with CuArrays, then there's no need for a custom rule in Zygote. In fact, it may be possible to remove https://github.com/FluxML/Zygote.jl/blob/v0.6.40/src/lib/array.jl#L567-L594 entirely.

Red-Portal commented 2 years ago

@ToucheSir I think it will. I'll try to run some tests once this is handled and modify the PR in Zygote if that's the case.

sethaxen commented 2 years ago

That's right, @ToucheSir. It might take 2 PRs, but my goal is for all Cholesky-specific adjoints in Zygote to be removed. As noted in https://github.com/JuliaDiff/ChainRules.jl/issues/611#issuecomment-1149650865, there's already a Zygote PR that started to do this. It might be stale, though.