Closed — Cheung-Z closed this 8 months ago
Hi,
Sorry for the late reply.
The function of `min_thresh` is to prevent `probs_thresh` from becoming too large and ruling out too many tokens.
For example, by setting `min_tokens_to_keep (int) = 3`, `probs_thresh` can never be higher than the probability of the 3rd-largest token, so the top-3 tokens will always be kept.
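As a minimal sketch of that capping behavior (reconstructed from the linked implementation; the function name and toy logits here are illustrative, not from the repo):

```python
import numpy as np
import torch

def relative_top_filter(scores, relative_top=0.1, min_tokens_to_keep=1):
    # Work in log-probability space.
    scores_normalized = scores.log_softmax(dim=-1)
    sorted_logits, _ = torch.sort(scores_normalized, descending=True)
    # min_thresh: log-prob of the k-th largest token (k = min_tokens_to_keep).
    min_thresh = sorted_logits[..., min_tokens_to_keep - 1]
    probs_max = torch.max(scores_normalized, dim=-1).values
    probs_thresh = probs_max + np.log(relative_top)
    # Cap the threshold so the top-k tokens always survive the filter.
    probs_thresh = torch.min(min_thresh, probs_thresh).unsqueeze(-1)
    return scores_normalized < probs_thresh  # True = token is filtered out

logits = torch.tensor([[5.0, 4.9, 4.8, 0.0, -1.0]])
# relative_top=0.95 alone would keep only the top token, but
# min_tokens_to_keep=3 caps the threshold at the 3rd-largest log-prob.
mask = relative_top_filter(logits, relative_top=0.95, min_tokens_to_keep=3)
# mask → [[False, False, False, True, True]]: the top-3 tokens are kept.
```

The cap is just `torch.min(min_thresh, probs_thresh)`: whichever of the two thresholds is lower wins, so the relative-top cutoff can never climb above the k-th largest token's log-probability.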
The implementation was borrowed from contrastive decoding: https://github.com/XiangLi1999/ContrastiveDecoding/blob/170e9142e92159c1237d731e240f5eb14aabf428/transformers/src/transformers/generation_logits_process.py#L235
In our experiments we didn't use `min_tokens_to_keep` (it was always set to 1), which is equivalent to not setting any `min_thresh` at all: the filter is allowed to keep only the top-1 token.
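A minimal sketch of why the cap is inert when `min_tokens_to_keep = 1` (variable names follow the quoted code; the toy scores are illustrative):

```python
import numpy as np
import torch

scores_normalized = torch.tensor([[2.0, 1.0, 0.5, -1.0]]).log_softmax(dim=-1)
relative_top = 0.1
min_tokens_to_keep = 1

sorted_logits, _ = torch.sort(scores_normalized, descending=True)
min_thresh = sorted_logits[..., min_tokens_to_keep - 1]  # the max itself
probs_max = torch.max(scores_normalized, dim=-1).values
probs_thresh = probs_max + np.log(relative_top)  # np.log(relative_top) < 0

# With min_tokens_to_keep = 1, min_thresh == probs_max, and since
# probs_thresh < probs_max, torch.min(min_thresh, probs_thresh) always
# returns probs_thresh — the cap never fires.
assert torch.equal(min_thresh, probs_max)
assert torch.min(min_thresh, probs_thresh).item() == probs_thresh.item()
```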
In dola.py line 112, `get_relative_top_filter()`:

```python
sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True)
min_thresh = sorted_logits[..., min_tokens_to_keep-1]
probs_max = torch.max(scores_normalized, dim=-1).values
probs_thresh = probs_max + np.log(relative_top)
```

My concern: isn't `min_thresh = probs_max`? And since `np.log(relative_top) < 0`, `probs_thresh` must be less than `min_thresh` — so what is the meaning of `min_thresh`? Thx