I have a question about the following code:
```python
import torch

def whitened_barlow_twins_loss(self, y1, y2):
    # Embed both views: backbone -> projector -> whiten_bn (N x d each)
    z1 = self.whiten_bn(self.projector(self.backbone(y1)))
    z2 = self.whiten_bn(self.projector(self.backbone(y2)))
    if self.args.whiten == 'true':
        z1 = self.whiten_net.zca_forward(z1)  # N * d
        z2 = self.whiten_net.zca_forward(z2)
    c = torch.mm(z1.transpose(0, 1), z2)  # d * d cross-correlation
    c.div_(self.args.batch_size)
    # Invariance term: push the diagonal of c towards 1
    loss = torch.diagonal(c).add_(-1).pow_(2).sum()
    if self.args.off == 'true':
        # Redundancy-reduction term: push the off-diagonal entries of c towards 0
        off_diag = off_diagonal(c).pow_(2).sum()
        loss += self.args.lambd * off_diag
    return loss
```
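To make the feature-level computation concrete, here is a minimal NumPy sketch of the same loss. It is not the repository's code: standardize and zca_whiten are hypothetical stand-ins for whiten_bn and whiten_net.zca_forward, assuming the former standardizes each feature over the batch and the latter applies ZCA whitening over the batch. The lambd default is also an assumption. Here @ and np.diagonal play the roles of torch.mm and torch.diagonal.

```python
import numpy as np

def standardize(z, eps=1e-5):
    # Hypothetical stand-in for whiten_bn: per-feature standardization over the batch.
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)

def zca_whiten(z, eps=1e-5):
    # Hypothetical stand-in for zca_forward: ZCA whitening, rows are samples (N x d).
    zc = z - z.mean(axis=0)
    cov = zc.T @ zc / zc.shape[0]                      # d x d covariance
    eigvals, eigvecs = np.linalg.eigh(cov)
    w = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return zc @ w                                      # still N x d, covariance ~ identity

def off_diagonal(x):
    # Off-diagonal entries of a square matrix as a flat vector.
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].reshape(n - 1, n + 1)[:, 1:].flatten()

def feature_level_loss(z1, z2, lambd=5e-3, whiten=True, off=True):
    z1, z2 = standardize(z1), standardize(z2)
    if whiten:
        z1, z2 = zca_whiten(z1), zca_whiten(z2)
    c = z1.T @ z2 / z1.shape[0]                        # d x d cross-correlation
    loss = ((np.diagonal(c) - 1.0) ** 2).sum()         # invariance term
    if off:
        loss += lambd * (off_diagonal(c) ** 2).sum()   # redundancy-reduction term
    return loss, c

rng = np.random.default_rng(0)
z1 = rng.normal(size=(256, 8))                         # N = 256 instances, d = 8 features
z2 = z1 + 0.1 * rng.normal(size=(256, 8))              # a slightly perturbed second view
loss, c = feature_level_loss(z1, z2)
print(c.shape)  # (8, 8)
```

Because each view is whitened first, its own d × d covariance is already close to the identity, so the off-diagonal term of c mainly measures correlation between features of the two different views.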
Could you answer the following question?
1. You mention that the loss can be calculated at both the instance level and the feature level, but I can only find the feature-level loss (the d × d matrix c).
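You can add z1 = z1.transpose(0, 1) (and likewise for z2) before z1 = self.whiten_net.zca_forward(z1).
Since z1 is N × d, this makes it d × N, so the whitening is computed over the instance dimension. Assuming zca_forward preserves the input shape, c = torch.mm(z1.transpose(0, 1), z2) then becomes an N × N instance-level matrix.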
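The transpose trick can be sketched the same way. In this hypothetical NumPy version, transposing the N × d embeddings makes the d features act as samples and the N instances act as dimensions, so ZCA whitening decorrelates instances and the cross-correlation becomes N × N instead of d × d. zca_whiten is again a stand-in for zca_forward, and normalizing c by d (the number of rows after transposing) is an assumption. The whitening is only well-conditioned when d is comfortably larger than N.

```python
import numpy as np

def zca_whiten(z, eps=1e-5):
    # Hypothetical stand-in for zca_forward: rows are samples, columns are dimensions.
    zc = z - z.mean(axis=0)
    cov = zc.T @ zc / zc.shape[0]
    eigvals, eigvecs = np.linalg.eigh(cov)
    w = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return zc @ w

def off_diagonal(x):
    # Off-diagonal entries of a square matrix as a flat vector.
    n = x.shape[0]
    return x.flatten()[:-1].reshape(n - 1, n + 1)[:, 1:].flatten()

def instance_level_loss(z1, z2, lambd=5e-3):
    # z1, z2: N x d. After transposing, features are the "samples" and
    # instances are the "dimensions", so whitening decorrelates instances.
    t1 = zca_whiten(z1.T)                              # d x N
    t2 = zca_whiten(z2.T)                              # d x N
    c = t1.T @ t2 / t1.shape[0]                        # N x N, normalized by d (assumed)
    loss = ((np.diagonal(c) - 1.0) ** 2).sum()
    loss += lambd * (off_diagonal(c) ** 2).sum()
    return loss, c

rng = np.random.default_rng(0)
z1 = rng.normal(size=(16, 64))                         # N = 16 instances, d = 64 features
z2 = z1 + 0.1 * rng.normal(size=(16, 64))
loss, c = instance_level_loss(z1, z2)
print(c.shape)              # (16, 16): instance level
print((z1.T @ z2).shape)    # (64, 64): feature level, for comparison
```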