datamllab / LongLM

[ICML'24 Spotlight] LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning
https://arxiv.org/pdf/2401.01325.pdf
MIT License

LongLM isn't compatible with gemma-2-27b-it or gemma-2b-it #46

Closed uebian closed 2 months ago

uebian commented 3 months ago

I found that the current version of LongLM cannot be applied to the Gemma 1 or Gemma 2 models. I wrote a minimal test to reproduce the issue:

# transformers version 4.38.2
# this example is tested with 4 RTX3090s, 24GB memory each
import warnings
warnings.filterwarnings("ignore")

import torch 
import json
import time
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

import SelfExtend 

window_size = 1024
group_size = 32

model_name = '/tmp/gemma-2b-it/'
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
SelfExtend.apply(model, group_size, window_size, enable_flash_attention=False)
prompt = "How are you?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

start_time = time.time()
tokens = model.generate(input_ids, max_new_tokens=4096)
answer = tokenizer.decode(tokens[0].tolist()[input_ids.shape[1]:], skip_special_tokens=True)
print( answer )

When SelfExtend.apply is called on the loaded model, the script fails with the error message below:

$ python3 test.py 
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.07it/s]
Traceback (most recent call last):
  File "/var/lib/condor/execute/slot1/dir_2652801/test.py", line 22, in <module>
    SelfExtend.apply(model, group_size, window_size, enable_flash_attention=False)
  File "/var/lib/condor/execute/slot1/dir_2652801/SelfExtend.py", line 160, in apply
    raise Exception(f"Failed to modify the attention method of {arch_name}")
Exception: Failed to modify the attention method of GemmaForCausalLM

I found that it fails at the duplicate check on line 24 of SelfExtend.py; when it fails, instance is False.
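
For context, here is a minimal sketch of how I understand the instance-level patching to work (this is my reading of the mechanism, not the actual LongLM code; the helper name and signature below are assumptions):

import types

def modify_method_of_instance_sketch(model, target_class_name, method_name, new_method):
    # Walk every submodule and rebind `method_name` on instances whose class name
    # matches `target_class_name`. If no submodule matches (for example because the
    # attention class is GemmaSdpaAttention rather than GemmaAttention), nothing is
    # patched and the function returns False, which is the failure I see above.
    patched = False
    for module in model.modules():
        if type(module).__name__ == target_class_name:
            setattr(module, method_name, types.MethodType(new_method, module))
            patched = True
    return patched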

Below is a conda env export dump listing the packages in my Python environment:

channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.7.2=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.14=h5eee18b_0
  - pip=24.0=py310h06a4308_0
  - python=3.10.14=h955ad1f_1
  - readline=8.2=h5eee18b_0
  - setuptools=69.5.1=py310h06a4308_0
  - sqlite=3.45.3=h5eee18b_0
  - tk=8.6.14=h39e8969_0
  - wheel=0.43.0=py310h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - zlib=1.2.13=h5eee18b_1
  - pip:
      - accelerate==0.33.0
      - aiohttp==3.9.5
      - aiosignal==1.3.1
      - annotated-types==0.7.0
      - anyio==4.4.0
      - async-timeout==4.0.3
      - attrs==23.2.0
      - certifi==2024.7.4
      - charset-normalizer==3.3.2
      - click==8.1.7
      - cloudpickle==3.0.0
      - cmake==3.30.1
      - datasets==2.20.0
      - dill==0.3.8
      - diskcache==5.6.3
      - distro==1.9.0
      - dnspython==2.6.1
      - einops==0.8.0
      - email-validator==2.2.0
      - exceptiongroup==1.2.2
      - fastapi==0.111.1
      - fastapi-cli==0.0.4
      - filelock==3.15.4
      - flash-attn==2.6.3
      - frozenlist==1.4.1
      - fsspec==2024.5.0
      - h11==0.14.0
      - httpcore==1.0.5
      - httptools==0.6.1
      - httpx==0.27.0
      - huggingface-hub==0.24.2
      - idna==3.7
      - interegular==0.3.3
      - jinja2==3.1.4
      - jsonschema==4.23.0
      - jsonschema-specifications==2023.12.1
      - lark==1.1.9
      - llvmlite==0.43.0
      - lm-format-enforcer==0.10.3
      - markdown-it-py==3.0.0
      - markupsafe==2.1.5
      - mdurl==0.1.2
      - mpmath==1.3.0
      - msgpack==1.0.8
      - multidict==6.0.5
      - multiprocess==0.70.16
      - nest-asyncio==1.6.0
      - networkx==3.3
      - ninja==1.11.1.1
      - numba==0.60.0
      - numpy==1.26.4
      - nvidia-cublas-cu12==12.1.3.1
      - nvidia-cuda-cupti-cu12==12.1.105
      - nvidia-cuda-nvrtc-cu12==12.1.105
      - nvidia-cuda-runtime-cu12==12.1.105
      - nvidia-cudnn-cu12==8.9.2.26
      - nvidia-cufft-cu12==11.0.2.54
      - nvidia-curand-cu12==10.3.2.106
      - nvidia-cusolver-cu12==11.4.5.107
      - nvidia-cusparse-cu12==12.1.0.106
      - nvidia-ml-py==12.555.43
      - nvidia-nccl-cu12==2.20.5
      - nvidia-nvjitlink-cu12==12.5.82
      - nvidia-nvtx-cu12==12.1.105
      - openai==1.37.1
      - outlines==0.0.46
      - packaging==24.1
      - pandas==2.2.2
      - pillow==10.4.0
      - prometheus-client==0.20.0
      - prometheus-fastapi-instrumentator==7.0.0
      - protobuf==5.27.2
      - psutil==6.0.0
      - py-cpuinfo==9.0.0
      - pyairports==2.1.1
      - pyarrow==17.0.0
      - pyarrow-hotfix==0.6
      - pycountry==24.6.1
      - pydantic==2.8.2
      - pydantic-core==2.20.1
      - pygments==2.18.0
      - python-dateutil==2.9.0.post0
      - python-dotenv==1.0.1
      - python-multipart==0.0.9
      - pytz==2024.1
      - pyyaml==6.0.1
      - pyzmq==26.0.3
      - ray==2.33.0
      - referencing==0.35.1
      - regex==2024.7.24
      - requests==2.32.3
      - rich==13.7.1
      - rpds-py==0.19.1
      - safetensors==0.4.3
      - sentencepiece==0.2.0
      - shellingham==1.5.4
      - six==1.16.0
      - sniffio==1.3.1
      - starlette==0.37.2
      - sympy==1.13.1
      - tiktoken==0.7.0
      - tokenizers==0.19.1
      - torch==2.3.1
      - torchvision==0.18.1
      - tqdm==4.66.4
      - transformers==4.43.3
      - triton==2.3.1
      - typer==0.12.3
      - typing-extensions==4.12.2
      - tzdata==2024.1
      - urllib3==2.2.2
      - uvicorn==0.30.3
      - uvloop==0.19.0
      - vllm==0.5.3.post1
      - vllm-flash-attn==2.5.9.post1
      - watchfiles==0.22.0
      - websockets==12.0
      - xformers==0.0.27
      - xxhash==3.4.1
      - yarl==1.9.4
Mooler0410 commented 3 months ago

Hi! Thanks for your interest!

Could you please print out the loaded model's architecture? The modification fails when the target module cannot be found. In your case, for Gemma 2, Gemma2ForCausalLM should be modified rather than GemmaForCausalLM; they are different classes in Hugging Face Transformers, and we haven't implemented SelfExtend for Gemma 2 yet.

Another possible problem is that almost every release of Hugging Face Transformers makes some changes to the {Model_name}ForCausalLM classes. We will check the newest Gemma implementation in Transformers and release an update if needed.
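
As a quick sanity check (these are standard Transformers attributes, not LongLM API), something like the following shows which architecture and attention class were actually loaded:

print(model.config.architectures)
# e.g. ['GemmaForCausalLM'] for gemma-2b-it vs. ['Gemma2ForCausalLM'] for gemma-2-27b-it

print(type(model.model.layers[0].self_attn).__name__)
# e.g. 'GemmaSdpaAttention' when transformers selects the SDPA implementation by default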

uebian commented 3 months ago

Hi, thank you for getting back to me. I'm using gemma-2b-it, which is a Gemma 1 model, not Gemma 2. The model can be downloaded from https://huggingface.co/google/gemma-2b-it

Mooler0410 commented 3 months ago

Sorry for the oversight. Could you please share the output of print(loaded_model)? This should print the names of all modules in the loaded model.

uebian commented 3 months ago

Sorry for the delayed response. I have printed some information that might help figure out the issue. Code:

import warnings
warnings.filterwarnings("ignore")

import torch
import json
import time
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

import SelfExtend

window_size = 1024
group_size = 32

model_id = '/tmp/gemma-2b-it/'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print(model)

Output:

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=256000, bias=False)
)
uebian commented 2 months ago

I also found that SelfExtend cannot be applied to CodeLlama. Model structure:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32016, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32016, bias=False)
)

Error log:

Traceback (most recent call last):
  File "/home/ubuntu/LongLM/test.py", line 27, in <module>
    SelfExtend.apply(model, group_size, window_size)
  File "/home/ubuntu/LongLM/SelfExtend.py", line 123, in apply
    raise Exception(f"Failed to modify the attention method of {arch_name}")
Exception: Failed to modify the attention method of LlamaForCausalLM
Mooler0410 commented 2 months ago

It seems the modification failure is caused by a change in the default attention module. The modification function assumes that the default attention class is "LlamaAttention" / "GemmaAttention", but it is actually "LlamaSdpaAttention" / "GemmaSdpaAttention". You may refer to: https://github.com/datamllab/LongLM/issues/23#issuecomment-1986716092
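
As a side note, an alternative workaround (just a sketch, assuming your transformers version accepts the attn_implementation argument, which was added around 4.36) is to force the eager attention implementation at load time, so the layers use the plain LlamaAttention / GemmaAttention classes that the current modification code expects:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_id,                       # same local path as in your script above
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",    # avoid the *SdpaAttention subclasses
)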

uebian commented 2 months ago

Yes, by replacing "LlamaAttention" with "LlamaSdpaAttention", it works. Thank you very much.

FYI: below is the patch I applied:

diff --git a/SelfExtend.py b/SelfExtend.py
index 8f294fa..2aee66d 100644
--- a/SelfExtend.py
+++ b/SelfExtend.py
@@ -116,9 +116,9 @@ def apply(loaded_model, group_size, window_size, enable_flash_attention=False, s
                                             group_size_1=group_size, 
                                             group_size_2=window_size,
                                             scale_base=scale_base)
-            # after the default version of attention in 4.36 is LlamaSpdaAttention, but in before 4,36 or in 4.38, it is LlamaAttention
+            # after the default version of attention in 4.36 is LlamaSdpaAttention, but in before 4,36 or in 4.38, it is LlamaAttention
             # print("loaded_model", loaded_model)
-            modifed_2 = modify_method_of_instance(loaded_model, "LlamaAttention", "forward", self_extend_attention_forward)
+            modifed_2 = modify_method_of_instance(loaded_model, "LlamaSdpaAttention", "forward", self_extend_attention_forward)
             if not modifed_2:
                 raise Exception(f"Failed to modify the attention method of {arch_name}")
     elif 'Mistral' in arch_name:
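
A slightly more tolerant variant of that change (just a sketch; I am assuming modify_method_of_instance returns a truthy value when it patches at least one module) would try both class names, since the default attention class differs across transformers versions:

modifed_2 = (
    modify_method_of_instance(loaded_model, "LlamaAttention", "forward", self_extend_attention_forward)
    or modify_method_of_instance(loaded_model, "LlamaSdpaAttention", "forward", self_extend_attention_forward)
)
if not modifed_2:
    raise Exception(f"Failed to modify the attention method of {arch_name}")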
piotr25691 commented 2 months ago

For CPU users, use the fork from #25.

It worked for me.