huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

past_key_values not accepted in generate with GPTNeoX #20347

Closed ValeKnappich closed 1 year ago

ValeKnappich commented 1 year ago

System Info

Python 3.7.13, transformers 4.22.2

Who can help?

@LysandreJik @patrickvonplaten


Reproduction

The past_key_values kwarg is not accepted when calling model.generate(..., past_key_values=pkv) on a GPTNeoXForCausalLM, even though model.forward does accept it. It works fine with other model classes such as GPT2.

Minimal example to reproduce the error:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import transformers

model_id = "NinedayWang/PolyCoder-160M" # small model with GPTNeoXForCausalLM class
model = AutoModelForCausalLM.from_pretrained(model_id)
tok = AutoTokenizer.from_pretrained(model_id)
assert isinstance(model, transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM)
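# Dummy cache tensor: its contents are irrelevant for reproducing this error,
# since generate() validates the kwarg names before the value is ever used
pkv = torch.rand(
    (
        1,   # batch size
        10,  # number of tokens
        2 * model.config.num_hidden_layers,
        model.config.num_attention_heads,
        model.config.hidden_size // model.config.num_attention_heads,
    )
)
out = model.generate(**tok("Hello world"), past_key_values=pkv)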

Error message:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/st/st_us-052400/st_st175337/conda/envs/thesis/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/st/st_us-052400/st_st175337/conda/envs/thesis/lib/python3.7/site-packages/transformers/generation_utils.py", line 1146, in generate
    self._validate_model_kwargs(model_kwargs.copy())
  File "/home/st/st_us-052400/st_st175337/conda/envs/thesis/lib/python3.7/site-packages/transformers/generation_utils.py", line 862, in _validate_model_kwargs
    f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
ValueError: The following `model_kwargs` are not used by the model: ['past_key_values'] (note: typos in the generate arguments will also show up in this list)
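Separately from the validation error: once past_key_values does reach forward, GPTNeoX (to my understanding of the 4.22 source) reads it in the legacy cache format, a tuple with one (key, value) pair per layer, each tensor shaped (batch, num_attention_heads, past_length, head_size). That differs from the single 5-D tensor in the example. The hypothetical helper below only computes those shapes, so it runs without torch; the config numbers in the usage line are illustrative, not PolyCoder-160M's actual config.

```python
def neox_legacy_cache_shapes(num_hidden_layers, num_attention_heads, hidden_size,
                             batch_size, past_length):
    """Return the expected (key_shape, value_shape) pair for every layer.

    Hypothetical helper, assuming the legacy GPTNeoX cache format: a tuple of
    num_hidden_layers (key, value) pairs, each tensor shaped
    (batch, num_attention_heads, past_length, head_size).
    """
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    head_size = hidden_size // num_attention_heads
    layer_shape = (batch_size, num_attention_heads, past_length, head_size)
    # Key and value share the same shape in every layer
    return tuple((layer_shape, layer_shape) for _ in range(num_hidden_layers))


shapes = neox_legacy_cache_shapes(num_hidden_layers=12, num_attention_heads=12,
                                  hidden_size=768, batch_size=1, past_length=10)
print(len(shapes), shapes[0][0])  # 12 (1, 12, 10, 64)
```

With torch, each entry would be a tensor of that shape (e.g. torch.rand(layer_shape)), and the attention_mask passed alongside would need to cover the past_length cached positions plus the new tokens.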

I traced the error to its location ("transformers/generation_utils.py", line 862, in _validate_model_kwargs) and found the cause:

        unused_model_args = []
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
        if "kwargs" in model_args:
            model_args |= set(inspect.signature(self.forward).parameters)
        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)

        if unused_model_args:
            raise ValueError(
                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
                " generate arguments will also show up in this list)"
            )

It first collects the parameter names of prepare_inputs_for_generation and only adds the parameters of forward to the accepted set if one of them is literally named "kwargs". Unlike GPT2, whose catch-all is **kwargs, GPTNeoX's prepare_inputs_for_generation names it **model_kwargs, so forward's parameters (including past_key_values) are never added and the kwarg is rejected.
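The name sensitivity can be reproduced with the standard library alone. This sketch copies the check quoted above into a standalone function and runs it against two hypothetical stand-in classes, one whose catch-all is named **kwargs (like GPT2) and one whose catch-all is named **model_kwargs (like GPTNeoX in 4.22); no transformers code is imported.

```python
import inspect


class GPT2Like:
    # Hypothetical stand-in: catch-all named **kwargs, as in GPT2
    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        return {}

    def forward(self, input_ids=None, past_key_values=None, attention_mask=None):
        return None


class NeoXLike(GPT2Like):
    # Hypothetical stand-in: catch-all named **model_kwargs, as in GPTNeoX
    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None,
                                      **model_kwargs):
        return {}


def unused_model_args(model, model_kwargs):
    """Same name-based check as _validate_model_kwargs, minus the raise."""
    model_args = set(inspect.signature(model.prepare_inputs_for_generation).parameters)
    # Only a parameter literally called "kwargs" unlocks forward's parameters
    if "kwargs" in model_args:
        model_args |= set(inspect.signature(model.forward).parameters)
    return [k for k, v in model_kwargs.items() if v is not None and k not in model_args]


kwargs = {"past_key_values": object(), "attention_mask": [1, 1]}
print(unused_model_args(GPT2Like(), kwargs))  # []
print(unused_model_args(NeoXLike(), kwargs))  # ['past_key_values']
```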

So either the GPTNeoX class or the _validate_model_kwargs method in generation_utils.py should be adapted.

Expected behavior

generate should pass along all valid model_kwargs to the model.

sgugger commented 1 year ago

cc @gante

gante commented 1 year ago

Hey @ValeKnappich 👋

Yeah, model_kwargs needs to be added to _validate_model_kwargs. I'm on it :)
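The suggestion here is to also recognize the name model_kwargs. A slightly more general variant, sketched below as an assumption and not necessarily the patch that was merged, checks whether prepare_inputs_for_generation has any ** catch-all parameter via inspect.Parameter.VAR_KEYWORD, so the catch-all's name no longer matters while typos in generate arguments are still reported. The functions neox_prepare and neox_forward are hypothetical stand-ins for the model methods.

```python
import inspect


def accepts_var_keyword(fn):
    """True if fn has a ** catch-all parameter, whatever it is named."""
    return any(p.kind == inspect.Parameter.VAR_KEYWORD
               for p in inspect.signature(fn).parameters.values())


def unused_model_args_fixed(prepare_fn, forward_fn, model_kwargs):
    """Hypothetical variant of the _validate_model_kwargs check that matches
    the parameter kind instead of the literal string "kwargs"."""
    model_args = set(inspect.signature(prepare_fn).parameters)
    if accepts_var_keyword(prepare_fn):
        model_args |= set(inspect.signature(forward_fn).parameters)
    return [k for k, v in model_kwargs.items() if v is not None and k not in model_args]


def neox_prepare(input_ids, past=None, attention_mask=None, **model_kwargs):
    return {}


def neox_forward(input_ids=None, past_key_values=None, attention_mask=None):
    return None


print(unused_model_args_fixed(neox_prepare, neox_forward, {"past_key_values": 0}))  # []
print(unused_model_args_fixed(neox_prepare, neox_forward, {"past_key_valuez": 0}))  # ['past_key_valuez']
```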

ValeKnappich commented 1 year ago

Great, thanks :)

ValeKnappich commented 1 year ago

@gante @sgugger

The kwarg validation was only a superficial issue. With the fix it no longer throws an error, but past_key_values is still not passed on to the forward method. The prepare_inputs_for_generation method looks like the core of the problem:

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape

        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # cut decoder_input_ids if past is used
        if past and past[0] is not None:
            input_ids = input_ids[:, -1:]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

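Note that model_kwargs is silently dropped here: since the method only has a parameter named past, a past_key_values passed to generate ends up in model_kwargs and never reaches forward. I will create a PR shortly.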
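A minimal sketch of how the method could forward a user-supplied cache instead of dropping it, assuming the fix simply falls back to model_kwargs["past_key_values"] when past is None; this is not necessarily what the PR did. Plain Python lists stand in for tensors so the sketch runs without torch, and the rest of the logic mirrors the method quoted above.

```python
def prepare_inputs_for_generation(input_ids, past=None, attention_mask=None, **model_kwargs):
    """List-based sketch: input_ids is a list of token-id rows (batch x seq_len).

    Assumption: a `past_key_values` arriving via **model_kwargs is used when
    `past` is None, instead of being silently discarded.
    """
    if past is None:
        past = model_kwargs.get("past_key_values")

    # Build an all-ones attention mask over the full input when none is given
    if attention_mask is None:
        attention_mask = [[1] * len(row) for row in input_ids]

    # Same cut as the quoted method: with a cache, keep only the last token per row
    if past and past[0] is not None:
        input_ids = [row[-1:] for row in input_ids]

    return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}


inputs = prepare_inputs_for_generation([[5, 6, 7]], past_key_values=("layer0",))
print(inputs["input_ids"])  # [[7]]
```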

patrickvonplaten commented 1 year ago

@gante @ArthurZucker I think we should rename all occurrences of "past" to "past_key_values" in prepare_inputs_for_generation and deprecate "past" if necessary.

"past" was simply the name for the past key values states before we renamed everything to past_key_values, so this is just a left-over.

ArthurZucker commented 1 year ago

Agreed

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.