pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/

This QLoRA config makes the model initialize with 15GB, instead of 6GB #1398

Open felipemello1 opened 3 weeks ago

felipemello1 commented 3 weeks ago

Running:

tune run lora_finetune_single_device --config llama3_1/8B_qlora_single_device model.lora_attn_modules="[q_proj, v_proj]" model.apply_lora_to_mlp=False model.apply_lora_to_output=False model.lora_rank=8 model.lora_alpha=16 

gives me:

INFO:torchtune.utils.logging:Memory stats after model init:
        GPU peak memory allocation: 14.41 GiB
        GPU peak memory reserved: 14.62 GiB
        GPU peak memory active: 14.41 GiB

but running

tune run lora_finetune_single_device --config llama3_1/8B_qlora_single_device

gives me

INFO:torchtune.utils.logging:Memory stats after model init:
        GPU peak memory allocation: 6.19 GiB
        GPU peak memory reserved: 6.27 GiB
        GPU peak memory active: 6.19 GiB

This doesn't make sense, since the default config applies LoRA to more modules:

model:
  _component_: torchtune.models.llama3_1.qlora_llama3_1_8b
  lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

Environment:

conda create -n torchtune_debugging python=3.11
conda activate torchtune_debugging
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121
pip install -e ".[dev]"
pip uninstall torchao
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121
tune run lora_finetune_single_device --config llama3_1/8B_qlora_single_device model.lora_attn_modules="[q_proj, v_proj]" model.apply_lora_to_mlp=False model.apply_lora_to_output=False model.lora_rank=8 model.lora_alpha=16 
felipemello1 commented 3 weeks ago

cc: @ebsmothers

felipemello1 commented 3 weeks ago

I believe it happens because we only quantize the base weight of a module if it has LoRA applied to it. By applying LoRA to fewer modules, we quantize less: https://github.com/pytorch/torchtune/blob/f9f75bb563ecae371492a9d49da4a9f514c081b3/torchtune/models/llama3_1/_component_builders.py#L337

Is that the expected behavior?
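A minimal sketch of the pattern being described (hypothetical helper, not the actual builder code; the LoRALinear signature is assumed from torchtune.modules.peft):

import torch.nn as nn
from torchtune.modules.peft import LoRALinear

def build_attn_proj(in_dim, out_dim, use_lora, lora_rank, lora_alpha, quantize_base):
    # Only projections that receive a LoRA adapter are built with
    # quantize_base=True, so only their frozen base weight ends up in NF4.
    if use_lora:
        return LoRALinear(
            in_dim,
            out_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            quantize_base=quantize_base,
        )
    # Projections without LoRA fall back to a plain nn.Linear and keep their
    # weight in bf16/fp32, which is why shrinking lora_attn_modules
    # increases peak memory at model init.
    return nn.Linear(in_dim, out_dim, bias=False)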

felipemello1 commented 3 weeks ago

We need to add this: https://github.com/pytorch/torchtune/blob/f9f75bb563ecae371492a9d49da4a9f514c081b3/torchtune/models/llama3/_component_builders.py#L222

here: https://github.com/pytorch/torchtune/blob/f9f75bb563ecae371492a9d49da4a9f514c081b3/torchtune/models/llama3_1/_component_builders.py#L99
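For illustration, a hedged sketch of the fallback pattern that fix would introduce (assuming FrozenNF4Linear from torchtune.modules.low_precision; not the exact diff):

import torch.nn as nn
from torchtune.modules.low_precision import FrozenNF4Linear
from torchtune.modules.peft import LoRALinear

def build_proj(in_dim, out_dim, use_lora, lora_rank, lora_alpha, quantize_base):
    if use_lora:
        return LoRALinear(
            in_dim,
            out_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            quantize_base=quantize_base,
        )
    # Even without a LoRA adapter, the frozen base weight should still be
    # quantized when quantize_base is set, instead of staying in bf16/fp32.
    if quantize_base:
        return FrozenNF4Linear(in_dim, out_dim, bias=False)
    return nn.Linear(in_dim, out_dim, bias=False)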