JuliaDiff / ChainRules.jl

forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs

Determinant could be faster if use LU #456

Closed cortner closed 3 years ago

cortner commented 3 years ago

I noticed the rrule for det does not use the LU factorisation. Is this intentional? Or is it implicit?

EDIT: implicitly both the rrule and the frule do use the factorisation, but it could be reused for efficiency. For the rrule a question of numerical stability remains, though it is not yet clear to me whether it is resolvable.

oxinabox commented 3 years ago

Yeah, I think it makes sense that det would use LU under the hood. But if we computed the LU ourselves, we could reuse that factorization to do / x_lu instead of * inv(x) in the pullback, right?

https://github.com/JuliaDiff/ChainRules.jl/blob/37a9e9b3427c4ca80cd66a5252ab04e8110a414e/src/rulesets/LinearAlgebra/dense.jl#L87-L94

It would be one of the cases where we can benefit from changing the primal computation.
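
A rough sketch of that idea, just to make it concrete (this is not the rule as it stands in dense.jl; the BlasFloat restriction, the lack of handling for singular matrices, and the exact conjugation here are illustrative choices):

    using ChainRulesCore, LinearAlgebra

    # Sketch only: compute the primal through an explicit LU factorization
    # and reuse it in the pullback, so inv does not factorize a second time.
    function ChainRulesCore.rrule(::typeof(det), x::StridedMatrix{<:LinearAlgebra.BlasFloat})
        F = lu(x; check = false)       # the factorization we get to keep
        Ω = det(F)                     # cheap once F is known
        function det_pullback(ΔΩ)
            # same formula as conj(Ω) * ΔΩ * inv(x)', but inv(F) solves
            # against the existing factors (singular x is not handled here)
            ∂x = conj(Ω) * ΔΩ * inv(F)'
            return (NoTangent(), ∂x)
        end
        return Ω, det_pullback
    end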

cortner commented 3 years ago

Yes that was my point.

One difficulty (??) - what if the LinearAlgebra implementation changes? Unlikely here but maybe not in general. I guess ChainRules needs to keep track of such changes? But is it also a bit concerning that this strategy requires code duplication?

oxinabox commented 3 years ago

what if the LinearAlgebra implementation changes? Unlikely here but maybe not in general.

The math remains the same even if the implementation changes. We don't promise the same primal answer as running the original primal function call; we only suggest that we will meet the same accuracy thresholds where documented (just like Julia versions). If the primal function is accurate to 5 ULP, then we will aim to also be within 5 ULP, but might err in the opposite direction.

But is it also a bit concerning that this strategy requires code duplication?

It's how it is. Sometimes you want to modify the primal for this reason: you need some extra internal information from the middle of the computation, like this, or you want to calculate it in a totally different way.

In theory not having a rule can be better, if the optimizer can inline everything, see the common subexpressions (like computing lu for det and for inv), and eliminate them. In practice that is very rare (and doing it between the pullback and the primal computation is impossible without opaque closures).
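
For example (a toy illustration of the duplicated factorization in question; g and g_shared are made-up names):

    using LinearAlgebra

    # Both det and inv factorize x internally, so the LU is done twice here;
    # the optimizer would have to see through both calls to merge that work,
    # the same way it would have to merge the primal's lu with the inv in a
    # pullback.
    g(x) = det(x) + tr(inv(x))

    # Doing the shared work once by hand, which is what a hand-written rule
    # that changes the primal computation gets to do:
    function g_shared(x)
        F = lu(x)
        return det(F) + tr(inv(F))
    end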

cortner commented 3 years ago

Thanks for the thoughts

sethaxen commented 3 years ago

IIRC, we didn't use lu here because this rule is for abstract types, meaning it will, in an almost type-piratical way, make any specialized implementation of det for some other matrix type invisible to the AD; instead we are doing something generic. That specialized matrix type might have some other very efficient way to compute det that doesn't use the LU decomposition, and the decomposition may be sloooow for that type. So, to try to salvage this bad situation, we potentially do twice the work, i.e. implicitly decomposing twice.
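
To make the concern concrete (using Diagonal just as a familiar example of a type whose det is far cheaper than any dense factorization):

    using LinearAlgebra

    D = Diagonal(rand(1000))
    det(D)          # O(n): LinearAlgebra has a specialized det for Diagonal
    lu(Matrix(D))   # O(n^3): the cost a generic "always factorize" rule would pay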

cortner commented 3 years ago

I think it is a little worse than that. By calling inv you are assembling the inverse matrix, which should almost never be done. Even when computed via LU, is this even numerically stable? It's ok to have a fallback, but I feel that, at least for any matrix type that has an LU decomposition defined, one should do it via LU rather than inv.

oxinabox commented 3 years ago

We can at least add the LU path for StridedMatrix{<:BlasFloat}, which IIRC should all support lu.

sethaxen commented 3 years ago

I think it is a little worse than that. By calling inv you are assembling the inverse matrix, which should almost never be done. Even when computed via LU, is this even numerically stable? It's ok to have a fallback, but I feel that, at least for any matrix type that has an LU decomposition defined, one should do it via LU rather than inv.

If you work out the rrule for computing the determinant from the LU decomposition and then compose it with the rrule for the LU decomposition, you end up with (conj(d) * Δd) * inv(F)', where F is the LU decomposition, d is the determinant, and Δd is the cotangent of the determinant, and you'll notice this is identical to our rrule for det, so there's no way around computing the inverse here.

From the LU decomposition, the matrix inverse can be computed quickly, cheaply, and in-place using two triangular solves, so the only remaining question is one of stability. The matrix inverse does not exist exactly when the determinant is exactly zero. We don't currently do any special-casing for the zero-determinant case, but perhaps we should. By the subgradient convention, the cotangent should then be the zero matrix.
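For illustration, recovering the inverse from an already-computed factorization looks like this (the explicit F \ I is just to show the solves; inv(::LU) does the equivalent work in-place):

    using LinearAlgebra

    A = rand(4, 4)
    F = lu(A)
    # Each column of the identity costs only two triangular solves against
    # the existing factors; no second factorization is performed.
    Ainv = F \ Matrix{Float64}(I, 4, 4)
    Ainv ≈ inv(A)   # true, up to roundoff
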

We can at least add the LU path for StridedMatrix{<:BlasFloat}, which IIRC should all support lu.

This would be safe to do. Since I think det is almost always computed from a factorization, the best solution ultimately is probably to ensure we have rules for all common factorizations, remove this det rrule entirely, and then only add a few det rules with factorization arguments if necessary. For det from lu, this will probably be a little less efficient than calling lu in the rrule, because instead of just calling inv(F), it goes through a few more steps. I think the main factorization we're missing rules for here is Bunch-Kaufman for Symmetric{<:BlasFloat} matrices.

cortner commented 3 years ago

When I take grad(det(A)), sure, there is no other way, except possibly to have a lazy representation of the inverse (which I'd personally prefer, but I appreciate this may lead to other problems).

However, if A = A(p) and I want grad(det(A(p))), then wouldn't the process be

  • compute A + store the pullback ddet -> dA(p)'[ddet]
  • compute det(A) + ddet

When applying the pullback, would I not want to use the LU factorisation instead of the "collected" matrix?

sethaxen commented 3 years ago

However, if A = A(p) and I want grad(det(A(p))), then wouldn't the process be

  • compute A + store the pullback ddet -> dA(p)'[ddet]
  • compute det(A) + ddet

I wasn't able to follow. What do these terms mean?

cortner commented 3 years ago

f(p) = det(A(p))

Forward pass:

  1. Compute A and the pullback closure ddet -> adjoint(dA(p)) * ddet
  2. Compute det(A) and ddet = ddet(A)

Backward pass:

  1. apply the closure to get grad_f(p) = adjoint(dA(p)) * ddet

(perfectly possible I'm missing something or getting something wrong here ... I only just started to think about AD at the implementation level.)

sethaxen commented 3 years ago

I'm sorry, it's still not entirely clear to me what your notation means. e.g. you seem to be using the prefix of d to denote a cotangent, but cotangents cannot be multiplied by each other (due to linearity of the pullback operator), so I'm not certain what adjoint(dA(p)) * ddet means.

cortner commented 3 years ago

I'm applying an adjoint operator - this application is denoted by *; I'm not multiplying two tangents. I apologize if my notation is not as precise as you'd like, but as an expert you can probably make an educated guess at what I mean.

cortner commented 3 years ago

So I'm trying to write down a concrete use-case, and at least the ones I had in mind when posting this seem to be more relevant for frule than rrule. I will momentarily edit the first post to add this point.

Re the frule:

  D[ det(A) ] = det(A) tr[ A \ DA ]

and for the A \ DA operation one can again reuse the factorisation. (I'm specifically interested in cases where DA has special structure, e.g., low rank.)
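
A quick numerical check of that identity, reusing a single factorization for both det(A) and the solve inside the trace (plain script, not an actual frule):

    using LinearAlgebra

    A, ΔA = rand(5, 5), rand(5, 5)      # ΔA plays the role of DA above
    F = lu(A)
    ddet = det(F) * tr(F \ ΔA)

    h = 1e-6                            # finite-difference reference
    fd = (det(A + h * ΔA) - det(A)) / h
    isapprox(fd, ddet; rtol = 1e-4)     # true for a well-conditioned A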

But I appreciate the problem of a generic vs specialised implementation.

I'll write something else on rrule later.

cortner commented 3 years ago

Returning to rrule - the following is not quite my use-case, but it is reasonably close and doesn't require further explanation. I'm actually confused about how the rrule applies here.

We have parameters p = (pj) and

   L = sum_i   f( det(A(xi, p)) - yi )

with D = d / dpj, Ai = A(xi,p), fi' = f'(det(Ai) - yi), DAi = DA(xi, p), then

  DL = sum_i fi' * det(Ai) * tr[ Ai \ DAi ]

I'm now struggling to re-order the operations to see the backpropagation. Is it simply this?

  DL = tr[  sum_i  { fi' * det Ai * inv(Ai) } *  DAi  ]

? I.e. the backpropagation would, for each i, seed det(Ai) with the cotangent fi', pull that back to the fi' * det(Ai) * inv(Ai) factor above as the cotangent of Ai, and then pull back through A(xi, p)?

And would you agree that this indicates I should use the frule in such a scenario?

If what I've written is correct, then there is still the issue that "collecting" inv(Ai) is not necessary and, for numerical stability, could be replaced with a lazy inverse operation?
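
Written out by hand, the reverse pass I have in mind is roughly the following sketch (real matrices assumed; A, pullback_A, f′, xs, ys and p are hypothetical placeholders for the problem above):

    using LinearAlgebra

    function loss_gradient_sketch(f′, A, pullback_A, xs, ys, p)
        grad = zero(p)
        for (x, y) in zip(xs, ys)
            Ai  = A(x, p)
            Fi  = lu(Ai)                  # reused for det and for the solve
            di  = det(Fi)
            Δdi = f′(di - y)              # cotangent of det(Ai)
            ∂Ai = Δdi * di * inv(Fi)'     # cotangent of Ai; this inv is the
                                          # step the lazy-inverse question is about
            grad += pullback_A(x, p, ∂Ai) # adjoint(dA/dp) applied to ∂Ai
        end
        return grad
    end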

cortner commented 3 years ago

I now think this issue is irrelevant; I will close it and open a new issue.