lezcano / geotorch

Constrained optimization toolkit for PyTorch
https://geotorch.readthedocs.io
MIT License

Can two parametrizations be used in the same tensor? #42

Open dherrera1911 opened 6 months ago

dherrera1911 commented 6 months ago

Hello,

I am wondering whether two parametrizations can be used on the same tensor. In particular, I want to parametrize a matrix to be both positive definite and in the SL group.

In the example below, I create a class with a matrix that should be both PSD and in SL(n):

import torch
import torch.nn as nn
import geotorch

class PrNorm(torch.nn.Module):
    def __init__(self, nDim):
        super().__init__()
        self.B = nn.Parameter(torch.eye(nDim))
        # Apply both parametrizations to the same tensor
        geotorch.positive_semidefinite(self, "B")
        geotorch.sln(self, "B")

    def forward(self, x):
        # Quadratic form x^T B x
        quadratic = torch.einsum('i,ij,j->', x, self.B, x)
        return quadratic

nDim = 7
prnorm = PrNorm(nDim)

However, I get the following error when I instantiate the class:

InManifoldError: Tensor not contained in PSSD(
  n=7
  (0): Stiefel(n=7, k=7, triv=linalg_matrix_exp)
  (1): Rn(n=7)
). Got

I think the parametrized tensor gets re-initialized by sln, so it is no longer PSD, which leads to the error. Is there a way to achieve what I intend here with geotorch?

lezcano commented 6 months ago

Alas, these parametrisations are not compositional as-is, but it shouldn't be too difficult to implement your own.

Note that the class PSD is implemented in terms of the class PSSDFixedRank by fixing the rank to be the highest and choosing a particular transformation like softplus:
https://github.com/lezcano/geotorch/blob/ba38d406c245d609fee4b4dac3f6427bf6d73a8e/geotorch/psd.py#L6

On the other hand, SL(n) is implemented in a similar way, but tweaking the function that's used to define the eigenvalues (we normalise them so that their product is one):
https://github.com/lezcano/geotorch/blob/ba38d406c245d609fee4b4dac3f6427bf6d73a8e/geotorch/sl.py#L31-L42
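
To make the normalisation step concrete, here is a rough paraphrase of the idea behind the linked sl.py lines (not the library's exact code, and the function name is just illustrative): the eigenvalues are first made positive, e.g. with softplus, and then divided by their geometric mean, so that their product, and hence the determinant, is exactly one.

import torch

def normalise_to_det_one(eigenvalues):
    # Make the eigenvalues positive, then divide by their geometric mean
    # so that their product (the determinant of the factored matrix) is 1.
    y = torch.nn.functional.softplus(eigenvalues)
    return y / y.prod(dim=-1, keepdim=True).pow(1.0 / y.shape[-1])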

As such, you can inherit from PSD and implement your own class that passes a normalised eigenvalue function to the parent class, like we do in SL(n). You'll also need to implement the methods in_manifold_eigen and sample, similarly to how SL(n) does it.
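
As an illustration, here is a minimal, untested sketch of that recipe. It assumes PSD's constructor accepts a callable f for the eigenvalue transform (as the PSSDFixedRank-based implementation linked above suggests) together with a triv argument; the class name SPDSL is just illustrative, and in_manifold_eigen and sample are only indicated with a comment:

import torch
from geotorch.psd import PSD

class SPDSL(PSD):
    """Sketch: symmetric positive definite matrices with determinant one."""

    def __init__(self, size, triv="expm"):
        def f_sl(x):
            # Same normalisation as the snippet above: positive eigenvalues
            # divided by their geometric mean, so their product is one.
            y = torch.nn.functional.softplus(x)
            return y / y.prod(dim=-1, keepdim=True).pow(1.0 / y.shape[-1])

        # Assumes f may be passed as a callable, as in PSSDFixedRank
        super().__init__(size, f=f_sl, triv=triv)

    # in_manifold_eigen and sample should also be overridden following
    # SL(n), e.g. checking / sampling positive eigenvalues whose product
    # is one.

The new class could then be registered on the parameter via torch.nn.utils.parametrize.register_parametrization, which is the mechanism geotorch's helper functions build on.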

Please reach out if there is anything that's not clear :)