tomaarsen / attention_sinks

Extend existing LLMs way beyond the original training length with constant memory usage, without retraining
https://huggingface.co/blog/tomaarsen/attention-sinks
Apache License 2.0

Experiments with MPT7b with seqlen > 2048 #14

Open vchiley opened 9 months ago

vchiley commented 9 months ago

From here:

"For MPT-7B-chat, a RuntimeError is encountered for transformers when the input length exceeds 2048."

Can you comment on what the RuntimeError was? You have run mpt7b with seq len > 8k.

If the mpt7b model config has max_seq_len=2048 then, by design, the model will throw an error whenever the sequence length exceeds that configured value. To fix this, simply configure the model with a longer max_seq_len.
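
For example, a minimal sketch of raising that limit when loading through transformers (assuming the mosaicml/mpt-7b checkpoint, whose config exposes max_seq_len):

    from transformers import AutoConfig, AutoModelForCausalLM

    # Raise MPT's configured sequence limit before loading the weights
    config = AutoConfig.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
    config.max_seq_len = 8192  # default is 2048
    model = AutoModelForCausalLM.from_pretrained(
        "mosaicml/mpt-7b", config=config, trust_remote_code=True
    )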

tomaarsen commented 9 months ago

Hello!

You're very right - well spotted. For the subsequent prompting experiments I indeed didn't increase the seq length to 8k, whereas I did do this for the perplexity experiments, as can be seen here. The RuntimeError is just due to hitting 2048 tokens.

I'll rerun the experiments with the configuration set to 8k or 32k or so, I think that'll be quite interesting.

tomaarsen commented 9 months ago

I've completed the experiments in 7a437d17b5d435635af1c7b8bfda6338deee8c75. When setting the max_seq_len to 8192, the result is extremely poor responses after 2048 tokens for transformers.

I hope that clarifies it!

Nintorac commented 9 months ago

I tried this too; I'm aiming to evaluate embedding quality with respect to different window lengths. I am using the HF feature-extraction pipeline.
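
For reference, a minimal sketch of that setup (the checkpoint name here is just an assumed example):

    from transformers import pipeline

    # Feature-extraction pipeline returns hidden states as nested lists
    # shaped [batch, num_tokens, hidden_size]
    extractor = pipeline(
        "feature-extraction", model="mosaicml/mpt-7b", trust_remote_code=True
    )
    features = extractor("Some long document ...")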

I believe the issue is that the model is encoding the entire context in a single shot.

I think there are several approaches that could be used to encode long sequences, the simplest of which would be to truncate the full context to the window length and then autoregressively feed the rest of the sequence through token by token. There's definitely a performance issue with that approach.

From what I can see, this problem will exist for all models, not just MPT? Please correct me if I'm wrong. Any tips on how to approach implementing streaming inference?

Nintorac commented 9 months ago

I had a very rough crack at it; as expected, it's really slow.

I also didn't really get a lot of matching in the logits between the outputs, so there may be some issues.

Anyway, this is probably the wrong place to implement it; it would need to be unique per model, I guess. Hope it helps to get the idea across at least.

        def overwrite_forward(module):
            # AttentionSinkKVCache and attention_sink_kwargs are assumed to come from
            # the enclosing scope, as in the attention_sinks injection code.
            from tqdm import tqdm
            import torch
            import types

            # Create the new cache
            module.attention_sink_kv_cache = AttentionSinkKVCache(**attention_sink_kwargs)

            # Keep track of the old forward method, we need it in the wrapped one
            old_forward = module.forward

            # Wrap the forward by overriding the past_key_values using the cache
            def wrapped_forward(self, *args, **kwargs):
                outputs = old_forward(*args, **kwargs)
                outputs.past_key_values = self.attention_sink_kv_cache(outputs.past_key_values)
                return outputs

            # Outer wrapper: split a long input into one full window followed by
            # single-token steps, carrying the trimmed KV cache between steps
            def wrapped_wrapped_forward(self, *args, **kwargs):
                print(kwargs.keys())  # debug: see which kwargs the pipeline passes through
                x = args[0] if args else kwargs.pop('input_ids')
                attention = kwargs.get('attention_mask')
                if attention is None:
                    attention = torch.ones_like(x)
                window_size = self.attention_sink_kv_cache.attention_sink_window_size
                t = tqdm(total=x.shape[-1])
                while x.shape[-1] > 0:
                    kwargs['use_cache'] = True
                    if kwargs.get('past_key_values') is None:
                        # First step: encode a full window in one shot
                        x_step = x[:, :window_size]
                        attn_step = attention[:, :window_size]
                        x = x[:, window_size:]
                        attention = attention[:, window_size:]
                    else:
                        # Subsequent steps: feed one token at a time, keeping the mask
                        # sized to the sink tokens + window + the current token
                        x_step = x[:, :1]
                        attn_step = torch.cat([attn_step, attention[:, :1]], -1)
                        attn_step = attn_step[:, -(window_size + self.attention_sink_kv_cache.attention_sink_size + 1):]
                        x = x[:, 1:]
                        attention = attention[:, 1:]
                    t.update(x_step.shape[-1])

                    kwargs['attention_mask'] = attn_step
                    output = wrapped_forward(self, x_step, **kwargs)
                    kwargs['past_key_values'] = output.past_key_values
                t.close()
                output.past_key_values = None
                return output

            module.forward = types.MethodType(wrapped_wrapped_forward, module)

The main idea is here, where we either encode a full window on the first pass or, afterwards, step through the remaining tokens one at a time:

                while x.shape[-1] > 0:
                    kwargs['use_cache'] = True
                    if kwargs.get('past_key_values') is None:
                        # First step: encode a full window in one shot
                        x_step = x[:, :window_size]
                        attn_step = attention[:, :window_size]
                        x = x[:, window_size:]
                        attention = attention[:, window_size:]
                    else:
                        # Subsequent steps: feed one token at a time
                        x_step = x[:, :1]
                        attn_step = torch.cat([attn_step, attention[:, :1]], -1)
                        attn_step = attn_step[:, -(window_size + self.attention_sink_kv_cache.attention_sink_size + 1):]
                        x = x[:, 1:]
                        attention = attention[:, 1:]

I think you could improve performance by creating an overlapped window while encoding, something like [x,x,x,x,x,x,o,o,o,o,o,o,o,o,o], where the xs are the window and the os are new tokens. Here the total encoding step length will still be limited by the model's context window, so if your sink window equals the context window then you can at most single-step.
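
A minimal sketch of planning how many new tokens each forward pass could take under that constraint (illustrative names, not code from the repo):

    def plan_steps(total_len: int, context_window: int, window_size: int):
        """Yield (start, end) token spans to feed per forward pass."""
        # After the first full window, each pass can take up to
        # context_window - window_size new tokens on top of the cached window;
        # if window_size == context_window this degenerates to single-token steps.
        chunk_size = max(context_window - window_size, 1)
        end = min(window_size, total_len)
        yield 0, end  # first pass: encode a full window in one shot
        while end < total_len:
            start = end
            end = min(start + chunk_size, total_len)
            yield start, end  # subsequent passes: a chunk of new tokens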

There are lots of parameters to tweak here (window_size, encoding_overlap_size); I'm pretty interested to see how embedding quality changes with respect to these.