brando90 closed this 3 years ago.
Ref: Kornblith et al. (2019), "Similarity of Neural Network Representations Revisited": https://arxiv.org/pdf/1905.00414.pdf
Official Google Research demo notebook: https://github.com/google-research/google-research/blob/master/representation_similarity/Demo.ipynb
Which sanity check should we write for this?
For now, this one works:
```python
import torch

# cxa_sim, approx_equal, mdl1, mdl2, layer_name, B and Din are defined elsewhere (not shown).
# Expecting ~1.0 assumes mdl1 and mdl2 compute the same representation at layer_name.
cxa_dist_type = 'lincka'
X: torch.Tensor = torch.distributions.Normal(loc=0.0, scale=1.0).sample((B, Din))
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} ({cxa_dist_type=})')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert approx_equal(sim, 1.0), f'Sim should be close to 1.0 but got {sim=}'
```
It also passes with the correct normalization.
Qualitatively the same results. Success!
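A self-contained version of the same sanity check, independent of `cxa_sim`, can help when debugging. The sketch below assumes linear CKA as defined by Kornblith et al. (2019) on column-centered activations; `linear_cka` is a hypothetical helper, not the function used above, and random Gaussian matrices stand in for real layer activations. Besides CKA(A, A) = 1, it checks invariance to an orthogonal change of basis, a property the paper shows linear CKA has.

```python
import numpy as np


def linear_cka(X: np.ndarray, Y: np.ndarray) -> float:
    """Hypothetical helper: linear CKA between X (n x p1) and Y (n x p2); rows are the same n inputs."""
    # Center each feature (column) over the batch, as the linear CKA formula assumes.
    X = X - X.mean(axis=0, keepdims=True)
    Y = Y - Y.mean(axis=0, keepdims=True)
    # CKA = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    cross = np.linalg.norm(Y.T @ X, ord='fro') ** 2
    norm_x = np.linalg.norm(X.T @ X, ord='fro')
    norm_y = np.linalg.norm(Y.T @ Y, ord='fro')
    return float(cross / (norm_x * norm_y))


rng = np.random.default_rng(0)
B, D = 64, 32  # batch size and feature width, arbitrary for this demo
acts = rng.standard_normal((B, D))  # stand-in for one layer's activations

# Sanity check 1: a representation compared with itself gives 1 (up to floating-point error).
sim_self = linear_cka(acts, acts)

# Sanity check 2: an orthogonal change of basis of the features also gives 1.
Q, _ = np.linalg.qr(rng.standard_normal((D, D)))  # random D x D orthogonal matrix
sim_rot = linear_cka(acts, acts @ Q)

print(f'{sim_self=:.6f} {sim_rot=:.6f}')
assert np.isclose(sim_self, 1.0) and np.isclose(sim_rot, 1.0)
```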
OPD (orthogonal Procrustes distance) is a little slower, though, so be careful when using it.
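Assuming OPD here means the orthogonal Procrustes distance, a minimal sketch suggests where the extra cost plausibly comes from: besides matrix products, it needs an SVD (through the nuclear norm) of the p1 x p2 matrix X^T Y, which linear CKA does not. `procrustes_distance` is a hypothetical helper, and the normalization used (per-feature centering plus unit Frobenius norm) is an assumption about the setup, not necessarily what the code above does.

```python
import numpy as np


def _center_and_normalize(X: np.ndarray) -> np.ndarray:
    """Center each feature over the batch, then scale the matrix to unit Frobenius norm."""
    X = X - X.mean(axis=0, keepdims=True)
    return X / np.linalg.norm(X, ord='fro')


def procrustes_distance(X: np.ndarray, Y: np.ndarray) -> float:
    """Hypothetical helper: orthogonal Procrustes distance, X is n x p1, Y is n x p2 (rows = inputs)."""
    X, Y = _center_and_normalize(X), _center_and_normalize(Y)
    # min over orthogonal R of ||X R - Y||_F^2 = ||X||_F^2 + ||Y||_F^2 - 2 * ||X^T Y||_*
    # (for p1 != p2, with the narrower matrix zero-padded). Both Frobenius norms are 1 here;
    # the nuclear norm (ord='nuc') requires an SVD of X^T Y, the step linear CKA avoids.
    nuclear = np.linalg.norm(X.T @ Y, ord='nuc')
    return float(2.0 - 2.0 * nuclear)


rng = np.random.default_rng(0)
acts = rng.standard_normal((64, 32))
Q, _ = np.linalg.qr(rng.standard_normal((32, 32)))  # random orthogonal matrix

d_self = procrustes_distance(acts, acts)  # ~0
d_rot = procrustes_distance(acts, 2.0 * acts @ Q)  # ~0: rotation and scale are factored out
d_rand = procrustes_distance(acts, rng.standard_normal((64, 32)))  # > 0
print(f'{d_self=:.2e} {d_rot=:.2e} {d_rand=:.3f}')
```

If a helper like this is much slower than the CKA one on wide layers, that would be consistent with the SVD being the bottleneck.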
TODO: What normalization does CKA need, and what sanity checks should we write for it?
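One partial answer to the normalization question, based on the linear CKA formula in Kornblith et al. (2019): the features should be centered (each feature's mean removed over the batch), while overall scale needs no separate normalization because the denominator divides it out. The hypothetical `linear_cka` below takes a `center` flag so both cases can be compared on random data; it is a sketch under those assumptions, not the `cxa_sim` implementation.

```python
import numpy as np


def linear_cka(X: np.ndarray, Y: np.ndarray, center: bool = True) -> float:
    """Hypothetical helper: linear CKA between X (n x p1) and Y (n x p2); `center` toggles mean removal."""
    if center:
        # Remove each feature's mean over the batch (the centering the CKA formula assumes).
        X = X - X.mean(axis=0, keepdims=True)
        Y = Y - Y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(Y.T @ X, ord='fro') ** 2
    return float(cross / (np.linalg.norm(X.T @ X, ord='fro') * np.linalg.norm(Y.T @ Y, ord='fro')))


rng = np.random.default_rng(0)
rep_a = rng.standard_normal((64, 32))
rep_b = rep_a + 0.5 * rng.standard_normal((64, 32))  # a noisy copy of rep_a

# Add a constant offset to every feature: centered CKA is unchanged, uncentered CKA shifts.
cka_c, cka_c_shift = linear_cka(rep_a, rep_b), linear_cka(rep_a, rep_b + 10.0)
cka_u, cka_u_shift = linear_cka(rep_a, rep_b, center=False), linear_cka(rep_a, rep_b + 10.0, center=False)

# Isotropic rescaling cancels in the ratio, so no extra scale normalization is needed.
cka_scaled = linear_cka(rep_a, 100.0 * rep_b)

print(f'centered:   {cka_c:.4f} vs shifted {cka_c_shift:.4f}')
print(f'uncentered: {cka_u:.4f} vs shifted {cka_u_shift:.4f}')
print(f'scaled:     {cka_scaled:.4f}')
```

With this setup the uncentered score moves noticeably once the offset is added, while the centered and rescaled scores match the original to floating-point precision, which suggests two more sanity checks worth writing.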