🚀 The feature, motivation and pitch
It is very common for large models to fall into infinite repetition loops during inference, and we need a way to prevent this. If such loops go undetected, they can significantly degrade inference efficiency.
Therefore, I would like a parameter no_repeat_n_gram that prevents the generation of any sequence in which the same n consecutive tokens appear twice, thus mitigating infinite loops. The proposed implementation is as follows: when generating token x_i, for each candidate value of x_i (under sampling, x_i can have multiple candidates), check whether emitting that candidate would repeat an n-gram of length no_repeat_n_gram_size that has already appeared in the sequence. If it would, set that candidate's logit to negative infinity, so the repeated n-gram can never be generated.
In practice, I would set n as large as possible, so that it acts only as a punishment for infinite loops without overly affecting the model's normal output. The reason I do not use repeat_penalty is that it penalizes every token that has already appeared during inference, which I consider an overly harsh penalty, whereas I only need a mechanism that specifically targets infinite loops.
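To make the proposal concrete, here is a minimal, unoptimized sketch of the masking step for a single sequence under sampling. The function name ban_repeated_ngrams, its signature, and the usage lines are assumptions for illustration only, not an existing vLLM or transformers API.

import torch


def ban_repeated_ngrams(logits: torch.Tensor, generated: list, ngram_size: int) -> torch.Tensor:
    """Set the logit of any token that would complete an already-seen n-gram to -inf.

    Hypothetical helper for illustration only: `logits` is a 1-D tensor of next-token
    logits (vocab size) and `generated` is the list of token ids produced so far.
    """
    if ngram_size <= 0 or len(generated) < ngram_size - 1:
        return logits
    # The (n-1)-token prefix that the next token would extend into a full n-gram.
    prefix = tuple(generated[len(generated) - (ngram_size - 1):]) if ngram_size > 1 else ()
    banned = set()
    # Collect every token that already followed this prefix somewhere in the sequence.
    for i in range(len(generated) - ngram_size + 1):
        if tuple(generated[i:i + ngram_size - 1]) == prefix:
            banned.add(generated[i + ngram_size - 1])
    if banned:
        logits[list(banned)] = float("-inf")
    return logits


# Usage per decoding step (names are placeholders):
#   logits = ban_repeated_ngrams(logits, output_token_ids, ngram_size=8)
#   next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)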
Alternatives
No response
Additional context
The implementation behind no_repeat_ngram_size (the NoRepeatNGramLogitsProcessor) in the transformers library might be helpful.
import torch
from typing import Iterable, List


def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    """
    Assume ngram_size=2 and prev_input_ids=tensor([[40, 2883, 2712, 4346]]). The output of generated ngrams look like
    this {(40,): [2883], (2883,): [2712], (2712,): [4346]}.

    Args:
        ngram_size (`int`):
            The number of sequential tokens taken as a group which may only occur once before being banned.
        prev_input_ids (`torch.Tensor`):
            Generated token ids for the current hypothesis.
        num_hypos (`int`):
            The number of hypotheses for which n-grams need to be generated.

    Returns:
        generated_ngrams (`dict`):
            Dictionary of generated ngrams.
    """
    # Initialize an empty list of dictionaries, one for each hypothesis (index) in the range of num_hypos
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        # Loop through each n-gram of size ngram_size in the list of tokens (gen_tokens)
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams


def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    """
    Determines the banned tokens for the current hypothesis based on previously generated n-grams.

    Args:
        banned_ngrams (`dict`):
            A dictionary containing previously generated n-grams for each hypothesis.
        prev_input_ids (`torch.Tensor`):
            Generated token ids for the current hypothesis.
        ngram_size (`int`):
            The number of sequential tokens taken as a group which may only occur once before being banned.
        cur_len (`int`):
            The current length of the token sequences for which the n-grams are being checked.

    Returns:
        List of tokens that are banned.
    """
    # Before decoding the next token, prevent decoding of ngrams that have already appeared
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    return banned_ngrams.get(ngram_idx, [])


def _calc_banned_ngram_tokens(
    ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
    banned_tokens = [
        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens
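For completeness, here is a rough sketch of how these helpers could be wired into a per-step logits pass. The wrapper apply_no_repeat_ngram and the tensor shapes below are assumptions for illustration, not vLLM or transformers code.

def apply_no_repeat_ngram(scores: torch.Tensor, input_ids: torch.Tensor, ngram_size: int) -> torch.Tensor:
    """Hypothetical glue code: mask next-token scores so that no n-gram is generated twice.

    scores:    [num_hypos, vocab_size] logits for the next token
    input_ids: [num_hypos, cur_len] token ids generated so far
    """
    num_hypos, cur_len = input_ids.shape
    banned_batch_tokens = _calc_banned_ngram_tokens(ngram_size, input_ids, num_hypos, cur_len)
    for hypo_idx, banned_tokens in enumerate(banned_batch_tokens):
        if banned_tokens:
            # Setting the banned logits to -inf drives their sampling probability to zero.
            scores[hypo_idx, banned_tokens] = float("-inf")
    return scores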
Before submitting a new issue...
[X] Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.