huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

How to manually stop the LLM output? #31963

Closed invokerbyxv closed 2 months ago

invokerbyxv commented 2 months ago

I'm using TextIteratorStreamer for streaming output.

Since the LLM may repeat its output indefinitely, I would like to be able to make it stop generating when a cancel request is received.

Is there any way to accomplish this?

model: glm-4-9b-chat

from threading import Thread
from typing import Dict

from fastapi import Request  # assumed: a FastAPI/Starlette Request, which provides is_disconnected()
from transformers import TextIteratorStreamer

# ChatCompletionResponse, ChatCompletionResponseStreamChoice, DeltaMessage and
# _dump_json are defined elsewhere in the server code.
async def predict(messages, model_id: str, raw_request: Request, gen_kwargs: Dict):
    global model, tokenizer
    choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(role='assistant'), finish_reason=None)
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object='chat.completion.chunk')
    yield '{}'.format(_dump_json(chunk, exclude_unset=True))

    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt')
    inputs = inputs.to(model.device)
    streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True, timeout=60.0, skip_special_tokens=True)
    generation_kwargs = dict(input_ids=inputs, streamer=streamer)
    generation_kwargs.update(gen_kwargs)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for new_text in streamer:
        print(new_text)
        if raw_request is not None and await raw_request.is_disconnected():
            print("disconnected")
            # todo stop generate
        choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(content=new_text), finish_reason=None)
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object='chat.completion.chunk')
        yield '{}'.format(_dump_json(chunk, exclude_unset=True))

    choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(content=''), finish_reason='stop')
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object='chat.completion.chunk')
    yield '{}'.format(_dump_json(chunk, exclude_unset=True))
    yield '[DONE]'
amyeroberts commented 2 months ago

cc @gante

gante commented 2 months ago

Hi @invokerbyxv 👋

Following our issues guidelines, we reserve GitHub issues for bugs in the repository and/or feature requests. For any other matters, we'd like to invite you to use our forum or our Discord 🤗

Since this is your first issue with us, I'm going to answer your question :)

Stopping generation based on what we see in the stream is not possible. However, we can encourage our LLM to avoid the behavior we dislike! In this case, repetitions can be tamed with these two generate flags (which you can also pass to a pipeline), as sketched in the example below:

  1. repetition_penalty=x, which lowers the odds of the model repeating tokens when x > 1.0 (the higher the value, the stronger the effect)
  2. no_repeat_ngram_size=n, which forbids the model from repeating any n-gram of size n

(docs here, click on Expand parameters)
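
For example, both flags can be passed straight to generate(), or added to the gen_kwargs dict in the snippet above (generation_kwargs.update(gen_kwargs) forwards them to model.generate()). Here is a minimal, self-contained sketch; the gpt2 checkpoint and the concrete values 1.2 and 3 are only placeholders to be tuned for your own model:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Small public checkpoint used purely as a stand-in for the model in the issue.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=False,
    repetition_penalty=1.2,   # >1.0 lowers the odds of repeating already-generated tokens
    no_repeat_ngram_size=3,   # never generate the same 3-gram twice
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))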