Closed — ElleLeonne closed this issue 7 months ago
Of note, if I bypass this and handle everything manually in setup(), it appears that my GPU doesn't actually take advantage of the quantization benefits. I am told my quantized model is 2 GB and the trainer reports 6 GB, but as soon as I start training, my 16 GB GPU overflows immediately. It even happens on my 24 GB GPU.
I'm unsure what's going wrong.
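For anyone hitting the same wall, the mismatch between these numbers can be sanity-checked with simple bytes-per-parameter arithmetic. This is only a sketch: the 4-billion-parameter count is a hypothetical stand-in chosen so the 4-bit size lands near 2 GB, not the actual model, and it ignores gradients, optimizer state, activations, and quantization metadata.

```python
def weight_footprint_gb(n_params: float, bits_per_param: float) -> float:
    """Approximate size of the weights alone, ignoring optimizer state,
    gradients, activations, and quantization metadata."""
    return n_params * bits_per_param / 8 / 1e9

# Hypothetical 4B-parameter model (stand-in, not the real model).
n = 4e9
fp32 = weight_footprint_gb(n, 32)  # ~16 GB: full precision
fp16 = weight_footprint_gb(n, 16)  # ~8 GB: half precision
nf4 = weight_footprint_gb(n, 4)    # ~2 GB: 4-bit quantized

# If (as suspected later in this thread) a full-precision copy of the
# weights exists alongside a half-precision compute copy, the working set
# is roughly fp32 + fp16 rather than the quantized size alone.
mixed = fp32 + fp16  # ~24 GB: enough to overflow even a 24 GB card
print(f"fp32={fp32} GB, fp16={fp16} GB, nf4={nf4} GB, duplicated={mixed} GB")
```

The point is just that a duplicated full/half-precision pair is an order of magnitude larger than the quantized weights, which would match the OOM behavior above.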
An update:
When I load the model using just quantization, it takes up 2 GB.
Lightning says it takes up 6 GB; I assume Lightning runs a sample backward pass and that the excess is stored gradients.
I can use the trainer.init_module() context manager to keep Lightning faithful to its own reported size.
However, as soon as my model receives any textual data at all, it goes OOM. My suspicion now is that the optimizer is not properly handling the quantization or respecting the frozen layers; I can think of no other reason that my 2 GB double-quantized 4-bit model would OOM on a single backward pass.
Solved: the Trainer defaults to mixed precision when the precision flag is not explicitly set to "32-true" or similar.
This produces duplicate tensor overhead.
Additionally, the trainer appears to base its size estimate on the precision the weights were saved in, rather than the precision the model is currently in, so that estimate can be safely ignored.
Bug description
I am attempting to make use of the Bitsandbytes precision plugin as found here: https://lightning.ai/docs/pytorch/2.1.0/api/lightning.pytorch.plugins.precision.BitsandbytesPrecisionPlugin.html
I am advised to initialize the module in the setup() hook, because otherwise the model is loaded in full 32-bit precision, which is too large to fit on my GPU.
When I do, the trainer complains that it is attempting to load weights with a size mismatch, as shown below.
I suppose I am wondering what exactly the BitsandbytesPrecisionPlugin is doing, and whether we need to perform any additional steps that aren't handled under the hood. Normally when I load a model, I need to pass a BitsAndBytes configuration object to it; I have abstained from this because of the tutorial. Lightning seems to imply that it can handle this for us, but I'm unsure whether that's true.
What version are you seeing the problem on?
v2.1
How to reproduce the bug
Error messages and logs
Attempting to load a model like this results in the following message being repeated for essentially every layer.
Environment
More info
No response