huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Add support for prefix_allowed_tokens_fn to maintain a state throughout decoding #28935

Open John-Boccio opened 7 months ago

John-Boccio commented 7 months ago

Feature request

Add an optional argument to prefix_allowed_tokens_fn that allows state to be maintained throughout decoding, or add a stateful alternative to prefix_allowed_tokens_fn.

Motivation

prefix_allowed_tokens_fn is great, but it has one major drawback: you cannot maintain state throughout decoding. This is inefficient because at each step you must walk through the past input_ids, rebuild your current "state", and only then figure out which tokens are allowed to appear next.

Instead, there should be a class we can subclass that gets passed the next token ID at each step of decoding (Constraint does not achieve this, as update does not receive every token ID). For example, if you are trying to constrain the output to JSON format (https://gist.github.com/BorisTheBrave/969f303a082c9da1916d04ee1eb04452), you could track where you currently are in the JSON as each token ID is received, instead of re-processing everything on each new token.
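
For concreteness, here is a rough sketch of the stateless pattern that prefix_allowed_tokens_fn forces today. It is illustrative only: the brace tracking is a simplified stand-in for the gist above, and only the (batch_id, input_ids) callback signature is the actual transformers API.

from transformers import AutoTokenizer

# Illustrative sketch: with prefix_allowed_tokens_fn, any "state" (here, JSON
# brace depth) has to be rebuilt from the full prefix at every decoding step.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
ALL_IDS = list(range(tokenizer.vocab_size))
OPEN_BRACE_ID = tokenizer.encode("{")[0]

def json_allowed_tokens(batch_id, input_ids):
    depth = 0
    # Re-scan the entire generated prefix on every call -- O(length) work per
    # step, which is exactly what a stateful interface would avoid.
    for token_id in input_ids.tolist():
        piece = tokenizer.decode([token_id])
        depth += piece.count("{") - piece.count("}")
    # Toy constraint: inside an object anything is allowed, otherwise force an
    # opening brace.
    return ALL_IDS if depth > 0 else [OPEN_BRACE_ID]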

Your contribution

Unfortunately, I can't make a PR.

amyeroberts commented 7 months ago

cc @gante

gante commented 6 months ago

Hi @John-Boccio 👋

We don't need transformers changes to enable your request :D You can parameterize an arbitrary function with a mutable input and use that mutable input as the state in your prefix_allowed_tokens_fn.

Here's an example of how to create a function with state:

from functools import partial

# A function with a `state` input, assumed to be a dictionary
def some_fn(foo, state=None):
    if state is None or not isinstance(state, dict):
        raise ValueError('`state` must be provided as a dictionary.')
    else:
        if 'bar' in state:
            state['bar'] += 1
        else:
            state['bar'] = 1
    return foo + state['bar']

# partial() allows us to create a new function from `some_fn` with a fixed value for `state`.
# Because the `state` input is mutable, the new function will keep track of the changes to `state`.
parameterized_fn = partial(some_fn, state={'bar': 0})

print(parameterized_fn(0))  # 1
print(parameterized_fn(0))  # 2
print(parameterized_fn(0))  # 3
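
Applied to generation, the same trick might look roughly like this. The constraint logic here is a trivial placeholder; only generate()'s prefix_allowed_tokens_fn argument and its (batch_id, input_ids) callback signature are the actual transformers API:

from functools import partial

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def allowed_tokens(batch_id, input_ids, state=None):
    # Update the mutable `state` with only the newest token instead of
    # re-scanning the whole prefix on each step.
    state.setdefault("seen", []).append(input_ids[-1].item())
    # ... derive the allowed set from `state` here; this stub allows everything.
    return list(range(model.config.vocab_size))

stateful_fn = partial(allowed_tokens, state={})

inputs = tokenizer("The answer is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    prefix_allowed_tokens_fn=stateful_fn,
    max_new_tokens=5,
    do_sample=False,  # greedy decoding
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
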
John-Boccio commented 6 months ago

Hi @gante ! Thank you for the suggestion. I actually had a similar idea and it works well, but it has one catch: you must be performing greedy decoding. As soon as you use more than one beam, all the beams share the objects passed into the partial function (i.e. all beams share the same state, with no way to tell which beam you are currently operating on).

I think there would have to be some sort of new parameter to generate, along the lines of prefix_allowed_tokens_cls, which lets you pass in a class that is instantiated once per beam during generation.
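
A rough sketch of what such an interface could look like (hypothetical: prefix_allowed_tokens_cls does not exist in transformers, and all names below are made up):

from typing import List

from transformers import AutoTokenizer

# Hypothetical sketch only -- prefix_allowed_tokens_cls is not a real
# generate() argument; this just illustrates the kind of per-beam object
# proposed above.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
ALL_IDS = list(range(tokenizer.vocab_size))
OPEN_BRACE_ID = tokenizer.encode("{")[0]
CLOSE_BRACE_ID = tokenizer.encode("}")[0]

class StatefulAllowedTokens:
    """One instance per beam, so every beam keeps its own state."""

    def __init__(self):
        self.depth = 0

    def update(self, token_id: int) -> None:
        # Would be called with each token generated on this beam, in order.
        if token_id == OPEN_BRACE_ID:
            self.depth += 1
        elif token_id == CLOSE_BRACE_ID:
            self.depth -= 1

    def allowed_tokens(self) -> List[int]:
        return ALL_IDS if self.depth > 0 else [OPEN_BRACE_ID]

# Imagined usage (not supported today):
# model.generate(**inputs, num_beams=4, prefix_allowed_tokens_cls=StatefulAllowedTokens)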

gante commented 6 months ago

For beam search to track specific beams, you would indeed have to change a few things -- including the API of the LogitsProcessors, so that the previous beam indices could be passed through to prefix_allowed_tokens_fn.

This falls outside the scope of what we want to support in transformers, at least for now 🤗 My suggestion would be to fork the library and adapt the generation loop to your needs :)