ShipengWang / Adam-NSCL

PyTorch implementation of our Adam-NSCL algorithm from our CVPR2021 (oral) paper "Training Networks in Null Space for Continual Learning"

Normalization of transform #4

Closed: danielm1405 closed this issue 10 months ago

danielm1405 commented 1 year ago

Hi, I do not understand the way the transform is normalized here: https://github.com/ShipengWang/Adam-NSCL/blob/a2f39b4273aa300739c460b33a5e8d2c674632b2/optim/adam_svd.py#L114

As far as I understand, we would like self.transforms to be orthogonal, i.e. the L2 norm of each column should be equal to 1. With the current normalization, self.transforms is not orthogonal: the norms of all of its columns are less than 1. Transforming the update therefore shrinks its norm, which effectively acts as if the learning rate were decreased.

A proper way to orthogonalize self.transforms is as follows:

self.transforms[p] = transform / transform.norm(dim=0, p=2)
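To make this concrete, here is a toy sketch of what I mean (my own illustration, not the repository code; it assumes the transform has the form U @ U.T for an orthonormal null-space basis U):

```python
import torch

# Toy illustration (not the repository code): build a projection onto the
# null space of a small feature matrix A, assuming the transform is U @ U.T
# for an orthonormal null-space basis U.
torch.manual_seed(0)
n, k = 8, 3                        # parameter dimension, number of feature rows
A = torch.randn(k, n)              # stands in for the stored feature matrix
_, _, Vh = torch.linalg.svd(A)     # rows k..n-1 of Vh span the null space of A
U = Vh[k:].T                       # orthonormal null-space basis, shape (n, n-k)
P = U @ U.T                        # plays the role of self.transforms[p]

print(P.norm(dim=0, p=2))          # column norms: all strictly below 1 here
P_fixed = P / P.norm(dim=0, p=2)   # the normalization proposed above
print(P_fixed.norm(dim=0, p=2))    # now every column has unit L2 norm
```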

When I ran the experiment with the fix above, I got very poor results: 58.82% on 10-split-CIFAR-100 vs 73.77% with the original method.

Could you explain why the transform is normalized in this way?

ShipengWang commented 1 year ago

Hi, it makes sense that the norm acts like a reduced learning rate, but that is not our intention. This normalization lets us provide an upper bound on the scale of the gradients after projection, since

torch.mm(P, grad).norm() <= grad.norm()

where P is the projection matrix (please refer to here for the reasoning).
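As a quick numerical check of that bound (a sketch under the assumption that P = U @ U.T for an orthonormal null-space basis U, not the exact code in this repository):

```python
import torch

# Sketch: a projection P = U @ U.T with orthonormal U never enlarges a vector,
# so the projected gradient norm is bounded by the original gradient norm.
torch.manual_seed(0)
n, k = 8, 3
A = torch.randn(k, n)
_, _, Vh = torch.linalg.svd(A)
U = Vh[k:].T                      # orthonormal null-space basis
P = U @ U.T

for _ in range(5):
    grad = torch.randn(n, 1)
    assert torch.mm(P, grad).norm() <= grad.norm() + 1e-6
```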

Indeed, the projection matrix is not necessarily an orthogonal matrix, and there seems to be some misunderstanding about the definition of an orthogonal matrix in your comment: an orthogonal matrix Q must satisfy Q^T Q = I, which is stronger than having unit-norm columns, and a projection onto a proper subspace is singular, so it cannot be orthogonal.

In your implementation, self.transforms[p] is no longer a projection matrix associated with the approximate null space.
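For illustration (again a toy sketch, not this repository's code): after dividing by the column norms, the matrix is no longer idempotent, so it is not a projection any more, and its spectral norm rises above 1, so the upper bound above is lost as well.

```python
import torch

# Sketch: column-normalizing the projection breaks idempotence (P @ P == P)
# and removes the norm bound, since the spectral norm exceeds 1.
torch.manual_seed(0)
n, k = 8, 3
A = torch.randn(k, n)
_, _, Vh = torch.linalg.svd(A)
P = Vh[k:].T @ Vh[k:]                         # proper null-space projection

P_colnorm = P / P.norm(dim=0, p=2)            # the proposed column normalization
print(torch.allclose(P @ P, P, atol=1e-5))                          # True
print(torch.allclose(P_colnorm @ P_colnorm, P_colnorm, atol=1e-5))  # False
print(torch.linalg.matrix_norm(P_colnorm, ord=2))                   # > 1 here
```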