brando90 / ultimate-anatome

Ἀνατομή is a PyTorch library to analyze representation of neural networks
MIT License

TODO: what normalization does CKA need + sanity checks for it? #5

Closed brando90 closed 3 years ago

brando90 commented 3 years ago

TODO: what normalization does CKA need + sanity checks for it?

brando90 commented 3 years ago

ref: https://arxiv.org/pdf/1905.00414.pdf Similarity of Neural Network Representations Revisited

brando90 commented 3 years ago

ref: https://github.com/google-research/google-research/tree/master/representation_similarity

brando90 commented 3 years ago

https://github.com/google-research/google-research/blob/master/representation_similarity/Demo.ipynb (official Google tutorial)

brando90 commented 3 years ago

issue: https://github.com/google-research/google-research/issues/865

brando90 commented 3 years ago

sanity check to write?

brando90 commented 3 years ago

for now this works:

cxa_dist_type = 'lincka'
X: torch.Tensor = torch.distributions.Normal(loc=0.0, scale=1.0).sample((B, Din))
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} ({cxa_dist_type=})')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert approx_equal(sim, 1.0), f'Sim should be close to 1.0 but got {sim=}'

This works with the correct normalization too.
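For reference, a minimal standalone sketch of the same kind of sanity check, independent of this repo's code. It assumes the normalization in question is the column centering (subtracting the per-feature mean over the batch) used for linear CKA in the paper referenced above. `center_columns` and `linear_cka` are hypothetical helper names for illustration, not this library's API, and the sizes are arbitrary.

```python
import torch


def center_columns(feats: torch.Tensor) -> torch.Tensor:
    """Subtract the per-feature mean taken over the batch dimension."""
    return feats - feats.mean(dim=0, keepdim=True)


def linear_cka(x: torch.Tensor, y: torch.Tensor) -> float:
    """Linear CKA between activations x [B, D1] and y [B, D2] (hypothetical helper)."""
    x, y = center_columns(x), center_columns(y)
    # ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    cross = torch.linalg.norm(y.T @ x, ord='fro') ** 2
    self_x = torch.linalg.norm(x.T @ x, ord='fro')
    self_y = torch.linalg.norm(y.T @ y, ord='fro')
    return (cross / (self_x * self_y)).item()


torch.manual_seed(0)
B, Din = 512, 32  # illustrative sizes, not taken from the issue
X = torch.randn(B, Din, dtype=torch.float64)

# Check 1: a representation compared with itself should give ~1.0.
sim_same = linear_cka(X, X)

# Check 2: linear CKA should be unchanged by an orthogonal transform,
# an isotropic rescaling, and (thanks to centering) a constant offset.
Q, _ = torch.linalg.qr(torch.randn(Din, Din, dtype=torch.float64))
sim_rot = linear_cka(X, 3.0 * (X @ Q) + 5.0)

# Check 3: independent Gaussian features should score well below 1.0.
sim_rand = linear_cka(X, torch.randn(B, Din, dtype=torch.float64))

print(f'{sim_same=} {sim_rot=} {sim_rand=}')
```

The first two values should be very close to 1.0 and the third noticeably lower; if the offset in check 2 changes the score, the centering step is probably missing.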

brando90 commented 3 years ago

Same results qualitatively! Success!

OPD is a little slower though, so be careful when using it.