pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

nn.Orthogonal #42243

Closed. AlexanderMath closed this issue 3 years ago.

AlexanderMath commented 4 years ago

🚀 Feature

A module nn.Orthogonal similar to nn.Linear where the weight matrix W is constrained to be orthogonal, i.e., W^T W = I.

Motivation

There has been growing interest in orthogonal parameterizations of neural networks, see, e.g., [1,2,3,4,5]. To use an orthogonal parameterization with PyTorch, one currently has to implement it oneself or use third-party code. It would be convenient if PyTorch had a built-in module nn.Orthogonal that handles everything automatically. In particular, it would be convenient if nn.Orthogonal supported different methods via, e.g., method={fasth, cayley, exp}.

Pitch

At ICML it was suggested that I make a pull request adding fasth [5] to PyTorch as nn.Orthogonal. I want to

  1. be sure this feature is desired
  2. discuss potential ways of interfacing with the user
  3. implement the code and submit a pull request.

I want nn.Orthogonal to support three methods: the Cayley transform, the matrix exponential, and fasth.
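To make the shape of the proposal concrete, here is a minimal, purely illustrative sketch of such a layer using one of the three methods (the Cayley transform). None of the names or design choices below are part of the actual proposal:

```python
import torch
import torch.nn as nn

class Orthogonal(nn.Module):
    """Illustrative sketch only (not the proposed implementation): a square
    linear layer whose weight is kept orthogonal via the Cayley transform."""

    def __init__(self, features):
        super().__init__()
        # Unconstrained parameter; only its skew-symmetric part is used below.
        self.param = nn.Parameter(0.01 * torch.randn(features, features))

    def weight(self):
        A = self.param - self.param.t()                          # skew-symmetric
        I = torch.eye(A.size(0), dtype=A.dtype, device=A.device)
        # Cayley transform: Q = (I - A)(I + A)^{-1} is orthogonal for skew-symmetric A.
        return (I - A) @ torch.inverse(I + A)

    def forward(self, x):
        return x @ self.weight().t()

layer = Orthogonal(64)
x = torch.randn(32, 64)
Q = layer.weight()
print(torch.dist(Q.t() @ Q, torch.eye(64)))  # ~0 up to floating point error
```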

Additional context

The contribution instructions (see screenshot below) state that algorithms from recently published research are generally not accepted, but they suggest opening an issue first, which I have now done. FastH is up to 20 times faster than the previous sequential algorithm (see the plot at the bottom of this page). Please note this is an algorithmic speed-up: it computes exactly the same thing as the previous algorithm, just faster.

[Screenshot: PyTorch contribution guidelines on submitting new features]

References

[1] Efficient Orthogonal Parametrisation of Recurrent Neural Networks Using Householder Reflections (ICML 2017)
[2] A Simple Parametrization of the Orthogonal and Unitary Group (ICML 2019)
[3] Stabilizing Gradients for Deep Neural Networks via Efficient SVD Parameterization (ICML 2018)
[4] Trivializations for Gradient-Based Optimization on Manifolds (NeurIPS 2019)
[5] Faster Orthogonal Parameterization with Householder Matrices (ICML Workshop 2020)

[Figure: FastH runtime compared with the previous sequential algorithm]

cc @albanD @mruberry

albanD commented 4 years ago

Note that we have https://github.com/pytorch/pytorch/pull/33344, which is closely related and would make this very simple to implement.

lezcano commented 3 years ago

While I completely agree that adding an nn.Orthogonal through the Cayley transform and the exponential map is a great idea, I do not see how "FastH" would be equally good.

FastH is just another way to implement a product of Householder reflections, and Householder reflections have been shown to have convergence problems (Z. Mhammedi et al.) when compared with the Cayley transform and the exponential map; in particular, they may introduce new critical points. At the same time, the paper that proposes FastH has not been published (as it stands it is just a workshop paper), nor does it present experiments on standard datasets such as MNIST, TIMIT or Penn Treebank. Some of these experiments were carried out in previous papers, though, and there the optimisation with Householder reflections was shown to be inferior to that given by the Cayley transform and the exponential map.
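For context, the object under discussion is an orthogonal matrix parameterised as a product of Householder reflections. The snippet below is a naive reference sketch of that product (it is not FastH's algorithm, which computes the same thing faster):

```python
import torch

def householder_product(V):
    """Naive reference (not FastH): the orthogonal matrix obtained by multiplying
    k Householder reflections H_i = I - 2 v_i v_i^T / ||v_i||^2, where the v_i
    are the rows of V (shape k x d)."""
    k, d = V.shape
    Q = torch.eye(d, dtype=V.dtype, device=V.device)
    for v in V:
        v = v / v.norm()
        Q = Q - 2.0 * torch.outer(v, v @ Q)   # left-multiply by H_i: H_i Q = Q - 2 v (v^T Q)
    return Q

Q = householder_product(torch.randn(8, 16))
print(torch.dist(Q.t() @ Q, torch.eye(16)))   # ~0: the product is orthogonal
```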

About the time comparison, I think the authors should go over their benchmarking again. As shown in the PR that introduced it, computing the matrix exponential of a single 1024 x 1024 matrix with torch.matrix_exp on GPU currently takes 700us = 0.0007s. I do not know where the figures in the workshop paper came from, but they certainly do not make for a fair comparison.

As a bottom line, as @albanD mentioned, this should be straightforward to implement once #33344 is finished. For now, I maintain a library that allows putting these kinds of constraints on arbitrary layers, based on #33344:

https://github.com/Lezcano/geotorch

Efficient Orthogonal Parametrisation of Recurrent Neural Networks Using Householder Reflections. Z. Mhammedi et al. ICML 2017.

AlexanderMath commented 3 years ago

I think that the authors should go over their benchmarking of the times again, ...

We report the time for both the forward pass and the gradient computation, e.g., the time of torch.matrix_exp(A+A.T).sum().backward(). We apologize for not making this clear. Note that the article was submitted and presented before PyTorch 1.7.0 introduced torch.matrix_exp; we therefore used your implementation of the Padé approximation (https://github.com/Lezcano/expRNN) to make the plot. That said, the following dumb test with torch.matrix_exp seems to give similar results, even though we ran it on newer hardware.
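A rough, self-contained reconstruction of such a timing test might look like the following (a sketch of mine, not the script used for the plot; the scaling of A is an added assumption to keep the exponential finite in float32):

```python
import math
import time
import torch

size = 1024
# Scaling by sqrt(size) is an assumption of this sketch (keeps exp(A + A.T) finite in float32).
A = (torch.randn(size, size, device="cuda") / math.sqrt(size)).requires_grad_(True)

# Warm-up so CUDA kernels and the autograd graph are set up before timing.
torch.matrix_exp(A + A.T).sum().backward()
torch.cuda.synchronize()

reps = 100
start = time.time()
for _ in range(reps):
    A.grad = None
    torch.matrix_exp(A + A.T).sum().backward()   # forward + backward, as described above
torch.cuda.synchronize()
print((time.time() - start) / reps, "seconds per forward+backward")
```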

[Screenshot: timing results of the torch.matrix_exp test]

My group is very busy, so we won't have time to implement FastH in PyTorch.

lezcano commented 3 years ago

For what it's worth, this is the time I get:

[-------- PyTorch matrix_exp and backward --------]
                 |  64x64  |  256x256  |  1024x1024
16 threads: ---------------------------------------
      batch: 1   |   3.6   |    3.9    |       8.35
Times are in milliseconds (ms)

when running the following statement:

"y = {}.matrix_exp(); _ = torch.autograd.grad(y, {}, {})".format(x, x, g)

with torch.utils.benchmark on an RTX 2080 with initialisation

x = torch.rand(batch, size, size).cuda() / size                                              
g = torch.rand(batch, size, size).cuda() / size                                              
x.requires_grad_(True) 
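A self-contained version of this benchmark could look roughly as follows (the setup mirrors the description above; the exact Timer invocation is my own sketch):

```python
import torch
from torch.utils import benchmark

batch, size = 1, 1024
x = (torch.rand(batch, size, size, device="cuda") / size).requires_grad_(True)
g = torch.rand(batch, size, size, device="cuda") / size

# Time forward (matrix_exp) plus backward (vector-Jacobian product against g).
timer = benchmark.Timer(
    stmt="y = x.matrix_exp(); torch.autograd.grad(y, x, g)",
    globals={"torch": torch, "x": x, "g": g},
    num_threads=16,
)
print(timer.timeit(100))
```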

In other words, according to your plot, it is exactly as fast as your algorithm.

I believe the difference comes from the fact that the matrix exponential takes longer to compute for matrices of larger norm. On the other hand, when doing optimisation with orthogonal constraints (or on an arbitrary manifold, for that matter), one can keep the norm of the matrix being exponentiated as small as one wants, as shown in https://arxiv.org/abs/1909.09501. The matrix may be initialised to zero and, in particular, it may be kept close to zero by changing the base point of the Riemannian exponential map, hence making the exponential cheaper to compute. The take-home message here is that comparing the speed of your method with one based on the matrix exponential is slightly trickier than what was done in your workshop paper.
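One way to see this effect directly (a small sketch of mine, not from the discussion above) is to time torch.matrix_exp on the same matrix at a small and at a large scale:

```python
import torch
from torch.utils import benchmark

size = 1024
base = torch.rand(size, size, device="cuda")
for scale in (1.0 / size, 1.0):                      # small norm vs. large norm
    x = base * scale
    # Scaling-and-squaring needs more squaring steps for larger norms, so the
    # large-scale matrix takes noticeably longer to exponentiate.
    t = benchmark.Timer(stmt="x.matrix_exp()", globals={"x": x})
    print(f"scale = {scale:g}: {t.timeit(50)}")
```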

jbschlosser commented 3 years ago

I reviewed this issue and agree with @albanD & @Lezcano that proper reparameterization support will make this straightforward. Rather than encapsulating this functionality within an nn.Orthogonal module, I personally like the more flexible approach demonstrated in geotorch (btw: impressive work @Lezcano!), which allows orthogonality constraints to be enforced on arbitrary tensor parameters (for example, the weight matrix of an RNN). I wonder whether functionality like this belongs in PyTorch core or in geotorch, which offers a lot of relevant features. Thoughts on this, @Lezcano?
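For reference, that style of constraint looks roughly like this with geotorch (based on its README at the time; the exact API may have changed):

```python
import torch
import torch.nn as nn
import geotorch

layer = nn.Linear(64, 64)
geotorch.orthogonal(layer, "weight")       # constrain layer.weight to stay orthogonal

x = torch.randn(8, 64)
y = layer(x)                               # train layer.parameters() with any optimiser as usual

W = layer.weight
print(torch.dist(W.t() @ W, torch.eye(64)))  # ~0: the constraint holds by construction
```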

lezcano commented 3 years ago

I think it would be worth adding some basic functionality to core PyTorch, perhaps the Orthogonal layer and the SymmetricPositiveDefinite constraint (and perhaps low-rank as well?), as those are by far the two most used and will be extremely useful to researchers once the reparametrisation framework is in place. We can discuss this once the reparametrisation framework is added.

About the second point (the reparametrisation framework), I have been meaning to contact @albanD this week, but I have been too busy. All the work on the reparametrisation framework is done; it is just missing two architectural decisions.

albanD commented 3 years ago

The general reparametrization should definitely be in core, indeed. For these specific modules, it is less certain. In particular, we have quite an involved process for vetting new nn.Modules: we usually ask for research papers that make use of the module and for evidence of a long-term trend toward it. But it should be fairly easy to implement in a third-party library once the main reparametrization is in core.

lezcano commented 3 years ago

For what it's worth, there is a whole subfield doing optimisation on manifolds, but at the moment the main software for it is a one-off library (Manopt) written in MATLAB. If citations mean anything, the paper presenting that software has 600+ citations.

Boumal, Nicolas, Bamdev Mishra, et al. (2014). “Manopt, a MATLAB toolbox for optimization on manifolds”. In: Journal of Machine Learning Research, JMLR 15.1, pp. 1455–1459. issn: 1533-7928. doi: 10.5555/2627435.2638581

At the same time, one of the five most cited papers of all time in Mathematical Programming (the most important journal in optimisation) is on optimisation with orthogonality constraints (628 citations, which is a huge number in a field like optimisation, though less remarkable in ML):

Wen, Z., Yin, W. A feasible method for optimization with orthogonality constraints. Math. Program. 142, 397–434 (2013). https://doi.org/10.1007/s10107-012-0584-1

The same holds for the SIAM Journal on Matrix Analysis and Applications, with this paper (2588 citations):

Edelman, Alan, Tomás A. Arias, and Steven Thomas Smith (1998). “The geometry of algorithms with orthogonality constraints”. In: SIAM Journal on Matrix Analysis and Applications 20.2, pp. 303–353. issn: 0895-4798. doi: 10.1137/S0895479895290954

I think this is a very important topic and, if anything, it is here to stay. As such, I think PyTorch could definitely benefit from having a basic, performant implementation of a parametrisation that makes all these ideas plug and play, but then again, this is something we can discuss down the road :)

albanD commented 3 years ago

Thanks for the references! I would need to take a closer look, but it feels like the orthogonal case at least would fare well with respect to these rules. But indeed we can discuss that further once the main PR is merged.

toshas commented 3 years ago

This package of mine might be relevant together with the upcoming parameterization functionality: https://github.com/toshas/torch-householder

lezcano commented 3 years ago

That package does indeed look interesting; I have opened an issue there with a couple of questions I had after a first look :)

Even then, the question of the practical efficiency of Householder reflections still stands, as I have not seen them work as well as the exponential / Cayley maps in the RNN setting (see the results in Z. Mhammedi et al., cited above), although I agree that an independent review paper on optimisation with orthogonal constraints would be of much help here. But then again, this should be part of a discussion on which optimisation algorithms to use if optimisation with orthogonal constraints is implemented, and that should wait until the parametrisation framework is in PyTorch core.