invenia / PDMatsExtras.jl

Extra Positive (Semi-)Definite Matrices
MIT License

Generalizing WoodburyPDMat #25

Open sethaxen opened 3 years ago

sethaxen commented 3 years ago

WoodburyPDMat currently requires that S and D both be positive and diagonal. Technically, neither of these needs to be the case. For example, in the L-BFGS algorithm, the approximate inverse Hessian starts from a diagonal PD S and constructs a D that is symmetric but not PD, yet S+ADA' is guaranteed to be PD.
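To make the L-BFGS case concrete, here is a minimal sketch assuming the standard compact representation of the L-BFGS inverse Hessian (Byrd, Nocedal & Schnabel, 1994), H = γI + A*M*A' with A = [Sm γ*Ym], so that γI plays the role of S and M plays the role of D. The helpers `bfgs_inverse_update`, `lbfgs_compact` and `recursive_bfgs`, the toy Hessian `G` and the step pairs are all hypothetical illustrations, not code from Pathfinder.jl. The sketch checks that the compact form matches the recursive BFGS update and is PD, while M is symmetric but not PD.

```julia
using LinearAlgebra

# Dense BFGS update of an inverse-Hessian approximation H with one pair (s, y).
function bfgs_inverse_update(H, s, y)
    ρ = 1 / dot(s, y)          # the curvature condition s'y > 0 keeps H positive definite
    V = I - ρ * y * s'
    return V' * H * V + ρ * s * s'
end

# Apply the updates one pair at a time, starting from H0 = γI.
function recursive_bfgs(γ, Sm, Ym)
    n = size(Sm, 1)
    H = Matrix(γ * I, n, n)
    for j in axes(Sm, 2)
        H = bfgs_inverse_update(H, Sm[:, j], Ym[:, j])
    end
    return H
end

# Compact form: H = γI + A*M*A' with A = [Sm γ*Ym] and symmetric, indefinite M.
function lbfgs_compact(γ, Sm, Ym)
    m = size(Sm, 2)
    SY = Sm' * Ym
    Rinv = Matrix(inv(UpperTriangular(SY)))  # R[i, j] = s_i'y_j for i <= j, zero below
    Dk = Diagonal(diag(SY))                  # Dk[i, i] = s_i'y_i
    M11 = Rinv' * (Dk + γ * (Ym' * Ym)) * Rinv
    M12 = -Rinv'
    M = [M11 M12; M12' zeros(m, m)]          # zero lower-right block => M is not PD
    A = [Sm γ * Ym]
    return A, M
end

# Toy problem: a fixed SPD Hessian G generates the gradient differences y_i = G*s_i.
G = [4.0 1.0 0.0 0.0; 1.0 3.0 0.0 0.0; 0.0 0.0 2.0 0.5; 0.0 0.0 0.5 1.0]
Sm = [1.0 0.0; 0.0 1.0; 1.0 0.5; 0.0 1.0]   # steps s_1, s_2 as columns
Ym = G * Sm
γ = 0.5

A, M = lbfgs_compact(γ, Sm, Ym)
H_compact = γ * I + A * M * A'
H_rec = recursive_bfgs(γ, Sm, Ym)
```

Because M has a zero diagonal block and an invertible off-diagonal block, it always has both positive and negative eigenvalues, which is why a WoodburyPDMat restricted to PD diagonal D cannot represent this H directly.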

Pathfinder.jl includes a WoodburyPDMat implementation that relaxes these constraints and uses a different set of decompositions to implement the necessary overloads, but ideally this implementation would live in a more general package. Would you be interested in integrating this implementation here?

sethaxen commented 2 years ago

@oxinabox maybe? Thoughts on this?

oxinabox commented 2 years ago

@sdl1 @wytbella are more familiar with the math and can comment on whether that is desirable.

sethaxen commented 2 years ago

@sdl1 @wytbella, any input on this would be appreciated.

wytbella commented 2 years ago

I think we can relax the constraints here, but @willtebbutt is probably a better person to confirm this.

willtebbutt commented 2 years ago

Yup, that seems fine to me. I agree that the important property is that S+ADA' is positive definite; it doesn't really matter whether S or D is positive definite or diagonal.

sethaxen commented 2 years ago

Okay, I'll open a PR.

sethaxen commented 2 years ago

Heh, this may be a while; this code is a bit foggy in my brain now. Either way, this will be a breaking change: the decompositions needed when S and D are not diagonal should really be done upon construction, and they may differ from the decompositions one would want when they are diagonal. So we would also be doing #3. I expect that also means we will need a hand-written rrule for the constructor, which shouldn't be hard.

willtebbutt commented 2 years ago

No problem. If you get a minute, could you elaborate on why this is likely to need a hand-written constructor rrule?

sethaxen commented 2 years ago

I recall Zygote throws an error requesting a rule when it hits custom constructors of arrays. But maybe that's just when there are multiple constructors.

willtebbutt commented 2 years ago

I think that's just for inner constructors. If you handle it with an outer, it should be fine.
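Following the suggestion to do the work in an outer constructor, here is a minimal sketch of an array type that factorizes S once at construction time while keeping only the compiler-generated default inner constructor. `CachedWoodbury` and its field names are hypothetical, not PDMatsExtras' API, and whether Zygote differentiates through this pattern without a hand-written rrule is an assumption from the comment above, not something checked here.

```julia
using LinearAlgebra

# Hypothetical array type for W = S + A*D*A' that caches a Cholesky factorization of S.
# No inner constructor is written, so only the default inner constructor exists.
struct CachedWoodbury{T<:Real,TA<:AbstractMatrix{T},TD<:AbstractMatrix{T},TS<:AbstractMatrix{T},TC<:Cholesky} <: AbstractMatrix{T}
    A::TA
    D::TD
    S::TS
    chol_S::TC
end

# Outer constructor: the decomposition happens here, upon construction.
function CachedWoodbury(A::AbstractMatrix{T}, D::AbstractMatrix{T}, S::AbstractMatrix{T}) where {T<:Real}
    chol_S = cholesky(Symmetric(S))  # throws PosDefException unless S is PD
    return CachedWoodbury{T,typeof(A),typeof(D),typeof(S),typeof(chol_S)}(A, D, S, chol_S)
end

Base.size(W::CachedWoodbury) = size(W.S)
# Entry (i, j) of A*D*A' is a_i' * D * a_j, where a_i is row i of A.
Base.getindex(W::CachedWoodbury, i::Int, j::Int) = dot(W.A[i, :], W.D * W.A[j, :]) + W.S[i, j]

A = [1.0 0.0; 0.5 1.0; 0.0 2.0]
D = [2.0 0.3; 0.3 -0.1]          # symmetric, not PD
S = Matrix(3.0I, 3, 3)
W = CachedWoodbury(A, D, S)
```

Every method that needs S's factorization can then reuse `W.chol_S` instead of recomputing it.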

sethaxen commented 2 years ago

Ah good to know. @willtebbutt, I noticed validate_woodbury_arguments currently only checks that diagonal elements are non-negative. Is this type meant to represent positive semi-definite matrices as well?

willtebbutt commented 2 years ago

Good question. I think this is an oversight rather than something intentional. For example, we regularly compute log determinants of these matrices and assume they'll be finite.
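To illustrate why a non-negativity check is too weak for that assumption, here is a small sketch with diagonal D and S. The checks `nonneg_ok` and `pos_ok` are hypothetical stand-ins, not the actual `validate_woodbury_arguments`: a zero on the diagonal of S passes the non-negative check, yet S + A*D*A' is singular, so its determinant is zero and its log determinant is not finite.

```julia
using LinearAlgebra

# Hypothetical checks on the diagonal entries of diagonal D and S.
nonneg_ok(D::Diagonal, S::Diagonal) = all(>=(0), D.diag) && all(>=(0), S.diag)
pos_ok(D::Diagonal, S::Diagonal) = all(>(0), D.diag) && all(>(0), S.diag)

A = [1.0 0.0; 0.0 0.0; 0.0 1.0]
D = Diagonal([1.0, 1.0])
S = Diagonal([1.0, 0.0, 1.0])   # non-negative, but only positive semi-definite
W = A * D * A' + S              # dense [2 0 0; 0 0 0; 0 0 2]
```

A strict positivity check on S (when D is PSD) rules this case out.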

sethaxen commented 2 years ago

Ah, the main blocker here will be that there's currently no ChainRules rrule for qr, and since the primal function mutates, Zygote throws an error. At first glance, there doesn't seem to be a simpler rule we can use that avoids the rrule for qr, which is really complicated. It might be time to try to get this into ChainRules.


Ah, I should note that the generalization still requires S to be positive definite; otherwise our factorizations don't work. Also, if S and D are both diagonal and positive definite, the generalized implementations will be slower than the current ones, since the generalization replaces a cheap Cholesky decomposition of the diagonal D with a thin QR decomposition (which is necessary when D is not PD). This might be a reason not to compute the decompositions in the constructor, and instead to choose the decomposition in each function based on which constraints are satisfied. If we went that route, we could make a non-breaking generalization, where Zygote would only fail if the current assumptions are not met.
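Both decompositions rest on Sylvester's determinant identity, det(I_n + XY) = det(I_k + YX). A short derivation in the thread's notation, assuming A is n-by-k with n ≥ k, S = LL^T is a Cholesky factorization, and the thin QR has Q^T Q = I_k:

```latex
\begin{aligned}
W &= S + A D A^\top = L\,(I_n + B D B^\top)\,L^\top,
  \qquad B = L^{-1} A = Q R, \quad Q^\top Q = I_k,\\
\log\det W &= \log\det S + \log\det\!\big(I_n + Q (R D R^\top) Q^\top\big)
            = \log\det S + \log\det\!\big(I_k + R D R^\top\big),\\
\text{if } D = U_D^\top U_D \succ 0:\quad
\log\det W &= \log\det S + \log\det\!\big(I_k + (B U_D^\top)^\top (B U_D^\top)\big).
\end{aligned}
```

The middle line only needs D to be symmetric, and I_k + R D R^T is positive definite exactly when W is, so its Cholesky factorization exists even for indefinite D. The last line is the cheaper path available when D is PD, since it needs no QR.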

sethaxen commented 2 years ago

e.g. we could do this:

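using LinearAlgebra

function LinearAlgebra.logdet(W::WoodburyPDMat)
    # W = S + A*D*A'. Factor S = U'U; this requires S to be PD.
    C_S = cholesky(W.S)
    C = if W.D isa Diagonal && isposdef(W.D)
        # Fast path: D = U_D'U_D, so inv(U') * A*D*A' * inv(U) = B*B'.
        C_D = cholesky(W.D)
        B = C_S.U' \ (W.A * C_D.U')
        # det(I + B*B') == det(I + B'*B), and the latter is only k-by-k.
        Symmetric(muladd(B', B, I))
    else
        # General path (D may be indefinite): thin QR of inv(U') * A = Q*R,
        # so det(I + Q*R*D*R'*Q') == det(I + R*D*R').
        R = qr(C_S.U' \ W.A).R
        Symmetric(muladd(R, W.D * R', I))
    end
    # C is PD whenever W is, so this Cholesky succeeds for any valid W.
    C_C = cholesky(C)
    return logdet(C_S) + logdet(C_C)
end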
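
To sanity-check both branches outside the package, here is a hedged, self-contained restatement as a free function `woodbury_logdet(A, D, S)` (a hypothetical name, not PDMatsExtras' API), compared against a dense log determinant of S + A*D*A'. It assumes S is a dense PD matrix and A is tall, and it swaps `cholesky(D).U'` for `sqrt(D)`, which is the same factor for a PD diagonal D.

```julia
using LinearAlgebra

# Hypothetical free-function version of the logdet above, for W = S + A*D*A'.
function woodbury_logdet(A::AbstractMatrix, D::AbstractMatrix, S::AbstractMatrix)
    C_S = cholesky(Symmetric(S))           # S must be PD
    C = if D isa Diagonal && isposdef(D)
        B = C_S.L \ (A * sqrt(D))          # fast path: no QR needed
        Symmetric(B' * B + I)
    else
        R = qr(C_S.L \ A).R                # general path: D only needs to be symmetric
        Symmetric(R * (D * R') + I)
    end
    return logdet(C_S) + logdet(cholesky(C))
end

dense_logdet(A, D, S) = logdet(cholesky(Symmetric(A * D * A' + S)))

A = reshape(collect(1.0:10.0), 5, 2) ./ 10
S = Matrix(2.0I, 5, 5)
D_indef = [1.0 0.5; 0.5 -0.2]   # symmetric and indefinite, but S + A*D*A' is still PD
D_diag = Diagonal([1.0, 3.0])   # diagonal and PD
```

With `D_indef` only the QR branch applies; with `D_diag` the fast branch is taken, and both agree with the dense computation.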

willtebbutt commented 2 years ago

Ahhh I see. I'm slightly concerned about what this will do for type stability (if logdet is currently type-stable). I guess it might need a high-level rrule to circumvent any inference problems.

sethaxen commented 2 years ago

This should be type-stable. One could design a pathological case where, e.g., D has special overloads for both left multiplication by UpperTriangular and right multiplication by LowerTriangular that produce a different type than B'B, but this seems unlikely to happen in the real world. I could check what happens if sparse arrays are used.

willtebbutt commented 2 years ago

Sorry, I meant on the reverse pass, specifically with Zygote. It doesn't handle conditionals well, so even though the primal should be stable and the rules for each of the expressions are type-stable, the forward and reverse Zygote passes likely won't be.

sethaxen commented 2 years ago

Ugh, right. Well thankfully, with rrule_via_ad, that shouldn't be hard.