elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

Support returning token count information #287

Closed · brainlid closed this issue 4 months ago

brainlid commented 7 months ago

Various LLMs work with different maximum token lengths (e.g. 4k, 8k, 16k).

When text is generated, it would be helpful to return the submitted token count and the resulting generated token count. This helps us know when we're nearing the model's maximum token limit.

We can also use this information to explain that a given text generation stopped because it reached the token limit.
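
For illustration, the kind of result shape being asked for might look like this (all field names here are hypothetical, not an existing Bumblebee API):

# Hypothetical result shape; the field names are illustrative only:
%{
  results: [%{text: "generated text..."}],
  num_input_tokens: 512,    # tokens in the submitted prompt
  num_output_tokens: 1024,  # tokens produced during generation
  finish_reason: :length    # e.g. generation stopped at the token limit
}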

jonatanklosko commented 7 months ago

For debugging purposes you can tokenize the text on your own:

%{"input_ids" => input_ids} = Bumblebee.apply_tokenizer(tokenizer, text)
{1, num_tokens} = Nx.shape(input_ids)
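
From there the caller can check how close the input is to the model's limit, e.g. (the 4096 limit is just an example value):

max_tokens = 4096
tokens_remaining = max_tokens - num_tokens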

Returning more information from the serving should also be fine. The question is whether, when streaming, we should return it as a last element; otherwise it's inconsistent.

josevalim commented 7 months ago

Maybe on a stream we could return the number of tokens so far with every chunk?
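
A minimal sketch of what such chunks could look like (the shape is hypothetical):

# Each element would carry the text chunk plus a running count:
%{text: "Hello", num_output_tokens: 1}
%{text: " world", num_output_tokens: 2}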

jonatanklosko commented 7 months ago

If we want to return the input token count too (and perhaps something else in the future), then we would have to repeat it for each chunk as well?

josevalim commented 7 months ago

Maybe we can return the input count as metadata? Or is the serving required to return a stream with no metadata on the side?

jonatanklosko commented 7 months ago

We could, but that's limited to input metadata, which doesn't generalize. If there's any information we want to return at the end of generation, we still have the same issue.

We could actually have an option to emit a last stream element with the same output that the non-streaming version would return.
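
For example (shapes are hypothetical), a consumer would then see the text chunks followed by one final element carrying the full result:

# Hypothetical stream contents with such an option enabled:
[
  "Hello",
  " world",
  %{text: "Hello world", num_input_tokens: 5, num_output_tokens: 2}
]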

josevalim commented 7 months ago

In this particular case, I would include it in every stream element. The input size doesn't need to come from the GPU, so that's a Map.put on each element and it's cheap. For the output, we can show how many tokens have been generated so far, which is useful in itself. So I wouldn't complicate it for now, WDYT?
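
A minimal sketch of that approach, assuming the input token count is already known on the host before generation starts (names are illustrative):

# Attaching the precomputed input count to every chunk is a cheap
# host-side map update, no GPU round trip involved:
stream = Stream.map(stream, &Map.put(&1, :num_input_tokens, num_input_tokens))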

jonatanklosko commented 7 months ago

Duplicating seems weird to me. I'm thinking that having an explicit finish event could actually be useful for the end user anyway. Currently we can send a message after the stream halts, but then metadata would need to be aggregated explicitly, like this:

meta =
  Enum.reduce(stream, %{}, fn chunk, acc ->
    # Forward the text as it arrives
    send(pid, {:chunk, chunk.text})

    # Keep the latest counts; the final iteration holds the totals
    Map.merge(acc, %{
      num_input_tokens: chunk.num_input_tokens,
      num_output_tokens: chunk.num_output_tokens
    })
  end)

send(pid, {:done, meta})

Note that currently we stream strings, so changing to a map is a breaking change, but that's not a big deal at this point.
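
With an opt-in final event, the consumer side could instead become (the {:done, meta} element and the chunk shape are hypothetical):

# The stream would emit regular chunks and one final {:done, meta} element:
Enum.each(stream, fn
  {:done, meta} -> send(pid, {:done, meta})
  chunk -> send(pid, {:chunk, chunk.text})
end)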

josevalim commented 7 months ago

I think opting in to a last event with metadata can work, yeah!

brainlid commented 7 months ago

@jonatanklosko an explicit event to indicate it is done/complete is great. That way the caller can differentiate between exceeding a token limit, a (possible future feature) cancelled stream, or some other error state like "we just stopped receiving messages".

I like the metadata :done message. Very nice. :+1: