lezcano / geotorch

Constrained optimization toolkit for PyTorch
https://geotorch.readthedocs.io
MIT License

Custom initialization for class #4

Closed: jiahaosu closed this issue 3 years ago

jiahaosu commented 3 years ago

I am wondering whether I can initialize the constrained objects myself (for example, I would like to initialize the orthogonal matrix to the identity map).

Best,

Jiahao

lezcano commented 3 years ago

This functionality is implemented in the fibrations branch. It will be merged into master as soon as I get around to updating the documentation. If I find the time, I will do it today.

You can find examples of how to initialise weights in the examples folder of that branch. It is as simple as assigning the tensor you want to initialise it with to the parametrised parameter.
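For readers on a recent PyTorch release: PyTorch's own torch.nn.utils.parametrize follows a similar assign-to-initialize pattern, where assigning to a parametrized tensor routes the new value through the parametrization's right_inverse method. The following is a minimal sketch with a hypothetical Symmetric parametrization. It is not part of geotorch and is not necessarily how the fibrations branch implements assignment.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Symmetric(nn.Module):
    """Hypothetical parametrization: maps an unconstrained matrix to a symmetric one."""

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Mirror the strict upper triangle of X onto its lower triangle
        return X.triu() + X.triu(1).transpose(-1, -2)

    def right_inverse(self, A: torch.Tensor) -> torch.Tensor:
        # Return an unconstrained X with forward(X) == A (valid when A is symmetric)
        return A.triu()


layer = nn.Linear(3, 3, bias=False)
parametrize.register_parametrization(layer, "weight", Symmetric())

# Initialize by assignment: PyTorch calls Symmetric.right_inverse on the new value
S = torch.tensor([[1.0, 2.0, 3.0],
                  [2.0, 4.0, 5.0],
                  [3.0, 5.0, 6.0]])
layer.weight = S
```

The optimizer only ever updates the unconstrained tensor (layer.parametrizations.weight.original), so right_inverse only has to return some preimage of the assigned value, not a unique one.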

For example, to initialise an orthogonal layer to a uniformly random orthogonal matrix (distributed according to the Haar measure) using torch.nn.init.orthogonal_, you can do:

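# Constrain recurrent_kernel.weight to be orthogonal
geotorch.orthogonal(self.recurrent_kernel, "weight")
# Assigning to the parametrised weight initialises it (here, to a Haar-random orthogonal matrix)
self.recurrent_kernel.weight = torch.nn.init.orthogonal_(self.recurrent_kernel.weight)
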
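As a self-contained illustration of the same pattern, here is a sketch that swaps in PyTorch's built-in torch.nn.utils.parametrizations.orthogonal (available in recent PyTorch releases) for geotorch.orthogonal, since the snippet above assumes a surrounding module with a recurrent_kernel attribute. The layer size and seed are arbitrary choices for the example.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrizations

torch.manual_seed(0)

layer = nn.Linear(4, 4, bias=False)
# PyTorch's built-in orthogonal constraint, used here as a stand-in for geotorch.orthogonal
parametrizations.orthogonal(layer, "weight")

# Draw a Haar-random orthogonal matrix and assign it to initialize the layer
Q = torch.nn.init.orthogonal_(torch.empty(4, 4))
layer.weight = Q

# The constrained weight now equals Q (up to float rounding) and stays orthogonal
W = layer.weight
```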
jiahaosu commented 3 years ago

Thanks a lot for your quick response! I am looking forward to the merged master branch.

jiahaosu commented 3 years ago

I tried the method you suggested, but I got an error.

Here is my interactive session:

>>> layer = nn.Linear(3, 3, bias = False)
>>> layer.weight
Parameter containing:
tensor([[ 0.4125, -0.2594, -0.3813],
        [-0.3678,  0.5076, -0.3462],
        [ 0.4600,  0.0209, -0.4292]], requires_grad=True)
>>> geotorch.orthogonal(layer, "weight")
>>> layer.weight
tensor([[ 0.1501, -0.9853,  0.0816],
        [-0.0489,  0.0750,  0.9960],
        [-0.9875, -0.1535, -0.0369]], grad_fn=<MmBackward>)
>>> layer.weight = torch.nn.init.eye_(layer.weight)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/sujiahao/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 819, in __setattr__
    object.__setattr__(self, name, value)
  File "/Users/sujiahao/miniconda3/envs/pytorch/lib/python3.6/site-packages/geotorch/parametrize.py", line 182, in set_value
    module.parametrizations[tensor_name].set_value_(value)
  File "/Users/sujiahao/miniconda3/envs/pytorch/lib/python3.6/site-packages/geotorch/parametrize.py", line 52, in set_value_
    value = module.initialize_(value)
  File "/Users/sujiahao/miniconda3/envs/pytorch/lib/python3.6/site-packages/geotorch/so.py", line 79, in initialize_
    if not SO.in_manifold(X, self.base.size()):
  File "/Users/sujiahao/miniconda3/envs/pytorch/lib/python3.6/site-packages/geotorch/so.py", line 98, in in_manifold
    error = D.abs().sum(dim=-2).max(dim=-1) / k
TypeError: unsupported operand type(s) for /: 'torch.return_types.max' and 'int'

Could you suggest the correct usage of the code? Thanks!!!

Best,

Jiahao

lezcano commented 3 years ago

Thanks for the MWE! May I ask which PyTorch version you are using, and whether this was on the latest version of the fibrations branch? I was not able to reproduce it with PyTorch 1.7.1 and the current fibrations branch:

>>> import torch
>>> import geotorch
>>> import torch.nn as nn
>>> layer = nn.Linear(3, 3, bias = False)
>>> layer.weight
Parameter containing:
tensor([[ 0.1486, -0.0671, -0.4187],
        [-0.2514, -0.2428, -0.2468],
        [ 0.4127, -0.2797,  0.1743]], requires_grad=True)
>>> geotorch.orthogonal(layer, "weight")
>>> layer.weight
tensor([[-0.5787, -0.6748, -0.4581],
        [-0.0457,  0.5876, -0.8078],
        [ 0.8143, -0.4465, -0.3709]], grad_fn=<MmBackward>)
>>> layer.weight = torch.nn.init.eye_(layer.weight)
>>> layer.weight
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]], grad_fn=<MmBackward>)
>>> layer.parametrizations.weight[0].base
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
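A comparable check can be sketched with PyTorch's built-in orthogonal parametrization, assuming a recent PyTorch release. This is a swapped-in analogue, not geotorch: with the default use_trivialization=True, PyTorch also keeps the assigned orthogonal matrix in a buffer named base on the parametrization module, much like the geotorch session above.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrizations

layer = nn.Linear(3, 3, bias=False)
parametrizations.orthogonal(layer, "weight")

# Assign the identity; with the default use_trivialization=True the assigned
# orthogonal matrix is stored in the parametrization's `base` buffer
layer.weight = torch.eye(3)

weight = layer.weight
base = layer.parametrizations.weight[0].base
```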

lezcano commented 3 years ago

Oh, I see the problem now! Before PyTorch 1.7.0, the matrix 1-norm was not implemented in PyTorch, so I implemented it myself... and my implementation had the bug you found: .max(dim=-1) returns a (values, indices) named tuple rather than a tensor, so it cannot be divided by an int. At the moment I do not test against older versions of PyTorch; I should add that soon, sorry about that! I'll push a fix in a couple of minutes. In any case, I plan to merge all of this, properly tested, into master in the next few days, so everything should be working and tested for all the manifolds by next week :)

Edit: Committed now. Please tell me if you still have problems!
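The fix for the traceback above can be sketched as follows: the matrix 1-norm is the maximum absolute column sum, and the named tuple returned by .max(dim=-1) needs its .values field taken before dividing. The is_orthogonal helper is hypothetical; it only mirrors the shape of the failing line (error = norm / k) and is not geotorch's actual in_manifold code.

```python
import torch


def matrix_one_norm(A: torch.Tensor) -> torch.Tensor:
    """Matrix 1-norm (maximum absolute column sum), batched over leading dimensions."""
    # .max(dim=-1) returns a (values, indices) named tuple; dividing that tuple
    # by an int is what raised the TypeError, so take .values explicitly
    return A.abs().sum(dim=-2).max(dim=-1).values


def is_orthogonal(X: torch.Tensor, eps: float = 1e-4) -> bool:
    """Hypothetical membership test: is X^T X close to the identity in the 1-norm?"""
    k = X.size(-1)
    Id = torch.eye(k, dtype=X.dtype, device=X.device)
    D = X.transpose(-2, -1) @ X - Id
    error = matrix_one_norm(D) / k
    return bool((error < eps).all())
```

On recent PyTorch releases, torch.linalg.norm(A, ord=1) computes the same quantity for a 2-D A, so a hand-rolled version is only needed on older versions.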

lezcano commented 3 years ago

I just finished implementing and testing all of that, and it is now merged into master. I also implemented convenience sample() methods for all the manifolds, to make life a bit easier for users :) You can check the examples and take a look at the implementation. In summary, you can now do things like:

import torch
import torch.nn as nn
import geotorch

linear = nn.Linear(8, 8)
geotorch.orthogonal(linear, "weight")
SO = linear.parametrizations.weight[0]
# Initialize weight using the Haar measure
linear.weight = SO.sample("uniform")
# Initialize weight using a distribution that works well on RNNs
linear.weight = SO.sample("torus")
# Initialize weight to be the identity matrix
linear.weight = torch.nn.init.eye_(linear.weight)
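For intuition about what a Haar-uniform sample involves, here is a sketch of the standard QR-with-sign-correction construction in plain PyTorch. This is a common textbook technique, not necessarily how geotorch's sample("uniform") is implemented, and haar_orthogonal is a hypothetical helper name.

```python
import torch
from typing import Optional


def haar_orthogonal(n: int, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Sample an n x n orthogonal matrix from the Haar (uniform) measure on O(n)."""
    # Start from a matrix with i.i.d. standard Gaussian entries
    Z = torch.randn(n, n, generator=generator)
    Q, R = torch.linalg.qr(Z)
    # QR alone is not Haar-uniform because of the sign ambiguity in R's diagonal;
    # flip each column of Q by the sign of the matching diagonal entry of R
    signs = torch.sign(torch.diagonal(R))
    return Q * signs.unsqueeze(0)


g = torch.Generator().manual_seed(0)
W = haar_orthogonal(8, generator=g)
```

This samples from O(n), so the determinant may be -1. To restrict to SO(n), one option is to negate the first column whenever the determinant is negative; the thread does not say whether geotorch does exactly this.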

Do you have any other questions? Can we close this?

jiahaosu commented 3 years ago

Thanks! We can close this issue.