ciaran-regan-ie closed this issue 2 months ago.
cc @gante too for the cache
Related, but not sure if this should be a separate issue. The problem is actually slightly more general than what you've described. For example, the forward function of `LlamaAttention` contains no reference to `use_cache`. This has the following consequences:

- If `past_key_value` is passed to `LlamaAttention`, the issue stated above will occur (i.e., the cache will always be updated).
- If `past_key_value` is not passed (i.e., `past_key_value=None`), the returned `past_key_value` will also always be `None`, regardless of the value of `use_cache`.

Also FYI, it appears this issue has existed since the new cache structure was introduced in `v4.36`; the correct behaviour existed in prior versions when the tuples were used, e.g., here.
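To make the two consequences concrete, here is a minimal pure-Python sketch of the pattern described above. `ToyCache` and `toy_attention_forward` are hypothetical stand-ins, not the actual `transformers` classes; like the forward function described, the toy function accepts `use_cache` but never reads it.

```python
class ToyCache:
    """Hypothetical stand-in for a Cache object: records one entry per update."""

    def __init__(self):
        self.entries = []

    def update(self, key, value, layer_idx):
        self.entries.append((layer_idx, key, value))


def toy_attention_forward(hidden, past_key_value=None, use_cache=False, layer_idx=0):
    # Mirrors the reported pattern: use_cache is accepted but never consulted.
    if past_key_value is not None:
        # Runs whenever a cache object is passed, even if use_cache=False.
        past_key_value.update(hidden, hidden, layer_idx)
    return hidden, past_key_value


# Consequence 1: a passed cache is updated even with use_cache=False.
cache = ToyCache()
_, returned = toy_attention_forward([1.0, 2.0], past_key_value=cache, use_cache=False)
print(len(cache.entries))  # 1

# Consequence 2: without a cache, None comes back even with use_cache=True.
_, returned = toy_attention_forward([1.0, 2.0], past_key_value=None, use_cache=True)
print(returned)  # None
```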
Hi @ciaran-regan-ie (and @nickfraser )👋 Thank you for opening the issue and elaborating on the problem!
Before taking your comments and projects into consideration, let me share my view (and the context behind some changes in `v4.36`). We moved towards having a custom object to store the cache: the `Cache` classes. Many use cases, such as using `StaticCache` for `torch.compile`, require passing an instantiated object, even if its actual contents are empty. As such, `use_cache` lost some of its importance: we now often pass an empty cache, which implies `use_cache=True`. We have also identified some cases where `use_cache` must be `False`, such as at train time. I dislike implicit/redundant flags, since they tend to create problems, so I would love to remove it when (and if) we release transformers `v5`.
Our code, therefore, gravitated to its current state, where we check `use_cache` in the core modeling class and create a new cache instance if needed. From that point onwards, we assume that `past_key_values is not None` == we want to use the cache. Why wouldn't we? We are either passing the cache object or telling the model to create a new one.
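The flow described here can be sketched as follows. Everything in the block (`ToyCache`, `toy_layer`, `toy_model_forward`) is a hypothetical simplification, not `transformers` code: the outer function is the only place that reads `use_cache`, and the inner layer treats "a cache object was passed" as "use the cache".

```python
class ToyCache:
    """Hypothetical stand-in for a Cache object; records which layers updated it."""

    def __init__(self):
        self.entries = []

    def update(self, key, value, layer_idx):
        self.entries.append(layer_idx)


def toy_layer(hidden, past_key_value, layer_idx):
    # Inner blocks assume: past_key_value is not None <=> caching is wanted.
    if past_key_value is not None:
        past_key_value.update(hidden, hidden, layer_idx)
    return hidden


def toy_model_forward(hidden, past_key_values=None, use_cache=False, num_layers=2):
    # use_cache is only checked here, in the outer-most ("core") function.
    if use_cache and past_key_values is None:
        past_key_values = ToyCache()
    for layer_idx in range(num_layers):
        hidden = toy_layer(hidden, past_key_values, layer_idx)
    return hidden, past_key_values


_, cache = toy_model_forward([0.5], use_cache=True)
print(cache.entries)  # [0, 1]
_, cache = toy_model_forward([0.5], use_cache=False)
print(cache)  # None

# The case raised in this thread: an existing cache is still updated with use_cache=False.
existing = ToyCache()
toy_model_forward([0.5], past_key_values=existing, use_cache=False)
print(existing.entries)  # [0, 1]
```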
From what I'm reading in your comments, my assumption may be incorrect!
In `generate`, a cache is created by default and passed to the model, and we are not checking `use_cache`. Will open a PR to fix it! 💪

Hi @gante,
Thanks for the detailed reply. Also, please feel free to tell me to open a new issue if that is more appropriate. I understand the new behaviour. In my case, I was calling sub-layers of a Llama-based model directly (with `use_cache=True`) for some research work, which caused some strange behaviour. In this case, the cache instantiation code in `LlamaModel` is bypassed and `LlamaDecoderLayer` returns `(torch.Tensor, None)`.

I find this behaviour quite unintuitive, but I accept that this is a niche use case.
@nickfraser If I understand correctly, you were expecting `LlamaDecoderLayer` to return `(torch.Tensor, Cache)` when `use_cache=True` (`Cache` being a new instance), correct? I would expect a cache to be returned too, given the input argument name 😅

Shifting the cache instantiation from the inner-most block (prior to `v4.36`) to the outer-most block was a hard requirement to enable `torch.compile`, but it does conflict with reasonable expectations for `use_cache`.
Here's what I'm thinking of doing:

1. `use_cache` except in the core modeling class (e.g. `LlamaModel`) and classes that use it (e.g. `LlamaForCausalLM`). Internal blocks will essentially assume `use_cache = past_key_values is not None`. Having multiple places where a new cache can be instantiated will bloat the modeling code and is prone to errors, and I think we can assume that folks who use internal layers are power users and know how to instantiate a cache :)
2. `past_key_values`.

WDYT? It should make things much cleaner from a user perspective, while being manageable on our end 🤗
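Under the proposal, a power user who drives internal layers directly would create the cache up front and pass it in. The sketch below is a hypothetical illustration of that usage only: `ToyCache` and `ToyDecoderLayer` are made-up names, the layer has no `use_cache` argument, and keying the toy cache by `layer_idx` is a simplification chosen here, not how the real cache classes store entries.

```python
class ToyCache:
    """Hypothetical cache: maps layer_idx -> list of stored (key, value) pairs."""

    def __init__(self):
        self.per_layer = {}

    def update(self, key, value, layer_idx):
        self.per_layer.setdefault(layer_idx, []).append((key, value))


class ToyDecoderLayer:
    """Hypothetical internal block under the proposal: no use_cache argument."""

    def __init__(self, layer_idx):
        self.layer_idx = layer_idx

    def __call__(self, hidden, past_key_values=None):
        # Caching is implied solely by whether a cache object was passed.
        if past_key_values is not None:
            past_key_values.update(hidden, hidden, self.layer_idx)
        return hidden, past_key_values


# A "power user" driving internal layers directly builds the cache up front.
layers = [ToyDecoderLayer(i) for i in range(3)]
cache = ToyCache()
hidden = [1.0]
for layer in reversed(layers):  # custom layer order, e.g. for experiments
    hidden, cache = layer(hidden, past_key_values=cache)
print(sorted(cache.per_layer))  # [0, 1, 2]
```

Calling a layer without a cache simply returns `None` in the cache slot, which under this design is the documented meaning of "no caching" rather than a surprise.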
(cc @ArthurZucker )
> If I understand correctly, you were expecting `LlamaDecoderLayer` to return `(torch.Tensor, Cache)` when `use_cache=True` (`Cache` being a new instance), correct?

Yes, exactly.

> Shifting the cache instantiation from the inner-most block (prior to v4.36) to the outer-most block was a hard requirement to enable `torch.compile`

Makes sense.

> WDYT? It should make things much cleaner from a user perspective, while being manageable on our end 🤗

Your suggestion makes a lot of sense to me - sounds great! Thanks for being so amenable too! <3
@gante Thank you so much!
System Info

- `transformers` version: 4.44.0

Who can help?

@ArthurZucker
Reproduction
I'm experimenting with shuffling layers in a pre-trained model. The `layer_idx` inside the Attention object makes this difficult, as described in this issue. To work around this, I'm setting `use_cache = False`; however, even with `use_cache = False`, an error occurs because `past_key_value.update` is still called in the Attention forward pass. A simple solution would be to use `use_cache` in the forward pass by adding the following `and` logic:

Here is my code to reproduce. The first run succeeds because the layers have not been switched yet, but the second run fails as the cache attempts to update.
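A hypothetical sketch of the kind of `and` guard being suggested (not the original snippet, and not the patch that was eventually merged): the cache is only touched when a cache object exists and `use_cache` is true. `ToyCache` and `toy_attention_forward` are illustrative stand-ins for the real attention module and cache.

```python
class ToyCache:
    """Hypothetical stand-in for a Cache object; counts update calls."""

    def __init__(self):
        self.num_updates = 0

    def update(self, key, value, layer_idx):
        self.num_updates += 1


def toy_attention_forward(hidden, past_key_value=None, use_cache=False, layer_idx=0):
    # Suggested extra `and use_cache` condition: only update the cache
    # when a cache object exists AND caching was requested.
    if past_key_value is not None and use_cache:
        past_key_value.update(hidden, hidden, layer_idx)
    return hidden, past_key_value


cache = ToyCache()
toy_attention_forward([1.0], past_key_value=cache, use_cache=False)
print(cache.num_updates)  # 0
toy_attention_forward([1.0], past_key_value=cache, use_cache=True)
print(cache.num_updates)  # 1
```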
Expected behavior

When `use_cache = False`, the cache should not be updating, right?

Happy to help with PRs if you feel it's necessary!