Closed simonvary closed 1 year ago
Do you have a small repro?
Also, hi Simon! Hope you're doing great in Belgium working with Absil :)
Hi, thanks, it is going well! Likewise, I hope you are doing well after finishing the DPhil :) I added the code from Issue 10 that I am trying to get working.
Right, so the `initialize_` API is from before 0.3. Now, the API for a manifold has the following functions: `forward`, `right_inverse`, `in_manifold` and `sample`. You'll need to implement these via flattening and unflattening for everything to work as expected.
You can have a look at `stiefel.py` to see the exact API for these methods. As for what they do, I think it should be clear from their names :)
https://github.com/lezcano/geotorch/blob/ba38d406c245d609fee4b4dac3f6427bf6d73a8e/geotorch/stiefel.py#L40-L47
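Flattening is safe to wrap around the matrix methods because it is a pure re-indexing. Here is a minimal, torch-free sketch, assuming the logical row-major order that `X.flatten(1)` uses and `X.view(size)` undoes. The helpers `unfold_index` and `fold_index` are hypothetical; they map entry `(o, i, a, b)` of an `(out, in, kh, kw)` weight to row `o`, column `(i * kh + a) * kw + b` of the unfolded matrix, and back.

```python
from itertools import product


def unfold_index(idx, size):
    """Hypothetical helper: map a 4-D index (o, i, a, b) of a tensor with shape
    size = (out, in, kh, kw) to its (row, col) position after flatten(1)."""
    o, i, a, b = idx
    _, _, kh, kw = size
    return o, (i * kh + a) * kw + b


def fold_index(rc, size):
    """Hypothetical helper: the inverse of unfold_index (what view(size) undoes)."""
    row, col = rc
    _, _, kh, kw = size
    i, rest = divmod(col, kh * kw)
    a, b = divmod(rest, kw)
    return row, i, a, b


size = (20, 40, 3, 3)  # layer.weight.shape for nn.Conv2d(40, 20, 3, 3)
n_cols = size[1] * size[2] * size[3]  # 360 columns after flatten(1)

targets = set()
for idx in product(*(range(s) for s in size)):
    rc = unfold_index(idx, size)
    assert 0 <= rc[1] < n_cols
    assert fold_index(rc, size) == idx  # folding undoes unfolding
    targets.add(rc)

# Every (row, col) position of the 20 x 360 matrix is hit exactly once
print(len(targets) == size[0] * n_cols)  # True
```

Because the mapping is a bijection, nothing is lost by unfolding, delegating to the matrix version of a method, and folding back wherever a tensor is returned.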
Thank you for the advice. I tried implementing it as follows, which allows me to register the parameterisation. However, I cannot seem to access the `weight` parameter.
Basically, I just have these `_unfold(X)` and `_fold(X)` functions, which reshape the tensor into a matrix and back, respectively. I then plug them into the `geotorch.Stiefel` API. I don't use the `@transpose` decorators, since they are already applied in `super().right_inverse()` and `super().forward()`, but I do declare the `self.transposed` flag.
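The shape bookkeeping behind `size_flattened` and the `transposed` flag can be checked without geotorch. This sketch assumes the convention reflected in the `FlattenedStiefel(n=360, k=20, ..., transposed)` repr printed in this thread: `n` is the larger dimension, `k` the smaller one, and `transposed` is set when the matrix is wider than it is tall. `stiefel_dims` is a hypothetical helper, and `math.prod` stands in for `numpy.prod`.

```python
from math import prod


def size_flattened(size, dim):
    # Same logic as in the post: keep `dim`, multiply all other sizes together
    size = list(size)
    size_dim = size[dim]
    size[dim] = 1
    return (size_dim, prod(size))


def stiefel_dims(shape):
    """Hypothetical helper: return (n, k, transposed) for a matrix shape,
    with n >= k and transposed=True when the matrix is wider than tall."""
    rows, cols = shape
    transposed = rows < cols
    n, k = (cols, rows) if transposed else (rows, cols)
    return n, k, transposed


weight_shape = (20, 40, 3, 3)  # nn.Conv2d(40, 20, 3, 3).weight.shape
matrix_shape = size_flattened(weight_shape, 0)
print(matrix_shape)                # (20, 360)
print(stiefel_dims(matrix_shape))  # (360, 20, True)
```

For this shape the result matches the printed repr, so `self.transposed = self.size[0] < prod(self.size[1:])` agrees with what the parent class reports.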
```python
from numpy import prod
import torch
import torch.nn as nn
import geotorch
from geotorch.utils import transpose, _extra_repr
from geotorch.exceptions import VectorError, InManifoldError


def size_flattened(size, dim):
    size = list(size)
    size_dim = size[dim]
    size[dim] = 1
    return (size_dim, prod(size))


class FlattenedStiefel(geotorch.Stiefel):
    def __init__(self, size, triv="expm"):
        super().__init__(size_flattened(size, 0), triv)
        self.size = list(size)
        self.transposed = self.size[0] < prod(self.size[1:])

    def _unfold(self, X):
        X = X.flatten(1)
        return X

    def _fold(self, X):
        return X.view(self.size)

    def forward(self, X):
        X = self.unfold(X)
        X = self.frame(X)
        X = super().forward(X)
        X = self._fold(X[..., : self.k])
        return X

    def right_inverse(self, X, check_in_manifold=True):
        if check_in_manifold and not self.in_manifold(X):
            raise InManifoldError(X, self)
        X = self._unfold(X)
        X = super().right_inverse(X, check_in_manifold=False)
        return self._fold(X)

    def in_manifold(self, X, eps=1e-4):
        X = self._unfold(X)
        if X.size(-1) > X.size(-2):
            X = X.transpose(-2, -1)
        return geotorch.so._has_orthonormal_columns(X, eps)

    def sample(self, distribution="uniform", init_=None):
        X = super().sample(distribution, init_)
        X = self._fold(X)
        return X


def flattened_orthogonal(module, tensor_name="weight", triv="expm"):
    return geotorch.constraints._register_manifold(module, tensor_name, FlattenedStiefel, triv)


layer = nn.Conv2d(40, 20, 3, 3)  # So layer.weight.shape = (20, 40, 3, 3)
flattened_orthogonal(layer, "weight")  # Register the parameterization
print(layer)
```
which gives:
```
ParametrizedConv2d(
  40, 20, kernel_size=(3, 3), stride=(3, 3)
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): FlattenedStiefel(n=360, k=20, triv=linalg_matrix_exp, transposed)
    )
  )
)
```
However, if I try to access `W = layer.weight`, I get `AttributeError: 'ParametrizedConv2d' object has no attribute 'weight'`.
Any idea why the `weight` attribute was not registered correctly? Thanks!
That looks like it's going in the right direction. Now, I suggest you test `FlattenedStiefel.forward`, `FlattenedStiefel.right_inverse` and so on directly on an input before registering the parametrization. That exception comes from how properties interact with `__getattr__` in Python: it means an `AttributeError` was raised while the property was being evaluated. Python then treats the attribute as missing and falls back to `nn.Module.__getattr__`, which raises that other `AttributeError` about `weight` and hides the original one, making everything quite confusing.
Oh, thanks for the very good hint on how to look for the bug! I found that I had a typo and some other small mistakes. The final result looks like this, if anyone is interested:
```python
from numpy import prod
import torch
import torch.nn as nn
import geotorch
from geotorch.exceptions import InManifoldError


def size_flattened(size, dim):
    # Matrix shape obtained by keeping dimension `dim` and flattening the rest,
    # e.g. (20, 40, 3, 3) -> (20, 360) for dim=0
    size = list(size)
    size_dim = size[dim]
    size[dim] = 1
    return (size_dim, prod(size))


class FlattenedStiefel(geotorch.Stiefel):
    def __init__(self, size, triv="expm"):
        super().__init__(size_flattened(size, 0), triv)
        self.size = list(size)
        self.transposed = self.size[0] < prod(self.size[1:])

    def _unfold(self, X):
        # (out, in, kh, kw) -> (out, in * kh * kw)
        X = X.flatten(1)
        return X

    def _fold(self, X):
        # (out, in * kh * kw) -> (out, in, kh, kw)
        return X.view(self.size)

    def forward(self, X):
        # Unfold, apply the matrix Stiefel parametrization, then fold back
        X = self._unfold(X)
        X = super().forward(X)
        X = self._fold(X)
        return X

    def right_inverse(self, X, check_in_manifold=True):
        # Check on the original (folded) tensor; the parent call skips the check
        if check_in_manifold and not self.in_manifold(X):
            raise InManifoldError(X, self)
        X = self._unfold(X)
        X = super().right_inverse(X, check_in_manifold=False)
        return self._fold(X)

    def in_manifold(self, X, eps=1e-4):
        # A wide matrix is checked through its transpose, i.e. orthonormal rows
        X = self._unfold(X)
        if X.size(-1) > X.size(-2):
            X = X.transpose(-2, -1)
        return geotorch.so._has_orthonormal_columns(X, eps)

    def sample(self, distribution="uniform", init_=None):
        # Sample a matrix on the Stiefel manifold and reshape it like the weight
        X = super().sample(distribution, init_)
        X = self._fold(X)
        return X


def flattened_orthogonal(module, tensor_name="weight", triv="expm"):
    return geotorch.constraints._register_manifold(module, tensor_name, FlattenedStiefel, triv)


layer = nn.Conv2d(40, 20, 3, 3)  # layer.weight.shape = (20, 40, 3, 3)
flattened_orthogonal(layer, "weight")  # Make the kernels orthogonal
print(layer)
W = layer.weight  # W has size (20, 40, 3, 3)
W = W.flatten(1)  # (20, 360)
print(torch.allclose(W @ W.T, torch.eye(20), atol=1e-4))
```
prints:
```
ParametrizedConv2d(
  40, 20, kernel_size=(3, 3), stride=(3, 3)
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): FlattenedStiefel(n=360, k=20, triv=linalg_matrix_exp, transposed)
    )
  )
)
True
```
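The final check, `W @ W.T` being close to the identity for the 20 × 360 unfolded weight, says that the rows of the unfolded matrix are orthonormal, which is essentially what `in_manifold` tests after transposing a wide matrix. A torch-free sketch of that test on plain lists, using a small hypothetical 2 × 3 matrix for illustration; `has_orthonormal_rows` is a hypothetical helper, not geotorch's `_has_orthonormal_columns`.

```python
from math import sqrt


def has_orthonormal_rows(rows, eps=1e-4):
    """Hypothetical helper: True if rows @ rows.T equals the identity up to eps."""
    m = len(rows)
    for i in range(m):
        for j in range(m):
            dot = sum(x * y for x, y in zip(rows[i], rows[j]))
            target = 1.0 if i == j else 0.0
            if abs(dot - target) > eps:
                return False
    return True


s = 1 / sqrt(2)
# A wide 2 x 3 matrix with orthonormal rows (rows < cols, like the 20 x 360 case)
W = [
    [s, s, 0.0],
    [0.0, 0.0, 1.0],
]
print(has_orthonormal_rows(W))  # True

# Scaling a row breaks the unit-norm condition
print(has_orthonormal_rows([[2 * s, 2 * s, 0.0], [0.0, 0.0, 1.0]]))  # False
```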
Thanks for the help; it has been quite informative to read through your library!
Great that you made it work in the end! Closing this one.
Hi,
I am struggling to impose orthogonality constraints on the weights of convolutional layers, as was also asked in Issue 10, where the weight tensor of shape (out, in, k, k) is unfolded into a matrix of shape (out, in·k·k) on which orthogonality is imposed. If I try to run the code from @lezcano's answer, I get:
```
line 28, in _register_manifold
    tensor.copy_(X)
RuntimeError: The size of tensor a (3) must match the size of tensor b (20) at non-singleton dimension 3
```
Trying to run the other answer from @bokveizen, I get:
```
raise InManifoldError(X, self)
geotorch.exceptions.InManifoldError: Tensor not contained in FlattenedStiefel(n=180, k=40, triv=linalg_matrix_exp, transposed). Got: tensor([[[[ ...
```
I tried rewriting the `in_manifold` check to run `self.forward()` first, but the code still raises the `RuntimeError`. What works is rolling back to the commit from a year ago; however, I would like to reproduce the same functionality in geotorch@v0.3.0.
Thanks for the help!
Code from the answer of @bokveizen in Issue 10: