elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

Halting Nx Serving streams with a stop token #288

Closed. zblanco closed this issue 7 months ago

zblanco commented 7 months ago

The ChatML format terminates a message with <|im_end|>; other LLMs have similar kinds of terminating strings.

When streaming results with Bumblebee / Nx.Serving, I'd want the serving to stop emitting tokens once a token like this is produced, so it doesn't do more work than necessary. This would probably look like passing a stop_token: "<|im_end|>" option, or a function that detects when to halt, somewhere like Bumblebee.Text.generation/4.
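
For illustration, the proposed option might look something like this (a hypothetical API shape, not an option Bumblebee currently supports):

serving =
  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
    stream: true,
    # hypothetical option: halt the stream once this string is generated
    stop_token: "<|im_end|>"
  )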

jonatanklosko commented 7 months ago

What is the HF checkpoint you are using?

zblanco commented 7 months ago

{:hf, "teknium/OpenHermes-2-Mistral-7B"} or {:hf, "ehartford/dolphin-2.2-mistral-7b"}, with similar results: when streaming, a stop token like <|im_end|> would be produced, but generation would continue, emitting a user prompt, an answer, and so on until max_new_tokens was reached.

Example serving:

def serving do
  repo = {:hf, "teknium/OpenHermes-2-Mistral-7B"}

  # Load the model onto the host; the base Mistral tokenizer is used
  # because this checkpoint's tokenizer isn't loaded directly
  {:ok, model_info} = Bumblebee.load_model(repo, backend: {EXLA.Backend, client: :host})
  {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "mistralai/Mistral-7B-v0.1"})

  {:ok, generation_config} =
    Bumblebee.load_generation_config(repo, spec_module: Bumblebee.Text.Mistral)

  # Sampling strategy and maximum generation length
  generation_config =
    Bumblebee.configure(generation_config,
      max_new_tokens: 600,
      strategy: %{type: :multinomial_sampling, top_p: 0.14, top_k: 49}
    )

  # Build a streaming serving compiled with EXLA
  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
    compile: [batch_size: 1, sequence_length: 1028],
    stream: true,
    defn_options: [compiler: EXLA, preallocate_params: true]
  )
end
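
For context, the stream is consumed roughly like this (a sketch; the ChatML-formatted prompt is illustrative):

serving = serving()

prompt = "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"

# With stream: true, Nx.Serving.run returns a stream of text chunks
serving
|> Nx.Serving.run(prompt)
|> Enum.each(&IO.write/1)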

jonatanklosko commented 7 months ago

This is mostly a checkpoint issue. Generally, checkpoints intended for generation ship a generation_config.json with a varying number of generation options, but it should at least include "eos_token_id", which determines exactly which token to stop at. Looking at teknium/OpenHermes-2-Mistral-7B, it doesn't have generation_config.json at all, so in this case you need to set the token yourself with Bumblebee.configure(generation_config, eos_token_id: 32000).
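
Applied to the serving above, that would look roughly like this (a sketch; 32000 is the id of <|im_end|> in this checkpoint's extended vocabulary):

generation_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 600,
    strategy: %{type: :multinomial_sampling, top_p: 0.14, top_k: 49},
    # id of <|im_end|>; generation halts once this token is produced
    eos_token_id: 32000
  )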

zblanco commented 7 months ago

Oh, that makes a lot of sense. Things are working as intended now, without the workaround I had in place, which was letting the generation stream run and discarding the tokens produced after the stop token.
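
For anyone landing here later, one way to double-check which id a stop string maps to is to run it through the checkpoint's tokenizer (a sketch; it assumes the checkpoint ships a tokenizer Bumblebee can load, with the added ChatML tokens, and the output may include a leading BOS id):

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "teknium/OpenHermes-2-Mistral-7B"})

# Encode the stop string and inspect the resulting ids; an added special
# token maps to a single id (32000 for <|im_end|> here)
inputs = Bumblebee.apply_tokenizer(tokenizer, "<|im_end|>")
inputs["input_ids"] |> Nx.to_flat_list() |> IO.inspect()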