abertsch72 / unlimiformer

Public repo for the NeurIPS 2023 paper "Unlimiformer: Long-Range Transformers with Unlimited Length Input"
MIT License
1.05k stars 77 forks

Working with 8bit and 4bit quantized models #19

Open jordancole21 opened 1 year ago

jordancole21 commented 1 year ago

Hey! Great work on this project! I got it to work on a couple of T5 instruction-tuned models from Hugging Face. I was just curious: has anyone been able to get the code to work with quantized models? Currently, when I set 'load_in_4bit=True' I get this error:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <cell line: 1>:1
│
│ /content/unlimiformer/src/unlimiformer.py:707 in convert_model
│
│   704 │   @classmethod
│   705 │   def convert_model(cls, model, *args, **kwargs):
│   706 │   │   model_clone = AutoModelForSeq2SeqLM.from_config(model.config)
│ ❱ 707 │   │   model_clone.load_state_dict(model.state_dict())
│   708 │   │   type_to_class = {
│   709 │   │   │   BartModel: UnlimiformerBART,
│   710 │   │   │   BartForConditionalGeneration: UnlimiformerBART,
│
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:2041 in load_state_dict
│
│   2038 │   │   │   │   │   │   ', '.join('"{}"'.format(k) for k in missing_keys)))
│   2039 │   │
│   2040 │   │   if len(error_msgs) > 0:
│ ❱ 2041 │   │   │   raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
│   2042 │   │   │   │   │   │   │   self.__class__.__name__, "\n\t".join(error_msgs)))
│   2043 │   │   return _IncompatibleKeys(missing_keys, unexpected_keys)
│   2044 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Error(s) in loading state_dict for T5ForConditionalGeneration:
    size mismatch for encoder.block.0.layer.0.SelfAttention.q.weight: copying a param with shape torch.Size([524288, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
    size mismatch for encoder.block.0.layer.0.SelfAttention.k.weight: copying a param with shape torch.Size([524288, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).

Does anyone have any solutions to this?

urialon commented 1 year ago

Hi @jordancole21 , Thank you for your interest in our work!

The idea there is that we are duplicating the model before overriding its functions to inject Unlimiformer. One quick and dirty hack would be to just remove this model duplication, and change the function to:

    @classmethod
    def convert_model(cls, model, *args, **kwargs):
        model_clone = model  # reuse the original (possibly quantized) model instead of duplicating it
        ...  # rest of the function here, unchanged

However, if you manage to get the following to work and submit a PR, that would be great: in the file src/unlimiformer.py, when calling the function load_state_dict, can you try replacing that line with something like:

model_clone.load_state_dict(model.state_dict(), load_in_8bit=model.is_loaded_in_8bit, load_in_4bit=model.is_loaded_in_4bit)

I'm not sure if that should be the exact code, but the idea is to call model_clone.load_state_dict() with the same quantization as the original model.
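Another direction that might achieve the same effect is to skip the state-dict copy entirely and re-load the clone from the original checkpoint with the same quantization flags. This is a rough, untested sketch: checkpoint_name is a hypothetical extra argument (convert_model currently only receives the model object), and the is_loaded_in_* flags are the ones transformers/bitsandbytes set on quantized models:

from transformers import AutoModelForSeq2SeqLM

def clone_with_same_quantization(model, checkpoint_name):
    # checkpoint_name is hypothetical: the hub id or path the original model was loaded from.
    in_4bit = getattr(model, "is_loaded_in_4bit", False)
    in_8bit = getattr(model, "is_loaded_in_8bit", False)
    if in_4bit or in_8bit:
        # Quantized weights are stored packed (hence shapes like torch.Size([524288, 1])),
        # so they cannot be copied into a freshly instantiated fp32 model via load_state_dict.
        # Re-loading with the same flags keeps the clone quantized the same way.
        return AutoModelForSeq2SeqLM.from_pretrained(
            checkpoint_name,
            load_in_4bit=in_4bit,
            load_in_8bit=in_8bit,
            device_map="auto",
        )
    # Unquantized path: the current behaviour in convert_model.
    clone = AutoModelForSeq2SeqLM.from_config(model.config)
    clone.load_state_dict(model.state_dict())
    return clone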

Can you try and let me know if it works?

Thanks, Uri

urialon commented 1 year ago

Sorry, I checked and my previous suggestion doesn't work.

Do you have any idea of how to duplicate the model object, including its quantization settings?

If not, I'd recommend just using model_clone = model without actually cloning the model.

Best, Uri

jordancole21 commented 1 year ago

> Sorry, I checked and my previous suggestion doesn't work.
>
> Do you have any idea of how to duplicate the model object, including its quantization settings?
>
> If not, I'd recommend just using model_clone = model without actually cloning the model.
>
> Best, Uri

Thank you so much for the quick reply!

Ok so I tried that, and it looks like it gets through the 'model = Unlimiformer.convert_model(model)' code without any issue. But when I run inference with the model (lmsys/fastchat-t5-3b-v1.0), it gives me CUDA errors, even though it's only a 3-billion-parameter model running in 4-bit on an A100-40G.

Edit: For context, I'm passing about 8k tokens in as input. And these are my Unlimiformer and generate arguments:

Unlimiformer args:

unlimiformer_args = UnlimiformerArguments(
    gpu_datastore=False,
    use_datastore=True,
    unlimiformer_chunk_size=512,
    unlimiformer_verbose=True
)

Generate args:

# Sample Input 
input_ids = tokenizer(prompt, return_tensors="pt").to('cuda')

# Call model
with torch.no_grad():
    outputs = model.generate(
        **input_ids,
        max_length=2048,
        temperature=.7, 
        early_stopping=True,
        do_sample=True,
        top_p=.92,
        top_k=0,
        repetition_penalty=1.1
        )

Any ideas on why I would still be getting a CUDA memory error?

urialon commented 1 year ago

What kind of CUDA errors? Out of memory?

jordancole21 commented 1 year ago

> What kind of CUDA errors? Out of memory?

Yes sorry, just out of memory errors:

OutOfMemoryError: CUDA out of memory. Tried to allocate 13.19 GiB (GPU 0; 39.56 GiB total capacity; 30.03 GiB 
already allocated; 7.91 GiB free; 30.06 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory
try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and 
PYTORCH_CUDA_ALLOC_CONF
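
For reference, the max_split_size_mb setting that message mentions is configured via an environment variable before CUDA is first used (e.g. at the very top of the notebook, before loading the model); the 128 below is just an example value, not a recommendation:

import os

# Must run before the first CUDA allocation, i.e. before the model is loaded.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"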
urialon commented 1 year ago

Hi @jordancole21 , I don't know.

I'm not sure whether the 4-bit is the problem or whether it is something else. Did you get a stack trace? I wonder if the problem occurs during the initial encoding of the input, or during decoding.
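
One quick way to narrow that down (a rough sketch, reusing the model and input_ids from your notebook cells) would be to generate a single new token, so the run is dominated by the initial encoding:

import torch

# If this already runs out of memory, the OOM happens while encoding the long input;
# if it completes, memory is growing during decoding instead.
with torch.no_grad():
    _ = model.generate(**input_ids, max_new_tokens=1)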

jordancole21 commented 1 year ago

> Hi @jordancole21 , I don't know.
>
> I'm not sure whether the 4-bit is the problem or whether it is something else. Did you get a stack trace? I wonder if the problem occurs during the initial encoding of the input, or during decoding.

Oops, sorry, I should have sent the full thing earlier. Here's the full traceback:

Token indices sequence length is longer than the specified maximum sequence length for this model (14873 > 2048). Running this sequence through the model will result in indexing errors
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <cell line: 5>:6                                                                              │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115 in decorate_context       │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1325 in generate        │
│                                                                                                  │
│   1322 │   │   if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:      │
│   1323 │   │   │   # if model is encoder decoder encoder_outputs are created                     │
│   1324 │   │   │   # and added to `model_kwargs`                                                 │
│ ❱ 1325 │   │   │   model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(           │
│   1326 │   │   │   │   inputs_tensor, model_kwargs, model_input_name                             │
│   1327 │   │   │   )                                                                             │
│   1328                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:638 in                  │
│ _prepare_encoder_decoder_kwargs_for_generation                                                   │
│                                                                                                  │
│    635 │   │   model_input_name = model_input_name if model_input_name is not None else self.ma  │
│    636 │   │   encoder_kwargs["return_dict"] = True                                              │
│    637 │   │   encoder_kwargs[model_input_name] = inputs_tensor                                  │
│ ❱  638 │   │   model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)          │
│    639 │   │                                                                                     │
│    640 │   │   return model_kwargs                                                               │
│    641                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165 in new_forward                   │
│                                                                                                  │
│   162 │   │   │   with torch.no_grad():                                                          │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │
│   164 │   │   else:                                                                              │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │
│   167 │                                                                                          │
│   168 │   module.forward = new_forward                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/t5/modeling_t5.py:1090 in forward    │
│                                                                                                  │
│   1087 │   │   │   │   │   None,  # past_key_value is always None with gradient checkpointing    │
│   1088 │   │   │   │   )                                                                         │
│   1089 │   │   │   else:                                                                         │
│ ❱ 1090 │   │   │   │   layer_outputs = layer_module(                                             │
│   1091 │   │   │   │   │   hidden_states,                                                        │
│   1092 │   │   │   │   │   attention_mask=extended_attention_mask,                               │
│   1093 │   │   │   │   │   position_bias=position_bias,                                          │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165 in new_forward                   │
│                                                                                                  │
│   162 │   │   │   with torch.no_grad():                                                          │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │
│   164 │   │   else:                                                                              │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │
│   167 │                                                                                          │
│   168 │   module.forward = new_forward                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/t5/modeling_t5.py:693 in forward     │
│                                                                                                  │
│    690 │   │   else:                                                                             │
│    691 │   │   │   self_attn_past_key_value, cross_attn_past_key_value = None, None              │
│    692 │   │                                                                                     │
│ ❱  693 │   │   self_attention_outputs = self.layer[0](                                           │
│    694 │   │   │   hidden_states,                                                                │
│    695 │   │   │   attention_mask=attention_mask,                                                │
│    696 │   │   │   position_bias=position_bias,                                                  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165 in new_forward                   │
│                                                                                                  │
│   162 │   │   │   with torch.no_grad():                                                          │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │
│   164 │   │   else:                                                                              │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │
│   167 │                                                                                          │
│   168 │   module.forward = new_forward                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/t5/modeling_t5.py:600 in forward     │
│                                                                                                  │
│    597 │   │   output_attentions=False,                                                          │
│    598 │   ):                                                                                    │
│    599 │   │   normed_hidden_states = self.layer_norm(hidden_states)                             │
│ ❱  600 │   │   attention_output = self.SelfAttention(                                            │
│    601 │   │   │   normed_hidden_states,                                                         │
│    602 │   │   │   mask=attention_mask,                                                          │
│    603 │   │   │   position_bias=position_bias,                                                  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:165 in new_forward                   │
│                                                                                                  │
│   162 │   │   │   with torch.no_grad():                                                          │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │
│   164 │   │   else:                                                                              │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │
│   167 │                                                                                          │
│   168 │   module.forward = new_forward                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/t5/modeling_t5.py:550 in forward     │
│                                                                                                  │
│    547 │   │   │   │   position_bias = position_bias[:, :, -hidden_states.size(1) :, :]          │
│    548 │   │   │                                                                                 │
│    549 │   │   │   if mask is not None:                                                          │
│ ❱  550 │   │   │   │   position_bias = position_bias + mask  # (batch_size, n_heads, seq_length  │
│    551 │   │                                                                                     │
│    552 │   │   if self.pruned_heads:                                                             │
│    553 │   │   │   mask = torch.ones(position_bias.shape[1])                                     │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 13.19 GiB (GPU 0; 39.56 GiB total capacity; 30.03 GiB 
already allocated; 7.91 GiB free; 30.06 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory
try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and 
PYTORCH_CUDA_ALLOC_CONF

I'm thinking it may be happening during the encoding of the input? lol, this is all still a little over my head honestly

urialon commented 1 year ago

I can't tell which step it is in, because unlimiformer.py does not appear in the traceback. Does it work well without Unlimiformer?

jordancole21 commented 1 year ago

> I can't tell which step it is in, because unlimiformer.py does not appear in the traceback. Does it work well without Unlimiformer?

Hm, yeah, even when I tried just the 4-bit model without Unlimiformer it also ran into CUDA memory issues, and when I tried the model with full weights + Unlimiformer it still gave me a CUDA memory error. But it seems to work without 4-bit on this smaller model with Unlimiformer:

MBZUAI/LaMini-T5-738M

Also, if it helps, this is the Colab notebook I'm working in: https://colab.research.google.com/drive/1U1Pt6-htLzQ5gQdMBl3ZMkDXi9phzsnO?usp=sharing

urialon commented 1 year ago

Thanks @jordancole21. We are currently working on supporting decoder models such as LLaMA, so I am keeping this issue open. If you manage to solve this in the meantime, we would love to accept contributions.

Best, Uri