tomaarsen / attention_sinks

Extend existing LLMs way beyond the original training length with constant memory usage, without retraining
https://huggingface.co/blog/tomaarsen/attention-sinks
Apache License 2.0

Error when using Qwen 7b chat #36

Open Minami-su opened 7 months ago

Minami-su commented 7 months ago
import torch
from transformers import AutoTokenizer, TextStreamer, GenerationConfig, AutoConfig
from attention_sinks import AutoModelForCausalLM

model_id = "Qwen-7B-Chat"
# Note: instruct or chat models also work.
# config = AutoConfig.from_pretrained(model_id)
# config.seq_length = 1_000_000
# config.max_position_embeddings = 1_000_000
# config.max_seq_len = 8192

# Load the chosen model and corresponding tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # config=config,
    # for efficiency:
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    # `attention_sinks`-specific arguments:
    attention_sink_size=4,
    attention_sink_window_size=252,  # <- low for the sake of faster generation
)

model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
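# Qwen's tokenizer ships without a pad token; reuse its end-of-document id.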
tokenizer.pad_token_id = tokenizer.eod_id

# Our input text
text = "Vaswani et al. (2017) introduced the Transformers"

# Encode the text
input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    # A TextStreamer prints tokens as they're being generated
    streamer = TextStreamer(tokenizer)
    generated_tokens = model.generate(
        input_ids,
        generation_config=GenerationConfig(
            # use_cache=True is required, the rest can be changed up.
            use_cache=True,
            min_new_tokens=100_000,
            max_new_tokens=1_000_000,
            penalty_alpha=0.6,
            top_k=5,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        ),
        streamer=streamer,
    )
    # Decode the final generated text
    output_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
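An aside on the decoding settings: because `penalty_alpha` and `top_k` are both set, `generate` dispatches to contrastive search (visible in the tracebacks below). As a sanity check, it may be worth ruling that decoding path out by retrying with plain sampling; a minimal sketch reusing the objects above, with a short token budget just to see whether it still crashes:

# Hedged sanity check (assumption: the failure might be specific to the
# contrastive search path): retry with plain sampling instead.
with torch.no_grad():
    streamer = TextStreamer(tokenizer)
    generated_tokens = model.generate(
        input_ids,
        generation_config=GenerationConfig(
            use_cache=True,        # still required by attention_sinks
            do_sample=True,        # plain sampling, no penalty_alpha/top_k
            max_new_tokens=1_000,  # short run, enough to surface the bug
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        ),
        streamer=streamer,
    )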

error:

The model is automatically converting to fp16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...
Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 8/8 [00:07<00:00,  1.08it/s]
[Attention Sinks] Injected Position Shifting into 32 attention classes.
[Attention Sinks] Injected Attention Sink KV Cache into 1 model class.
Vaswani et al. (2017) introduced the
Traceback (most recent call last):
  File "/home/luhao/test.py", line 33, in <module>
    generated_tokens = model.generate(
  File "/root/.cache/huggingface/modules/transformers_modules/Qwen-7B-Chat2/modeling_qwen.py", line 1261, in generate
    return super().generate(
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/transformers/generation/utils.py", line 1623, in generate
    return self.contrastive_search(
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/transformers/generation/utils.py", line 2132, in contrastive_search
    outputs = self(
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/Qwen-7B-Chat2/modeling_qwen.py", line 1045, in forward
    transformer_outputs = self.transformer(
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/attention_sinks/inject_mixin.py", line 140, in wrapped_forward
    outputs = old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/Qwen-7B-Chat2/modeling_qwen.py", line 893, in forward
    outputs = block(
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/Qwen-7B-Chat2/modeling_qwen.py", line 612, in forward
    attn_outputs = self.attn(
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/attention_sinks/models/qwen/pos_shift.py", line 100, in qwen_pos_shift_attention_forward
    attn_output, attn_weight = self._attn(
  File "/root/.cache/huggingface/modules/transformers_modules/Qwen-7B-Chat2/modeling_qwen.py", line 352, in _attn
    attn_weights = torch.where(
RuntimeError: The size of tensor a (33) must match the size of tensor b (17) at non-singleton dimension 2
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [6,0,0], thread: [96,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
(the same assertion repeats for threads [97,0,0] through [107,0,0])
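A hedged reading of this first failure: Qwen's `_attn` applies its causal mask to the attention scores with `torch.where`, which requires every non-singleton dimension to match, while the sink cache truncates the key/value history to `attention_sink_size + attention_sink_window_size` (4 + 252 = 256 here), so the mask can end up sized for a different sequence length than the scores. A minimal shape check illustrating the constraint (all shapes assumed; only the two mismatched sizes are taken from the error message):

import torch

# torch.where(mask, scores, fill) broadcasts its arguments, so any
# non-singleton dimension that differs raises a RuntimeError.
scores = torch.randn(1, 32, 33, 17)                # (batch, heads, q_len, kv_len), assumed
mask = torch.ones(1, 1, 17, 17, dtype=torch.bool)  # a mask built for a 17-token window
try:
    torch.where(mask, scores, torch.tensor(float("-inf")))
except RuntimeError as e:
    print(e)  # size mismatch at non-singleton dimension 2, as in the traceback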
Minami-su commented 7 months ago

I then tried pip install git+https://github.com/tomaarsen/attention_sinks.git@model/qwen_fa, but a different error occurs:

The repository for Qwen-7B-Chat2 contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/Qwen-7B-Chat2.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y
The model is automatically converting to fp16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 8/8 [00:07<00:00,  1.06it/s]
[Attention Sinks] Injected Position Shifting into 32 attention classes.
[Attention Sinks] Injected Attention Sink KV Cache into 1 model class.
Vaswani et al. (2017) introduced the
Traceback (most recent call last):
  File "/home/luhao/test.py", line 36, in <module>
    generated_tokens = model.generate(
  File "/root/.cache/huggingface/modules/transformers_modules/Qwen-7B-Chat2/modeling_qwen.py", line 1261, in generate
    return super().generate(
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/transformers/generation/utils.py", line 1623, in generate
    return self.contrastive_search(
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/transformers/generation/utils.py", line 2007, in contrastive_search
    outputs = self(
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/Qwen-7B-Chat2/modeling_qwen.py", line 1045, in forward
    transformer_outputs = self.transformer(
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/attention_sinks/inject_mixin.py", line 131, in wrapped_forward
    outputs = old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/Qwen-7B-Chat2/modeling_qwen.py", line 893, in forward
    outputs = block(
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/Qwen-7B-Chat2/modeling_qwen.py", line 612, in forward
    attn_outputs = self.attn(
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/train/lib/python3.9/site-packages/attention_sinks/models/qwen/pos_shift.py", line 217, in qwen_pos_shift_attention_forward
    causal_mask = registered_causal_mask[:, :, key.size(-2) - query.size(-2) : key.size(-2), : key.size(-2)]
TypeError: 'NoneType' object is not subscriptable
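The second failure is more direct: `qwen_pos_shift_attention_forward` on the `model/qwen_fa` branch slices `registered_causal_mask`, but in this configuration Qwen appears to leave that buffer as None (its modeling code only registers it on some code paths), hence the TypeError. A defensive sketch of what a guard could look like, following the names in the traceback; this is an assumption-laden illustration, not the library's actual fix:

import torch

def sliced_causal_mask(registered_causal_mask, query, key):
    # Slice the causal mask the way pos_shift.py does, but build a
    # lower-triangular fallback when the model never registered one.
    key_len, query_len = key.size(-2), query.size(-2)
    if registered_causal_mask is None:
        registered_causal_mask = torch.tril(
            torch.ones(key_len, key_len, dtype=torch.bool, device=query.device)
        ).view(1, 1, key_len, key_len)
    return registered_causal_mask[:, :, key_len - query_len : key_len, :key_len]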