voidism / DoLa

Official implementation for the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language Models"
https://arxiv.org/abs/2309.03883

questions about code #3

Closed Cheung-Z closed 8 months ago

Cheung-Z commented 1 year ago

In `dola.py`, line 112, `get_relative_top_filter()`:

```python
sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True)
min_thresh = sorted_logits[..., min_tokens_to_keep - 1]
probs_max = torch.max(scores_normalized, dim=-1).values
probs_thresh = probs_max + np.log(relative_top)
```

My concern: isn't `min_thresh` equal to `probs_max`? And since `np.log(relative_top) < 0`, `probs_thresh` must be less than `min_thresh`. So what is the purpose of `min_thresh`? Thanks!

voidism commented 8 months ago

Hi,

Sorry for the late reply. The purpose of `min_thresh` is to keep `probs_thresh` from becoming too large and ruling out too many tokens.

For example, by setting `min_tokens_to_keep (int) = 3`, `probs_thresh` will never be higher than the probability of the 3rd-largest token, so the top-3 tokens will always be kept.

The implementation was borrowed from contrastive decoding: https://github.com/XiangLi1999/ContrastiveDecoding/blob/170e9142e92159c1237d731e240f5eb14aabf428/transformers/src/transformers/generation_logits_process.py#L235

In our experiments, we didn't use `min_tokens_to_keep` (it was always set to 1), so it is equivalent to not setting any `min_thresh` (we allow the filter to keep only the top-1 token).
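To illustrate: a minimal sketch of the relative-top filter, where `min_thresh` caps `probs_thresh` so the top `min_tokens_to_keep` tokens always survive. This assumes the rest of the function follows the linked contrastive-decoding implementation (the `filter_value` masking and the final `torch.min` cap are my reading of it, not verbatim repo code):

```python
import numpy as np
import torch

def relative_top_filter_sketch(scores_normalized, relative_top=0.1,
                               min_tokens_to_keep=1, filter_value=-float("inf")):
    """scores_normalized: log-probabilities, shape (..., vocab_size)."""
    sorted_logits, _ = torch.sort(scores_normalized, descending=True)
    # Log-prob of the k-th largest token; acts as a floor on how many
    # tokens the filter may rule out.
    min_thresh = sorted_logits[..., min_tokens_to_keep - 1]
    probs_max = torch.max(scores_normalized, dim=-1).values
    # Relative threshold: log(max_prob * relative_top).
    probs_thresh = probs_max + np.log(relative_top)
    # The cap: probs_thresh can never exceed the k-th largest log-prob,
    # so the top-k tokens are always kept.
    probs_thresh = torch.min(min_thresh, probs_thresh).unsqueeze(-1)
    return scores_normalized.masked_fill(scores_normalized < probs_thresh,
                                         filter_value)
```

With `min_tokens_to_keep=1`, `min_thresh == probs_max`, and since `np.log(relative_top) < 0` the cap never triggers, which matches the observation in the question. With `min_tokens_to_keep=3`, an aggressive `relative_top` that would otherwise keep only the top token is capped so the top-3 survive.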