[Closed] lezcano closed this issue 3 years ago
Let me try to disentangle this ticket; there seem to be several subquestions.
> the implementation and the differences with `torch.householder_product`
First, this aspect does not pertain to the exemplary parameterization. The function provided by this package aims to replicate the exact interface and behavior of the ORGQR function, which it does. The original interface expects two arguments: a matrix of reflectors `v` and a vector `tau`, computed from the reflector magnitudes as `2 / ||v||^2`. I can only speculate why `tau` (which can be computed from the first argument) was exposed as an argument in the original LAPACK interface; perhaps this was done to avoid an expensive square root when forming the reflector matrices as `I - tau * v * v.T` instead of `I - 2 * u * u.T` (where `u = v / ||v||`). For this reason, I allowed myself to extend the interface for the case when only the first argument is provided, and unified both cases by normalizing the reflectors, which worked faster according to my benchmarks. Nevertheless, this normalization is nothing but a performance hack and has little to do with parameterizations.
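The redundancy of `tau` can be checked numerically. A minimal sketch (my own, not code from the package): with `tau = 2 / ||v||^2`, the matrix `I - tau * v * v.T` coincides with `I - 2 * u * u.T` for `u = v / ||v||`.

```python
import torch

# Sketch: tau is recoverable from the reflector itself, so the two
# formulations of a Householder reflection are the same matrix.
torch.manual_seed(0)
d = 5
v = torch.randn(d)

tau = 2.0 / v.dot(v)                              # tau = 2 / ||v||^2
h_tau = torch.eye(d) - tau * torch.outer(v, v)    # I - tau v v^T

u = v / v.norm()                                  # normalized reflector
h_unit = torch.eye(d) - 2.0 * torch.outer(u, u)   # I - 2 u u^T

print(torch.allclose(h_tau, h_unit, atol=1e-6))   # True
```

Note that the `tau` route needs only `v.dot(v)`, while normalization requires the square root in `v.norm()`, which is consistent with the speculation above.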
> Now, this is another constraint. Each column is normalised, which removes another `k` degrees of freedom (`k` dimensions to be precise).
Let us consider the case of St(2,1). We can parameterize it with `v = [1, p].T`, where the real value `p` is the only degree of freedom. Now, does normalization remove that one and only degree of freedom? It does not: the knowledge of `p` gives no new information about the fixed entry `1`, and vice versa. This argument would not hold if, instead of `1`, we put another real-valued parameter of the reflector, but that would yield a different parameterization. It would still be compatible with the `orgqr`/`ormqr`/`householder_product` interfaces, but before plugging it into unconstrained optimization, one would have to constrain the reflector norm to be nonzero (which comes for free with ones on the diagonal).
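A short sketch of the scale-invariance behind this argument (my own check, not code from the package): rescaling a reflector leaves the reflection unchanged, so fixing the scale — by ones on the diagonal or by normalization — removes a redundancy rather than a degree of freedom.

```python
import torch

def reflection(v):
    # Householder reflection I - 2 v v^T / ||v||^2; invariant to v -> c v
    return torch.eye(len(v)) - 2.0 * torch.outer(v, v) / v.dot(v)

torch.manual_seed(0)
v = torch.randn(4)
for c in (0.5, 1.0, -3.0):
    assert torch.allclose(reflection(v), reflection(c * v), atol=1e-5)
```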
It is true that such a parameterization (with ones on the diagonal) does not represent all possible matrices; in the St(2,1) example, such a reflector covers only the `(-pi/2, pi/2)` range, which parameterizes the entire circle except the single vector `[1, 0].T`. The design behind the original set of LAPACK functions (including `geqrf`) deals with this situation by always choosing one reflector out of the two possible ones, which leads to better numerical stability. This is implemented by negating the diagonal entries of the (initially identity) matrix R in the QR decomposition of the orthonormal frame. In a differentiable parameterization, we cannot flip signs on the diagonal of R, and thus we can either keep `R = I` and accept the inability to parameterize the truncated identity matrix, or choose `R = -I` and let go of the negative truncated identity. Of course, reaching representations close to the critical point in an unconstrained setting is a subject for a separate study. Nevertheless, this issue pertains only to the selected parameterization (with ones on the diagonal); this is the price to pay for the guaranteed nonzero norm of the reflector parameters.
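The St(2,1) picture above can be made concrete with a small sketch (mine, not from the package): the reflector `v = [1, p].T` yields a unit-norm frame for every real `p`, and the frame only approaches `[1, 0].T` as `p -> +-inf`, never attaining it.

```python
import torch

def frame(p):
    # Frame on St(2,1): first column of I - 2 v v^T / ||v||^2 for v = [1, p]
    v = torch.tensor([1.0, p])
    h = torch.eye(2) - 2.0 * torch.outer(v, v) / v.dot(v)
    return h[:, 0]

print(frame(0.0))   # p = 0 maps to [-1, 0]
print(frame(1e6))   # numerically close to [1, 0], but never exactly equal
```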
Please let me know if this answers your questions, or we can continue the conversation via email.
When having a look at the implementation and at the differences with `torch.householder_product`, I found a weird thing. At the moment, when called with a tensor of the form `hh = param.tril(diagonal=-1) + torch.eye(d, r)` (as per the documentation), we are passing `nk - k(k+1)/2` parameters. This is correct, as it is the number of parameters necessary to parametrise the orthogonal matrices (i.e. it is the dimension of the Stiefel manifold).

Now, when this matrix is passed to `torch_householder_orgqr`, its columns are normalised: https://github.com/toshas/torch-householder/blob/afe72ebf9a01c257a5070a6d01b3f81be9e8efd6/torch_householder/householder.py#L67

Now, this is another constraint. Each column is normalised, which removes another `k` degrees of freedom (`k` dimensions to be precise). This means that the current implementation cannot represent all the possible orthogonal matrices. Do you know what is going on here?
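The parameter count quoted above can be checked numerically. A sketch (my own, using `d` and `r` for the dimensions as in the documented `hh` expression): a `d x r` strictly lower-triangular tensor carries exactly `dr - r(r+1)/2` free entries, the dimension of the Stiefel manifold St(d, r).

```python
import torch

d, r = 5, 3
param = torch.arange(1.0, d * r + 1).reshape(d, r)   # all-nonzero placeholder
hh = param.tril(diagonal=-1) + torch.eye(d, r)        # the documented form
n_free = int(param.tril(diagonal=-1).count_nonzero()) # free parameters in hh
print(n_free, d * r - r * (r + 1) // 2)               # 9 9
```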