facebookresearch / DomainBed

DomainBed is a suite to test domain generalization algorithms
MIT License

Discrepancies between CORAL implementation and paper #100

Closed cwognum closed 2 years ago

cwognum commented 2 years ago

Following equations (1), (2), and (3) in the paper, I'm not sure the current implementation matches them.

Equations from paper

[screenshot of equations (1)–(3) from the paper]

Code example

import torch
import numpy as np

# Two domains, each a batch of two samples with two features
data = [
    [[1.0, 0.0], [1.0, 1.0]], 
    [[0.0, 0.0], [1.0, 1.0]]
]
phis = torch.tensor(data)

cov_torch = torch.cov(phis[0])
# tensor([[0.5000, 0.0000],
#         [0.0000, 0.0000]])

cov_np = np.cov(data[0])
# array([[0.5, 0. ],
#        [0. , 0. ]])

mean_x = phis[0].mean(0, keepdim=True)
cent_x = phis[0] - mean_x
cova_x = (cent_x.t() @ cent_x) / (len(phis[0]) - 1)
# tensor([[0.0000, 0.0000],
#         [0.0000, 0.5000]])

The discrepancy is even larger for the second domain:


cov_torch = torch.cov(phis[1])
# tensor([[0., 0.],
#         [0., 0.]])

# Left out numpy for brevity, but same results as `torch.cov()`

mean_y = phis[1].mean(0, keepdim=True)
cent_y = phis[1] - mean_y
cova_y = (cent_y.t() @ cent_y) / (len(phis[1]) - 1)
# tensor([[0.5000, 0.5000],
#         [0.5000, 0.5000]])

Putting it all together, the resulting penalties differ quite significantly too:

penalty = torch.norm(torch.cov(phis[0]) - torch.cov(phis[1]))
# tensor(0.5000)

penalty = (cova_x - cova_y).pow(2).mean()
# tensor(0.1875)

And that's without even considering the `mean_diff` term, which also doesn't seem to be mentioned in the paper.
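For reference, here is a minimal sketch of the penalty as I understand the implementation to compute it, reconstructed from the variable names quoted above (`mean_x`, `cent_x`, `cova_x`, `mean_diff`); the exact code in the repo may differ:

```python
import torch

def coral_penalty(x, y):
    """Sketch of a CORAL-style penalty matching both the first and
    second moments of two batches (rows = samples, cols = features)."""
    mean_x = x.mean(0, keepdim=True)
    mean_y = y.mean(0, keepdim=True)
    cent_x = x - mean_x
    cent_y = y - mean_y
    # Unbiased feature-wise (column-wise) covariance matrices
    cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
    cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)
    mean_diff = (mean_x - mean_y).pow(2).mean()
    cova_diff = (cova_x - cova_y).pow(2).mean()
    return mean_diff + cova_diff

phis = torch.tensor([
    [[1.0, 0.0], [1.0, 1.0]],
    [[0.0, 0.0], [1.0, 1.0]],
])
penalty = coral_penalty(phis[0], phis[1])
# mean_diff = 0.125, cova_diff = 0.1875, so penalty = 0.3125
```

On the toy data above this gives `mean_diff = 0.125` on top of the `cova_diff = 0.1875` already computed, so the mean term contributes a comparable amount to the total penalty.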

I can imagine that I'm missing something here. Could you elaborate on whether the deviation from the paper is intentional and, if so, why?

cwognum commented 2 years ago

Ah wait, I think I see it now: `torch.cov` (like `np.cov`) treats each row as a variable, so I was computing the covariance across samples rather than across features. Computing the feature-wise (column-wise) covariance instead gives:

cov_torch = torch.cov(phis[0].T)
# tensor([[0.0000, 0.0000],
#         [0.0000, 0.5000]])

I'm still not 100% sure why both the first and second moments are matched rather than just the second, but I can see why that would be important. Closing the issue for now.
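To double-check the equivalence: `torch.cov` on the transposed batch does reproduce the manual column-wise computation from the first comment.

```python
import torch

x = torch.tensor([[1.0, 0.0], [1.0, 1.0]])

# torch.cov treats each row as a variable, so transpose first
# to get the covariance over features (columns)
cov_builtin = torch.cov(x.T)

# Manual unbiased feature-wise covariance, as in the first comment
mean_x = x.mean(0, keepdim=True)
cent_x = x - mean_x
cov_manual = (cent_x.t() @ cent_x) / (len(x) - 1)

assert torch.allclose(cov_builtin, cov_manual)
# both: tensor([[0.0000, 0.0000],
#               [0.0000, 0.5000]])
```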