vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature]: need no_repeat_n_gram in SamplingParams #7842

Open pspdada opened 3 months ago

pspdada commented 3 months ago

šŸš€ The feature, motivation and pitch

It is very common for large models to fall into infinite repetition loops during inference, and we need a way to prevent this from happening. If such loops are not caught, they can significantly hurt inference efficiency.

Therefore, I need a parameter no_repeat_n_gram to prevent the generation of sequences in which an n-gram of consecutive tokens repeats, thus mitigating infinite loops. The idea is: for each candidate next token x_i (under sampling, x_i can take several possible values), check whether generating that token would complete an n-gram that has already appeared, i.e. violate no_repeat_n_gram_size; if it would, set its logit to negative infinity, thereby preventing any repeated n-gram from being generated.
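
For concreteness, here is a minimal per-sequence sketch of that masking step (the function name and signature are illustrative; this is not an existing vLLM API):

from typing import List

import torch

def ban_repeated_ngrams(generated_ids: List[int], logits: torch.Tensor, ngram_size: int) -> torch.Tensor:
    """Set the logit of any token that would complete an already-seen n-gram to -inf."""
    if len(generated_ids) < ngram_size:
        # Not enough tokens yet to have produced a full n-gram.
        return logits
    # The (n-1)-token suffix that the next token would extend into an n-gram.
    prefix = tuple(generated_ids[-(ngram_size - 1):]) if ngram_size > 1 else ()
    banned = []
    # Scan every n-gram seen so far; if its prefix matches the current suffix,
    # its final token must not be generated again.
    for i in range(len(generated_ids) - ngram_size + 1):
        ngram = tuple(generated_ids[i:i + ngram_size])
        if ngram[:-1] == prefix:
            banned.append(ngram[-1])
    if banned:
        logits[banned] = -float("inf")
    return logits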

In practice, I would set n as large as possible so that it acts purely as a guard against infinite loops without affecting the model's normal output. The reason I do not use repetition_penalty is that it penalizes every token that has already appeared during generation, which I consider an overly harsh penalty; I only need a mechanism that specifically targets infinite loops.
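
As a stopgap, something like this can already be wired in per request through the logits_processors field of SamplingParams, assuming the two-argument processor form that receives the generated token ids and the next-token logits (reusing the ban_repeated_ngrams sketch above); a built-in no_repeat_n_gram option would still be preferable:

from functools import partial

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(
    temperature=0.8,
    max_tokens=256,
    # Each processor is called with (generated token ids, logits) and must
    # return the (possibly modified) logits for the next token.
    logits_processors=[partial(ban_repeated_ngrams, ngram_size=10)],
)
outputs = llm.generate(["Tell me a story."], params)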

Alternatives

No response

Additional context

The implementation behind no_repeat_ngram_size in the transformers library (the helpers used by NoRepeatNGramLogitsProcessor) might be helpful:


from typing import Iterable, List

import torch


def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    """
    Assume ngram_size=2 and prev_input_ids=tensor([[40, 2883, 2712, 4346]]). The output of generated ngrams looks like
    this: {(40,): [2883], (2883,): [2712], (2712,): [4346]}.

    Args:
        ngram_size (`int`):
            The number of sequential tokens taken as a group which may only occur once before being banned.
        prev_input_ids (`torch.Tensor`):
            Generated token ids for the current hypothesis.
        num_hypos (`int`):
            The number of hypotheses for which n-grams need to be generated.

    Returns:
        generated_ngrams (`dict`):
            Dictionary of generated ngrams.
    """
    # Initialize an empty list of dictionaries, one for each hypothesis (index) in the range of num_hypos
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        # Loop through each n-gram of size ngram_size in the list of tokens (gen_tokens)
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams

def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    """
    Determines the banned tokens for the current hypothesis based on previously generated n-grams.

    Args:
        banned_ngrams (`dict`):
            A dictionary containing previously generated n-grams for each hypothesis.
        prev_input_ids (`torch.Tensor`):
            Generated token ids for the current hypothesis.
        ngram_size (`int`):
            The number of sequential tokens taken as a group which may only occur once before being banned.
        cur_len (`int`):
            The current length of the token sequences for which the n-grams are being checked.

    Returns:
        List of tokens that are banned.
    """
    # Before decoding the next token, prevent decoding of ngrams that have already appeared
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    return banned_ngrams.get(ngram_idx, [])

def _calc_banned_ngram_tokens(
    ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
    banned_tokens = [
        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens
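
For completeness, transformers combines these helpers in its NoRepeatNGramLogitsProcessor roughly as follows (a sketch of the upstream logic, not a verbatim copy):

class NoRepeatNGramLogitsProcessor:
    """Ban every token that would complete an n-gram already present in the hypothesis."""

    def __init__(self, ngram_size: int):
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        num_hypos = scores.shape[0]
        cur_len = input_ids.shape[-1]
        banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_hypos, cur_len)
        # Mask out every banned continuation for each hypothesis.
        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")
        return scores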


HeegonJin commented 3 weeks ago

are you developing this feature?

pspdada commented 3 weeks ago

> are you developing this feature?

No.