huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

GenerationMixin: model_kwargs not passed to model in assisted decoding #25020

Closed: sinking-point closed this issue 1 year ago

sinking-point commented 1 year ago

Who can help?

@gante

Reproduction

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained("gpt2")
assist = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("The first rule of fight", return_tensors='pt')

# Regular generation: token_type_ids is forwarded to the model
outputs = model.generate(**inputs, token_type_ids=torch.tensor([[0,0,0,0,0]], dtype=torch.long))
print(tokenizer.decode(outputs[0]))

# output: The first rule of fight!!!!!!!!!!!!!!!

# Same call, but with distilgpt2 as the assistant model
outputs = model.generate(**inputs, token_type_ids=torch.tensor([[0,0,0,0,0]], dtype=torch.long), num_beams=1, assistant_model=assist)
print(tokenizer.decode(outputs[0]))

# output: The first rule of fight-or-flight is to be prepared for the enemy. If you are

Expected behavior

I would expect assisted generation to produce the same output as regular generation, since token_type_ids is passed into generate in both cases. The generate method is expected to pass extra arguments to the model via the model's prepare_inputs_for_generation method.

However, unlike the other generation methods, assisted generation does not forward the model_kwargs to the model.
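A minimal, dependency-free sketch of the difference described above. ToyModel, greedy_step and assisted_step_buggy are hypothetical names invented for illustration; the real GenerationMixin code is considerably more involved. The point is only that one path routes extra kwargs through prepare_inputs_for_generation and the other does not.

```python
from typing import Any, Dict, List, Optional


class ToyModel:
    """Hypothetical stand-in for a transformers model; not the real API."""

    def prepare_inputs_for_generation(
        self, input_ids: List[int], token_type_ids: Optional[List[int]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        # Same role as the real method: bundle every input the forward pass needs.
        return {"input_ids": input_ids, "token_type_ids": token_type_ids}

    def __call__(self, input_ids: List[int], token_type_ids: Optional[List[int]] = None) -> str:
        # Report which inputs actually reached the forward pass.
        return "with token_type_ids" if token_type_ids is not None else "without token_type_ids"


def greedy_step(model: ToyModel, input_ids: List[int], **model_kwargs: Any) -> str:
    # Regular decoding: extra kwargs flow through prepare_inputs_for_generation.
    model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
    return model(**model_inputs)


def assisted_step_buggy(model: ToyModel, input_ids: List[int], **model_kwargs: Any) -> str:
    # The behaviour reported in this issue: model_kwargs are silently dropped.
    return model(input_ids=input_ids)
```

With these definitions, greedy_step(ToyModel(), [1, 2], token_type_ids=[0, 0]) returns "with token_type_ids", while assisted_step_buggy with the same arguments returns "without token_type_ids".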

sinking-point commented 1 year ago

I'm happy to have a go at fixing this if a maintainer is willing to support.

gante commented 1 year ago

@sinking-point thank you for spotting it! Yes, I'd be very happy to support you in fixing this :D

sinking-point commented 1 year ago

@gante No problem, thank you for offering to support.

I've come up against a problem. This is from the GPT2 prepare_inputs_for_generation method, but I imagine it's the same for many other models:

        # only last token for inputs_ids if past is defined in kwargs
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

It assumes that if past_key_values is given, you only need the last token. In assisted generation, this is not the case, as multiple candidate tokens go in one pass.

Arguably, this is a bug in the implementation of prepare_inputs_for_generation. It would be better to cut off only as many tokens as are already covered by past_key_values. For example, with 20 cached positions in past_key_values and 25 tokens given, it should keep the last 5 tokens.
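A minimal sketch comparing the two slicing rules, using nested Python lists in place of (batch, seq_len) tensors; the function names are hypothetical. With real tensors the proposed rule would be input_ids[:, past_length:], where past_length would typically be read from the cache (for GPT-2's tuple-of-tuples cache, past_key_values[0][0].shape[2]).

```python
from typing import List


def trim_current(input_ids: List[List[int]], past_length: int) -> List[List[int]]:
    # Logic in the style of the GPT-2 snippet above: with any cache, keep only the last token.
    if past_length:
        return [row[-1:] for row in input_ids]
    return input_ids


def trim_by_past_length(input_ids: List[List[int]], past_length: int) -> List[List[int]]:
    # Proposed logic: drop only the positions the cache already covers.
    return [row[past_length:] for row in input_ids]


ids = [list(range(25))]  # one sequence of 25 tokens, 20 of them cached
print(trim_current(ids, 20))         # keeps 1 token
print(trim_by_past_length(ids, 20))  # keeps the 5 uncached tokens
```

For regular one-token-at-a-time decoding (21 tokens, 20 cached) both rules return the same single token, so the proposed rule changes nothing for the existing generation methods.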

I have 2 options:

  1. Fix prepare_inputs_for_generation in all models. This seems like it could be a lot of work, so I'm not sure I can take that on alone.
  2. Modify the output from prepare_inputs_for_generation in assisted_decoding to correct the input_ids. This would be easier, but it removes control of this process from the models. It also may be insufficient, as models may create other kwargs in prepare_inputs_for_generation to match the shape of input_ids.

What do you think?

sinking-point commented 1 year ago

I propose to implement a prepare_inputs_for_assisted_generation method in GenerationMixin.

It will call the prepare_inputs_for_generation method and modify the input_ids in the returned dict to the correct number of candidate tokens.

Models can then override this if they need to implement custom logic.
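A minimal sketch of the proposed design, assuming a mixin default that wraps the model's own prepare_inputs_for_generation and a model that overrides it only because an extra input must match input_ids in length. MixinSketch, TokenTypeModelSketch, has_past and num_new_tokens are hypothetical names, not the code in #25135 or the transformers API; sequences are single 1-D lists (batch of one) for brevity.

```python
from typing import Any, Dict, List, Optional


class MixinSketch:
    """Hypothetical stand-in for GenerationMixin; not the real transformers class."""

    def prepare_inputs_for_generation(
        self, input_ids: List[int], has_past: bool = False, **kwargs: Any
    ) -> Dict[str, Any]:
        # Per-model logic in the style quoted earlier: with a cache, keep only the last token.
        if has_past:
            input_ids = input_ids[-1:]
        return {"input_ids": input_ids, **kwargs}

    def prepare_inputs_for_assisted_generation(
        self, input_ids: List[int], num_new_tokens: int, **kwargs: Any
    ) -> Dict[str, Any]:
        # Default: reuse the model's own preparation, then restore every token the
        # cache does not cover yet, so all candidates are scored in one forward pass.
        if num_new_tokens < 1:
            raise ValueError("num_new_tokens must be at least 1")
        model_inputs = self.prepare_inputs_for_generation(input_ids, has_past=True, **kwargs)
        model_inputs["input_ids"] = input_ids[-num_new_tokens:]
        return model_inputs


class TokenTypeModelSketch(MixinSketch):
    """Hypothetical model whose extra input must match the length of input_ids."""

    def prepare_inputs_for_generation(
        self,
        input_ids: List[int],
        has_past: bool = False,
        token_type_ids: Optional[List[int]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        if has_past:
            input_ids = input_ids[-1:]
            if token_type_ids is not None:
                token_type_ids = token_type_ids[-1:]
        return {"input_ids": input_ids, "token_type_ids": token_type_ids, **kwargs}

    def prepare_inputs_for_assisted_generation(
        self,
        input_ids: List[int],
        num_new_tokens: int,
        token_type_ids: Optional[List[int]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        # Override only where the default is insufficient: also widen token_type_ids.
        model_inputs = super().prepare_inputs_for_assisted_generation(
            input_ids, num_new_tokens, token_type_ids=token_type_ids, **kwargs
        )
        if token_type_ids is not None:
            model_inputs["token_type_ids"] = token_type_ids[-num_new_tokens:]
        return model_inputs
```

For example, MixinSketch().prepare_inputs_for_assisted_generation([1, 2, 3, 4, 5], num_new_tokens=3) returns {"input_ids": [3, 4, 5]}. The subclass illustrates the concern from option 2 above: without its override, token_type_ids would be left at length 1 while input_ids has length 3.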

gante commented 1 year ago

Hey @sinking-point 👋 I appreciate your bias for action with #25135, but I'd like to propose a different route. A route that would benefit us all in the long run and implies a shorter PR :)

With your solution in #25135, we have a new function to maintain. From experience, different models will eventually make us add conditional branches to accommodate all expected input flavors -> it will be a burden (and a mess) in the long run 😞

You mentioned an alternative plan, fixing the existing prepare_inputs_for_generation to detect how many new tokens there are. In the long run, this is a much better route -- no additional maintenance burden and may unblock future applications with similar problems. However, fixing all models takes a very long time (and may not be the best use of our time, as some models are not used with assisted generation). So... let's modify it for the model you are using now, and raise an exception with instructions regarding how to enable other models :) I've successfully used this strategy in the past, e.g. here.

Would you be up for it?
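One way the "raise an exception with instructions" plan could look, as a minimal sketch. It assumes a per-model opt-in flag; BaseModelSketch, UpdatedModelSketch, supports_assisted_generation and check_assisted_generation_support are hypothetical names, not something transformers is stated to provide.

```python
class BaseModelSketch:
    """Hypothetical base class; not part of transformers."""

    # Hypothetical opt-in flag: a model sets it to True once its
    # prepare_inputs_for_generation keeps every uncached token, not just the last one.
    supports_assisted_generation: bool = False


class UpdatedModelSketch(BaseModelSketch):
    """A model whose prepare_inputs_for_generation has (hypothetically) been updated."""

    supports_assisted_generation = True


def check_assisted_generation_support(model: BaseModelSketch) -> None:
    # Fail early with actionable instructions instead of producing silently wrong output.
    if not getattr(model, "supports_assisted_generation", False):
        raise ValueError(
            f"{type(model).__name__} does not support assisted generation yet. To enable it, "
            "update its prepare_inputs_for_generation so that, when past_key_values is set, "
            "it keeps every token not covered by the cache instead of only the last one, "
            "then set supports_assisted_generation = True."
        )
```

The flag keeps the check out of the generation loop itself: each model declares its own readiness, and the error message tells contributors what to change.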

sinking-point commented 1 year ago

Hi @gante. Thanks for taking a look, but I don't agree with your assessment here.

The issue with your suggestion is it would break assisted generation for models it currently works with. This would be a regression of functionality, and could break people's code.

The prepare_inputs_for_assisted_generation default implementation is intended to work for most, but not necessarily all models. If a new model is added that it doesn't work with, the model can override this method (as with prepare_inputs_for_generation). This avoids the need for adding conditional branches to the default implementation.

gante commented 1 year ago

> The issue with your suggestion is it would break assisted generation for models it currently works with. This would be a regression of functionality, and could break people's code.

How so? input ids length = Attention mask length - past KV length, which would be true in all generation methods.
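The identity above can be checked with plain numbers. A minimal sketch, assuming an unpadded batch where each attention_mask row covers both cached and new positions; num_new_tokens is a hypothetical helper name. With real tensors the same quantity would be attention_mask.shape[1] - past_length.

```python
from typing import List


def num_new_tokens(attention_mask: List[List[int]], past_length: int) -> int:
    # attention_mask spans cached and new positions, so subtracting the number of
    # cached positions leaves the tokens this forward pass still has to process.
    total_length = len(attention_mask[0])
    if past_length > total_length:
        raise ValueError("cache is longer than the attention mask")
    return total_length - past_length


first_step = num_new_tokens([[1] * 5], past_length=0)       # prompt only
regular_step = num_new_tokens([[1] * 21], past_length=20)   # one-token decoding step
assisted_step = num_new_tokens([[1] * 25], past_length=20)  # several candidates in one pass
```

Here first_step is 5, regular_step is 1 and assisted_step is 5: the same formula covers the prompt, regular decoding and assisted decoding.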

sinking-point commented 1 year ago

Maybe I'm misunderstanding you, but isn't your suggestion to:

  1. Make prepare_inputs_for_generation compatible with assisted generation in the models I need only
  2. Raise an exception when anyone tries to use assisted generation with other models?

Currently, most models work with assisted generation. After implementing your suggestion, they would not.

gante commented 1 year ago

I see, you are correct, if we change the code to use prepare_inputs_for_generation instead of manual input preparation, then the models that don't update this function will fail with assisted generation because the function only prepares one token at a time. In other words, we have to update them all.

Still, I'm very biased toward updating them all, it is a much wiser long-term solution and it is not that much more work -- all variations of assisted generation/speculative decoding will need it. It is more work to you (if you still want to implement it), but this sort of choice is critical to ensure we can keep maintaining transformers 🤗

sinking-point commented 1 year ago

I don't want to go through 170+ models and fix them manually one by one.

I'm hoping they're similar enough that I can script it. I'll give that a go.

sinking-point commented 1 year ago

If I'm honest though, I still disagree with you that this is a more maintainable approach.

The reason this repetitive effort is necessary is that the logic is repeated for every model rather than being implemented in the mixin.

If the logic in my PR needs to be changed, you just have to change it once, in one place (the mixin). Your concern regarding an eventual need for conditional branches is addressed by the ability of models to override the function, implementing their own logic only if they need to rather than every single time.

If I change all the prepare_inputs_for_generation functions individually and then the logic needs to be changed again, someone will have to go through and update all the models again.

If we're optimising for future dev time, we should focus on hoisting logic from the models to the mixin when the opportunity presents itself, in my opinion.

Is there anyone who can chime in to give a third opinion?

gante commented 1 year ago

> The reason this repetitive effort is necessary is that the logic is repeated for every model rather than being implemented in the mixin.

The reason the logic is repeated is a core principle of our design philosophy -- https://huggingface.co/docs/transformers/philosophy. This philosophy is one of the reasons transformers is so successful.

You are saying that we can code the wrapper once in the mixin and then overwrite it on a per-model basis... so pretty much the same as updating prepare_inputs_for_generation, but with extra steps and additional encapsulation. This is precisely why I want to avoid going this route.

As the main developer and maintainer of everything generate-related, I can assure you your suggestion is worse in the long run. Generalist functions containing functionality that is strongly model-dependent are the main reason why generate is so hard to develop at the moment, their complexity grows very quickly.

To wrap up: if we end up going in this direction, there has to be a much stronger reason than saving an hour or two of work.

> Is there anyone who can chime in to give a third opinion?

Feel free to ping others, but ultimately it's me who you have to convince :)

sinking-point commented 1 year ago

I hope I didn't come across as trying to undermine your authority. I just find that when there's a disagreement between two people, a third perspective can help to form a better consensus. If you agree, you would know better than me who to tag.

> You are saying that we can code the wrapper once in the mixin and then overwrite it on a per-model basis... so pretty much the same as updating prepare_inputs_for_generation, but with extra steps and additional encapsulation. This is precisely why I want to avoid going this route.

It's not the same. With my solution, in most cases the default implementation would suffice and there would be no need to override it. In fact, as it stands the tests pass for all models - none of them need to override the method. I'm just saying that in the event that, as you fear, you would have to add a conditional branch to the default implementation, you could instead override it in the model.

I don't think we have any fundamental disagreement on design philosophy. At the extreme end of the spectrum, you could do away with GenerationUtils and implement it all in every model. I think we can agree that to take the 'repeat yourself' philosophy to that extent is impractical. All we disagree on is where to draw the line.

That said, since you're the one who will have to deal with the consequences of whatever approach we take, I'm willing to defer to your preference.

gante commented 1 year ago

> I hope I didn't come across as trying to undermine your authority. I just find that when there's a disagreement between two people, a third perspective can help to form a better consensus.

Not interpreted as so 🤗 We are internally aligned that generate consists of too many nested calls and that adding generalist functions on model-dependent parts is a recipe for chaos, hence my assertive comment. I hope this doesn't come across as downplaying your comments and suggestions -- since we bear the load of maintenance, sometimes we have to say no to seemingly good suggestions, using our past experience as a guide.

> All we disagree on is where to draw the line.

Precisely :)

> That said, since you're the one who will have to deal with the consequences of whatever approach we take, I'm willing to defer to your preference.

Thank you for being understanding 🤗 Let me know if I can help in any way!

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

sinking-point commented 1 year ago

This should not be closed yet. It should be closed when https://github.com/huggingface/transformers/pull/25242 is merged.