mistralai / mistral-inference

Official inference library for Mistral models
https://mistral.ai/
Apache License 2.0

Suggested improvement of eos logic in generate.py #180

Open vvatter opened 4 weeks ago

vvatter commented 4 weeks ago

https://github.com/mistralai/mistral-inference/blob/c24ac864ab623ca39bda4f48c334eed6e55f13a2/src/mistral_inference/generate.py#L91

In the generate() function of generate.py, there is some curious XOR logic for updating the boolean is_finished vector:

        if eos_id is not None:
            is_finished = is_finished ^ (next_token == eos_id).cpu()

Even after emitting an eos token, Mistral often keeps generating. This means that if you are running large batches, the shortest response can hit eos, then generate a second eos and flip back to is_finished == False before the longest response has finished, and this can keep happening until max_tokens is reached. It seems to me that this should be an OR.
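A toy, pure-Python repro of the flip-back (plain booleans stand in for the tensors in generate.py, and eos id 2 is an arbitrary choice for illustration):

```python
EOS_ID = 2  # hypothetical eos token id for illustration

# XOR update, as in the current generate.py: a second eos un-finishes the sequence
is_finished = False
for tok in [5, 2, 7, 2]:  # the sequence emits eos twice
    is_finished = is_finished ^ (tok == EOS_ID)
# is_finished is back to False after the second eos

# OR update, as proposed: once finished, always finished
is_done = False
for tok in [5, 2, 7, 2]:
    is_done = is_done or (tok == EOS_ID)
# is_done stays True
```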

Additionally, the current approach allows tokens following an EOS to be included in the outputs; since the tokenizer decodes EOS as an empty string, these may contribute to confusing output sequences. This could potentially relate to the issues discussed in #149.

To address both issues, I suggest the following modifications, which ensure that is_finished remains True after encountering an eos token and that no tokens are returned after that point.

        if eos_id is not None:
            # OR: once a sequence hits eos, it stays finished
            is_finished = is_finished | (next_token == eos_id).cpu()
            # zero out tokens of finished sequences, then pad them with eos_id
            next_token = next_token * (~is_finished).to(next_token.device)
            next_token = next_token + eos_id * is_finished.to(next_token.device)
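A pure-Python sketch of the proposed batched update (plain lists stand in for the torch tensors, and eos id 2 is again a hypothetical choice):

```python
EOS_ID = 2  # hypothetical eos token id for illustration

def update(is_finished, next_tokens):
    # OR: once a sequence hits eos, it stays finished
    is_finished = [f or (t == EOS_ID) for f, t in zip(is_finished, next_tokens)]
    # mask tokens of finished sequences with eos_id so nothing after
    # the first eos leaks into the decoded output
    next_tokens = [EOS_ID if f else t for f, t in zip(is_finished, next_tokens)]
    return is_finished, next_tokens

# Batch of two: the first sequence hits eos one step before the second
finished, toks = update([False, False], [2, 7])  # -> [True, False], [2, 7]
finished, toks = update(finished, [9, 2])        # -> [True, True], [2, 2]
```

The mask-then-pad step mirrors the two tensor operations in the suggested patch: multiplying by `~is_finished` zeroes finished positions, and adding `eos_id * is_finished` fills them with eos.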