moskomule / anatome

Ἀνατομή is a PyTorch library to analyze representation of neural networks
MIT License

why orthonormalize the cca combination instead of the cca vectors/canonical neurons? #30

Open brando90 opened 2 years ago

brando90 commented 2 years ago

Why does anatome compute PWCCA by orthonormalizing the weight matrix a instead of the CCA vectors x_tilde = x @ a? For example:

https://github.com/moskomule/anatome/blob/393b36df77631590be7f4d23bff5436fa392dc0e/anatome/distance.py#L161

The authors do the latter, i.e. they orthonormalize x_tilde, not a:

https://arxiv.org/abs/1806.05759
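
For reference, the projection weighting under discussion can be written out as below. This is a sketch consistent with the computation described in this thread, not a quotation from the paper. The symbol names are assumptions chosen here: $h_i$ is the $i$-th orthonormalized CCA vector (a column of x_tilde), $x_j$ is the activation vector of neuron $j$ (a column of x), and $\rho_i$ is the $i$-th canonical correlation.

```latex
\tilde{\alpha}_i = \sum_{j} \left| \langle h_i, x_j \rangle \right|, \qquad
\alpha_i = \frac{\tilde{\alpha}_i}{\sum_{k} \tilde{\alpha}_k}, \qquad
d_{\mathrm{PWCCA}}(x, y) = 1 - \sum_{i} \alpha_i \, \rho_i
```

The point of the question is which object plays the role of $h_i$: the columns of the weight matrix a, or the columns of x @ a.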

current anatome code:

def pwcca_distance(x: Tensor,
                   y: Tensor,
                   backend: str
                   ) -> Tensor:
    """ Projection Weighted CCA proposed in Morcos et al. 2018.
    Args:
        x: input tensor of Shape DxH, where D>H
        y: input tensor of Shape DxW, where D>W
        backend: svd or qr
    Returns:
    """

    a, b, diag = cca(x, y, backend)
    a, _ = torch.linalg.qr(a)  # reorthonormalize
    alpha = (x @ a).abs_().sum(dim=0)
    alpha /= alpha.sum()
    return 1 - alpha @ diag

related: https://stackoverflow.com/questions/69993768/how-does-one-implement-pwcca-in-pytorch-match-the-original-pwcca-implemented-in

brando90 commented 2 years ago

What looks incorrect to me is that you have:

a, b, diag = cca(x, y, backend)  # a.size()  = [D1, C]
a, _ = torch.linalg.qr(a)  # a.size() = [D1, C]
alpha = (x @ a).abs_()   # alpha.size() = [B, C]

but you need to compute x_tilde first:

a, b, diag = cca(x, y, backend)  # a.size()  = [D1, C]
x_tilde = x @ a   # x_tilde.size() = [B, C]
x_tilde, _ = torch.linalg.qr(x_tilde)  # x_tilde.size() = [B, C]
alpha = (x_tilde.T @ x).abs_()   # alpha.size() = [C, D1]

Written in one line, the difference between the two is more obvious (ignoring the extra orthonormalization for clarity):

a, b, diag = cca(x, y, backend)  # a.size()  = [D1, C]
alpha = ((x @ a).T @ x).abs_()   # alpha.size() = [C, D1]

That seems to be different from what anatome currently has.
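
To make the difference concrete, here is a minimal sketch comparing the two weightings on random data. The matrix `a` is a random, hypothetical stand-in for the weights returned by `cca()`, and the sizes are arbitrary, so only the shapes and the fact that the two weight vectors differ are meaningful.

```python
import torch

torch.manual_seed(0)
B, D1, C = 100, 8, 5  # batch size, neurons, CCA components (arbitrary sizes)
x = torch.randn(B, D1, dtype=torch.float64)
a = torch.randn(D1, C, dtype=torch.float64)  # hypothetical stand-in for cca()'s `a`

# Current anatome behaviour: orthonormalize the weight matrix a, then project x.
a_orth, _ = torch.linalg.qr(a)                      # [D1, C]
alpha_current = (x @ a_orth).abs().sum(dim=0)       # [B, C] -> [C]
alpha_current = alpha_current / alpha_current.sum()

# Proposed: orthonormalize the CCA vectors x_tilde = x @ a, then project onto x.
x_tilde, _ = torch.linalg.qr(x @ a)                 # [B, C], orthonormal columns
alpha_proposed = (x_tilde.T @ x).abs().sum(dim=1)   # [C, D1] -> [C]
alpha_proposed = alpha_proposed / alpha_proposed.sum()

print(alpha_current)
print(alpha_proposed)
```

Both results are length-C weight vectors that sum to 1, but they are computed from different inner products, so in general they do not agree.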

moskomule commented 2 years ago

Hi, yes, you're right. I misinterpreted what they mean by "CCA vectors".

brando90 commented 2 years ago

@moskomule Also, shouldn't we be using the centered x and y when doing the rest of the computations? (Note that the current code uses the raw inputs x and y.)
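
As a small illustration of what centering means here, the sketch below subtracts the per-neuron mean over the batch dimension. The helper name `zero_mean` is hypothetical and only mirrors the idea; it is not claimed to be anatome's own function.

```python
import torch


def zero_mean(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Hypothetical helper: subtract the mean along `dim` (the batch dimension)."""
    return t - t.mean(dim=dim, keepdim=True)


torch.manual_seed(0)
x = torch.randn(50, 4, dtype=torch.float64) + 3.0  # activations with a non-zero mean
x_centered = zero_mean(x, dim=0)
col_means = x_centered.mean(dim=0)  # per-neuron means after centering
print(col_means)
```

After centering, every column mean is zero up to floating-point error, so the products x_tilde.T @ x behave like covariances rather than raw second moments.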

brando90 commented 2 years ago

> Hi, yes you're right. I misinterpreted what they say "CCA vectors".

Here is the corrected implementation, including the if statement that chooses which layer's activations define the weights:

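from typing import Optional

import torch
from torch import Tensor

# cca and _zero_mean are assumed to be the helpers from anatome/distance.py


def pwcca_distance3(x: Tensor,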
                    y: Tensor,
                    backend: str,
                    use_layer_matrix: Optional[str] = None,
                    epsilon: float = 1e-10
                    ) -> Tensor:
    """ Projection Weighted CCA proposed in Morcos et al. 2018.

    Args:
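        x: input tensor of Shape NxD1, where it's recommended that N>D1
        y: input tensor of Shape NxD2, where it's recommended that N>D2
        backend: svd or qr
        use_layer_matrix: 'x' or 'y', the layer whose activations define the weights;
            if None, the layer with fewer non-negligible neurons is used (ties go to 'x')
        epsilon: threshold below which a neuron's squared norm is treated as zero

    Returns:
        scalar tensor, the PWCCA distance 1 - alpha @ diag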

    """
    x = _zero_mean(x, dim=0)
    y = _zero_mean(y, dim=0)
    # x = _divide_by_max(_zero_mean(x, dim=0))
    # y = _divide_by_max(_zero_mean(y, dim=0))
    B, D1 = x.size()
    B2, D2 = y.size()
    assert B == B2
    C_ = min(D1, D2)
    a, b, diag = cca(x, y, backend)
    C = diag.size(0)
    assert (C == C_)
    assert a.size() == torch.Size([D1, C])
    assert diag.size() == torch.Size([C])
    assert b.size() == torch.Size([D2, C])
    if use_layer_matrix is None:
        # sigma_xx_approx = x
        # sigma_yy_approx = y
        sigma_xx_approx = x.T @ x
        sigma_yy_approx = y.T @ y
        x_diag = torch.diag(sigma_xx_approx.abs())
        y_diag = torch.diag(sigma_yy_approx.abs())
        x_idxs = (x_diag >= epsilon)
        y_idxs = (y_diag >= epsilon)
        use_layer_matrix = 'x' if x_idxs.sum() <= y_idxs.sum() else 'y'
    if use_layer_matrix == 'x':
        x_tilde = x @ a
        assert x_tilde.size() == torch.Size([B, C])
        x_tilde, _ = torch.linalg.qr(x_tilde)  # positional: the keyword name differs across PyTorch versions
        assert x_tilde.size() == torch.Size([B, C])
        alpha_tilde_dot_x_abs = (x_tilde.T @ x).abs_()
        assert alpha_tilde_dot_x_abs.size() == torch.Size([C, D1])
        alpha_tilde = alpha_tilde_dot_x_abs.sum(dim=1)
        assert alpha_tilde.size() == torch.Size([C])
    elif use_layer_matrix == 'y':
        y_tilde = y @ b
        assert y_tilde.size() == torch.Size([B, C])
        y_tilde, _ = torch.linalg.qr(y_tilde)  # positional: the keyword name differs across PyTorch versions
        assert y_tilde.size() == torch.Size([B, C])
        alpha_tilde_dot_y_abs = (y_tilde.T @ y).abs_()
        assert alpha_tilde_dot_y_abs.size() == torch.Size([C, D2])
        alpha_tilde = alpha_tilde_dot_y_abs.sum(dim=1)
        assert alpha_tilde.size() == torch.Size([C])
    else:
        raise ValueError(f"Invalid input: {use_layer_matrix=}")
    assert alpha_tilde.size() == torch.Size([C])
    alpha = alpha_tilde / alpha_tilde.sum()
    assert alpha.size() == torch.Size([C])
    return 1.0 - (alpha @ diag)

Feel free to remove the asserts.
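
For readers who want to try the weighting without anatome installed, here is a self-contained sketch of the same idea. It is not anatome's implementation: the function name `pwcca_distance_sketch` is hypothetical, CCA is computed directly via QR followed by an SVD instead of calling `cca`, and the weights are always taken from x (there is no layer-selection if statement).

```python
import torch
from torch import Tensor


def pwcca_distance_sketch(x: Tensor, y: Tensor) -> Tensor:
    """Hypothetical sketch of a PWCCA distance for x: [N, D1] and y: [N, D2], N > D1, D2."""
    # Center each neuron over the batch dimension.
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    # Orthonormal bases for the column spaces of x and y.
    qx, _ = torch.linalg.qr(x)  # [N, D1]
    qy, _ = torch.linalg.qr(y)  # [N, D2]
    # Singular values of qx.T @ qy are the canonical correlations.
    u, rho, _ = torch.linalg.svd(qx.T @ qy, full_matrices=False)
    rho = rho.clamp(0.0, 1.0)  # guard against rounding slightly above 1
    # CCA vectors for x; qx @ u already has orthonormal columns.
    x_tilde = qx @ u  # [N, C] with C = min(D1, D2)
    # Projection weights: how much each CCA vector accounts for the neurons of x.
    alpha_tilde = (x_tilde.T @ x).abs().sum(dim=1)  # [C, D1] -> [C]
    alpha = alpha_tilde / alpha_tilde.sum()
    return 1.0 - alpha @ rho


torch.manual_seed(0)
x = torch.randn(200, 6, dtype=torch.float64)
y = torch.randn(200, 10, dtype=torch.float64)
d_xx = pwcca_distance_sketch(x, x)  # a representation compared with itself
d_xy = pwcca_distance_sketch(x, y)  # two unrelated random representations
print(d_xx.item(), d_xy.item())
```

Comparing a representation with itself should give a distance of about zero, since every canonical correlation is 1 and the weights sum to 1. That makes a cheap sanity check before comparing against another implementation.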

brando90 commented 2 years ago

Note that despite all that effort, I still can't match their PWCCA value:

Google's: pwcca_mean=0.21446671574218149
Google's (fixed): pwcca_mean2=0.21446671574218149
Our code: pwcca_ultimateanatome=tensor(0.2131, dtype=torch.float64)
Our code: pwcca_ultimateanatome_L1=tensor(0.2131, dtype=torch.float64)
Our code: pwcca_ultimateanatome_L2=tensor(0.2120, dtype=torch.float64)
Our code: pwcca_extended_anatome=tensor(0.2112, dtype=torch.float64)
Our code: pwcca_extended_anatome_L1=tensor(0.2112, dtype=torch.float64)
Our code: pwcca_extended_anatome_L2=tensor(0.2116, dtype=torch.float64)

This is puzzling, because all the previous code, especially your SVCCA code, matched fairly well.
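
To put the mismatch in perspective, the snippet below computes the relative difference between each value and Google's, using the numbers from the output above. The anatome-based values are only known to the four decimals printed, so the percentages are approximate.

```python
# Reference value: Google's pwcca_mean from the output above.
reference = 0.21446671574218149

# Values printed for the anatome-based variants (rounded to 4 decimals in the output).
ours = {
    "pwcca_ultimateanatome": 0.2131,
    "pwcca_ultimateanatome_L1": 0.2131,
    "pwcca_ultimateanatome_L2": 0.2120,
    "pwcca_extended_anatome": 0.2112,
    "pwcca_extended_anatome_L1": 0.2112,
    "pwcca_extended_anatome_L2": 0.2116,
}

# Relative difference of each variant with respect to the reference.
rel_diffs = {name: abs(value - reference) / reference for name, value in ours.items()}
for name, rel in rel_diffs.items():
    print(f"{name}: relative difference {rel:.2%}")
```

All variants land within a couple of percent of the reference, which suggests a small systematic difference rather than a completely different quantity being computed.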

brando90 commented 2 years ago

I can confirm that the code I shared above seems to give consistent results across experiments. I suggest we use that one.