InternLM / lmdeploy

LMDeploy is a toolkit for compressing, deploying, and serving LLMs.
https://lmdeploy.readthedocs.io/en/latest/
Apache License 2.0

drop stop words #1823

Open grimoire opened 1 week ago

grimoire commented 1 week ago

fix for https://github.com/InternLM/lmdeploy/pull/1754

Stop words should NOT be cached.

  1. Users should be able to get the very same result if they gather all inputs/outputs and recompute them. Since we won't return the stop words to the user, we should not cache them either (see the sketch after this list).
  2. Baichuan2 would forget the history if we put the eos in the cache.
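
A minimal sketch of the invariant behind point 1, using hypothetical names (the session/cache objects are illustrative, not lmdeploy's actual API): the tokens returned to the user and the tokens appended to the session cache should be the same sequence, with stop words already removed.

def strip_stop_words(token_ids, stop_token_ids):
    # Drop trailing stop-word tokens before the sequence is returned or cached.
    while token_ids and token_ids[-1] in stop_token_ids:
        token_ids = token_ids[:-1]
    return token_ids

# output_ids = strip_stop_words(raw_output_ids, stop_token_ids)
# session.history_ids += input_ids + output_ids  # cache exactly what the user sees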
grimoire commented 1 week ago

@zhulinJulia24

zhulinJulia24 commented 5 days ago

@zhulinJulia24

Fixed!

grimoire commented 5 days ago

@lvhan028 Is the behavior aligned with turbomind?

lvhan028 commented 5 days ago

@lvhan028 Is the behavior aligned with turbomind?

Turbomind caches the stop_words but not the eos_id. https://github.com/InternLM/lmdeploy/blob/da439dfd186265faf8074797f5ed4c8a3f3c4f2d/lmdeploy/turbomind/turbomind.py#L751
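
A rough sketch of that behavior (illustrative names, not the actual turbomind.py code): when the session cache is updated, eos_id is filtered out while other stop words are kept.

def update_session_cache(cached_ids, output_ids, eos_id):
    # eos_id is dropped before caching; other stop words stay in the cache.
    cached_ids.extend(t for t in output_ids if t != eos_id)
    return cached_ids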

lvhan028 commented 4 days ago

"Since we won't give stop words to user, we should not cache it too." I don't think so. In the non-stateful case, it is OK, since message2pormpt will add the stop_words in between. So it means, in the stateful case, the stop_words should be saved.

Regarding the baichuan model, I think we should experiment with the eos_id token: add eos_id after each assistant's answer in a multi-round scenario and use transformers to do the inference. Let's check whether it loses the memory.

grimoire commented 3 days ago

Regarding the baichuan model, I think we should experiment with the eos_id token: add eos_id after each assistant's answer in a multi-round scenario and use transformers to do the inference. Let's check whether it loses the memory.

Sure it does

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

def main():
    model_path = '/path/to/Baichuan2-13B-Chat/'

    # Toggle: set eos = '' to reproduce the "w/o eos" run, '</s>' for "with eos".
    eos = '</s>'
    messages = [
        {"role": "user", "content": "Do you know John Wick?"},
        {"role": "assistant", "content": f"Yes, it is a movie. {eos}"},
        {"role": "user", "content": "Tell me more about it."},
    ]

    tokenizer = AutoTokenizer.from_pretrained(model_path,
        revision="v2.0",
        use_fast=False,
        trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path,
        revision="v2.0",
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True)
    model.generation_config = GenerationConfig.from_pretrained(model_path, revision="v2.0")
    with torch.inference_mode():
        response = model.chat(tokenizer, messages)
        print(response)

if __name__ == '__main__':
    main()

output with eos

I'm sorry, I am not sure what you are referring to. Can you provide more context or clarification?

output w/o eos

"John Wick" is a 2014 action thriller film directed by Chad Stahelski and written by Derek Kolstad. The film stars Keanu Reeves as the title character, an ex-secret service agent who goes on a revenge mission after his car and dog are stolen at the behest of a Russian mobster (played by Alfie Allen). The film was released to positive reviews from critics, who praised Reeves' performance and the film's fast-paced action sequences. A sequel, "John Wick: Chapter 2", was released in 2017, and a third installment, "John Wick: Chapter 3 – Parabellum", was released in 2019.
grimoire commented 3 days ago

I can align the behavior with TurboMind, but putting the logic in the template is more reasonable.
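
For instance, a minimal sketch of "putting the logic in the template" (a hypothetical class, not lmdeploy's actual chat template): the template, not the engine cache, decides whether eos follows an assistant answer.

class ChatTemplateSketch:
    eos = '</s>'

    def append_assistant(self, prompt, answer, add_eos=True):
        # The template owns the decision of appending eos, so the engine
        # never needs to keep stop words in its cached history.
        return f"{prompt}{answer}{self.eos if add_eos else ''}"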