moskomule / anatome

Ἀνατομή is a PyTorch library to analyze representations of neural networks
MIT License
61 stars 6 forks source link

Shouldn't the similarity of a pre-trained net vs a random net decrease from random data to real data? #20

Closed brando90 closed 2 years ago

brando90 commented 2 years ago

I was trying to sanity-check anatome and noticed that the CCA similarity between a random net and a pre-trained net increased when moving from random data to real images. I found this surprising: since the pre-trained net was trained on images, feeding it real images should accentuate its difference from a random net.

Isn't that what you'd expect?

resnet18
Din ~ (C*H*W)=12288
Are random net & pre-trained net similar? They should not (so sim should be small):
-> scca_full_random_data=0.6459464728832245 (but might be more similar than expected on random data)
Are random net & pre-trained net similar? They should not (so sim should be small):
scca_full_mini_imagenet_data=0.9605575129389763 (the difference should be accentuated with real data, so lowest sim)
brando90 commented 2 years ago

snippets of code:

def cxa_dist(mdl1: nn.Module, mdl2: nn.Module, X: Tensor, layer_name: str,
             downsample_size: Optional[int] = None, iters: int = 1, cxa_dist_type: str = 'pwcca') -> float:
    """
    Compute an anatome CCA-family distance between the outputs of ``layer_name`` in ``mdl1`` and ``mdl2``.

    Both models are deep-copied first, so the forward hooks installed by ``SimilarityHook`` live only on the
    copies. The caller's models are never modified, and no explicit hook removal is needed: the hooked copies
    go out of scope when this function returns.

    :param mdl1: first model.
    :param mdl2: second model; must also contain a submodule named ``layer_name``.
    :param X: input batch fed to both models.
    :param layer_name: name of the submodule whose output is compared (e.g. ``'layer2.1.bn2'``).
    :param downsample_size: forwarded to ``SimilarityHook.distance`` as ``size``; ``None`` means no downsampling.
    :param iters: number of forward passes per model. Both models are put in eval mode below, which disables
        dropout and makes BatchNorm use its running statistics, so extra passes are presumably redundant.
        NOTE(review): confirm whether the hook accumulates activations across passes.
    :param cxa_dist_type: distance name passed to ``SimilarityHook`` (default ``'pwcca'``).
    :return: the distance as a Python ``float``.
    """
    import copy

    import torch
    from anatome import SimilarityHook

    # Work on copies so that hooks (and eval-mode switching) never leak into the caller's models.
    mdl1 = copy.deepcopy(mdl1)
    mdl2 = copy.deepcopy(mdl2)
    hook1 = SimilarityHook(mdl1, layer_name, cxa_dist_type)
    hook2 = SimilarityHook(mdl2, layer_name, cxa_dist_type)
    mdl1.eval()
    mdl2.eval()
    # Only activations are needed. Skip building the autograd graph, which the recorded layer outputs would
    # otherwise presumably keep alive, wasting memory.
    with torch.no_grad():
        for _ in range(iters):
            mdl1(X)
            mdl2(X)
    dist = hook1.distance(hook2, size=downsample_size)
    return float(dist)

def cxa_sim(mdl1: nn.Module, mdl2: nn.Module, X: Tensor, layer_name: str,
            downsample_size: Optional[int] = None, iters: int = 1, cxa_dist_type: str = 'pwcca') -> float:
    """
    Similarity counterpart of ``cxa_dist``: returns ``1.0 - cxa_dist(...)`` for the same arguments.

    :param downsample_size: forwarded to ``cxa_dist``; ``None`` means no downsampling.
    :return: ``1.0`` minus the distance computed by ``cxa_dist``.
    """
    dist = cxa_dist(mdl1, mdl2, X, layer_name,
                    downsample_size=downsample_size, iters=iters, cxa_dist_type=cxa_dist_type)
    return 1.0 - dist

def dCXA(mdl1: nn.Module, mdl2: nn.Module, X: Tensor, layer_name: str,
         downsample_size: Optional[int] = None, iters: int = 1, cxa_dist_type: str = 'pwcca') -> float:
    """Short alias for ``cxa_dist``: the CCA-family *distance* between two models at ``layer_name``."""
    return cxa_dist(mdl1, mdl2, X, layer_name,
                    downsample_size=downsample_size, iters=iters, cxa_dist_type=cxa_dist_type)

