huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

T5Attention forward pass failing when not using KV cache #34448

Open gardberg opened 4 weeks ago

gardberg commented 4 weeks ago

Who can help?

@zucchini-nlp

Reproduction

Problem: T5Attention forward pass fails when not using KV cache.

This is caused by cache_position being None in T5Attention.forward (line 525 of modeling_t5.py in the stack trace below). @zucchini-nlp

Code to reproduce:

```python
import torch
from transformers.models.t5.modeling_t5 import T5Attention
from transformers.models.t5 import T5Config
import json
from huggingface_hub import hf_hub_download

T5_REPO = "google-t5/t5-small"

BATCH_SIZE = 2
tgt_len = 3
EMBED_SIZE = 512  # d_model of t5-small

CONFIG_NAME = "config.json"

# Download and load the t5-small config
t5_config_path = hf_hub_download(repo_id=T5_REPO, filename=CONFIG_NAME)

with open(t5_config_path, "r") as f:
    t5_config = json.load(f)

t5_config = T5Config.from_dict(t5_config)

# Random hidden states of shape (batch, seq_len, d_model)
xq = torch.randn((BATCH_SIZE, tgt_len, EMBED_SIZE))

# Self-attention call with no KV cache, no query_length and no cache_position
torch_t5_mha = T5Attention(t5_config).eval()
with torch.no_grad():
    attn_out, kv_state, pos_bias = torch_t5_mha(xq)  # raises TypeError (see stack trace below)
```
Stack trace

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 29
     27 torch_t5_mha = T5Attention(t5_config, has_relative_attention_bias=True).eval()
     28 with torch.no_grad():
---> 29     attn_out, kv_state, pos_bias = torch_t5_mha(xq)

File ~/dev/attention/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/attention/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/attention/.venv/lib/python3.11/site-packages/transformers/models/t5/modeling_t5.py:525, in T5Attention.forward(self, hidden_states, mask, key_value_states, position_bias, past_key_value, layer_head_mask, query_length, use_cache, output_attentions, cache_position)
    523 key_length = key_states.shape[-2]
    524 # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
--> 525 real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
    526 if not self.has_relative_attention_bias:
    527     position_bias = torch.zeros(
    528         (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
    529     )

TypeError: 'NoneType' object is not subscriptable
```
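The failing line uses query_length when it is given and otherwise falls back to cache_position[-1] + 1, so a direct call that passes neither ends up evaluating None[-1]. The sketch below isolates just that expression in plain Python, assuming lists as stand-ins for tensors; real_seq_length and try_call are hypothetical helpers, not transformers code, and the sketch does not model the rest of forward.

```python
def real_seq_length(query_length, cache_position):
    # Same fallback order as line 525: prefer query_length, otherwise
    # read the last (0-indexed) cache position and add 1.
    return query_length if query_length is not None else cache_position[-1] + 1


def try_call(query_length=None, cache_position=None):
    # Return the computed length, or the error message if the expression fails.
    try:
        return real_seq_length(query_length, cache_position)
    except TypeError as exc:
        return f"TypeError: {exc}"


# As in the reproducer: neither argument is passed.
print(try_call())                               # TypeError: 'NoneType' object is not subscriptable
# Either argument alone is enough for this particular line.
print(try_call(query_length=3))                 # 3
print(try_call(cache_position=list(range(3))))  # 3
```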

Expected behavior

The T5Attention forward pass should not fail when the KV cache is not used, and the use_cache flag should actually control whether the cache is used, which it currently doesn't.
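One way to get this behavior would be for forward to derive a default cache_position when the caller does not pass one. The sketch below assumes that design; default_cache_position, resolve_cache_position and past_length are hypothetical names (past_length stands in for however the number of cached tokens is known), lists stand in for tensors, and this is not how transformers implements it.

```python
from typing import List, Optional


def default_cache_position(seq_length: int, past_length: int = 0) -> List[int]:
    # Positions of the new tokens, offset by the number of tokens already cached.
    # With no cache (past_length == 0) this is range(seq_length), the list
    # analogue of torch.arange(seq_length).
    return list(range(past_length, past_length + seq_length))


def resolve_cache_position(
    cache_position: Optional[List[int]], seq_length: int, past_length: int = 0
) -> List[int]:
    # Only fill in a default when the caller did not pass cache_position.
    if cache_position is None:
        cache_position = default_cache_position(seq_length, past_length)
    return cache_position


print(resolve_cache_position(None, seq_length=3))                 # [0, 1, 2]
print(resolve_cache_position(None, seq_length=1, past_length=3))  # [3]
print(resolve_cache_position([5], seq_length=1))                  # [5]
```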

zucchini-nlp commented 3 weeks ago

Hmm, you're right, that breaks backward compatibility (BC). But since we are going to get rid of past_key_value[0].shape[2] in favor of cache_position, and since cache_position is not tied to whether caching is enabled or not, I think it is better if you prepare your own cache_position and pass it to the attention module.

The cache position is almost the same as the position ids, except that pad tokens are also counted. So if no caching is used, it should simply be an arange over the input length: cache_position = torch.arange(input_seq_length). In subsequent calls, cache_position can be either cache_position = torch.arange(new_input_seq_length) when no caching is done, or cache_position = prev_cache_position[-1:] + 1 when decoding one token at a time with a cache.
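To make this concrete, here is a list-based sketch, assuming plain Python lists as stand-ins for the torch expressions above; the function names are hypothetical. With one left pad token, the position ids count only the real tokens while cache_position counts every slot, and each cached decoding step continues from the last position.

```python
from typing import List


def position_ids_from_mask(attention_mask: List[int]) -> List[int]:
    # Position ids count only real tokens (mask == 1); pad slots get a dummy 0 here.
    positions, next_pos = [], 0
    for m in attention_mask:
        positions.append(next_pos if m else 0)
        next_pos += m
    return positions


def prefill_cache_position(input_seq_length: int) -> List[int]:
    # First call (or any call without caching): every slot, pads included,
    # gets a position. List analogue of torch.arange(input_seq_length).
    return list(range(input_seq_length))


def next_cache_position(prev_cache_position: List[int]) -> List[int]:
    # One-token decoding step with a cache: prev_cache_position[-1:] + 1.
    return [p + 1 for p in prev_cache_position[-1:]]


mask = [0, 1, 1, 1]  # one left pad followed by three real tokens
print(position_ids_from_mask(mask))  # [0, 0, 1, 2]
cache_position = prefill_cache_position(len(mask))
print(cache_position)  # [0, 1, 2, 3]
for _ in range(2):
    cache_position = next_cache_position(cache_position)
    print(cache_position)  # [4], then [5]
```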

gardberg commented 3 weeks ago

I see, do you mean that you are going to deprecate the past_key_value parameter? Or, if I understand correctly, that the cache_position variable is used to update past_key_value? Sorry for the confusion.

Thanks for the explanation on cache_position, I'll take a look and see if I can use it to get my code working.

In my opinion, if cache_position is required for the forward pass to work, it should not be an optional parameter. Alternatively, an informative error should be raised explaining how to call forward properly. I'm not sure about the history behind having so many optional parameters default to None in most methods in transformers, but it makes the code quite confusing to read and use.
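The "informative error" alternative could look roughly like the guard below. This is a hypothetical sketch: require_length_info and its error wording are illustrative, not part of transformers, and lists stand in for tensors.

```python
from typing import Optional, Sequence


def require_length_info(
    query_length: Optional[int], cache_position: Optional[Sequence[int]]
) -> int:
    """Hypothetical guard: fail early with a hint instead of "'NoneType' object is not subscriptable"."""
    if query_length is not None:
        return query_length
    if cache_position is None:
        raise ValueError(
            "T5Attention.forward needs either query_length or cache_position. "
            "Without a KV cache, pass cache_position=torch.arange(seq_length)."
        )
    # cache_position is 0-indexed, so the last entry + 1 is the length including past tokens.
    return cache_position[-1] + 1


print(require_length_info(None, [0, 1, 2]))  # 3
print(require_length_info(4, None))          # 4
```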

I'd be happy to write a draft fix for this and set up a PR, but unfortunately I'm getting showered in dependency errors when trying to set up a dev environment according to the contribution guide. I might take a second look at this later.

Thanks again for the feedback 👍

zucchini-nlp commented 3 weeks ago

> I see, do you mean that you are going to deprecate the use of the past_key_value parameter?

No, it is more of a new parameter that will be used to track the actual length of the cache. Since we now have different cache types, such as Static or Offloaded, it is easier to track lengths in a separate tensor.
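The point about Static caches can be illustrated with a toy: a static cache pre-allocates its storage, so reading the length off the storage size (the analogue of past_key_value[0].shape[2]) gives the allocated capacity rather than the number of tokens seen so far. ToyStaticCache is a hypothetical list-based stand-in, not the transformers StaticCache class.

```python
from typing import List, Optional


class ToyStaticCache:
    """Hypothetical pre-allocated cache: storage size is fixed at max_cache_len."""

    def __init__(self, max_cache_len: int):
        self.keys: List[Optional[str]] = [None] * max_cache_len

    def update(self, new_keys: List[str], cache_position: List[int]) -> None:
        # Write each new key into the slot given by its cache position.
        for key, pos in zip(new_keys, cache_position):
            self.keys[pos] = key

    def allocated_length(self) -> int:
        # Shape-based "length": always the pre-allocated size.
        return len(self.keys)


cache = ToyStaticCache(max_cache_len=8)
cache_position = [0, 1, 2]
cache.update(["k0", "k1", "k2"], cache_position)
print(cache.allocated_length())  # 8, not the 3 tokens seen so far
print(cache_position[-1] + 1)    # 3, the real length
```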

These arguments should already be passed if T5 is used through T5Model, but I guess you are trying to override a few modules for your own architecture. Yeah, in that case it is compulsory to prepare all inputs manually. I agree that defaulting them to None can be misleading, and we don't have much documentation about cache_position yet except for the model docstring. We also now have a small doc page with general information about the new cache format here.

Let me know what would be a better way to document these.

gardberg commented 3 weeks ago

I see, thanks for the info. I'll take a look at that doc page.

One idea is to add type hints. I see they have been added to some methods, but I think wider use would really improve readability, e.g. cache_position: Optional[torch.LongTensor] = None.

It is even possible to add information about the shape of the tensor to the type (might be a bit of a hack):

```python
from typing import Generic, Optional, TypeVar

from torch import LongTensor

class LongArray(LongTensor, Generic[TypeVar("Shape")]):
    ...

cache_position: Optional[LongArray["seq_len"]] = None
```
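A less hacky variant of the same idea, swapping in typing.Annotated (Python 3.9+) instead of subclassing LongTensor, attaches the shape as metadata. This is a sketch under that assumption: SeqLenLongTensor and attention_forward are hypothetical names, and "torch.LongTensor" is written as a string forward reference so the snippet runs without torch installed.

```python
import inspect
from typing import Annotated, Optional, get_args

# Hypothetical alias: the shape note travels with the type as Annotated metadata.
SeqLenLongTensor = Annotated["torch.LongTensor", "shape: (seq_len,)"]


def attention_forward(
    hidden_states: "torch.Tensor",
    query_length: Optional[int] = None,
    cache_position: Optional[SeqLenLongTensor] = None,
):
    """Hypothetical signature used only to show the annotations; no attention logic."""
    return None


# Readers and tools can recover both the optionality and the shape note.
hint = inspect.signature(attention_forward).parameters["cache_position"].annotation
inner, none_type = get_args(hint)
print(none_type is type(None))  # True
print(get_args(inner)[1])       # shape: (seq_len,)
```

Static type checkers ignore the metadata, so this documents the expected shape without enforcing it.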

Just an idea! Nice that you have created a doc page though, that is great.

mimbres commented 2 weeks ago

FYI, I've been getting this error since updating to version 4.46. It didn’t occur in 4.45.

```
transformers/models/t5/modeling_t5.py", line 525, in forward
real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
TypeError: 'NoneType' object is not subscriptable
```

zucchini-nlp commented 2 weeks ago

@mimbres can you provide a reproducer? Are you also calling only the attention module with custom logic?

gardberg commented 2 weeks ago

Started a draft PR :) https://github.com/huggingface/transformers/pull/34621

mimbres commented 2 weeks ago

@zucchini-nlp My case is identical to @gardberg's with T5Attention. I can confirm that version 4.45 does not have the issue.