SegoleneMartin / PADDLE


Incorrect distance computation in `get_logits` #2

Closed danoneata closed 6 months ago

danoneata commented 7 months ago

Hello! The current code for distance computation in the `get_logits` method reads as follows:

logits = (
    -samples.matmul(self.w.transpose(1, 2))
    - 1 / 2 * (self.w**2).sum(2).view(n_tasks, 1, -1)
    - 1 / 2 * (samples**2).sum(2).view(n_tasks, -1, 1)
)

But it should rather be

logits = (
-    -samples.matmul(self.w.transpose(1, 2))
+    +samples.matmul(self.w.transpose(1, 2))
    - 1 / 2 * (self.w**2).sum(2).view(n_tasks, 1, -1)
    - 1 / 2 * (samples**2).sum(2).view(n_tasks, -1, 1)
)

because −½‖x − y‖² = ⟨x, y⟩ − ½‖x‖² − ½‖y‖². With the minus sign in front of the inner product, the expression instead equals −½‖x + y‖². This means that the current implementation updates the centroids that are far away from the query points.
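
The sign argument can be checked numerically. The following is a minimal sketch in plain Python, not the repository's code: the helper names `dot`, `logit`, and `neg_half_sq_dist` and the toy vectors are made up for illustration, and a single query and centroid stand in for the batched tensors used in `get_logits`.

```python
def dot(x, y):
    """Inner product <x, y> of two equal-length vectors."""
    return sum(a * b for a, b in zip(x, y))


def logit(x, w, sign):
    """sign * <x, w> - 1/2 ||w||^2 - 1/2 ||x||^2 for one query x and one centroid w.

    sign=+1 is the proposed fix, sign=-1 is the code quoted in this issue.
    """
    return sign * dot(x, w) - 0.5 * dot(w, w) - 0.5 * dot(x, x)


def neg_half_sq_dist(x, w):
    """Reference value: -1/2 ||x - w||^2."""
    return -0.5 * sum((a - b) ** 2 for a, b in zip(x, w))


# Hypothetical toy data: one query, one nearby centroid, one distant centroid.
x = [1.0, 2.0]
near = [1.0, 2.5]
far = [-1.0, -2.0]

# With the plus sign, the logit matches -1/2 ||x - w||^2 ...
print(logit(x, near, +1), neg_half_sq_dist(x, near))
# ... and the nearby centroid scores higher than the distant one.
print(logit(x, near, +1) > logit(x, far, +1))
# With the minus sign, the ordering flips and the distant centroid wins.
print(logit(x, near, -1) < logit(x, far, -1))
```

With the minus sign, the distant centroid `far = -x` gets the largest possible logit (zero), which is consistent with the expression reducing to −½‖x + w‖².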

SegoleneMartin commented 6 months ago

Hello! Thanks. I fixed it. All the results in the paper were produced with the correct sign, but I must have introduced this error while cleaning the code. A somewhat simpler and cleaner version of PADDLE can be found in the repo transductive_CLIP.

danoneata commented 6 months ago

Thanks for the quick response and the follow-up clarification! Do you have a link to the transductive_CLIP repo?

SegoleneMartin commented 6 months ago

Sure, here: https://github.com/SegoleneMartin/transductive-CLIP

danoneata commented 6 months ago

Thank you!