QuentinJGMace closed this issue 5 months ago.
I think there is a mistake in how the logits are computed in paddle.py.
This is the code right now:
logits = (- samples.matmul(self.w.transpose(1, 2))
          - 1 / 2 * (self.w**2).sum(2).view(n_tasks, 1, -1)
          - 1 / 2 * (samples**2).sum(2).view(n_tasks, -1, 1))
return - logits
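Expanding the squared Euclidean distance shows where the sign goes wrong. This assumes the intended logit for query x_i (a row of samples) and class k (row w_k of self.w) is -1/2 ||x_i - w_k||^2; the issue itself does not spell that out:

```latex
\begin{aligned}
-\tfrac{1}{2}\lVert x_i - w_k \rVert^2
  &= x_i^\top w_k - \tfrac{1}{2}\lVert w_k \rVert^2 - \tfrac{1}{2}\lVert x_i \rVert^2
  && \text{(intended logit)} \\
-\Bigl(-x_i^\top w_k - \tfrac{1}{2}\lVert w_k \rVert^2 - \tfrac{1}{2}\lVert x_i \rVert^2\Bigr)
  &= x_i^\top w_k + \tfrac{1}{2}\lVert w_k \rVert^2 + \tfrac{1}{2}\lVert x_i \rVert^2
  && \text{(what the current code returns)}
\end{aligned}
```

The returned value is neither -1/2 ||x_i - w_k||^2 nor +1/2 ||x_i - w_k||^2 (that one would carry -x_i^T w_k). Because ||w_k||^2 enters with a plus sign, classes whose weights have a large norm get inflated scores.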
I think it should be:
logits = (samples.matmul(self.w.transpose(1, 2))
          - 1 / 2 * (self.w**2).sum(2).view(n_tasks, 1, -1)
          - 1 / 2 * (samples**2).sum(2).view(n_tasks, -1, 1))
return logits  # = -1/2 * ||x_i - w_k||^2 for every (sample, class) pair
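A quick numerical check of the proposed fix, written as a minimal stdlib-only sketch. Plain lists stand in for a single task's tensors (samples as queries, w as class weights), and the helper names (current_logits, fixed_logits, neg_half_sq_dist, allclose, argmax_rows) are hypothetical, not taken from paddle.py. It assumes the target logit is -1/2 ||x - w||^2 between each sample and each class weight.

```python
import math

# Hypothetical sketch: plain lists stand in for one task's torch tensors.
# `samples` holds query features (n_query x d), `w` holds class weights (n_class x d).


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def current_logits(samples, w):
    # Current code: logits = -x.w - 1/2 ||w||^2 - 1/2 ||x||^2, then `return - logits`
    return [[-(-dot(x, c) - 0.5 * dot(c, c) - 0.5 * dot(x, x)) for c in w]
            for x in samples]


def fixed_logits(samples, w):
    # Proposed fix: logits = x.w - 1/2 ||w||^2 - 1/2 ||x||^2, returned as-is
    return [[dot(x, c) - 0.5 * dot(c, c) - 0.5 * dot(x, x) for c in w]
            for x in samples]


def neg_half_sq_dist(samples, w):
    # Reference: -1/2 ||x - w||^2 computed directly, without the expansion
    return [[-0.5 * sum((xi - ci) ** 2 for xi, ci in zip(x, c)) for c in w]
            for x in samples]


def allclose(a, b):
    # Element-wise float comparison of two same-shaped nested lists
    return all(math.isclose(p, q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))


def argmax_rows(m):
    # Predicted class = index of the largest logit in each row
    return [max(range(len(row)), key=row.__getitem__) for row in m]


samples = [[1.0, 0.0], [9.0, 1.0]]
w = [[0.0, 0.0], [10.0, 0.0]]  # one small-norm and one large-norm class weight

print(allclose(fixed_logits(samples, w), neg_half_sq_dist(samples, w)))    # True
print(allclose(current_logits(samples, w), neg_half_sq_dist(samples, w)))  # False
print(argmax_rows(fixed_logits(samples, w)))    # [0, 1]
print(argmax_rows(current_logits(samples, w)))  # [1, 1]
```

In this toy case the first sample sits at distance 1 from class 0 and distance 9 from class 1, yet the current formula assigns it to class 1, while the fixed formula agrees with the direct distance computation.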
Yes, sorry about that, and thanks for pointing it out! I fixed it.