Closed jglaser2 closed 3 years ago
Hi @jglaser2!
Given that the ObliqueManifold(n,k) is just k copies of the sphere, we can just use geotorch.sphere to do so.
m = geotorch.sphere(torch.nn.Linear(7, 3), "weight")
# m.weight is of shape (3, 7) with 3 rows of size 7, each of which lies in the sphere.
assert torch.allclose((m.weight * m.weight).sum(dim=-1), torch.ones(3))
Thanks for the quick response!
I'm using a layer of size (8,50) to go from a lower-dimensional space to a higher-dimensional space. I'd like to have 8 vectors of unit norm rather than 50 vectors of unit norm.
When using the Sphere class as you've described, however, it produces 50 vectors of unit norm.
I was previously using the "orthogonal"/Stiefel class, which did result in 8 vectors of unit norm as I wanted. I realize the Stiefel class does this because it transposes the data when size[-2] < size[-1] (line 31 in stiefel.py).
Should I modify the Sphere class to transpose the data in the same way, or do you have a different recommendation? Thanks!
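For reference, the two layouts being discussed can be checked in plain PyTorch without geotorch. This is only a sketch: W is a hypothetical stand-in tensor with the same shape as the weight of nn.Linear(8, 50), which is (50, 8), and the names rows_unit and cols_unit are made up for illustration.

```python
import torch

torch.manual_seed(0)

# Stand-in for nn.Linear(8, 50).weight, which has shape (out_features, in_features) = (50, 8).
W = torch.randn(50, 8)

# Normalising along the last dimension gives 50 rows of unit norm (vectors in R^8).
rows_unit = W / W.norm(dim=-1, keepdim=True)

# Normalising along dimension 0 gives 8 columns of unit norm (vectors in R^50).
cols_unit = W / W.norm(dim=0, keepdim=True)

print(rows_unit.norm(dim=-1).shape)  # one norm per row
print(cols_unit.norm(dim=0).shape)   # one norm per column
```

The second layout, 8 unit-norm vectors of length 50, is the one asked about here.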
The Stiefel constraint is a much harder one than the one you want.
If you want to apply the constraint along the other dimension, you could write a Transpose parametrisation as:
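import torch.nn as nn

class Transpose(nn.Module):
    # Swap the last two dimensions of the parametrised tensor.
    def forward(self, X):
        return X.transpose(-2, -1)
    # Transposition is its own inverse.
    def right_inverse(self, X):
        return X.transpose(-2, -1)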
and then use it as:
m = nn.Linear(8, 50)
torch.nn.utils.parametrize.register_parametrization(m, "weight", Transpose())
geotorch.sphere(m, "weight")
torch.nn.utils.parametrize.register_parametrization(m, "weight", Transpose())
I realised that the previous code would only work on PyTorch master, using the unsafe flag. Since the sphere is such a simple manifold, it might be best to implement the constraint from scratch as:
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P
class Sphere(nn.Module):
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim
    def forward(self, x):
        return x / x.norm(dim=self.dim, keepdim=True)
m = P.register_parametrization(nn.Linear(8, 50), "weight", Sphere(dim=0))
assert torch.allclose((m.weight * m.weight).sum(dim=0), torch.ones(8))
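As a quick sanity check, the sketch below repeats the Sphere parametrisation above and takes one optimizer step, to show that the columns of m.weight still have unit norm afterwards. The toy loss, the batch x, and the learning rate are arbitrary choices made for illustration, not something from this thread.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P


class Sphere(nn.Module):
    """Normalise the tensor along `dim` so each slice has unit norm."""

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return x / x.norm(dim=self.dim, keepdim=True)


torch.manual_seed(0)
m = P.register_parametrization(nn.Linear(8, 50), "weight", Sphere(dim=0))

# The optimizer updates the unconstrained tensor stored under
# m.parametrizations.weight.original; m.weight is recomputed from it.
opt = torch.optim.SGD(m.parameters(), lr=0.1)

x = torch.randn(16, 8)           # arbitrary batch, for illustration only
loss = m(x).pow(2).mean()        # arbitrary toy loss
opt.zero_grad()
loss.backward()
opt.step()

# After the step, the 8 columns of the (50, 8) weight are still unit vectors.
col_norms = m.weight.norm(dim=0)
print(col_norms)
```

Because the normalisation is applied every time m.weight is accessed, the constraint holds regardless of what the optimizer does to the underlying tensor.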
That new solution worked well! (as you said, the previous one threw an error). Thanks again for your help, it's a great package!
Could you provide some guidance on how to create ObliqueManifold(n,k) that you mention in the documentation? Thanks for your help!