moskomule / anatome

Ἀνατομή is a PyTorch library to analyze representations of neural networks
MIT License

Add Orthogonal Procrustes Distance + similarity->distance #11

Closed moskomule closed 2 years ago

moskomule commented 2 years ago


brando90 commented 2 years ago

Warning: this pull request makes anatome fail my sanity checks, e.g. when D is really large (much larger than the number of data points) the similarity should be 1.0, since the linear model has plenty of power to correlate the two data sets.

brando90 commented 2 years ago

Code to run the sanity check:

brando90 commented 2 years ago

#%%
"""
The similarity of the same network should always be 1.0 on same input.
"""
import torch
import torch.nn as nn

import uutils.torch_uu
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_identity_one_layer_linear_model

print('--- Sanity check: sCCA = 1.0 when using same net twice with same input. --')

Din: int = 10
Dout: int = Din
B: int = 2000
mdl1: nn.Module = get_named_identity_one_layer_linear_model(D=Din)
mdl2: nn.Module = mdl1
layer_name = 'fc0'
# cxa_dist_type = 'pwcca'
cxa_dist_type = 'svcca'

# - ends up comparing two matrices of size [B, Dout], on same data, on same model
X: torch.Tensor = torch.distributions.Normal(loc=0.0, scale=1.0).sample((B, Din))
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)

print(f'Should be very very close to 1.0: {sim=}')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert(approx_equal(sim, 1.0))

#%%
"""
Reproducing: How many data points: https://github.com/google/svcca/blob/master/tutorials/001_Introduction.ipynb

As n increases, the CCA sim should decrease until it converges to the true max linear correlation in the data.
This is because when the number of data points is small relative to D it's easy to correlate via Xw, Yw, since there are fewer equations (n data points) than unknowns (D features).
Similarly, the similarity decreases because the more data there is, the more variation has to be captured and thus the less
correlation there will be.
This is correct because 1/4*E[|| Xw - Yw||^2]^2 is proportional to the Pearson correlation (assuming Xw, Yw are standardized).

"""
from pathlib import Path
from matplotlib import pyplot as plt

import torch
import torch.nn as nn

import uutils
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_one_layer_random_linear_model

import uutils.plot as uulot

print('\n--- Sanity check: when number of data points B is smaller than D, then it should be trivial to make similarity 1.0 '
      '(even if nets/matrices are different)')
B: int = 10
Dout: int = 300
mdl1: nn.Module = get_named_one_layer_random_linear_model(B, Dout)
mdl2: nn.Module = get_named_one_layer_random_linear_model(B, Dout)
layer_name = 'fc0'
# cxa_dist_type = 'pwcca'
cxa_dist_type = 'svcca'

# - get sim for B << D e.g. [B=10, D=300]; easy to "fit", too many degrees of freedom
X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
# mdl1(X) : [B, Dout] = [B, B] [B, Dout]
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} (since we have many features to match the two Xw1, Yw2).')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
# assert(approx_equal(sim, 1.0))

print('\n-- Sanity: just makes sure that when little data is present sim is high and afterwards (as n->infty) sim (CCA) '
      'converges to the "true" cca value (eventually)')
# data_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000]
data_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000, 10_000]
# data_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000, 10_000, 50_000, 100_000]
# data_sizes: list[int] = [10, 25, 50, 100, 200, 500, 1_000, 2_000, 5_000, 10_000]
sims: list[float] = []
for b in data_sizes:
    X: torch.Tensor = uutils.torch_uu.get_identity_data(b)
    mdl1: nn.Module = get_named_one_layer_random_linear_model(b, Dout)
    mdl2: nn.Module = get_named_one_layer_random_linear_model(b, Dout)
    # print(f'{b=}')
    sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
    # print(f'{sim=}')
    sims.append(sim)

print(f'{sims=}')
uulot.plot(x=data_sizes, y=sims, xlabel='number of data points (n)', ylabel='similarity (svcca)', show=True, save_plot=True, plot_filename='ndata_vs_svcca_sim', title='Data points (n) vs Sim (SVCCA)', x_hline=Dout, x_hline_label=f'B=D={Dout}')

#%%

from pathlib import Path
from matplotlib import pyplot as plt

import torch
import torch.nn as nn

import uutils
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_one_layer_random_linear_model

from uutils.plot import plot, save_to_desktop
import uutils.plot as uuplot

B: int = 10  # [101, 200, 500, 1000, 2000, 5000, 10000]
Din: int = B
Dout: int = 300
mdl1: nn.Module = get_named_one_layer_random_linear_model(Din, Dout)
mdl2: nn.Module = get_named_one_layer_random_linear_model(Din, Dout)
layer_name = 'fc0'
# cxa_dist_type = 'pwcca'
cxa_dist_type = 'svcca'

X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)

print(f'Should be very very close to 1.0: {sim=}')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')

# data_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000, 10_000, 50_000]
B: int = 300
D_feature_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000, 10_000]
sims: list[float] = []
for d in D_feature_sizes:
    X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
    mdl1: nn.Module = get_named_one_layer_random_linear_model(B, d)
    mdl2: nn.Module = get_named_one_layer_random_linear_model(B, d)
    sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
    # print(f'{d=}, {sim=}')
    sims.append(sim)

print(f'{sims=}')
uuplot.plot(x=D_feature_sizes, y=sims, xlabel='number of features/size of dimension (D)', ylabel='similarity (svcca)', show=True, save_plot=True, plot_filename='D_vs_sim_svcca', title='Features (D) vs Sim (SVCCA)', x_hline=B, x_hline_label=f'B=D={B}')
# uuplot.plot(x=D_feature_sizes, y=sims, xlabel='number of features/size of dimension (D)', ylabel='similarity (svcca)', show=True, save_plot=True, plot_filename='D_vs_sim', title='Features (D) vs Sim (SVCCA)')
brando90 commented 2 years ago

Plots you should get (attached): ndata_vs_svcca_sim, D_vs_sim_svcca

brando90 commented 2 years ago

Something in this pull request breaks anatome...

moskomule commented 2 years ago

Something in this pull request breaks anatome...

Thank you for reporting. I introduced torch.linalg.svd etc., but it may work differently from torch.svd.
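
For reference, the two APIs differ in their defaults and return conventions (a standalone sketch from my understanding, not anatome code): torch.svd returns (U, S, V) and computes the reduced SVD by default, while torch.linalg.svd returns (U, S, Vh) with V already transposed and computes the full SVD unless full_matrices=False is passed.

import torch

x = torch.randn(5, 3)

# Legacy API: reduced SVD by default, returns V (not transposed).
U1, S1, V1 = torch.svd(x)                                # U1: [5, 3], V1: [3, 3]

# New API: full SVD by default; pass full_matrices=False to match the legacy shapes.
U2, S2, Vh2 = torch.linalg.svd(x, full_matrices=False)   # Vh2: [3, 3]

# Both factorizations reconstruct x (individual singular vectors may differ by sign).
print(torch.allclose(U1 @ torch.diag(S1) @ V1.t(), x, atol=1e-5))  # True
print(torch.allclose(U2 @ torch.diag(S2) @ Vh2, x, atol=1e-5))     # True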

brando90 commented 2 years ago

Something in this pull request breaks anatome...

Thank you for reporting. I introduced torch.linalg.svd etc., but it may work differently from torch.svd.

For now I will use the current version... If I have time I will test or try to fix the new one, but I'm also busy hehehe :) Hope this helps though; at least the sanity checks should let us catch obvious bugs.

brando90 commented 2 years ago

I see other differences too, like:

_matrix_normalize

instead of

_zero_mean

brando90 commented 2 years ago

I don't think you need _matrix_normalize for CCA (idk for the others). The formula already has the normalization built in:

max_{a, b} a^T X^T Y b / ( sqrt(a^T X^T X a) * sqrt(b^T Y^T Y b) )

It only assumes centering. In short, the Pearson correlation already normalizes in the denominator.

Though, I don't think this should have made a difference.

Centering is for sure needed since the product X^T Y only gives the covariance when the columns are centered.
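
A quick way to see this numerically (a minimal standalone sketch of plain CCA via QR, not anatome's cca function): the canonical correlations only depend on the column spaces of the centered matrices, so rescaling by a Frobenius norm changes nothing, while centering does matter.

import torch

def top_cca_corr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Center the columns so that x.t() @ y acts as a (scaled) cross-covariance.
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    # Orthonormal bases of the column spaces; the singular values of their product are the canonical correlations.
    qx, _ = torch.linalg.qr(x)
    qy, _ = torch.linalg.qr(y)
    return torch.linalg.svdvals(qx.t() @ qy)[0]

x, y = torch.randn(1000, 8), torch.randn(1000, 8)
print(top_cca_corr(x, y))
print(top_cca_corr(x / torch.linalg.norm(x, 'fro'), y))  # same value: the Frobenius normalization cancels out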

brando90 commented 2 years ago

If I may suggest this implementation for OPD, since it re-uses other code you already wrote:

from functools import partial

import torch
from torch import Tensor

# note: _check_shape_equal and _matrix_normalize are anatome's internal helpers
def orthogonal_procrustes_distance(x: Tensor,
                                   y: Tensor,
                                   ) -> Tensor:
    """ Orthogonal Procrustes distance used in Ding+21.
    Returns a distance in the interval [0, 1].

    Note:
        - for a raw representation A we first subtract the mean value from each column, then divide
          by the Frobenius norm, to produce the normalized representation A*, used in all our dissimilarity computations.
        - see uutils.torch_uu.orthogonal_procrustes_distance for my implementation
    Args:
        x: input tensor of shape DxH
        y: input tensor of shape DxW
    Returns:
    """
    _check_shape_equal(x, y, 0)

    # frobenius_norm = partial(torch.linalg.norm, ord="fro")
    nuclear_norm = partial(torch.linalg.norm, ord="nuc")

    x = _matrix_normalize(x, dim=0)
    y = _matrix_normalize(y, dim=0)
    # x = _zero_mean(x, dim=0)
    # x /= frobenius_norm(x)
    # y = _zero_mean(y, dim=0)
    # y /= frobenius_norm(y)
    # frobenius_norm(x) = 1, frobenius_norm(y) = 1
    # 0.5*d_proc(x, y)
    # - note this already outputs a value in [0, 1], i.e. it's not 2 - 2*nuclear_norm(x.t() @ y)
    return 1 - nuclear_norm(x.t() @ y)
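
For what it's worth, here is a self-contained version of the same computation with the helpers inlined (just a sketch; _check_shape_equal and _matrix_normalize above are anatome internals), plus a quick check that the distance from a representation to itself is ~0:

import torch
from torch import Tensor

def opd(x: Tensor, y: Tensor) -> Tensor:
    # Center each column, then scale the whole matrix to unit Frobenius norm.
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    x = x / torch.linalg.norm(x, 'fro')
    y = y / torch.linalg.norm(y, 'fro')
    # With ||x||_F = ||y||_F = 1, min over orthogonal Q of ||x - yQ||_F^2 = 2 - 2*||x.t() @ y||_nuc,
    # so 1 - ||x.t() @ y||_nuc is half of that and lies in [0, 1].
    return 1 - torch.linalg.norm(x.t() @ y, ord='nuc')

x = torch.randn(100, 30)
print(opd(x, x))                     # ~0.0: identical representations
print(opd(x, torch.randn(100, 30)))  # strictly positive for unrelated random representations
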
brando90 commented 2 years ago

I can confirm that normalizing by the Frobenius norm breaks one of the CCA sanity checks: the check where D is really large.

See:

--- Sanity check: when number of data points B is smaller than D, then it should be trivial to make similarity 1.0 (even if nets/matrices are different)
Should be very very close to 1.0: sim=0.9341215491294861 (since we have many features to match the two Xw1, Yw2).
Is it close to 1.0? False

Code:

from pathlib import Path
from matplotlib import pyplot as plt

import torch
import torch.nn as nn

import uutils
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_one_layer_random_linear_model

import uutils.plot as uulot

print('\n--- Sanity check: when number of data points B is smaller than D, then it should be trivial to make similarity 1.0 '
      '(even if nets/matrices are different)')
B: int = 10
Dout: int = 100
mdl1: nn.Module = get_named_one_layer_random_linear_model(B, Dout)
mdl2: nn.Module = get_named_one_layer_random_linear_model(B, Dout)
layer_name = 'fc0'
# cxa_dist_type = 'pwcca'
cxa_dist_type = 'svcca'

# - get sim for B << D e.g. [B=10, D=100]; easy to "fit", too many degrees of freedom
X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
# mdl1(X) : [B, Dout] = [B, B] [B, Dout]
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} (since we have many features to match the two Xw1, Yw2).')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
# assert(approx_equal(sim, 1.0))

but this didn't break the plots, surprisingly.

brando90 commented 2 years ago

Something in this pull request breaks anatome...

Thank you for reporting. I introduced torch.linalg.svd etc., but it may work differently from torch.svd.

No, that is not it (according to my sanity checks above, which ran with U, S, Vh = torch.linalg.svd(input, full_matrices=False)). I think it is the division by the Frobenius norm for CCA. It might be nice to figure out which metrics need that. Afaik, only orthogonal Procrustes needs it.

brando90 commented 2 years ago

Ok, found the bug!

You need to do the centering correctly, because / (like *) binds more tightly than -. So the normalization should be:

from torch import Tensor

def _matrix_normalize(input: Tensor,
                      dim: int
                      ) -> Tensor:
    """
    Center and normalize according to the Frobenius norm (not the standard deviation).

    Warning: this does not create standardized random variables in a random vector.

    Note: careful with this, it makes CCA behave in unexpected ways
    :param input:
    :param dim:
    :return:
    """
    from torch.linalg import norm
    return (input - input.mean(dim=dim, keepdim=True)) / norm(input, 'fro')

Or even better, reuse your _zero_mean.
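
To make the precedence point concrete (a tiny standalone example, not anatome code): without the parentheses the division only applies to the mean, so the activations get centered but never rescaled.

import torch
from torch.linalg import norm

x = torch.randn(50, 10)

buggy = x - x.mean(dim=0, keepdim=True) / norm(x, 'fro')    # parses as x - (mean / norm): x itself is never rescaled
fixed = (x - x.mean(dim=0, keepdim=True)) / norm(x, 'fro')  # center first, then divide everything by the norm

print(norm(buggy, 'fro'))  # roughly norm(x): the data was never divided
print(norm(fixed, 'fro'))  # norm of the centered data divided by norm(x), i.e. at most ~1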

My sanity checks now pass (for all metrics):

/Users/brando/anaconda3/envs/metalearning/bin/python /Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py --cmd-line --multiproc --qt-support=auto --client 127.0.0.1 --port 55456 --file /Users/brando/ultimate-utils/tutorials_for_myself/anatome_pg/sanity_checks_anatome.py
Connected to pydev debugger (build 212.5080.64)
--- Sanity check: sCCA = 1.0 when using same net twice with same input. --
Should be very very close to 1.0: sim=1.000000238418579 (cxa_dist_type='svcca')
Is it close to 1.0? True
Should be very very close to 1.0: sim=0.9999998807907104 (cxa_dist_type='pwcca')
Is it close to 1.0? True
Should be very very close to 1.0: sim=1.0 (cxa_dist_type='lincka')
Is it close to 1.0? True
Should be very very close to 1.0: sim=0.9997346997261047 (cxa_dist_type='opd')
Is it close to 1.0? True
--- Sanity check: when number of data points B is smaller than D, then it should be trivial to make similarity 1.0 (even if nets/matrices are different)
Should be very very close to 1.0: sim=1.000000238418579 (since we have many features to match the two Xw1, Yw2).
Is it close to 1.0? True
-- Sanity: just makes sure that when little data is present sim is high and afterwards (as n->infty) sim (CCA) converges to the "true" cca value (eventually)
sims=[0.9999998807907104, 0.9999999403953552, 0.9801047444343567, 0.9169793725013733, 0.6231850981712341, 0.3799371123313904, 0.2702748775482178, 0.18766719102859497, 0.11999624967575073, 0.08386451005935669]
Should be very very close to 1.0: sim=1.0
Is it close to 1.0? True
sims=[0.2898038625717163, 0.44516634941101074, 0.6200690865516663, 0.9168117046356201, 0.9173185229301453, 0.9742245674133301, 0.9898524284362793, 0.9903322458267212, 0.9898055791854858, 0.98990398645401, 0.9907135367393494]
brando90 commented 2 years ago

final sanity check code:

#%%
"""
The similarity of the same network should always be 1.0 on same input.
"""
import torch
import torch.nn as nn

import uutils.torch_uu
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_identity_one_layer_linear_model

print('--- Sanity check: sCCA = 1.0 when using same net twice with same input. --')

Din: int = 10
Dout: int = Din
B: int = 2000
mdl1: nn.Module = get_named_identity_one_layer_linear_model(D=Din)
mdl2: nn.Module = mdl1
layer_name = 'fc0'

# - ends up comparing two matrices of size [B, Dout], on same data, on same model
cxa_dist_type = 'svcca'
X: torch.Tensor = torch.distributions.Normal(loc=0.0, scale=1.0).sample((B, Din))
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} ({cxa_dist_type=})')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert(approx_equal(sim, 1.0)), f'Sim should be close to 1.0 but got {sim=}'

cxa_dist_type = 'pwcca'
X: torch.Tensor = torch.distributions.Normal(loc=0.0, scale=1.0).sample((B, Din))
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} ({cxa_dist_type=})')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert(approx_equal(sim, 1.0)), f'Sim should be close to 1.0 but got {sim=}'

cxa_dist_type = 'lincka'
X: torch.Tensor = torch.distributions.Normal(loc=0.0, scale=1.0).sample((B, Din))
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} ({cxa_dist_type=})')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert(approx_equal(sim, 1.0)), f'Sim should be close to 1.0 but got {sim=}'

cxa_dist_type = 'opd'
X: torch.Tensor = torch.distributions.Normal(loc=0.0, scale=1.0).sample((B, Din))
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} ({cxa_dist_type=})')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0, tolerance=1e-2)}')
assert(approx_equal(sim, 1.0, tolerance=1e-2)), f'Sim should be close to 1.0 but got {sim=}'

#%%
"""
Reproducing: How many data points: https://github.com/google/svcca/blob/master/tutorials/001_Introduction.ipynb

As n increases, the CCA sim should decrease until it converges to the true max linear correlation in the data.
This is because when the number of data points is small relative to D it's easy to correlate via Xw, Yw, since there are fewer equations (n data points) than unknowns (D features).
Similarly, the similarity decreases because the more data there is, the more variation has to be captured and thus the less
correlation there will be.
This is correct because 1/4*E[|| Xw - Yw||^2]^2 is proportional to the Pearson correlation (assuming Xw, Yw are standardized).

"""
from pathlib import Path
from matplotlib import pyplot as plt

import torch
import torch.nn as nn

import uutils
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_one_layer_random_linear_model

import uutils.plot as uulot

print('\n--- Sanity check: when number of data points B is smaller than D, then it should be trivial to make similarity 1.0 '
      '(even if nets/matrices are different)')
B: int = 10
Dout: int = 100
mdl1: nn.Module = get_named_one_layer_random_linear_model(B, Dout)
mdl2: nn.Module = get_named_one_layer_random_linear_model(B, Dout)
layer_name = 'fc0'
# cxa_dist_type = 'pwcca'
cxa_dist_type = 'svcca'

# - get sim for B << D e.g. [B=10, D=100]; easy to "fit", too many degrees of freedom
X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
# mdl1(X) : [B, Dout] = [B, B] [B, Dout]
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} (since we have many features to match the two Xw1, Yw2).')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert(approx_equal(sim, 1.0))

print('\n-- Sanity: just makes sure that when little data is present sim is high and afterwards (as n->infty) sim (CCA) '
      'converges to the "true" cca value (eventually)')
# data_sizes: list[int] = [10, 25, 50, 100, 200, 500, 1_000, 2_000, 5_000]
data_sizes: list[int] = [10, 25, 50, 100, 200, 500, 1_000, 2_000, 5_000, 10_000]
# data_sizes: list[int] = [10, 25, 50, 100, 200, 500, 1_000, 2_000, 5_000, 10_000, 50_000, 100_000]
# data_sizes: list[int] = [10, 25, 50, 100, 200, 500, 1_000, 2_000, 5_000, 10_000]
sims: list[float] = []
for b in data_sizes:
    X: torch.Tensor = uutils.torch_uu.get_identity_data(b)
    mdl1: nn.Module = get_named_one_layer_random_linear_model(b, Dout)
    mdl2: nn.Module = get_named_one_layer_random_linear_model(b, Dout)
    # print(f'{b=}')
    sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
    # print(f'{sim=}')
    sims.append(sim)

print(f'{sims=}')
uulot.plot(x=data_sizes, y=sims, xlabel='number of data points (n)', ylabel='similarity (svcca)', show=True, save_plot=True, plot_filename='ndata_vs_svcca_sim', title='Data points (n) vs Sim (SVCCA)', x_hline=Dout, x_hline_label=f'B=D={Dout}')

#%%

from pathlib import Path
from matplotlib import pyplot as plt

import torch
import torch.nn as nn

import uutils
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_one_layer_random_linear_model

from uutils.plot import plot, save_to_desktop
import uutils.plot as uuplot

B: int = 10  # [101, 200, 500, 1000, 2000, 5000, 10000]
Din: int = B
Dout: int = 300
mdl1: nn.Module = get_named_one_layer_random_linear_model(Din, Dout)
mdl2: nn.Module = get_named_one_layer_random_linear_model(Din, Dout)
layer_name = 'fc0'
# cxa_dist_type = 'pwcca'
cxa_dist_type = 'svcca'

X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=}')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert(approx_equal(sim, 1.0))

# data_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000, 10_000, 50_000]
B: int = 100
D_feature_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000, 10_000]
sims: list[float] = []
for d in D_feature_sizes:
    X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
    mdl1: nn.Module = get_named_one_layer_random_linear_model(B, d)
    mdl2: nn.Module = get_named_one_layer_random_linear_model(B, d)
    sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
    # print(f'{d=}, {sim=}')
    sims.append(sim)

print(f'{sims=}')
uuplot.plot(x=D_feature_sizes, y=sims, xlabel='number of features/size of dimension (D)', ylabel='similarity (svcca)', show=True, save_plot=True, plot_filename='D_vs_sim_svcca', title='Features (D) vs Sim (SVCCA)', x_hline=B, x_hline_label=f'B=D={B}')
# uuplot.plot(x=D_feature_sizes, y=sims, xlabel='number of features/size of dimension (D)', ylabel='similarity (svcca)', show=True, save_plot=True, plot_filename='D_vs_sim', title='Features (D) vs Sim (SVCCA)')

brando90 commented 2 years ago

reference on how to normalize: https://stats.stackexchange.com/questions/544812/how-should-one-normalize-activations-of-batches-before-passing-them-through-a-si

brando90 commented 2 years ago

Note: it's better to divide by the norm of the centered data. The accuracy of OPD increases dramatically; comparing the same matrix twice finally gives 1.0 up to 1e-4 instead of 1e-2.

brando90 commented 2 years ago

from torch import Tensor

def _matrix_normalize_using_centered_data(X: Tensor, dim: int = 1) -> Tensor:
    """
    Normalize a matrix wrt the data dimension according to the similarity-preprocessing standard.
    The assumption is that X is of size [n, d].
    Otherwise, specify which dimension to normalize with dim.

    ref: https://stats.stackexchange.com/questions/544812/how-should-one-normalize-activations-of-batches-before-passing-them-through-a-si
    """
    from torch.linalg import norm
    X_centered: Tensor = _zero_mean(X, dim=dim)  # _zero_mean: anatome helper that subtracts the mean along dim
    X_star: Tensor = X_centered / norm(X_centered, "fro")
    return X_star

results:

Connected to pydev debugger (build 212.5080.64)
--- Sanity check: sCCA = 1.0 when using same net twice with same input. --
Should be very very close to 1.0: sim=1.0000004768371582 (cxa_dist_type='svcca')
Is it close to 1.0? True
Should be very very close to 1.0: sim=1.000000238418579 (cxa_dist_type='pwcca')
Is it close to 1.0? True
Should be very very close to 1.0: sim=1.0 (cxa_dist_type='lincka')
Is it close to 1.0? True
Should be very very close to 1.0: sim=1.0000001192092896 (cxa_dist_type='opd')
Is it close to 1.0? True
--- Sanity check: when number of data points B is smaller than D, then it should be trivial to make similarity 1.0 (even if nets/matrices are different)
Should be very very close to 1.0: sim=0.9999998807907104 (since we have many features to match the two Xw1, Yw2).
Is it close to 1.0? True
-- Sanity: just makes sure that when little data is present sim is high and afterwards (as n->infty) sim (CCA) converges to the "true" cca value (eventually)
sims=[1.0000001192092896, 1.0, 0.9794895648956299, 0.9188864231109619, 0.6179077625274658, 0.3843235969543457, 0.2695028781890869, 0.18886375427246094, 0.11978656053543091, 0.0842815637588501]
Should be very very close to 1.0: sim=1.0
Is it close to 1.0? True
sims=[0.24919700622558594, 0.43115103244781494, 0.6279942393302917, 0.9188255667686462, 0.9206753969192505, 0.9731308817863464, 0.9901297688484192, 0.9902339577674866, 0.990931510925293, 0.9907766580581665, 0.9902600049972534]
brando90 commented 2 years ago

@moskomule I am curious. What is the final conclusion for you on normalizing the matrices before computing the distances? Do you plan to divide by the Frobenius norm (of the centered matrix) for:

  1. Only for OPD?
  2. CCA?
  3. CKA?

My hunch is that OPD is the only one that needs it and only centering is enough for the other two.
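
For what it's worth, that hunch can be checked numerically (a rough sketch; lin_cka below is the textbook linear CKA on centered matrices, not necessarily anatome's implementation): CCA and linear CKA are invariant to rescaling a centered matrix, so the Frobenius normalization is mathematically a no-op for them, while the nuclear-norm term in OPD is not scale-invariant and genuinely needs the unit-norm inputs.

import torch

def lin_cka(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Linear CKA on centered matrices: ||x^T y||_F^2 / (||x^T x||_F * ||y^T y||_F).
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    num = torch.linalg.norm(x.t() @ y, 'fro') ** 2
    return num / (torch.linalg.norm(x.t() @ x, 'fro') * torch.linalg.norm(y.t() @ y, 'fro'))

x, y = torch.randn(200, 30), torch.randn(200, 30)
print(lin_cka(x, y), lin_cka(3.7 * x, y))   # identical: rescaling x does not change linear CKA

nuc = lambda a, b: torch.linalg.norm(a.t() @ b, ord='nuc')
print(nuc(x, y), nuc(3.7 * x, y))           # differ by the factor 3.7: OPD's term is not scale-invariant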

moskomule commented 2 years ago

I agree with that, and if I remember correctly, that is how I implemented it.

brando90 commented 2 years ago

I agree with that, and if I remember correctly, that is how I implemented it.

At the risk of being redundant, I do want to note that that is not what the authors of the OPD paper do [see here](https://github.com/js-d/sim_metric/issues/4#issuecomment-953062107) (they normalize all the time), but with my sanity checks I doubt the difference will be large, and I will do what you do and just center for CCA and CKA but only normalize for OPD.

Thanks for discussions! :)