def sCXA(mdl1: nn.Module, mdl2: nn.Module, X: Tensor, layer_name: str,
         downsample_size: Optional[int] = None, iters: int = 1, cxa_dist_type: str = 'pwcca') -> float:
    """Short alias for ``cxa_sim``: the CCA-family *similarity* (``1 - distance``) at ``layer_name``."""
    return cxa_sim(mdl1, mdl2, X, layer_name,
                   downsample_size=downsample_size, iters=iters, cxa_dist_type=cxa_dist_type)

and

def anatome_test_are_random_vs_pretrain_resnets_different():
    """
    Sanity-check that a random ResNet-18 and a pre-trained ResNet-18 have dissimilar representations.

    The CCA-family similarity ``sCXA`` (default ``'pwcca'``) at ``layer_name`` is computed twice for the SAME
    pair of models: once on uniform random noise and once on mini-ImageNet images. The hypothesis under test
    is that the similarity is lower on real images, because the pre-trained weights are tuned to them and the
    random weights are not.

    - no downsample
    - still true if downsample (but perhaps similarity increases, due to collapsing nets makes r.v.s
    interact more btw each other, so correlation is expected to increase).
    """
    from torchvision.models import resnet18
    # Seed so that the random net and the random data are reproducible across runs.
    torch.manual_seed(0)
    B = 512
    C, H, W = 3, 64, 64
    print(f'Din ~ {(C*H*W)=}')
    downsample_size = 4  # only used by the commented-out downsampled variants below
    # Build each model ONCE and reuse it for both comparisons. Calling resnet18() again would draw a fresh
    # random initialization, so the two similarities would compare *different* random nets and the data
    # effect would be confounded with init noise. Reuse is safe because cxa_dist deep-copies its models.
    mdl1 = resnet18()
    mdl2 = resnet18(pretrained=True)
    # - layer name
    # layer_name = 'bn1'
    layer_name = 'layer2.1.bn2'
    # layer_name = 'layer4.1.bn2'
    # layer_name = 'fc'

    # -- we expect low CCA/sim since random nets vs pre-trained nets are different (especially on real data)
    # - random data test
    X: torch.Tensor = torch.distributions.Uniform(low=-1, high=1).sample((B, C, H, W))
    scca_full_random_data: float = sCXA(mdl1, mdl2, X, layer_name, downsample_size=None)
    print(f'Are random net & pre-trained net similar? They should not (so sim should be small):\n'
          f'-> {scca_full_random_data=} (but might be more similar than expected on random data)')
    # scca_downsampled: float = sCXA(mdl1, mdl2, X, layer_name, downsample_size=downsample_size)
    # print(f'Are random net & pre-trained net similar? They should not (so sim should be small): {scca_downsampled=}')

    # - mini-imagenet test (the difference should be accentuated btw random net & pre-trained on real img data)
    from uutils.torch_uu.dataloaders import get_set_of_examples_from_mini_imagenet
    k_eval: int = B  # num examples is about M = k_eval*(num_classes) = B*(num_classes)
    # NOTE(review): per the comment above, this batch presumably has ~B*num_classes examples (vs. B for the
    # random-data batch), and possibly a different H, W. CCA-based estimates are sensitive to the number of
    # samples relative to the feature dimension, so confirm the two batches are comparable before reading
    # much into the gap between the two similarities.
    X: torch.Tensor = get_set_of_examples_from_mini_imagenet(k_eval)
    scca_full_mini_imagenet_data: float = sCXA(mdl1, mdl2, X, layer_name, downsample_size=None)
    print(f'Are random net & pre-trained net similar? They should not (so sim should be small):\n'
          f'{scca_full_mini_imagenet_data=} (the difference should be accentuated with real data, so lowest sim)')
    assert scca_full_mini_imagenet_data < scca_full_random_data, (
        'Sim should decrease, because the pre-trained net '
        'was trained on real images, so the weights are '
        'tuned for it but random weights are not, which '
        'should increase the difference so the sim '
        'should be lowest here. i.e. we want '
        f'{scca_full_mini_imagenet_data}<{scca_full_random_data}'
    )
    # scca_downsampled: float = sCXA(mdl1, mdl2, X, layer_name, downsample_size=downsample_size)
    # print(f'Are random net & pre-trained net similar? They should not (so sim should be small): {scca_downsampled=}')