tomaarsen / attention_sinks

Extend existing LLMs way beyond the original training length with constant memory usage, without retraining
https://huggingface.co/blog/tomaarsen/attention-sinks
Apache License 2.0

TypeError: 'NoneType' object is not subscriptable #43

Open Kuchiriel opened 6 months ago

Kuchiriel commented 6 months ago

Error:

```
TypeError                                 Traceback (most recent call last)
in ()
     95 if DEVICE == "cuda":
     96     with amp.autocast():
---> 97         result = pipe(
     98             prompt,
     99             max_new_tokens=1024,

8 frames
/usr/local/lib/python3.10/dist-packages/attention_sinks/generation/utils.py in _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, standardize_cache_format)
     30     # Only this branch is changed, it's required to stop the attention_mask from extending itself too far
     31     attention_mask_size = model_kwargs["attention_mask"].size(-1)
---> 32     past_key_values_size = model_kwargs["past_key_values"][0][0].size(2)
     33     if attention_mask_size == past_key_values_size:
     34         model_kwargs["attention_mask"] = torch.cat(

TypeError: 'NoneType' object is not subscriptable
```

The error happens because `model_kwargs["past_key_values"]` is `None` when `_update_model_kwargs_for_generation` is called, so indexing it with `[0][0]` fails. The key/value cache is required for the attention_sinks model, which uses a different attention mechanism than the standard Transformer model. To fix the error, you can make sure `past_key_values` is set in `model_kwargs` (e.g. taken from `outputs`) before it is indexed. For example:

```python
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, standardize_cache_format):
    if is_encoder_decoder:
        model_kwargs["encoder_past"] = outputs["encoder_past"]
    if self.config.is_attention_sink:
        model_kwargs["past_key_values"] = outputs["past_key_values"]

    attention_mask_size = model_kwargs["attention_mask"].size(-1)
    past_key_values_size = model_kwargs["past_key_values"][0][0].size(2)
    if attention_mask_size == past_key_values_size:
        model_kwargs["attention_mask"] = torch.cat(
            (model_kwargs["attention_mask"], torch.zeros((1, past_key_values_size), dtype=torch.int64)),
            dim=-1,
        )
    return model_kwargs
```
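
For anyone else hitting this, here is a minimal defensive sketch of that changed branch. It keeps the same function signature as the attention_sinks patch shown in the traceback; the `None` check and the fallback to `outputs.past_key_values` are my own assumptions about a workaround, not the repo's actual fix:

```python
import torch


def _update_model_kwargs_for_generation(
    self, outputs, model_kwargs, is_encoder_decoder=False, standardize_cache_format=False
):
    # Hypothetical guard: prefer the cache already in model_kwargs, otherwise fall
    # back to whatever the model returned; skip the resize if there is no cache yet.
    past_key_values = model_kwargs.get("past_key_values") or getattr(outputs, "past_key_values", None)
    if past_key_values is not None:
        model_kwargs["past_key_values"] = past_key_values
        attention_mask = model_kwargs["attention_mask"]
        # Only extend the mask once it matches the (sink-truncated) cache length.
        if attention_mask.size(-1) == past_key_values[0][0].size(2):
            model_kwargs["attention_mask"] = torch.cat(
                (attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))),
                dim=-1,
            )
    return model_kwargs
```

If the cache is still `None` with a guard like this, the underlying question is why the model is not returning `past_key_values` at all (for example, generation running with `use_cache=False`).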