Closed Red-Portal closed 2 years ago
hasn't been done yet in ChainRules.jl, which makes me wonder if that's against the current policy?
It is. ChainRules only contains the rules for Base and for standard libraries.
CUDA.jl should depend on ChainRulesCore.jl and add rules itself. Or we should tweak the rules in ChainRules.jl to be a bit more generic.
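For readers unfamiliar with the suggestion above, here is a minimal sketch of how a downstream package can define its own rule by depending only on ChainRulesCore.jl. The function `scale_by_two` is hypothetical and stands in for whatever function the package owns; it is not part of CUDA.jl or ChainRules.jl.

```julia
using ChainRulesCore

# Hypothetical function owned by the downstream package (illustration only).
scale_by_two(x::AbstractArray) = 2 .* x

# Reverse-mode rule: return the primal result together with a pullback.
function ChainRulesCore.rrule(::typeof(scale_by_two), x::AbstractArray)
    y = scale_by_two(x)
    # The pullback maps the output cotangent ȳ to cotangents for
    # (the function itself, x). The function has no fields, hence NoTangent().
    scale_by_two_pullback(ȳ) = (NoTangent(), 2 .* unthunk(ȳ))
    return y, scale_by_two_pullback
end

y, pb = rrule(scale_by_two, [1.0, 2.0])
_, x̄ = pb([1.0, 1.0])
```

Because the rule lives next to the function it differentiates, ChainRules.jl itself does not need to know about the package.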
Hi, I think
Or we should tweak the rules in ChainRules.jl to be a bit more generic.
would be ideal, but it's not currently possible to do efficiently, because the rrule requires a right-to-left triangular-dense multiplication, which does not exist for rdiv!. (The Julia docs cite performance reasons, but I'm not sure that is a valid reason for it not existing.) The ideal scenario would be for upstream to add a triangular-dense specialization for rdiv!, but that would involve a lot of bureaucracy. Any thoughts?
but it's not currently possible to do in an efficient way because the rrule requires a right-to-left triangular-dense multiplication, which is not a thing for rdiv!. (The Julia docs state performance reasons, but I'm not sure if this is a valid reason for this not existing.) The ideal scenario would involve upstream willing to add a triangular-dense specialization for rdiv!, but that would involve a lot of bureaucracy.
Can you clarify which rdiv! specialization is missing? Our rule can effectively be written as
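function _cholesky_pullback(A, C::Cholesky, ΔC::Tangent)
    Ā = similar(C.factors)                # output buffer
    U = C.U
    Ū = ΔC.U
    mul!(Ā, Ū, U')                        # Ā = Ū * U'
    LinearAlgebra.copytri!(Ā, 'U', true)  # mirror upper triangle into lower (conjugated)
    ldiv!(U, Ā)                           # Ā = U \ Ā
    rdiv!(Ā, U')                          # Ā = Ā / U'
    rmul!(Ā, one(eltype(Ā)) / 2)          # scale by 1/2
    return Hermitian(Ā)
end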
so I assume you're referring to rdiv!(Ā, U')? What would its signature be for CUDA?
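As a sanity check of the steps in the rule above, the following CPU-only sketch runs the same in-place sequence on a small matrix and compares it with an out-of-place reference. The input matrix and the cotangent `Ū` are arbitrary values chosen for illustration, and `Ū` is stored as a plain `Matrix` here (an assumption made to keep the example simple). It says nothing about whether the same calls dispatch efficiently for CUDA arrays.

```julia
using LinearAlgebra

A = [4.0 2.0 1.0; 2.0 5.0 3.0; 1.0 3.0 6.0]   # symmetric positive definite
C = cholesky(A)
U = C.U
Ū = [1.0 2.0 3.0; 0.0 4.0 5.0; 0.0 0.0 6.0]   # arbitrary upper-triangular cotangent

# In-place sequence, mirroring the rule quoted in the thread
Ā = similar(C.factors)
mul!(Ā, Ū, U')                         # Ā = Ū * U'
LinearAlgebra.copytri!(Ā, 'U', true)   # copy upper triangle into lower
ldiv!(U, Ā)                            # Ā = U \ Ā
rdiv!(Ā, U')                           # Ā = Ā / U'
rmul!(Ā, 0.5)

# Out-of-place reference built from the same formula
S = Ū * U'
Φ = triu(S) + triu(S, 1)'              # symmetric matrix from the upper triangle of S
Āref = (U \ Φ) / U' ./ 2

@assert Ā ≈ Āref
```

Note that the only `rdiv!` call has the dense matrix on the left and the triangular factor on the right.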
@sethaxen Oh, I thought it was supposed to be rdiv!(U', Ā). So there's no problem here? If that's the case I'll create a PR.
It's weird that both ChainRules.jl and Zygote.jl used trsm! instead of rdiv!, which reinforced my misconception.
Oh, I thought it was supposed to be rdiv!(U', Ā).
This would break the mathematical constraint that pushforwards/pullbacks are linear operators. So indeed, no problems here!
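To spell out the linearity argument: reading off the steps of the rule quoted earlier in the thread, and writing Φ for the operation performed by copytri! (copying the upper triangle onto the lower, with conjugation), the pullback is roughly

```latex
\bar{A} \;=\; \tfrac{1}{2}\, U^{-1}\, \Phi\!\left(\bar{U}\, U^{H}\right) U^{-H}
```

which is linear in the cotangent $\bar{U}$ (real-linear in the complex case). By contrast, rdiv!(U', Ā) would compute $U^{H} \bar{A}^{-1}$, which involves the inverse of the cotangent-dependent matrix and so is not linear in it. The symbol Φ is notation introduced here for illustration, not a name used by ChainRules.jl.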
If that's the case I'll create a PR.
As it so happens, I'm in the middle of a PR fixing issues raised in #611 that would also fix this issue, so no need! I'll tag you in that PR so you can check that it also resolves your issue.
It's weird that both ChainRules.jl and Zygote.jl used trsm! instead of rdiv!, which reinforced my misconception.
I'm not certain of the provenance of this specific code, but some of our LinearAlgebra rules were adapted from Zygote, so that could be the reason. trsm! really doesn't seem to be necessary here, but I'll profile the difference in the new PR.
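For context, here is a small CPU-only sketch showing that a BLAS.trsm! call with side 'R', an upper-triangular factor, and conjugate transpose computes the same result as rdiv!(B, U'). The matrices are arbitrary illustrative values, and the sketch assumes the factorization stores the upper factor (C.uplo == 'U'), which is checked explicitly. It does not measure performance.

```julia
using LinearAlgebra

A = [4.0 2.0; 2.0 3.0]        # symmetric positive definite
C = cholesky(A)
@assert C.uplo == 'U'         # assumption: upper factor is stored in C.factors
U = C.U
B = [1.0 2.0; 3.0 4.0]

# Generic in-place right division by the adjoint of the triangular factor
X1 = rdiv!(copy(B), U')

# Direct BLAS call: solve X * U^H = 1.0 * B, reading only the upper triangle
X2 = BLAS.trsm!('R', 'U', 'C', 'N', 1.0, C.factors, copy(B))

@assert X1 ≈ X2
```

The rdiv! form works through generic dispatch rather than a hard-coded BLAS call, which is why it is the more natural candidate for array types other than `Matrix`.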
All great. If you could ping me when merging the PR, I'll copy the changes back to Zygote.jl.
If the version in #611 works with CuArrays, then there's no need for a custom rule in Zygote. In fact, it may be possible to remove https://github.com/FluxML/Zygote.jl/blob/v0.6.40/src/lib/array.jl#L567-L594 entirely.
@ToucheSir I think it will. I'll try to run some tests once this is handled and modify the PR in Zygote if that's the case.
That's right, @ToucheSir. It might be 2 PRs here, but my goal is for all Cholesky-specific adjoints in Zygote to be removed. As noted in https://github.com/JuliaDiff/ChainRules.jl/issues/611#issuecomment-1149650865, there's already a Zygote PR that started to do this. It might be stale though.
Hi, it recently turned out that the adjoint for the Cholesky in Zygote.jl is not GPU compatible, as it explicitly calls BLAS.trsm!. At the moment, there doesn't seem to be a computationally efficient solution to that except to write a GPU specialization that calls CUBLAS.trsm!. See the discussion in Zygote.

Currently, ChainRules.jl has the same issue: https://github.com/JuliaDiff/ChainRules.jl/blob/e4029dfe651c6483ff92da1de3241c0c94bd3256/src/rulesets/LinearAlgebra/factorization.jl#L513-L519

But it seems that something like
hasn't been done yet in ChainRules.jl, which makes me wonder if that's against the current policy?