SSARCandy / DeepCORAL

🧠 A PyTorch implementation of 'Deep CORAL: Correlation Alignment for Deep Domain Adaptation.', ECCV 2016
https://ssarcandy.tw/2017/10/31/deep-coral/
226 stars, 42 forks

I think you have an error. #19

Open typhoon1104 opened 4 years ago

typhoon1104 commented 4 years ago

DEEP_CORAL_LOSS:

    def CORAL(source, target):
        d = source.data.shape[1]
        ns = source.data.shape[0]
        nt = target.data.shape[0]

        # source covariance
        xm = torch.mean(source, 0, keepdim=True) - source
        xc = (xm.t() @ xm) / (ns - 1)

        # target covariance
        xmt = torch.mean(target, 0, keepdim=True) - target
        xct = (xmt.t() @ xmt) / (nt - 1)

        print(xc, xct)

        # squared Frobenius norm between source and target covariances
        loss = torch.sum(torch.mul((xc - xct), (xc - xct)))
        loss = loss / (4 * d * d)
        return loss

ZhouWenjun2019 commented 3 years ago

I don't understand this part either. Do you have any idea?

ch-andrei commented 2 years ago

TL;DR: no error, this code is "correct". See my answer for this issue.
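
For anyone following the thread, here is a minimal sanity check of the snippet posted at the top, assuming a PyTorch version that provides `torch.cov` (1.10 or later). It only checks two things: that the mean-centered product divided by `n - 1` matches the unbiased sample covariance, and that the summed squared difference equals the squared Frobenius norm. The names `covariance` and `coral_loss` and the random inputs are illustrative and not taken from the repository.

```python
import torch


def covariance(x):
    """Unbiased sample covariance of x with shape (n_samples, n_features)."""
    n = x.shape[0]
    # (mean - x) and (x - mean) give the same product, since the sign cancels.
    centered = x - x.mean(dim=0, keepdim=True)
    return centered.t() @ centered / (n - 1)


def coral_loss(source, target):
    """Squared Frobenius distance between covariances, scaled by 1 / (4 d^2)."""
    d = source.shape[1]
    diff = covariance(source) - covariance(target)
    return (diff * diff).sum() / (4 * d * d)


torch.manual_seed(0)
source = torch.randn(32, 8, dtype=torch.float64)        # made-up source batch
target = torch.randn(48, 8, dtype=torch.float64) * 2.0  # made-up target batch

# torch.cov expects variables in rows, so the batch is transposed first.
assert torch.allclose(covariance(source), torch.cov(source.t()))

# The elementwise sum of squares is the squared Frobenius norm.
diff = covariance(source) - covariance(target)
frob_sq = torch.linalg.norm(diff, ord="fro") ** 2
assert torch.allclose(coral_loss(source, target), frob_sq / (4 * 8 * 8))
```

Note that the two batches may have different sizes (`ns` and `nt` in the original snippet); only the feature dimension `d` has to match.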