RahulSChand / gpu_poor

Calculate token/s & GPU memory requirement for any LLM. Supports llama.cpp/ggml/bnb/QLoRA quantization
https://rahulschand.github.io/gpu_poor/

Results are inconsistent and not reliable enough #1

Closed · Anindyadeep closed this 1 year ago

Anindyadeep commented 1 year ago

Hey @RahulSChand, awesome work on creating this calculator. However, I am running into some problems and getting unreliable results. Here are the issues I am facing:

The configuration I am using is as follows:

- Model: CodeLlama
- Parameter size: 7B
- Batch size: 1
- Context length: 2048
1. QLoRA's GPU memory requirement is showing as higher than LoRA's.

   For LoRA it shows 177 GB, for QLoRA it shows 180 GB, and for full fine-tuning it shows 216 GB.

2. When I upload the config.json file versus just entering the parameter count, the results are inconsistent.

3. The memory requirement should not be this high. For example, with a batch size of just 1 and a context length of 2048, it shows triple-digit GB values for LoRA and QLoRA. Now consider this graph. Reference

image

According to this graph, the memory requirement for LoRA is 16 GB, but the calculator shows 177 GB.

So, can you please address these doubts? If there is any way to fix this, that would be awesome.

RahulSChand commented 1 year ago

@Anindyadeep Thanks for letting me know. I checked this issue & ran QLoRA training with a 1000 context length on my 4090 24GB GPU. Below is the memory screenshot (it takes ~23 GB & the website also gives the same value).

image

As for the link you provided, where only 16GB of memory is used: that is because the finetuning is done on the alpaca dataset, which has a context length of around 700 (Link: https://github.com/gururise/AlpacaDataCleaned#finetune-considerations). For length=700, the website gives you a 15GB memory requirement (same as the image you posted).

The memory requirement depends on the context length (activation memory), since many (length, dim) & (dim, dim, head) intermediate states are generated in the forward pass (and are also needed for the backward pass). These tensors are not updated (they don't have grad), but they are needed to compute the grad of the LoRA params. So your memory requirement can increase a lot with context length.
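To make that concrete, here is a very rough back-of-the-envelope sketch of how activation memory scales with context length for a LLaMA-7B-shaped model. The per-layer state count, precision, and shapes are illustrative assumptions, not the formula the website actually uses; the point is only that one term grows linearly and one quadratically with context length.

```python
# Illustrative only: how activation memory scales with context length.
# The constants below (states per layer, fp16 activations, LLaMA-7B-like
# shapes) are assumptions, not the website's actual formula.

def rough_activation_bytes(context_len: int,
                           hidden_dim: int = 4096,
                           num_heads: int = 32,
                           num_layers: int = 32,
                           states_per_layer: int = 10,    # (length, dim) tensors kept per layer
                           bytes_per_value: int = 2) -> int:  # fp16/bf16
    # (length, hidden_dim) intermediates: grow linearly with context length
    linear_term = states_per_layer * context_len * hidden_dim
    # (heads, length, length) attention scores: grow quadratically
    quadratic_term = num_heads * context_len * context_len
    return (linear_term + quadratic_term) * bytes_per_value * num_layers

for length in (700, 1000, 2048):
    print(f"context {length:>4}: ~{rough_activation_bytes(length) / 1e9:.1f} GB of activations")
```

Absolute numbers depend heavily on which intermediates the framework keeps and in what precision, but the growth trend with context length is the relevant part.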

Anindyadeep commented 1 year ago

Ahh, that makes a lot of sense. Actually, I was pretty blown away by the GPU numbers, but the same blog post showed that it takes 240 GB (6 x 40 GB) for full fine-tuning on the same dataset. So it now kind of makes sense that it would take more with a 4096 context length (~quadratic increase).

However, can you please clear up one more doubt: why is the memory requirement > the memory requirement in LoRA?

RahulSChand commented 1 year ago

@Anindyadeep sorry, I didn't get your last question. What do you mean by "memory requirement is > memory requirement in LoRA"? Do you mean that the website is giving a memory requirement for QLoRA that is larger than for LoRA? I checked & this doesn't seem to be the case for your configuration (codellama-7b, 2048 context length).

Let me know if I am misunderstanding your question.

Anindyadeep commented 1 year ago

Here:

For full fine-tuning

image


For LoRA

image


For QLoRA

image

RahulSChand commented 1 year ago

@Anindyadeep oh okay, got it. This is because for QLoRA & any other bitsandbytes (https://github.com/TimDettmers/bitsandbytes) quantization method there is an overhead during the forward pass (this overhead is usually small when the context length is small). This is also present if you use bitsandbytes llm.int8 quantization.

So even though QLoRA is smaller than LoRA (theoretically), the quantization overhead introduced by bitsandbytes can offset this when context length is large.

Below is an approximate way to calculate this overhead (an empirical formula I figured out after a lot of trial & error with 3b/7b/13b models & bitsandbytes QLoRA runs):

`QLoRA overhead = (15*hidden_dim + 6*intermediate_dim) * numLayers * contextLen * 0.75 bytes`

I am also not sure what happens in the high context length regime (maybe for large context lengths like >2048 this approximation is very wrong and the overhead doesn't grow linearly with contextLen). This is something I need to check.
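For reference, the same empirical formula as a small Python helper, evaluated with CodeLlama-7B-style config values (hidden_dim=4096, intermediate_dim=11008, 32 layers, as in the model's config.json) and the 2048 context length from this issue; the specific numbers plugged in are illustrative.

```python
# Rahul's empirical QLoRA/bitsandbytes overhead estimate from above,
# wrapped in a function. The CodeLlama-7B numbers plugged in below
# (hidden 4096, intermediate 11008, 32 layers) are for illustration.

def qlora_overhead_bytes(hidden_dim: int, intermediate_dim: int,
                         num_layers: int, context_len: int) -> float:
    return (15 * hidden_dim + 6 * intermediate_dim) * num_layers * context_len * 0.75

overhead = qlora_overhead_bytes(hidden_dim=4096, intermediate_dim=11008,
                                num_layers=32, context_len=2048)
print(f"~{overhead / 1e9:.1f} GB of extra QLoRA/bnb overhead")  # roughly 6 GB
```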

Anindyadeep commented 1 year ago

Well, that answers all my doubts. Thanks a lot, @RahulSChand. I learned some new things here, and it seems like I need to revisit some of those nuances. Thanks again.

AaronZLT commented 2 months ago

> @Anindyadeep oh okay, got it. This is because for QLoRA & any other bitsandbytes (https://github.com/TimDettmers/bitsandbytes) quantization method there is an overhead during the forward pass (this overhead is usually small when the context length is small). This is also present if you use bitsandbytes llm.int8 quantization.
>
> So even though QLoRA is smaller than LoRA (theoretically), the quantization overhead introduced by bitsandbytes can offset this when the context length is large.
>
> Below is an approximate way to calculate this overhead (an empirical formula I figured out after a lot of trial & error with 3b/7b/13b models & bitsandbytes QLoRA runs):
>
> `QLoRA overhead = (15*hidden_dim + 6*intermediate_dim) * numLayers * contextLen * 0.75 bytes`
>
> I am also not sure what happens in the high context length regime (maybe for large context lengths like >2048 this approximation is very wrong and the overhead doesn't grow linearly with contextLen). This is something I need to check.

Thank you for your reply! I'm a little confused: we know that quantization has an overhead, but can this overhead (the quant factors, i.e., the quantization and double-quantization scaling factors) really consume that much memory? For one NF4 parameter, how much space does it need during training (I assume 4 bits + ?)
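For what it's worth, here is a back-of-the-envelope sketch of the storage cost per frozen NF4 parameter, assuming the block sizes described in the QLoRA paper (weights quantized in blocks of 64, absmax constants double-quantized in blocks of 256). These assumptions come from the paper, not from this repository, and they cover only the stored base weights, not the forward-pass overhead discussed above.

```python
# Back-of-the-envelope: bits per frozen NF4 base-model parameter, assuming
# the QLoRA paper's block sizes (weights quantized in blocks of 64, absmax
# constants double-quantized to FP8 in blocks of 256 with FP32 second-level
# constants). Covers the stored weights only, not activations or the
# bitsandbytes forward-pass overhead discussed earlier in the thread.

weight_bits = 4                      # the NF4 value itself
absmax_fp8 = 8 / 64                  # one FP8 absmax per 64 weights
absmax2_fp32 = 32 / (64 * 256)       # one FP32 constant per 256 blocks
bits_per_param = weight_bits + absmax_fp8 + absmax2_fp32

print(f"~{bits_per_param:.3f} bits per parameter")                   # ~4.127 bits
print(f"7B base weights: ~{7e9 * bits_per_param / 8 / 1e9:.2f} GB")  # ~3.6 GB
```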