lezcano / geotorch

Constrained optimization toolkit for PyTorch
https://geotorch.readthedocs.io
MIT License

ObliqueManifold #22

Closed: jglaser2 closed this issue 3 years ago

jglaser2 commented 3 years ago

Could you provide some guidance on how to create the ObliqueManifold(n, k) that you mention in the documentation? Thanks for your help!

lezcano commented 3 years ago

Hi @jglaser2!

Since ObliqueManifold(n, k) is just the product of k copies of the sphere in R^n, we can use geotorch.sphere to build it:

```python
import torch
import geotorch

m = geotorch.sphere(torch.nn.Linear(7, 3), "weight")
# m.weight has shape (3, 7): 3 rows of size 7, each of which lies on the sphere.
assert torch.allclose((m.weight * m.weight).sum(dim=-1), torch.ones(3))
```

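For intuition, the "k copies of the sphere" description can be written out directly: a point on ObliqueManifold(n, k) is a k × n matrix whose rows each have unit Euclidean norm, so any matrix with nonzero rows can be mapped onto it by dividing each row by its norm. The following is a minimal plain-PyTorch sketch, independent of geotorch; `to_oblique` and `is_oblique` are hypothetical helper names used only for illustration.

```python
import torch


def to_oblique(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Hypothetical helper: divide every slice of x along `dim` by its norm.

    Assumes no slice along `dim` is all zeros (that would divide by zero).
    """
    return x / x.norm(dim=dim, keepdim=True)


def is_oblique(x: torch.Tensor, dim: int = -1) -> bool:
    """Hypothetical helper: check every slice of x along `dim` has unit norm."""
    norms = x.norm(dim=dim)
    return torch.allclose(norms, torch.ones_like(norms))


# k = 3 vectors in R^n with n = 7, matching the (3, 7) weight above.
x = torch.randn(3, 7)
w = to_oblique(x)
assert is_oblique(w)

# Normalizing again changes nothing: the map is idempotent on the manifold.
assert torch.allclose(to_oblique(w), w)
```

Passing `dim=0` instead normalizes columns rather than rows.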
jglaser2 commented 3 years ago

Thanks for the quick response!

I'm using a layer of size (8, 50), i.e. nn.Linear(8, 50), to go from a lower-dimensional space to a higher-dimensional one. I'd like to have 8 vectors of unit norm rather than 50.

When I use the Sphere class as you've described, however, I get 50 vectors of unit norm.
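The difference comes from how PyTorch lays out the weight: `nn.Linear(in_features, out_features)` stores `weight` with shape `(out_features, in_features)`, so `nn.Linear(8, 50)` has a `(50, 8)` weight. Normalizing along the last dimension therefore yields 50 unit-norm rows, while normalizing along dimension 0 yields the 8 unit-norm columns wanted here. A small plain-PyTorch sketch (no geotorch involved) that shows both:

```python
import torch
import torch.nn as nn

layer = nn.Linear(8, 50)
w = layer.weight.detach()
print(w.shape)  # torch.Size([50, 8]): (out_features, in_features)

# Normalizing along the last dim: 50 unit-norm rows, each of length 8.
rows_unit = w / w.norm(dim=-1, keepdim=True)
print(rows_unit.norm(dim=-1).shape)  # torch.Size([50])

# Normalizing along dim 0: 8 unit-norm columns, each of length 50.
cols_unit = w / w.norm(dim=0, keepdim=True)
print(cols_unit.norm(dim=0).shape)  # torch.Size([8])
```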

I was previously using the "orthogonal" (Stiefel) constraint, which did give me 8 vectors of unit norm, as I wanted. I realize this is because the Stiefel class transposes the data when size[-2] < size[-1] (line 31 in stiefel.py).
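For a tall matrix such as this `(50, 8)` weight, the Stiefel constraint means orthonormal columns, so each column is unit-norm as a side effect. The converse fails, since unit-norm columns need not be mutually orthogonal. The sketch below illustrates the difference in plain PyTorch. It uses a reduced QR factorization (`torch.linalg.qr`) only as a convenient way to get orthonormal columns; that choice is an assumption for illustration, not how geotorch parametrizes the Stiefel manifold.

```python
import torch

torch.manual_seed(0)
x = torch.randn(50, 8)
eye = torch.eye(8)

# Stiefel-style point: reduced QR gives q of shape (50, 8) with orthonormal columns.
q, _ = torch.linalg.qr(x)
assert torch.allclose(q.T @ q, eye, atol=1e-4)                    # orthonormal columns,
assert torch.allclose(q.norm(dim=0), torch.ones(8), atol=1e-5)    # hence unit-norm columns

# Oblique point: only normalize the columns of the same matrix.
w = x / x.norm(dim=0, keepdim=True)
gram = w.T @ w
assert torch.allclose(gram.diagonal(), torch.ones(8), atol=1e-5)  # unit-norm columns,
assert not torch.allclose(gram, eye, atol=1e-4)                   # but not orthogonal
```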

Should I modify the Sphere class to transpose the data in the same way, or do you have a different recommendation? Thanks!

lezcano commented 3 years ago

The Stiefel constraint is much stronger than the one you want: it forces the vectors to be orthonormal, not just unit-norm. If you want to apply the sphere constraint along the other dimension, you could write a Transpose parametrisation:

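```python
import torch.nn as nn

class Transpose(nn.Module):
    # Swaps the last two dimensions; applying it twice is the identity.
    def forward(self, X):
        return X.transpose(-2, -1)

    def right_inverse(self, X):
        return X.transpose(-2, -1)
```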

and then use it as:

```python
import torch.nn as nn
import torch.nn.utils.parametrize
import geotorch

m = nn.Linear(8, 50)
torch.nn.utils.parametrize.register_parametrization(m, "weight", Transpose())
geotorch.sphere(m, "weight")
torch.nn.utils.parametrize.register_parametrization(m, "weight", Transpose())
```

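The intent of that stack (transpose, normalize each row, transpose back) can be checked without registering any parametrization: it is the same map as normalizing each column directly. The sketch below is plain PyTorch and repeats the `Transpose` module so it is self-contained; `normalize_rows` is a hypothetical stand-in for a row-wise sphere constraint, not geotorch's implementation. It also checks that `Transpose` is its own inverse, i.e. `forward(right_inverse(X)) == X`, which is the property PyTorch expects of `right_inverse`.

```python
import torch
import torch.nn as nn


class Transpose(nn.Module):
    # Same module as above: swaps the last two dimensions.
    def forward(self, X):
        return X.transpose(-2, -1)

    def right_inverse(self, X):
        return X.transpose(-2, -1)


def normalize_rows(X):
    # Hypothetical stand-in for a sphere constraint on the last dimension.
    return X / X.norm(dim=-1, keepdim=True)


t = Transpose()
W = torch.randn(50, 8)

# Transpose is its own inverse: forward(right_inverse(W)) == W.
assert torch.equal(t(t.right_inverse(W)), W)

# transpose -> normalize rows -> transpose back == normalize columns.
stacked = t(normalize_rows(t(W)))
direct = W / W.norm(dim=0, keepdim=True)
assert stacked.shape == (50, 8)
assert torch.allclose(stacked, direct)
```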
lezcano commented 3 years ago

I realised that the previous code would only work on PyTorch master, using the unsafe flag. Since the sphere is such a simple manifold, it might be best to implement the constraint from scratch:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P

class Sphere(nn.Module):
    # Divides x by its norm along `dim`, so every slice along `dim` has norm 1.
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return x / x.norm(dim=self.dim, keepdim=True)

# nn.Linear(8, 50).weight has shape (50, 8); dim=0 normalizes its 8 columns.
m = P.register_parametrization(nn.Linear(8, 50), "weight", Sphere(dim=0))
assert torch.allclose((m.weight * m.weight).sum(dim=0), torch.ones(8))
```

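With this parametrization, `m.weight` is recomputed from the unconstrained tensor `m.parametrizations.weight.original` each time it is accessed, and that tensor is what the optimizer updates, so the constraint holds after every step. A minimal sketch of one SGD step, repeating the same `Sphere` module so the block is self-contained; the input batch and loss are arbitrary placeholders chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P


class Sphere(nn.Module):
    # Divides x by its norm along `dim`, so every slice along `dim` has norm 1.
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return x / x.norm(dim=self.dim, keepdim=True)


torch.manual_seed(0)
m = P.register_parametrization(nn.Linear(8, 50), "weight", Sphere(dim=0))
opt = torch.optim.SGD(m.parameters(), lr=0.1)

# The trainable tensor is the unconstrained original, not m.weight itself.
original = m.parametrizations.weight.original
before = original.detach().clone()

x = torch.randn(16, 8)          # arbitrary placeholder batch of 16 inputs
loss = m(x).pow(2).mean()       # arbitrary placeholder loss
opt.zero_grad()
loss.backward()
opt.step()

# The original has changed, but m.weight still has 8 unit-norm columns.
assert not torch.equal(original.detach(), before)
assert torch.allclose(m.weight.norm(dim=0), torch.ones(8))
```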
jglaser2 commented 3 years ago

That new solution worked well! (As you said, the previous one threw an error.) Thanks again for your help; it's a great package!