hkust-nlp / deita

Deita: Data-Efficient Instruction Tuning for Alignment [ICLR2024]
Apache License 2.0

Cosine distance computation #20

Closed: sangkeun00 closed this issue 6 months ago

sangkeun00 commented 6 months ago

Hello,

While the paper and the code both say that cosine distance is used to promote diversity, it seems that the current implementation computes cosine similarity instead of distance:

https://github.com/hkust-nlp/deita/blob/983e98fe0441946f41720553b919f665824001d9/src/deita/selection/filter/base.py#L33-L35

If cosine similarity is used, it enforces data similarity rather than diversity. Any clarification would be much appreciated!

Best, Sang

VPeterV commented 6 months ago

Hi, we utilize cosine similarity as our metric. We will filter out any new samples whose cosine similarities with the already selected samples exceed a certain threshold. In other words, while we measure cosine similarity, we primarily use it to eliminate redundant data.
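The filtering rule described above can be sketched as follows. This is a minimal illustration, not the repository's implementation: the function name `select_diverse`, the toy 2-D embeddings, and the use of NumPy are assumptions made for the example. Only the idea (drop a candidate whose cosine similarity to an already selected sample exceeds the threshold) and the 0.9 value come from this thread.

```python
import numpy as np

def select_diverse(embeddings, threshold=0.9):
    """Greedy redundancy filter (hypothetical helper, not the deita API).

    A candidate is kept only if its cosine similarity to every sample
    kept so far is at or below `threshold`. Assumes non-zero vectors.
    """
    X = np.asarray(embeddings, dtype=float)
    # Normalize rows so that a dot product equals cosine similarity.
    X = X / np.linalg.norm(X, axis=1, keepdims=True)
    kept = []
    for i, x in enumerate(X):
        # Keep the first sample unconditionally; afterwards compare
        # against all previously kept samples.
        if not kept or np.max(X[kept] @ x) <= threshold:
            kept.append(i)
    return kept

embeddings = [
    [1.0, 0.0],
    [0.99, 0.05],  # near-duplicate of row 0
    [0.0, 1.0],
    [0.7, 0.7],
]
selected = select_diverse(embeddings)
print(selected)  # [0, 2, 3]
```

Row 1 is dropped because it is almost parallel to row 0, while rows 2 and 3 survive because their similarity to every kept sample is about 0.71 or lower. A high similarity threshold therefore removes near-duplicates and leaves the more dissimilar samples.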

sangkeun00 commented 6 months ago

Thanks for the prompt reply. My point is that if two data points are similar, cosine "similarity" is high whereas cosine "distance" is low, as their names imply. As you mentioned, the current implementation uses cosine similarity instead of cosine distance, and sets the threshold to 0.9 as below:

https://github.com/hkust-nlp/deita/blob/983e98fe0441946f41720553b919f665824001d9/src/deita/selection/filter/base.py#L94

Therefore, all data chosen by the above code should have very high (>0.9) cosine similarities. Intuitively, when we say data is diverse, my understanding is that the samples are dissimilar to each other. So, in my opinion, the above code reduces diversity rather than promoting it. Let me know if my understanding is wrong!

VPeterV commented 6 months ago

Got it. As you can see in our code here, we will select those elements using "~distance_bool" rather than "distance_bool". I guess the name "filtered_indices" is sorta confusing. They are actually indices after filtering (i.e. "selected_indices").

https://github.com/hkust-nlp/deita/blob/983e98fe0441946f41720553b919f665824001d9/src/deita/selection/filter/base.py#L103
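To make the role of the "~" concrete, here is a small sketch of mask negation. The array values and the name `too_similar` (standing in for "distance_bool") are made up for illustration, and NumPy is an assumption; the linked code may use a different array library.

```python
import numpy as np

# Hypothetical cosine similarities between four candidates and the
# already selected pool (illustrative values only).
similarities = np.array([0.95, 0.30, 0.91, 0.10])
threshold = 0.9

# True marks candidates that are redundant (too similar to the pool).
too_similar = similarities > threshold

# `~` flips the boolean mask, so the indices returned are the ones
# that pass the filter, i.e. the selected samples.
selected_indices = np.flatnonzero(~too_similar)
print(selected_indices.tolist())  # [1, 3]
```

Without the "~", the same line would return indices 0 and 2, the redundant samples, which is the behavior the original question assumed.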

sangkeun00 commented 6 months ago

My bad! I missed the ~. I think I got confused because the inequality sign in Algorithm 1 (d<T) and the one on page 6 (F:= d>T) are opposite. Thanks for the explanation!