SegoleneMartin / PADDLE


Error on logits sign (paddle.py) #3

Closed: QuentinJGMace closed this issue 5 months ago

QuentinJGMace commented 5 months ago

I think there is a mistake in the way the logits are computed in the paddle.py file.

This is the code right now:

```python
logits = (- samples.matmul(self.w.transpose(1, 2))
          - 1 / 2 * (self.w**2).sum(2).view(n_tasks, 1, -1)
          - 1 / 2 * (samples**2).sum(2).view(n_tasks, -1, 1))
return - logits
```
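Writing x for one row of `samples` and w for one row of `self.w` (dropping the task index), the expression above can be expanded term by term. It returns a positive squared norm of the sum x + w rather than a negative distance, whereas flipping the sign of the first term and dropping the final negation gives the negative half squared distance, which is largest for the closest w:

$$
-\left(-x^\top w - \tfrac{1}{2}\|w\|^2 - \tfrac{1}{2}\|x\|^2\right) = \tfrac{1}{2}\left(\|x\|^2 + 2\,x^\top w + \|w\|^2\right) = \tfrac{1}{2}\|x + w\|^2
$$

$$
x^\top w - \tfrac{1}{2}\|w\|^2 - \tfrac{1}{2}\|x\|^2 = -\tfrac{1}{2}\left(\|x\|^2 - 2\,x^\top w + \|w\|^2\right) = -\tfrac{1}{2}\|x - w\|^2
$$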

I think it should be:

```python
logits = (samples.matmul(self.w.transpose(1, 2))
          - 1 / 2 * (self.w**2).sum(2).view(n_tasks, 1, -1)
          - 1 / 2 * (samples**2).sum(2).view(n_tasks, -1, 1))
return logits
```
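A quick numerical check of the two versions can make the sign error concrete. The sketch below is a NumPy transcription, not the repository code: it assumes `samples` has shape `(n_tasks, n_samples, dim)` and `w` has shape `(n_tasks, n_classes, dim)`, replaces the PyTorch calls `matmul`/`transpose(1, 2)`/`view` with `@`/`transpose(0, 2, 1)`/`reshape`, and passes `self.w` as a plain argument `w`. The function names `logits_current`, `logits_fixed` and `pairwise_sq_norm` are hypothetical.

```python
import numpy as np


def logits_current(samples: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy transcription of the expression before the fix."""
    n_tasks = samples.shape[0]
    logits = (-samples @ w.transpose(0, 2, 1)
              - 1 / 2 * (w**2).sum(2).reshape(n_tasks, 1, -1)
              - 1 / 2 * (samples**2).sum(2).reshape(n_tasks, -1, 1))
    return -logits


def logits_fixed(samples: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy transcription of the proposed fix."""
    n_tasks = samples.shape[0]
    logits = (samples @ w.transpose(0, 2, 1)
              - 1 / 2 * (w**2).sum(2).reshape(n_tasks, 1, -1)
              - 1 / 2 * (samples**2).sum(2).reshape(n_tasks, -1, 1))
    return logits


def pairwise_sq_norm(samples: np.ndarray, w: np.ndarray, sign: float) -> np.ndarray:
    """||samples[t, i] + sign * w[t, k]||^2 for every (task, sample, class) triple."""
    # Broadcast to (n_tasks, n_samples, n_classes, dim), then sum over dim.
    d = samples[:, :, None, :] + sign * w[:, None, :, :]
    return (d**2).sum(-1)


rng = np.random.default_rng(0)
samples = rng.normal(size=(2, 5, 3))  # (n_tasks, n_samples, dim)
w = rng.normal(size=(2, 4, 3))        # (n_tasks, n_classes, dim)

# The fix equals -1/2 * ||x - w||^2 ...
print(np.allclose(logits_fixed(samples, w), -0.5 * pairwise_sq_norm(samples, w, -1.0)))  # True
# ... while the original code returns +1/2 * ||x + w||^2.
print(np.allclose(logits_current(samples, w), 0.5 * pairwise_sq_norm(samples, w, 1.0)))  # True
```

Computing the reference through explicit broadcasting, rather than reusing the expanded formula, keeps the check independent of the expression being tested.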

SegoleneMartin commented 5 months ago

Yes, sorry about that, and thanks for pointing it out! I fixed it.