I think there's an error in the calculation of the covariance matrices. In the line `xm = torch.mean(source, 1, keepdim=True) - source`, imo `1` must be replaced by `0`, since I would like to calculate the mean of each feature over the entire batch. Please correct me if I am wrong.
Thanks, wonderful work by the way.
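For illustration, here is a minimal PyTorch sketch of the corrected centering step. It assumes `source` is a `(batch_size, num_features)` tensor with samples as rows. The `covariance` helper and the random batch are hypothetical and are not taken from the repository. The original line computes `mean - source` rather than `source - mean`, but the sign cancels in `xm.t() @ xm`, so only the reduction dimension matters. The result is checked against `torch.cov`, which expects variables as rows and therefore receives the transposed matrix.

```python
import torch

def covariance(source: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: unbiased feature covariance of a (n, d) batch.

    Rows are samples and columns are features, so the mean is taken over
    dim 0 (the batch), giving one mean per feature.
    """
    n = source.size(0)
    # Center each feature by its batch mean; the mean has shape (1, d).
    xm = source - torch.mean(source, 0, keepdim=True)
    # (d, n) @ (n, d) -> (d, d) covariance, with Bessel's correction.
    return xm.t() @ xm / (n - 1)

torch.manual_seed(0)
# Hypothetical batch: 64 samples, 5 features with different offsets.
source = torch.randn(64, 5, dtype=torch.float64) * 3.0 + torch.arange(5, dtype=torch.float64)

fixed = covariance(source)
reference = torch.cov(source.t())  # torch.cov treats rows as variables
print(torch.allclose(fixed, reference))

# Buggy variant: dim 1 averages across features within each sample.
xm_bug = torch.mean(source, 1, keepdim=True) - source
buggy = xm_bug.t() @ xm_bug / (source.size(0) - 1)
print(torch.allclose(buggy, reference))
```

With `dim=0`, the result matches `torch.cov`. With `dim=1`, each sample is centered by its own cross-feature average, which is not a per-feature covariance.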
You're right, I'll fix it. Thanks
Interestingly, this bug didn't affect performance much: accuracy is about 58% after the fix, versus about 56% before it. 😕