fagp / sinkhorn-rebasin

MIT License

sinkhorn-rebasin incompatible with default torchvision VGG #9

Closed: epistoteles closed this issue 1 year ago

epistoteles commented 1 year ago

The implementation of the library assumes that a CNN has a linear classifier with only one layer.

When using the default torchvision.models.vgg11 (i.e., without replacing model.classifier with a single linear layer as in the example), the model is no longer functionally equivalent to the original after .identity_init() or .random_init().

fagp commented 1 year ago

Hello @epistoteles, thanks for using our library.

From a design perspective, I can assure you that RebasinNet makes no assumptions about the number of layers. My hypothesis, therefore, is that you are seeing a floating-point precision issue, and its impact grows with the number of layers and classes. Technically, the permutation transformation is implemented as the product $P_i W_i P_{i-1}^T$, which accumulates rounding errors in the least significant digits (RebasinNet L59). This way of implementing the permutation transformation is standard in the community.
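To illustrate why the transformation $P_i W_i P_{i-1}^T$ does not depend on the number of layers, here is a minimal NumPy sketch (not RebasinNet's code) that applies it to a three-layer ReLU classifier. It assumes hard permutation matrices, $P_0 = I$ for the inputs, and an unpermuted output layer so class logits keep their order; the layer sizes are hypothetical and much smaller than vgg11's 25088 -> 4096 -> 4096 -> 1000 classifier.

```python
import numpy as np

def perm_matrix(perm, dtype):
    # Row k of P is the unit vector e_{perm[k]}, so (P @ v)[k] == v[perm[k]].
    return np.eye(len(perm), dtype=dtype)[perm]

def forward(weights, biases, x):
    # Fully connected ReLU network; no activation after the last layer.
    h = x
    for i, (W, b) in enumerate(zip(weights, biases)):
        h = W @ h + b
        if i < len(weights) - 1:
            h = np.maximum(h, 0)
    return h

def permute(weights, biases, perms):
    # W_i' = P_i W_i P_{i-1}^T and b_i' = P_i b_i, with P_0 = I (network inputs)
    # and P_L = I for the last layer (class outputs keep their order).
    dtype = weights[0].dtype
    mats = [np.eye(weights[0].shape[1], dtype=dtype)]
    mats += [perm_matrix(p, dtype) for p in perms]
    mats.append(np.eye(weights[-1].shape[0], dtype=dtype))
    new_w = [mats[i + 1] @ W @ mats[i].T for i, W in enumerate(weights)]
    new_b = [mats[i + 1] @ b for i, b in enumerate(biases)]
    return new_w, new_b

rng = np.random.default_rng(0)
dims = [256, 128, 128, 10]  # hypothetical sizes for a 3-layer classifier
dtype = np.float32
weights = [(rng.standard_normal((dims[i + 1], dims[i])) / np.sqrt(dims[i])).astype(dtype)
           for i in range(3)]
biases = [rng.standard_normal(dims[i + 1]).astype(dtype) for i in range(3)]
perms = [rng.permutation(d) for d in dims[1:-1]]  # one permutation per hidden layer

x = rng.standard_normal(dims[0]).astype(dtype)
y = forward(weights, biases, x)
y_perm = forward(*permute(weights, biases, perms), x)
print(np.max(np.abs(y - y_perm)))  # small, but not guaranteed to be exactly 0
```

Because ReLU acts elementwise, it commutes with a permutation, so each hidden layer's reordering is undone by the $P_{i-1}^T$ factor in the next layer regardless of depth. Any remaining difference comes from summing the same products in a different order.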

Beyond that, I don't expect you to run into any other issues with vgg11. Since you didn't provide any code, I tried to reproduce the issue on my side in the following Colab. There, I show a successful re-basin of the original vgg11 from torchvision (1000 classes and three linear layers). The Colab also shows that functional equivalence holds up to 5 decimal digits in 32-bit precision and up to 14 decimal digits in 64-bit precision.
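The gap between 32-bit and 64-bit precision can be isolated with a small NumPy sketch (an illustration, not the Colab's code). It computes one neuron's pre-activation $w \cdot h$ with its terms in the original order and in a permuted order, which is what a hidden-unit permutation does. The vector length of 4096 is an assumption chosen to match vgg11's hidden width.

```python
import numpy as np

def reorder_discrepancy(dtype, n=4096, seed=0):
    # Same random data for every dtype: draw in float64, then cast.
    rng = np.random.default_rng(seed)
    w = rng.standard_normal(n).astype(dtype)
    h = rng.standard_normal(n).astype(dtype)
    perm = rng.permutation(n)
    original = np.dot(w, h)               # terms summed in the original order
    permuted = np.dot(w[perm], h[perm])   # same products, different summation order
    return abs(float(original) - float(permuted))

err32 = reorder_discrepancy(np.float32)
err64 = reorder_discrepancy(np.float64)
print(f"float32: {err32:.1e}  float64: {err64:.1e}")
```

The float64 discrepancy should come out many orders of magnitude smaller than the float32 one, which is consistent with the difference between the 5 and 14 decimal digits reported above. Stacking more layers compounds these per-neuron differences.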