Open sethaxen opened 3 years ago
@oxinabox maybe? Thoughts on this?
@sdl1 @wytbella are more familiar with the math and can comment on if that is desirable.
@sdl1 @wytbella, any input on this would be appreciated.
I think we can relax the constraints here, but @willtebbutt is probably a better person to confirm this.
Yup, that seems fine to me. I agree that the important property is that S+ADA' is positive definite -- it doesn't really matter whether either S or D are either positive-definite or diagonal.
Okay, I'll open a PR.
Heh, may be a while. This code is a bit foggy in my brain now. Either way, this will be a breaking change, as the decompositions needed when A and D are not diagonal should really be done upon construction, and they may be different from the decompositions one would want when they are diagonal. So we would also be doing #3. I expect that also means we will need a hand-written rrule for the constructor, which shouldn't be hard.
No problem. If you get a minute, could you elaborate on why this is likely to need a hand-written constructor rrule?
I recall Zygote throws an error requesting a rule when it hits custom constructors of arrays. But maybe that's just when there are multiple constructors.
I think that's just for inner constructors. If you handle it with an outer, it should be fine.
Ah, good to know. @willtebbutt, I noticed validate_woodbury_arguments currently only checks that diagonal elements are non-negative. Is this type meant to represent positive semi-definite matrices as well?
Good question. I think this is an oversight rather than something intentional: we regularly compute log determinants of these matrices and assume they'll be finite.
Ah, the main blocker here will be that there's currently no ChainRules rrule for qr, and since the primal function mutates, Zygote throws an error. At first glance, it doesn't look like there's a simpler rule we can use that doesn't involve the rrule for qr, which is really complicated. Looks like it might be time to try to get this in ChainRules.
> Yup, that seems fine to me. I agree that the important property is that S+ADA' is positive definite -- it doesn't really matter whether either S or D are either positive-definite or diagonal.
Ah, I should note that the generalization still requires that S is positive definite. Otherwise our factorizations don't work. Also, if both S and D are diagonal and positive definite, the generalized implementations will be slower than the current implementations, since the generalization replaces a diagonal Cholesky decomposition with a thin QR decomposition (necessary when D is not PD). This might be a reason not to compute the decompositions in the constructor and instead, in each function, choose the decomposition based on which constraints are satisfied. If we went that route, we could make a non-breaking generalization, where Zygote would only fail if the current assumptions are not met.
e.g. we could do this:
function LinearAlgebra.logdet(W::WoodburyPDMat)
    C_S = cholesky(W.S)
    C = if W.D isa Diagonal && isposdef(W.D)
        # current assumptions hold: use the diagonal Cholesky factor of D
        C_D = cholesky(W.D)
        B = C_S.U' \ (W.A * C_D.U')
        Symmetric(muladd(B', B, I))
    else
        # D is not a PD Diagonal: fall back to the R factor of a QR decomposition
        R = qr(C_S.U' \ W.A).R
        Symmetric(muladd(R, W.D * R', I))
    end
    C_C = cholesky(C)
    return logdet(C_S) + logdet(C_C)
end
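For anyone following along, here is a sketch of why both branches give the same log determinant. The notation is my own shorthand rather than anything in the package: A is assumed to be n×k with n ≥ k, S = U_S'U_S and D = U_D'U_D are the Cholesky factorizations, M = U_S'^{-1}A, and M = QR is the thin QR decomposition.

```latex
\begin{aligned}
W &= S + A D A^\top
   = U_S^\top \left( I_n + M D M^\top \right) U_S,
   \qquad M = U_S^{-\top} A, \\
\log\det W &= \log\det S + \log\det\left( I_n + M D M^\top \right). \\[4pt]
\text{PD diagonal } D:\quad
  & B = M U_D^\top
    \;\Rightarrow\; M D M^\top = B B^\top, \quad
    \det\left( I_n + B B^\top \right) = \det\left( I_k + B^\top B \right). \\
\text{general } D:\quad
  & M = Q R,\; Q^\top Q = I_k
    \;\Rightarrow\;
    \det\left( I_n + Q R D R^\top Q^\top \right)
    = \det\left( I_k + R D R^\top \right).
\end{aligned}
```

Both cases use the identity det(I_n + XY) = det(I_k + YX), so the Cholesky factorization in the last step is always of a k×k matrix. That matrix is PD whenever S and S+ADA' are.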
Ahhh I see. I'm slightly concerned about what this will do for type stability (if logdet is currently type-stable). I guess it might need a high-level rrule to circumvent any inference problems.
This should be type-stable. One could design a pathological case where e.g. D has special overloads for both left multiplication by UpperTriangular and right multiplication by LowerTriangular that produce a different type from B'B, but this seems unlikely to happen in the real world. I could check what happens if sparse arrays are used.
Sorry, I mean on the reverse-pass, specifically with Zygote. It doesn't handle conditionals well, so it's likely that even though the primal should be stable, and each of the rules for each of the expressions is type stable, the forwards / reverse Zygote passes won't be.
Ugh, right. Well thankfully, with rrule_via_ad, that shouldn't be hard.
WoodburyPDMat currently requires that S and D both be positive and diagonal. Technically neither of these needs to be the case. For example, in the L-BFGS algorithm, the approximate inverse Hessian that is constructed begins with a diagonal PD S and constructs D to be symmetric but non-PD, but such that S+ADA' is guaranteed to be PD.
Pathfinder.jl includes a WoodburyPDMat implementation that relaxes these constraints and uses a different set of decompositions to get the necessary overloads, but ideally this implementation would live in a more general package. Would you be interested in this implementation being integrated here?
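As a toy illustration of the point that S+ADA' can be PD even when D is not (the numbers are made up for this example and are not taken from L-BFGS or Pathfinder.jl):

```latex
S = \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}, \quad
A = \begin{pmatrix} 1 \\ 0 \end{pmatrix}, \quad
D = \begin{pmatrix} -\tfrac{1}{2} \end{pmatrix}
\quad\Longrightarrow\quad
S + A D A^\top = \begin{pmatrix} \tfrac{1}{2} & 0 \\ 0 & 1 \end{pmatrix}.
```

Here D is symmetric but negative, while the sum has eigenvalues 1/2 and 1 and so is PD.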