Closed cortner closed 3 years ago
Yeah, I think that makes sense. `det` would use LU under the hood. But if we computed the LU ourselves we could reuse that factorization to do `/ x_lu` instead of `* inv(x)` in the pullback, right? It would be one of the cases where we can benefit from changing the primal computation.
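A minimal sketch of the idea (hypothetical, not ChainRules' actual implementation): the primal computes `lu(x)` once, reads the determinant off the factorization, and the pullback reuses the same factorization instead of refactorizing `x`.

```julia
using LinearAlgebra

# Hypothetical sketch, not ChainRules' actual code: compute lu(x) once in
# the primal and reuse it in the pullback.
function det_with_pullback(x::AbstractMatrix)
    F = lu(x)          # single factorization shared by primal and pullback
    d = det(F)         # determinant read off the factorization
    function det_pullback(Δd)
        # cotangent of x: (conj(d) * Δd) * inv(x)', with the inverse
        # computed from the stored factorization F rather than from x
        (conj(d) * Δd * inv(F))'
    end
    return d, det_pullback
end

A = [2.0 1.0; 1.0 3.0]
d, pb = det_with_pullback(A)
```

Here `det_with_pullback` is a made-up name for illustration; a real rule would go through `ChainRulesCore.rrule`.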
Yes that was my point.
One difficulty (??) - what if the LinearAlgebra implementation changes? Unlikely here but maybe not in general. I guess ChainRules needs to keep track of such changes? But is it also a bit concerning that this strategy requires code duplication?
> what if the LinearAlgebra implementation changes? Unlikely here but maybe not in general.
The math remains the same even if the implementation changes. We don't promise the same primal answer as running the original primal function call; we only promise to meet the same accuracy thresholds where documented (just like Julia versions). If the primal function is accurate to 5 ULP, then we will also aim for 5 ULP, but we might be wrong in the opposite direction.
> But is it also a bit concerning that this strategy requires code duplication?
It's how it is. Sometimes you want to modify the primal for this reason: you need some extra internal information from the middle like this (or you want to calculate it in a totally different way).
In theory not having a rule can be better, if the optimizer can inline everything, see the common subexpressions (like computing `lu` for both `det` and `inv`), and eliminate them. In practice that is very rare (and doing it between the pullback and the primal computation is impossible without opaque closures).
Thanks for the thoughts
IIRC, we didn't use `lu` here because this rule is for abstract types, meaning it will in an almost type-piratical way make any specialized implementations of `det` for some other matrix type invisible to the AD, where instead we are doing something generic. That specialized matrix type might have some other very efficient way to compute `det` that doesn't use the LU decomposition, and the decomposition may be sloooow for that type. So to try to salvage this bad situation, we potentially do twice the work, i.e. implicitly decomposing twice.
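To illustrate the point about specialized methods (a made-up example, not from the discussion): `Diagonal` has a `det` that is just a product over the diagonal, and forcing it through a generic LU-based rule would throw that away.

```julia
using LinearAlgebra

# Made-up illustration: a structured type whose det never needs LU.
D = Diagonal([2.0, 3.0, 4.0])
d_specialized = det(D)      # dispatches to the cheap O(n) Diagonal method
d_manual = prod(D.diag)     # the same product over the diagonal entries
```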
I think it is a little worse than that. By calling `inv` you are assembling the inverse matrix, which should almost never be done. Even when computed via LU, is this even numerically stable? OK to have a fallback, but I feel at least for any matrix that has an LU decomposition defined one should do it via LU rather than `inv`.
We can at least add the LU path for `StridedMatrix{<:BlasFloat}`, which IIRC should all have LU.
> I think it is a little worse than that. By calling `inv` you are assembling the inverse matrix, which should almost never be done. Even when computed via LU, is this even numerically stable? OK to have a fallback, but I feel at least for any matrix that has an LU decomposition defined one should do it via LU rather than `inv`.
If you work out the rrule for computing the determinant from the LU decomposition and then compose it with the rrule for the LU decomposition, you end up with `(conj(d) * Δd) * inv(F)'`, where `F` is the LU decomposition, `d` is the determinant, and `Δd` is the cotangent of the determinant. You'll notice this is identical to our rrule for `det`, so there's no way around computing the inverse here.
From the LU decomposition, the matrix inverse can be computed quickly, cheaply, and in-place using two applications of backward substitution, so the only remaining question is one of stability. The matrix inverse does not exist exactly when the determinant is exactly zero. We don't currently do any special-casing for the zero-determinant case, but perhaps we should. By the subgradient convention, the cotangent should then be the zero matrix.
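As an illustrative sketch of the "two applications of backward substitution" point: given `F = lu(A)` with `P*A = L*U`, the inverse follows from triangular solves against the row-permuted identity, reusing the factorization.

```julia
using LinearAlgebra

A = [4.0 3.0; 6.0 3.0]
F = lu(A)                          # P*A = L*U with permutation vector F.p
n = size(A, 1)
# inv(A) = inv(U) * inv(L) * P, i.e. two triangular solves against P*I:
PI = Matrix{Float64}(I, n, n)[F.p, :]
invA = F.U \ (F.L \ PI)
```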
> we can at least add the LU path for `StridedMatrix{<:BlasFloat}`, which IIRC should all have LU.
This would be safe to do. Since I think `det` is almost always computed from a factorization, the best solution ultimately is probably to ensure we have rules for all common factorizations, remove this `det` rrule entirely, and then only add a few `det` rules with factorization arguments if necessary. For `det` from `lu`, this will probably be a little less efficient than calling `lu` in the rrule, because instead of just calling `inv(F)`, it goes through a few more steps. I think the main factorization we're missing rules for here is Bunch-Kaufman for `Symmetric{<:BlasFloat}` matrices.
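A sketch of what a "rule with a factorization argument" would build on (assumed example): `det` of an existing `LU` object is read off `U`'s diagonal (up to the permutation sign), so no dense matrix or second factorization is needed.

```julia
using LinearAlgebra

A = [2.0 0.0; 1.0 3.0]
F = lu(A)
d_from_F = det(F)           # reads prod(diag(F.U)) times the pivot sign
d_diag   = prod(diag(F.U))  # equal to det(F) up to the permutation sign
```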
When I take `grad(det(A))`, sure, there is no other way, except possibly to have a lazy representation of the inverse (which I'd personally prefer, but I appreciate this may lead to other problems).
However, if `A = A(p)` and I want `grad(det(A(p)))`, then wouldn't the process be
- compute `A` + store the pullback `ddet -> dA(p)'[ddet]`
- compute `det(A)` + `ddet`
When applying the pullback, would I not want to use the LU factorisation instead of the "collected" matrix?
> However, if `A = A(p)` and I want `grad(det(A(p)))`, then wouldn't the process be
> - compute `A` + store the pullback `ddet -> dA(p)'[ddet]`
> - compute `det(A)` + `ddet`
I wasn't able to follow. What do these terms mean?
`f(p) = det(A(p))`
Forward pass:
- compute `A` and the pullback closure `ddet -> adjoint(dA(p)) * ddet`
- compute `det(A)` and `ddet = ddet(A)`
Backward pass:
- `grad_f(p) = adjoint(dA(p)) * ddet`
(perfectly possible I'm missing something or getting something wrong here ... I only just started to think about AD at the implementation level.)
I'm sorry, it's still not entirely clear to me what your notation means. E.g. you seem to be using the prefix `d` to denote a cotangent, but cotangents cannot be multiplied by each other (due to linearity of the pullback operator), so I'm not certain what `adjoint(dA(p)) * ddet` means.
I'm applying an adjoint operator; this application is denoted by `*`. I'm not multiplying two cotangents. I apologize if my notation is not as precise as you'd like, but as an expert you can probably make an educated guess at what I mean.
So I'm trying to write down a concrete use-case, and at least the ones I had in mind when posting this seem to be more relevant for `frule` than `rrule`. I will momentarily edit the first post to add this point.
Re the `frule`:
`D[ det(A) ] = det(A) tr[ A \ DA ]`
and for the `A \ DA` operation one can again reuse the factorisation. (I'm specifically interested in cases where `DA` has special structure, e.g., low rank.)
But I appreciate the problem of a generic vs specialised implementation.
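The identity above can be checked numerically with a single shared factorization (illustrative sketch; the values of `A` and `DA` are made up):

```julia
using LinearAlgebra

A  = [3.0 1.0; 2.0 4.0]
DA = [0.5 0.0; 1.0 0.25]           # an arbitrary tangent direction
F  = lu(A)                         # factorize once, reuse for the solve
pushforward = det(A) * tr(F \ DA)  # D[det(A)] = det(A) * tr(A \ DA)

# finite-difference check of the same directional derivative
h  = 1e-6
fd = (det(A .+ h .* DA) - det(A)) / h
```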
I'll write something else on `rrule` later.
Returning to `rrule` - the following is not quite my use-case, but it is reasonably close and doesn't require further explanation. I'm actually confused how the `rrule` applies here. We have parameters `p = (pj)` and
`L = sum_i f( det(A(xi, p)) - yi )`
With `D = d/dpj`, `Ai = A(xi, p)`, `fi' = f'(det(Ai) - yi)`, `DAi = DA(xi, p)`, then
`DL = sum_i fi' * det(Ai) * tr[ Ai \ DAi ]`
I'm now struggling to re-order the operations to see the backpropagation. Is it simply this?
`DL = tr[ sum_i { fi' * det(Ai) * inv(Ai) } * DAi ]`
I.e. the backpropagation would be:
- compute `[ fi' ]`
- compute `Gi = fi' * det(Ai) * inv(Ai)`
- compute `sum_i Gi * DAi`
- take `tr`
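The reordering can be sanity-checked on invented data (all values below are made up for illustration): by linearity of the trace, summing the per-sample scalars agrees with tracing the assembled sum.

```julia
using LinearAlgebra

# Invented data: two samples i = 1, 2
As  = [[2.0 1.0; 0.0 3.0], [1.0 0.5; 0.5 2.0]]   # the Ai
DAs = [[1.0 0.0; 0.0 1.0], [0.0 1.0; 1.0 0.0]]   # the DAi
fps = [0.7, -1.3]                                # stand-ins for the fi'

# DL = sum_i fi' * det(Ai) * tr(Ai \ DAi)
lhs = sum(fp * det(A) * tr(A \ DA) for (fp, A, DA) in zip(fps, As, DAs))
# DL = tr( sum_i Gi * DAi ) with Gi = fi' * det(Ai) * inv(Ai)
rhs = tr(sum(fp * det(A) * inv(A) * DA for (fp, A, DA) in zip(fps, As, DAs)))
```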
And would you agree that this indicates I should use the `frule` in such a scenario?
If what I've written is correct, then there is still the issue that "collecting" `inv(Ai)` is not necessary and for numerical stability could be replaced with a lazy inverse operation?
I now think this issue is irrelevant; I will close it and reopen a new issue.
I noticed the `rrule` for `det` does not use the LU factorisation. Is this intentional? Or is it implicit?
EDIT: implicitly both `rrule` and `frule` do use the factorisation, but it could be reused for efficiency. For `rrule` a question of numerical stability remains, but it's not clear to me yet if it's resolvable?