lezcano / geotorch

Constrained optimization toolkit for PyTorch
https://geotorch.readthedocs.io
MIT License

Posing orthogonality condition on convolution weights in v0.3.0 #38

Closed: simonvary closed this issue 1 year ago

simonvary commented 1 year ago

Hi,

I am struggling to impose orthogonality constraints on the weights of convolutional layers, as was also asked in Issue 10, where we unfold the weight tensor of shape (in, out, k, k) into a matrix of shape (in, out * k * k) and impose orthogonality on that matrix. If I try to run the code from @lezcano's answer, I get:
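The unfolding is just a reshape. PyTorch stores an nn.Conv2d weight as (out, in, k, k), which is the layout the code below works with. Keeping dimension 0 and flattening the rest gives an out × (in·k·k) matrix. Here is a minimal sketch of the round trip. It assumes NumPy arrays as a stand-in for tensors (`reshape(out, -1)` plays the role of `X.flatten(1)`), and the names `unfold` and `fold` are hypothetical:

```python
import numpy as np


def unfold(W):
    # (out, in, k, k) -> (out, in*k*k): keep dim 0, flatten the rest,
    # analogous to W.flatten(1) in PyTorch
    return W.reshape(W.shape[0], -1)


def fold(M, size):
    # Inverse of unfold: restore the original 4-D shape
    return M.reshape(size)


size = (40, 20, 3, 3)  # weight shape of nn.Conv2d(20, 40, 3)
W = np.random.default_rng(0).standard_normal(size)
M = unfold(W)
print(M.shape)  # (40, 180)
```

Each row of the unfolded matrix is one flattened filter, so "orthogonal kernels" means orthonormal rows of this matrix.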

line 28, in _register_manifold
    tensor.copy_(X)
RuntimeError: The size of tensor a (3) must match the size of tensor b (20) at non-singleton dimension 3

Trying to run the other answer from @bokveizen, I get:

raise InManifoldError(X, self)
geotorch.exceptions.InManifoldError: Tensor not contained in FlattenedStiefel(n=180, k=40, triv=linalg_matrix_exp, transposed). Got: tensor([[[[ ...

I tried rewriting the in_manifold check to call self.forward() first, but the code still raises a RuntimeError.

What works is rewinding to the commit from a year ago; however, I would like to reproduce the same functionality in geotorch@v0.3.0.

Thanks for help!

Code from the answer of @bokveizen in Issue 10:

from numpy import prod
import torch
import torch.nn as nn
import geotorch

def size_flattened(size, dim):
    size = list(size)
    size_dim = size[dim]
    size[dim] = 1
    return (size_dim, prod(size))

class FlattenedStiefel(geotorch.Stiefel):
    def __init__(self, size, triv="expm"):
        # We assume that you want to flatten the dimensions [1:n]
        # See the comment in forward for why we keep the dim=1 and the dim=0
        super().__init__(size_flattened(size, 0), triv)
        # size = (out, in, k, k) so we transpose it
        self.size = size
        # size = list(size)
        # size[0], size[1] = size[1], size[0]
        # self.size = tuple(size)

    def forward(self, X):
        # The weight of a CNN layer with params (in, out, k, k)
        # is of size (out, in, k, k), so we transpose it before flattening it
        # X = X.T
        X = X.flatten(1)
        X = super().forward(X)
        X = X.view(self.size)
        # return X.T
        return X

    def initialize_(self, X, check_in_manifold=True):
        # X = X.T
        X = X.flatten(1)
        X = super().initialize_(X, check_in_manifold)
        X = X.view(self.size)
        # return X.T
        return X

    def sample(self, distribution="uniform", init_=None):
        X = super().sample(distribution, init_)
        X = X.view(self.size)
        # return X.T
        return X

def flattened_orthogonal(module, tensor_name="weight", triv="expm"):
    return geotorch.constraints._register_manifold(module, tensor_name, FlattenedStiefel, triv)

layer = nn.Conv2d(20, 40, 3, 3)  # Make the kernels orthogonal
flattened_orthogonal(layer, "weight")
print(layer)

W = layer.weight  # W has size (40, 20, 3, 3)
# W = W.T.flatten(1)  # W has size (20, 360) with orthogonal rows
# W = W.T # W has size (360, 20) with orthogonal columns
# # Check that W.T @ W = Id
# print(torch.allclose(W.T @ W, torch.eye(40), atol=1e-4))
W = W.flatten(1)
print(torch.allclose(W @ W.T, torch.eye(40), atol=1e-4))
lezcano commented 1 year ago

Do you have a small repro?

lezcano commented 1 year ago

also, hi Simon! Hope you're doing great in Belgium working with Absil :)

simonvary commented 1 year ago

Hi, thanks, it is going well! Likewise, I hope you are also doing well after finishing the DPhil :) I added the code from Issue 10 that I am trying to get working.

lezcano commented 1 year ago

Right, so the initialize_ API is from before 0.3. Now, a manifold has to implement the following methods: forward, right_inverse, in_manifold and sample. You'll need to implement each of them by flattening and unflattening the tensor for everything to work as expected.
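To make the four methods concrete, here is a minimal, framework-free sketch of the flatten/unflatten pattern. It assumes NumPy arrays instead of tensors, and `MatrixStiefelStub` and `FlattenedRows` are hypothetical names. The stub maps onto the manifold with a QR decomposition rather than geotorch's matrix-exponential trivialization. It only illustrates how each method unfolds its input, delegates to a matrix-level counterpart, and folds the result back:

```python
import numpy as np


class MatrixStiefelStub:
    """Hypothetical stand-in for a matrix manifold: n x k matrices (n >= k)
    with orthonormal columns. Uses QR, not geotorch's expm trivialization."""

    def __init__(self, n, k):
        self.n, self.k = n, k

    def forward(self, A):
        # Map an unconstrained n x k matrix onto the manifold via QR,
        # flipping column signs so that diag(R) > 0.
        Q, R = np.linalg.qr(A)
        return Q * np.sign(np.diag(R))

    def right_inverse(self, X):
        # If X already has orthonormal columns, its QR has a diagonal R with
        # +-1 entries, so forward(X) == X and X itself is a valid preimage.
        return X

    def in_manifold(self, X, eps=1e-4):
        return np.allclose(X.T @ X, np.eye(self.k), atol=eps)

    def sample(self, rng):
        return self.forward(rng.standard_normal((self.n, self.k)))


class FlattenedRows:
    """Hypothetical wrapper: a tensor of shape (out, ...) whose unfolded
    out x prod(rest) matrix has orthonormal rows (needs out <= prod(rest))."""

    def __init__(self, size):
        self.size = tuple(size)
        rows, cols = self.size[0], int(np.prod(self.size[1:]))
        if rows > cols:
            raise ValueError("orthonormal rows need out <= prod(size[1:])")
        # The transposed (cols x rows) matrix has orthonormal columns.
        self.inner = MatrixStiefelStub(cols, rows)

    def _unfold(self, X):
        return X.reshape(self.size[0], -1)  # like X.flatten(1) in PyTorch

    def _fold(self, M):
        return M.reshape(self.size)

    # Each method: unfold (and transpose), delegate, transpose back, fold.
    def forward(self, X):
        return self._fold(self.inner.forward(self._unfold(X).T).T)

    def right_inverse(self, X):
        return self._fold(self.inner.right_inverse(self._unfold(X).T).T)

    def in_manifold(self, X, eps=1e-4):
        return self.inner.in_manifold(self._unfold(X).T, eps)

    def sample(self, rng):
        return self._fold(self.inner.sample(rng).T)


rng = np.random.default_rng(0)
M = FlattenedRows((20, 40, 3, 3))  # weight shape of nn.Conv2d(40, 20, 3)
W = M.forward(rng.standard_normal((20, 40, 3, 3)))
print(W.shape, M.in_manifold(W))  # (20, 40, 3, 3) True
```

In geotorch itself these methods live on an nn.Module subclass and keep the signatures from stiefel.py, such as right_inverse(X, check_in_manifold=True). The sketch drops those details to keep the unfold/delegate/fold structure visible.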

lezcano commented 1 year ago

You can have a look at stiefel.py to see the exact API for these methods. As to what they do, I think it should be clear from their names :) https://github.com/lezcano/geotorch/blob/ba38d406c245d609fee4b4dac3f6427bf6d73a8e/geotorch/stiefel.py#L40-L47

simonvary commented 1 year ago

Thank you for the advice. I tried implementing it as follows, which allows me to register the parameterisation. However, I cannot seem to access the weight parameter.

Basically, I have two functions, _unfold(X) and _fold(X), that reshape the tensor into a matrix and back, respectively, and I plug them into the geotorch.Stiefel API. I don't use the @transpose decorators, since they are already applied in super().right_inverse() and super().forward(), but I do set the self.transposed flag.

from numpy import prod
import torch
import torch.nn as nn
import geotorch

from geotorch.utils import transpose, _extra_repr
from geotorch.exceptions import VectorError, InManifoldError

def size_flattened(size, dim):
    size = list(size)
    size_dim = size[dim]
    size[dim] = 1
    return (size_dim, prod(size))

class FlattenedStiefel(geotorch.Stiefel):
    def __init__(self, size, triv="expm"):
        super().__init__(size_flattened(size, 0), triv)
        self.size = list(size)
        self.transposed = self.size[0] < prod(self.size[1:])

    def _unfold(self, X):
        X = X.flatten(1)
        return X

    def _fold(self, X):
        return X.view(self.size)

    def forward(self, X):
        X = self.unfold(X)
        X = self.frame(X)
        X = super().forward(X)
        X = self._fold(X[..., : self.k])
        return X

    def right_inverse(self, X, check_in_manifold=True):
        if check_in_manifold and not self.in_manifold(X):
            raise InManifoldError(X, self)  
        X = self._unfold(X)
        X = super().right_inverse(X,check_in_manifold=False)
        return self._fold(X)

    def in_manifold(self, X, eps=1e-4):
        X = self._unfold(X)
        if X.size(-1) > X.size(-2):
            X = X.transpose(-2, -1)
        return geotorch.so._has_orthonormal_columns(X, eps)

    def sample(self, distribution="uniform", init_=None):
        X = super().sample(distribution, init_)
        X = self._fold(X)
        return X

def flattened_orthogonal(module, tensor_name="weight", triv="expm"):
    return geotorch.constraints._register_manifold(module, tensor_name, FlattenedStiefel, triv)

layer = nn.Conv2d(40, 20, 3, 3)   # So layer.weight.shape = (20, 40, 3, 3)
flattened_orthogonal(layer, "weight")  # Register the parameterization
print(layer)

which gives:

ParametrizedConv2d(
  40, 20, kernel_size=(3, 3), stride=(3, 3)
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): FlattenedStiefel(n=360, k=20, triv=linalg_matrix_exp, transposed)
    )
  )
)
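The repr above reports n=360, k=20, transposed for a weight that unfolds to a 20 × 360 matrix. The error in the first post reports n=180, k=40, transposed for a 40 × 180 one. Both are consistent with Stiefel storing a tall n × k matrix (n ≥ k) and handling a wide one by transposing it, which is what the self.transposed flag records. Here is a small sketch of that bookkeeping, where `flattened_stiefel_dims` is a hypothetical helper inferred from these printed values and not part of geotorch:

```python
from math import prod


def flattened_stiefel_dims(size):
    """Hypothetical helper: (n, k, transposed) for a weight of shape `size`
    unfolded as size[0] x prod(size[1:])."""
    rows, cols = size[0], prod(size[1:])
    transposed = rows < cols  # same test as self.transposed in the code above
    n, k = (cols, rows) if transposed else (rows, cols)
    return n, k, transposed


print(flattened_stiefel_dims((20, 40, 3, 3)))  # (360, 20, True)
print(flattened_stiefel_dims((40, 20, 3, 3)))  # (180, 40, True)
```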

However, if I try calling W = layer.weight, I get AttributeError: 'ParametrizedConv2d' object has no attribute 'weight'.

Any idea why the weight attribute has not been registered correctly? Thanks!

lezcano commented 1 year ago

That looks like it's going in the right direction. Now, I suggest you test FlattenedStiefel.forward, FlattenedStiefel.right_inverse and the other methods on an input tensor before registering the parametrization. That exception comes from how properties work in Python. The parametrized weight is exposed as a property, and if its getter raises an AttributeError (a typo in a method name is enough), Python sadly swallows that exception and falls back to __getattr__. That in turn raises this other exception, making everything quite confusing.
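Here is a pure-Python sketch of the masking effect. The classes `Inner` and `Outer` are hypothetical. `Outer.__getattr__` imitates the "object has no attribute" error that nn.Module raises, and `weight` plays the role of the parametrized property. Calling the property's getter directly through `fget` bypasses the fallback and surfaces the real error:

```python
class Inner:
    def forward(self):
        return self.unfold()  # typo: there is no `unfold`, so AttributeError


class Outer:
    def __init__(self):
        self.inner = Inner()

    @property
    def weight(self):
        return self.inner.forward()

    def __getattr__(self, name):
        # Only reached when normal lookup (including the property) raised
        # AttributeError, so the original error is lost.
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )


def message_of(fn):
    try:
        fn()
    except AttributeError as e:
        return str(e)
    return None


masked = message_of(lambda: Outer().weight)
real = message_of(lambda: Outer.weight.fget(Outer()))
print(masked)  # 'Outer' object has no attribute 'weight'
print(real)    # 'Inner' object has no attribute 'unfold'
```

Only AttributeError is masked this way. Other exceptions raised inside the getter propagate normally.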

simonvary commented 1 year ago

Oh, thanks for the very good hint on how to look for the bug! I managed to find that I had a typo and some other small mistakes. The final result looks like this, if anyone is interested:

from numpy import prod
import torch
import torch.nn as nn
import geotorch

from geotorch.exceptions import InManifoldError

def size_flattened(size, dim):
    size = list(size)
    size_dim = size[dim]
    size[dim] = 1
    return (size_dim, prod(size))

class FlattenedStiefel(geotorch.Stiefel):
    def __init__(self, size, triv="expm"):
        super().__init__(size_flattened(size, 0), triv)
        self.size = list(size)
        self.transposed = self.size[0] < prod(self.size[1:])

    def _unfold(self, X):
        X = X.flatten(1)
        return X

    def _fold(self, X):
        return X.view(self.size)

    def forward(self, X):
        X = self._unfold(X)
        X = super().forward(X)
        X = self._fold(X)
        return X

    def right_inverse(self, X, check_in_manifold=True):
        if check_in_manifold and not self.in_manifold(X):
            raise InManifoldError(X, self)  
        X = self._unfold(X)
        X = super().right_inverse(X, check_in_manifold=False)
        return self._fold(X)

    def in_manifold(self, X, eps=1e-4):
        X = self._unfold(X)
        if X.size(-1) > X.size(-2):
            X = X.transpose(-2, -1)
        return geotorch.so._has_orthonormal_columns(X, eps)

    def sample(self, distribution="uniform", init_=None):
        X = super().sample(distribution, init_)
        X = self._fold(X)
        return X

def flattened_orthogonal(module, tensor_name="weight", triv="expm"):
    return geotorch.constraints._register_manifold(module, tensor_name, FlattenedStiefel, triv)

layer = nn.Conv2d(40, 20, 3, 3)  # Make the kernels orthogonal
# layer.weight.shape = (20, 40, 3, 3)
flattened_orthogonal(layer, "weight")
print(layer)

W = layer.weight  # W has size (20, 40, 3, 3)
W = W.flatten(1)
print(torch.allclose(W @ W.T, torch.eye(20), atol=1e-4))

prints:

ParametrizedConv2d(
  40, 20, kernel_size=(3, 3), stride=(3, 3)
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): FlattenedStiefel(n=360, k=20, triv=linalg_matrix_exp, transposed)
    )
  )
)
True
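The closing check and the final in_manifold both test orthonormality of the unfolded matrix. When the matrix is wide, as with the 20 × 360 matrix here, that means its rows (W @ W.T = I). When it is tall, it means its columns. Here is a minimal NumPy sketch of that transpose-if-wide logic. `has_orthonormal_columns` and `unfolded_in_manifold` are hypothetical stand-ins, and geotorch's own `_has_orthonormal_columns` may use a different tolerance criterion:

```python
import numpy as np


def has_orthonormal_columns(X, eps=1e-4):
    # X^T X should equal the k x k identity, up to eps
    k = X.shape[-1]
    return np.allclose(X.T @ X, np.eye(k), atol=eps)


def unfolded_in_manifold(W, eps=1e-4):
    X = W.reshape(W.shape[0], -1)  # (out, in*k*k), like W.flatten(1)
    if X.shape[-1] > X.shape[-2]:  # wide: check the rows through X^T
        X = X.T
    return has_orthonormal_columns(X, eps)


rng = np.random.default_rng(0)

# Wide case: a (20, 40, 3, 3) weight whose 20 x 360 unfolding has orthonormal rows
Q, _ = np.linalg.qr(rng.standard_normal((360, 20)))
W = Q.T.reshape(20, 40, 3, 3)
print(unfolded_in_manifold(W))  # True

# Tall case: a (64, 4, 3, 3) weight unfolds to 64 x 36, so only columns can be orthonormal
Q_tall, _ = np.linalg.qr(rng.standard_normal((64, 36)))
W_tall = Q_tall.reshape(64, 4, 3, 3)
print(unfolded_in_manifold(W_tall))  # True

print(unfolded_in_manifold(rng.standard_normal((20, 40, 3, 3))))  # False
```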

Thanks for the help, it has been quite informative to read through your library!

lezcano commented 1 year ago

Great that you made it work in the end! Closing this one