moskomule / anatome

Ἀνατομή is a PyTorch library to analyze representation of neural networks
MIT License

Size Check #29

Closed Yupei-Du closed 2 years ago

Yupei-Du commented 2 years ago

Hello ~ Thank you for the implementation! It is amazing and helps me a lot!

I have a question: in CCA, you seem to check x.size(0) < x.size(1). I believe the implementation is correct, since I have seen similar checks in other implementations of CCA. However, I don't understand the rationale behind it. Could you explain a bit? Thanks!

Also, something possibly related (though it could have other causes): my data has a larger feature size (x.size(1)) than number of examples (x.size(0)), and sometimes I get this error: "The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated singular values." Are the two related? Could you provide any insight on this?

moskomule commented 2 years ago

Hi, for CCA, x.size(0) < x.size(1) cannot be defined.

You can use subsampling to make x.size(1) smaller than x.size(0) or use CKA.

import torch
from torchvision.models import resnet18
from anatome import Distance

random_model = resnet18()
learned_model = resnet18(pretrained=True)

distance = Distance(random_model, learned_model)
with torch.no_grad():
    distance.forward(torch.randn(256, 3, 224, 224))

# resize x.size(1) to size
distance.between("layer3.0.conv1", "layer3.0.conv1", size=8)

# use CKA
distance = Distance(random_model, learned_model, method='lincka')
...
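
For reference, linear CKA stays well defined when there are fewer examples than features. The block below is a minimal NumPy sketch of the standard linear CKA formula, not anatome's own 'lincka' code; the function name `linear_cka` and the random inputs are assumptions made for illustration only.

```python
import numpy as np


def linear_cka(x, y):
    """Linear CKA between two representations (rows are examples)."""
    # Center every feature over the examples.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    # ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return cross / (norm_x * norm_y)


rng = np.random.default_rng(0)
# 20 examples, 50 features: the regime where the CCA size check fails.
x = rng.standard_normal((20, 50))
y = rng.standard_normal((20, 50))

same = linear_cka(x, x)       # a representation compared with itself
different = linear_cka(x, y)  # two unrelated random representations
print(same, different)
```

Only matrix products and Frobenius norms are involved, so no decomposition has to converge, and unrelated inputs are not forced to a similarity of exactly 1.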
brando90 commented 2 years ago

> Hi, for CCA, x.size(0) < x.size(1) cannot be defined.

What do you mean by "not defined"? I removed that check and the code runs fine. It always gives a similarity of 1.0, since it is trivial to maximally correlate the two representations when there are more dimensions/features than data points, but in my experience it is well defined. I've never encountered numerical issues. That being said, one should never run experiments with x.size(0) < x.size(1). I usually run them such that x.size(0) >= 10 * x.size(1), as recommended by the original SVCCA authors.
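
The trivial solution can be reproduced with a small NumPy sketch. This is not anatome's CCA implementation: `canonical_correlations` is a hypothetical helper that computes canonical correlations as the singular values of the product of orthonormal bases of the two centered inputs, and the inputs are independent random matrices chosen only for illustration.

```python
import numpy as np


def canonical_correlations(x, y, tol=1e-10):
    """Canonical correlations between x and y (rows are examples)."""

    def orthonormal_basis(a):
        # Center the features, then keep the left singular vectors that
        # span the column space (drop numerically zero directions).
        a = a - a.mean(axis=0, keepdims=True)
        u, s, _ = np.linalg.svd(a, full_matrices=False)
        return u[:, s > tol * s.max()]

    qx, qy = orthonormal_basis(x), orthonormal_basis(y)
    # Singular values of Qx^T Qy are the canonical correlations.
    rho = np.linalg.svd(qx.T @ qy, compute_uv=False)
    return np.clip(rho, 0.0, 1.0)


rng = np.random.default_rng(0)

# Fewer examples than features: 20 examples, 50 features.
few = canonical_correlations(
    rng.standard_normal((20, 50)), rng.standard_normal((20, 50))
)

# Many more examples than features: 2000 examples, 5 features.
many = canonical_correlations(
    rng.standard_normal((2000, 5)), rng.standard_normal((2000, 5))
)

print(few.mean(), many.mean())
```

With 20 examples and 50 features, both centered matrices span the same 19-dimensional subspace, so every canonical correlation comes out as 1 even though the inputs are independent noise. With 2000 examples and 5 features the correlations stay small, which is the behaviour one would expect for unrelated data.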

Yupei-Du commented 2 years ago

Thank you both very much @moskomule @brando90 ! I think you made the point: when x.size(0) < x.size(1), the solution is trivial.