Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Hugging Face model quantization has odd behavior with BitsandbytesPrecisionPlugin #19732

Closed: ElleLeonne closed this issue 7 months ago

ElleLeonne commented 7 months ago

Bug description

I am trying to use the Bitsandbytes precision plugin documented here: https://lightning.ai/docs/pytorch/2.1.0/api/lightning.pytorch.plugins.precision.BitsandbytesPrecisionPlugin.html

I am advised to initialize the model in the configure_model() hook, because otherwise it is loaded in full 32-bit precision, which is too large to fit on my GPU.

When I do, the trainer fails while loading the weights and reports a size mismatch, as shown in the log below.

What exactly does the BitsandbytesPrecision plugin do, and are there additional steps it does not handle under the hood? Normally, when I load a model, I pass a BitsAndBytes configuration object to it. I have not done that here because the tutorial does not. Lightning seems to imply that it handles this for us, but I'm unsure whether that's true.

What version are you seeing the problem on?

v2.1

How to reproduce the bug

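# Minimal reproduction
import torch
import lightning as L
from lightning.pytorch.plugins import BitsandbytesPrecision
from transformers import AutoModelForCausalLM


class LightningModel(L.LightningModule):
    def __init__(self, model_path):
        super().__init__()
        self.model_path = model_path
        self.instantiated = False

    def configure_model(self):
        # Create the model in this hook, as the plugin docs advise; the hook can
        # run more than once, so guard against re-instantiation
        if not self.instantiated:
            self.model = AutoModelForCausalLM.from_pretrained(self.model_path)
            self.instantiated = True

    # Placeholder training hooks so that trainer.fit() accepts the module
    def training_step(self, batch, batch_idx):
        return self.model(**batch).loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)


model = LightningModel("google/gemma-2b")
trainer = L.Trainer(
    # "nf4-dq" is an assumed mode; the original used an undefined self.train_precision
    plugins=BitsandbytesPrecision(mode="nf4-dq", dtype=torch.float16),
)
trainer.fit(model)  # raises the size-mismatch error inside configure_model()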

Error messages and logs

Loading a model this way produces the output below, with a size-mismatch line for nearly every layer.

Seed set to 234
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/<redacted>/anaconda3/envs/<redacted>/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
 -- Instantiating --- 
Falling back to loading HuggingFace Model
Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:00<00:00,  2.79it/s]
Traceback (most recent call last):
  File "/home/<redacted>/Desktop/<redacted>/train.py", line 89, in <module>
    main(train_config, data_config)
  File "/home/<redacted>/Desktop/<redacted>/train.py", line 83, in main
    trainer.fit(model)
  File "/home/<redacted>/anaconda3/envs/<redacted>/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/<redacted>/anaconda3/envs/<redacted>/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/<redacted>/anaconda3/envs/<redacted>/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/<redacted>/anaconda3/envs/<redacted>/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 951, in _run
    call._call_configure_model(self)
  File "/home/<redacted>/anaconda3/envs/<redacted>/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 109, in _call_configure_model
    _call_lightning_module_hook(trainer, "configure_model")
  File "/home/<redacted>/anaconda3/envs/<redacted>/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 157, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/<redacted>/Desktop/<redacted>/training/lightning_core.py", line 73, in configure_model
    self.model = AutoModelForCausalLM.from_pretrained(self.model_path)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/<redacted>/anaconda3/envs/<redacted>/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py", line 563, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/<redacted>/anaconda3/envs/<redacted>/lib/python3.12/site-packages/transformers/modeling_utils.py", line 3531, in from_pretrained
    ) = cls._load_pretrained_model(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/<redacted>/anaconda3/envs/<redacted>/lib/python3.12/site-packages/transformers/modeling_utils.py", line 4009, in _load_pretrained_model
    raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
RuntimeError: Error(s) in loading state_dict for GemmaForCausalLM:
    size mismatch for model.layers.0.self_attn.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([262144000, 1]).
    size mismatch for model.layers.0.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.0.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.0.self_attn.o_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.1.self_attn.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.1.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.1.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.1.self_attn.o_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.1.mlp.gate_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.1.mlp.up_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.2.self_attn.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.2.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.2.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.2.self_attn.o_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.2.mlp.gate_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.2.mlp.up_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.2.mlp.down_proj.weight: copying a param with shape torch.Size([2048, 16384]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.3.self_attn.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.3.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.3.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.3.self_attn.o_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.3.mlp.gate_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.3.mlp.up_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.3.mlp.down_proj.weight: copying a param with shape torch.Size([2048, 16384]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.4.self_attn.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.4.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.4.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.4.self_attn.o_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.4.mlp.gate_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.4.mlp.up_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.4.mlp.down_proj.weight: copying a param with shape torch.Size([2048, 16384]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.5.self_attn.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.5.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.5.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.5.self_attn.o_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.5.mlp.gate_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.5.mlp.up_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.5.mlp.down_proj.weight: copying a param with shape torch.Size([2048, 16384]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.6.self_attn.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.6.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.6.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.6.self_attn.o_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.6.mlp.gate_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.6.mlp.up_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.6.mlp.down_proj.weight: copying a param with shape torch.Size([2048, 16384]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.7.self_attn.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.7.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.7.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.7.self_attn.o_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.7.mlp.gate_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.7.mlp.up_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.7.mlp.down_proj.weight: copying a param with shape torch.Size([2048, 16384]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.8.self_attn.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.8.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.8.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.8.self_attn.o_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.8.mlp.gate_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.8.mlp.up_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.8.mlp.down_proj.weight: copying a param with shape torch.Size([2048, 16384]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.9.self_attn.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.9.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.9.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.9.self_attn.o_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.9.mlp.gate_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.9.mlp.up_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.9.mlp.down_proj.weight: copying a param with shape torch.Size([2048, 16384]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.15.self_attn.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.15.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.16.self_attn.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.16.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.16.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.16.self_attn.o_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.17.self_attn.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    size mismatch for model.layers.17.self_attn.k_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.17.self_attn.v_proj.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.17.self_attn.o_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([2097152, 1]).
    size mismatch for model.layers.17.mlp.gate_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([262144, 1]).
    size mismatch for model.layers.17.mlp.up_proj.weight: copying a param with shape torch.Size([16384, 2048]) from checkpoint, the shape in current model is torch.Size([16777216, 1]).
    size mismatch for model.layers.17.mlp.down_proj.weight: copying a param with shape torch.Size([2048, 16384]) from checkpoint, the shape in current model is torch.Size([1024, 1]).
    You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
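The "current model" shapes in the log follow a pattern: each is a flat `[n, 1]` tensor whose length is half the element count of some Gemma-2B weight (though not always the one named on that line). That is consistent with bitsandbytes' 4-bit storage, which, as far as I know, packs two 4-bit values into each `uint8` byte. A small sketch of that arithmetic, assuming that layout; `packed_4bit_shape` is a hypothetical helper, not a library function, and the embedding and norm shapes are assumed from the Gemma-2B config rather than taken from the log:

```python
import math


def packed_4bit_shape(shape):
    """Shape of a weight after packing two 4-bit values per uint8 byte.

    Hypothetical helper: assumes the flat [ceil(n / 2), 1] layout that
    bitsandbytes uses for 4-bit weights stored as uint8.
    """
    numel = math.prod(shape)
    return ((numel + 1) // 2, 1)


# Checkpoint shapes for Gemma-2B (projection shapes as printed in the log;
# embedding and RMSNorm shapes assumed from the model config)
checkpoint_shapes = {
    "q_proj / o_proj": (2048, 2048),
    "k_proj / v_proj": (256, 2048),
    "gate_proj / up_proj": (16384, 2048),
    "down_proj": (2048, 16384),
    "embed_tokens": (256000, 2048),
    "RMSNorm weight": (2048,),
}

for name, shape in checkpoint_shapes.items():
    print(f"{name}: {shape} -> {packed_4bit_shape(shape)}")
```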

Environment

More info

No response

ElleLeonne commented 7 months ago

Of note: if I bypass this and handle everything manually in setup, my GPU doesn't seem to get the memory benefits of quantization. I am told my quantized model is 2 GB and the trainer reports 6 GB, yet as soon as training starts my 16 GB GPU runs out of memory. The same happens on my 24 GB GPU.

I'm unsure what's going wrong.

ElleLeonne commented 7 months ago

An update:

When I load the model with quantization alone, it takes up 2 GB.

Lightning reports 6 GB. I assume Lightning runs a sample backward pass and that the excess is stored gradients.

I can use the trainer.init_module() context manager to keep Lightning's memory usage consistent with its own reported size.

However, as soon as the model receives any text input at all, it runs out of memory (OOM). My current suspicion is that the optimizer is not handling the quantization properly or is not respecting the frozen layers. I can think of no other reason why a 2 GB, double-quantized 4-bit model would OOM on a single backward pass.

ElleLeonne commented 7 months ago

Solved: the Trainer defaults to mixed precision when its precision flag is not explicitly set to "32-true" or similar.

Mixed precision adds the overhead of duplicate tensors.

Additionally, for its size estimate the Trainer appears to use the precision the weights were saved in, rather than the precision they are currently in. This estimate can be safely ignored.
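The 2 GB and 6 GB figures can be roughly reconciled with some arithmetic. The sketch below rests on several assumptions that are not confirmed in this thread: Gemma-2B's configuration (hidden size 2048, 18 layers, MLP size 16384, 8 query heads and 1 KV head of dimension 256, vocabulary 256000, `lm_head` tied to the embedding); every linear layer stored as packed 4-bit with all other weights in float16 for the Hugging Face figure; and a size estimate that charges 4 bytes (32-bit) per parameter element regardless of its actual dtype. The small per-block quantization statistics are ignored.

```python
# Gemma-2B configuration (assumed; the projection shapes match the log above)
HIDDEN = 2048
INTERMEDIATE = 16384
NUM_LAYERS = 18
NUM_HEADS = 8
NUM_KV_HEADS = 1
HEAD_DIM = 256
VOCAB = 256_000


def linear_numel_per_layer():
    """Element count of the seven projection matrices in one decoder layer."""
    q = o = (NUM_HEADS * HEAD_DIM) * HIDDEN
    k = v = (NUM_KV_HEADS * HEAD_DIM) * HIDDEN
    mlp = 3 * INTERMEDIATE * HIDDEN  # gate_proj, up_proj, down_proj
    return q + k + v + o + mlp


linear_numel = NUM_LAYERS * linear_numel_per_layer()
# Token embedding (assumed tied with lm_head) plus two RMSNorms per layer and a final one
other_numel = VOCAB * HIDDEN + (2 * NUM_LAYERS + 1) * HIDDEN
total_params = linear_numel + other_numel

# Assumed 4-bit packing: two values per uint8 element halves the element count
packed_linear_numel = linear_numel // 2

# Assumed actual storage: 1 byte per packed element, 2 bytes per float16 element
actual_bytes = packed_linear_numel * 1 + other_numel * 2

# Assumed size estimate: 4 bytes for every element it counts
estimated_bytes = (packed_linear_numel + other_numel) * 4

print(f"parameters:       {total_params:,}")
print(f"actual storage:   {actual_bytes / 1e9:.2f} GB")
print(f"32-bit estimate:  {estimated_bytes / 1e9:.2f} GB")
```

Under these assumptions the two numbers differ only in how many bytes are charged per stored element, which fits the conclusion above that the Trainer's size estimate can be ignored.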