gardberg opened 4 weeks ago
Hmm, you're right, that breaks BC. But also, since we are going to get rid of `past_key_value[0].shape[2]` in favor of `cache_position`, and since `cache_position` is not an arg tied to whether caching is enabled or not, I think it is better if you prepare and pass your own `cache_position` to the attention module.

`cache_position` is almost the same as position ids, with the difference that pad tokens are also counted. So it should simply be an arange of the input length if no caching is used: `cache_position = torch.arange(input_seq_length)`. In subsequent calls `cache_position` can be either `cache_position = torch.arange(new_input_seq_length)` when no caching is done, or `cache_position = prev_cache_position[-1:] + 1` when the cache is used.
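A minimal sketch of the two cases described above (the tensor names and shapes are made-up placeholders, not code from the library):

```python
import torch

# Hypothetical padded batch: 2 sequences of length 10 (pad tokens are counted too).
input_ids = torch.randint(0, 100, (2, 10))

# No caching: cache_position is simply an arange over the input length,
# recomputed for each new input.
cache_position = torch.arange(input_ids.shape[1])   # tensor([0, 1, ..., 9])

# With caching, generating one new token per step: advance the last position by one.
next_cache_position = cache_position[-1:] + 1       # tensor([10])
```

Either way, the result is what gets passed as `cache_position` to the attention module's `forward`.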
I see, do you mean that you are going to deprecate the use of the `past_key_value` parameter? If I understand correctly, you mean that the `cache_position` variable is being used to update the `past_key_value`, sorry for the confusion.
Thanks for the explanation on `cache_position`, I'll take a look and see if I can use it to get my code working.

In my opinion, if the usage of `cache_position` is needed to get the forward pass to work, it should not be an optional parameter. Alternatively, an informative error should be thrown detailing how to call `forward` properly. I'm not sure of the history behind having a lot of optional parameters default to `None` in most methods in `transformers`, but it makes the code quite confusing to read and use.
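For illustration, a sketch of the kind of informative error being suggested (the helper name and wording are made up, not code from the library or from any PR):

```python
import torch

def _require_cache_position(hidden_states, query_length=None, cache_position=None):
    """Hypothetical guard: fail with a descriptive message instead of a bare TypeError."""
    if query_length is None and cache_position is None:
        raise ValueError(
            "T5Attention.forward needs either `query_length` or `cache_position`, even when "
            "no KV cache is used. For the no-cache case, pass "
            "`cache_position=torch.arange(hidden_states.shape[1], device=hidden_states.device)`."
        )

try:
    _require_cache_position(torch.randn(1, 10, 64))
except ValueError as err:
    print(err)  # descriptive message instead of "'NoneType' object is not subscriptable"
```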
I'd be happy to write a draft fix for this and set up a PR, but unfortunately I'm getting showered in dependency errors when trying to set up a dev environment according to the contribution guide. I might take a second look at this later.
Thanks again for the feedback 👍
> I see, do you mean that you are going to deprecate the use of the `past_key_value` parameter?

No, it is more of a new parameter that will be used to track the actual length of the cache. Since we now have different varieties of cache, like Static or Offloaded, it is easier to track lengths in a separate tensor.
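A minimal sketch of why a separate tensor helps (illustrative only, not code from the library): with a pre-allocated static cache, the key/value tensors keep their full shape no matter how many positions are filled, so `past_key_value[0].shape[2]` stops being a reliable length.

```python
import torch

# Hypothetical pre-allocated ("static") cache: the shape is fixed up front.
max_cache_len, filled = 128, 7
static_keys = torch.zeros(1, 8, max_cache_len, 64)

static_keys.shape[2]             # 128 -- the allocation size, not how many slots are filled
cache_position = torch.arange(filled)
int(cache_position[-1]) + 1      # 7   -- the actual length, tracked separately
```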
The arguments actually should be there if T5 is initialized as `T5Model`, but I guess you are trying to override a few modules for your own architecture. Yeah, in that case it will be compulsory to prepare all inputs manually. I agree that setting them to `None` can be misleading, and we don't have much documentation about `cache_position` yet except for the model docstring. Also, we now have a small doc page with general info about the new cache format here.

Let me know what would be a better way to document these.
I see, thanks for the info. I'll take a look at that doc page.
One idea is to add type hinting. I see that has been done for some methods, but I think a wider use would really improve readability, e.g. `cache_position: Optional[torch.LongTensor] = None`.
It is even possible to add information about the shape of the tensor to the type (might be a bit of a hack):

```python
from typing import Generic, Optional, TypeVar
from torch import LongTensor

# Sketch of the idea; the shape string is purely documentation.
class LongArray(LongTensor, Generic[TypeVar("Shape")]):
    ...

cache_position: Optional[LongArray["seq_len"]] = None
```
Just an idea! Nice that you have created a doc page though, that is great.
FYI, I've been getting this error since updating to version 4.46. It didn't occur in 4.45.

```
transformers/models/t5/modeling_t5.py", line 525, in forward
    real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
TypeError: 'NoneType' object is not subscriptable
```
@mimbres can you provide a reproducer? Are you also calling only the attention module with custom logic?
Started a draft PR :) https://github.com/huggingface/transformers/pull/34621
@zucchini-nlp My case is identical to @gardberg's with T5Attention. I can confirm that version 4.45 does not have the issue.
System Info

`transformers` version: 4.46.0

Who can help?

@zucchini-nlp

Information

Tasks

examples folder (such as GLUE/SQuAD, ...)

Reproduction

Problem: `T5Attention` forward pass fails when not using the KV cache. Caused by `cache_position` being `None` here. @zucchini-nlp

Code to reproduce:
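Roughly, judging from the stack trace below, the reproduction looks like this (only the three traced lines are from the report; the config values and input shape are made up for illustration):

```python
import torch
from transformers import T5Config
from transformers.models.t5.modeling_t5 import T5Attention

# Hypothetical config and input.
t5_config = T5Config(d_model=64, d_kv=8, num_heads=8)
xq = torch.randn(1, 10, t5_config.d_model)

torch_t5_mha = T5Attention(t5_config, has_relative_attention_bias=True).eval()
with torch.no_grad():
    # No past_key_value and no cache_position: raises TypeError in 4.46.0.
    attn_out, kv_state, pos_bias = torch_t5_mha(xq)
```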
Stack trace
{ "name": "TypeError", "message": "'NoneType' object is not subscriptable", "stack": "--------------------------------------------------------------------------- TypeError Traceback (most recent call last) Cell In[1], line 29 27 torch_t5_mha = T5Attention(t5_config, has_relative_attention_bias=True).eval() 28 with torch.no_grad(): ---> 29 attn_out, kv_state, pos_bias = torch_t5_mha(xq) File ~/dev/attention/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs) 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1735 else: -> 1736 return self._call_impl(*args, **kwargs) File ~/dev/attention/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs) 1742 # If we don't have any hooks, we want to skip the rest of the logic in 1743 # this function, and just call forward. 1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1745 or _global_backward_pre_hooks or _global_backward_hooks 1746 or _global_forward_hooks or _global_forward_pre_hooks): -> 1747 return forward_call(*args, **kwargs) 1749 result = None 1750 called_always_called_hooks = set() File ~/dev/attention/.venv/lib/python3.11/site-packages/transformers/models/t5/modeling_t5.py:525, in T5Attention.forward(self, hidden_states, mask, key_value_states, position_bias, past_key_value, layer_head_mask, query_length, use_cache, output_attentions, cache_position) 523 key_length = key_states.shape[-2] 524 # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) --> 525 real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 526 if not self.has_relative_attention_bias: 527 position_bias = torch.zeros( 528 (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype 529 ) TypeError: 'NoneType' object is not subscriptable" }
Expected behavior
The `T5Attention` forward pass should not fail when not using the KV cache, and the `use_cache` flag should actually affect whether the cache is used or not, which it currently doesn't.