pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Llama3 qlora load_state_dict takes forever #1246

Open l-berg opened 3 months ago

l-berg commented 3 months ago

While working on a customized LoRAFinetuneRecipeSingleDevice recipe, I upgraded torchtune from 0.1.1 to 0.2.1 and torchao from 0.1 to 0.3.1, and noticed that model loading times went up dramatically when using QLoRA. Loading llama3-8b now takes about 5 minutes, whereas it took only a few seconds in version 0.1.1. I was able to pinpoint the slowdown to the call model.load_state_dict(base_model_state_dict, strict=False).
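For anyone trying to narrow down a similar slowdown, one way to find the slow call is to wrap each setup stage in a small wall-clock timer. This is only a minimal sketch: `timed`, `read_checkpoint`, and `load_into_model` are hypothetical names, and the two stand-in functions take the place of the recipe's real checkpoint read and its model.load_state_dict(..., strict=False) call.

```python
import time
from contextlib import contextmanager

timings = {}


@contextmanager
def timed(label):
    """Record the wall-clock seconds spent inside the with-block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[label] = time.perf_counter() - start


# Hypothetical stand-ins for the recipe's setup steps. In a real recipe these
# would be the checkpoint read and model.load_state_dict(..., strict=False).
def read_checkpoint():
    return {"weight": [0.0] * 1000}


def load_into_model(state_dict):
    time.sleep(0.05)  # pretend this is the slow step


with timed("read_checkpoint"):
    state_dict = read_checkpoint()
with timed("load_state_dict"):
    load_into_model(state_dict)

# Report the stage that dominated the setup time.
slowest = max(timings, key=timings.get)
print(f"slowest stage: {slowest} ({timings[slowest]:.2f}s)")
```

Timing stages separately like this distinguishes time spent reading the checkpoint from disk from time spent copying (and, with QLoRA, quantizing) the weights into the model.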

Here are the steps I took to reproduce this issue in a new conda environment (pytorch version 2.4).

conda create --name tune python=3.11
conda activate tune
conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
pip install torchtune
tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir . --hf-token <token>
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device checkpointer.checkpoint_dir=$CHECKPOINT_DIR tokenizer.path=$CHECKPOINT_DIR/tokenizer.model checkpointer.output_dir=$CHECKPOINT_DIR output_dir=$CHECKPOINT_DIR

outputs

...
DEBUG:torchtune.utils.logging:Setting manual seed to local seed 4061266822. Local seed is seed + rank = 4061266822 + 0
Writing logs to /.../testtune/instruct/original/log_1722353553.txt
<WAITING FOR 5 MINUTES, RAM usage slowly rising to 15GB>
INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils.logging:Memory stats after model init:
        GPU peak memory allocation: 6.79 GB
        GPU peak memory reserved: 6.97 GB
        GPU peak memory active: 6.79 GB
...

When I downgrade to version 0.1.1, it's fast again:

pip install torchtune==0.1.1
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device checkpointer.checkpoint_dir=$CHECKPOINT_DIR tokenizer.path=$CHECKPOINT_DIR/tokenizer.model checkpointer.output_dir=$CHECKPOINT_DIR output_dir=$CHECKPOINT_DIR

outputs

DEBUG:torchtune.utils.logging:Setting manual seed to local seed 1435681341. Local seed is seed + rank = 1435681341 + 0
Writing logs to /.../testtune/instruct/original/log_1722354102.txt
INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils.logging:Memory Stats after model init:
{'peak_memory_active': 10.129489408, 'peak_memory_alloc': 10.129489408, 'peak_memory_reserved': 12.639535104}

after only a few seconds.

I am on a Slurm machine with 4 CPU threads and an NVIDIA 3090 (24 GB).

Any ideas on what might be the cause? The lora recipes without quantization work just fine.

gau-nernst commented 3 months ago

Not a torchtune author/contributor, but from the memory usage, I'm guessing that the old version performs NF4 quantization on GPU, while the new version performs it on CPU.
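To see why the device matters here, the following pure-Python sketch shows the general shape of block-wise NF4 quantization: each weight is scaled by its block's absolute maximum and mapped to the nearest of 16 fixed levels. It is an illustration only, not torchao's implementation. The level values are the NF4 codes from the QLoRA paper rounded to four decimals, the function names are hypothetical, and details such as packing two 4-bit codes per byte are omitted.

```python
# NF4 code values (16 levels in [-1, 1]), rounded to four decimals.
NF4_LEVELS = [
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
]


def quantize_block(block):
    """Return (absmax scale, list of 4-bit level indices) for one block."""
    scale = max(abs(x) for x in block) or 1.0  # avoid dividing by zero
    indices = []
    for x in block:
        normalized = x / scale
        # Nearest-level search: this runs once per weight.
        idx = min(range(16), key=lambda i: abs(NF4_LEVELS[i] - normalized))
        indices.append(idx)
    return scale, indices


def dequantize_block(scale, indices):
    """Map level indices back to approximate float weights."""
    return [NF4_LEVELS[i] * scale for i in indices]


weights = [0.5, -1.0, 0.0, 2.0]
scale, codes = quantize_block(weights)
restored = dequantize_block(scale, codes)
print(scale, codes, restored)
```

Because the nearest-level lookup runs for every weight, quantizing an 8B-parameter model means billions of element-wise operations. That is quick when done in parallel on a GPU, but plausibly takes minutes on a handful of CPU threads.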

joecummings commented 3 months ago

> Not a torchtune author/contributor, but from the memory usage, I'm guessing that the old version performs NF4 quantization on GPU, while the new version performs it on CPU.

Makes sense, this was suggested by @msaroufim, too. I will confirm.

@l-berg Apologies for the late response - do you notice the slowdown on 0.2.0 as well? This will help me narrow down where these changes could be coming from.

l-berg commented 3 months ago

Yes, upgrading from 0.1.1 to 0.2.0 results in the same increase in loading time, from ~10 s to more than 5 min, on my machine.

joecummings commented 3 months ago

Hi @l-berg - thanks for bringing this to our attention! The AO folks dug deep into this and found that a version-guarded inplace_copy function was the culprit. You can read more about it here: https://github.com/pytorch/ao/issues/642.

This will be fixed by @ebsmothers in #1294.
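For readers unfamiliar with the term, a "version-guarded" function is one that is only selected when the installed library version passes a check. The sketch below illustrates the general pattern only. It does not reproduce the torchao code or the specific bug described in the linked issue, and every name in it (`parse_version`, `choose_copy`, `fast_inplace_copy`, `slow_fallback_copy`, the `2.4.0` threshold) is hypothetical.

```python
from itertools import takewhile


def parse_version(version):
    """Turn a string such as '2.4.0+cu124' into a comparable tuple (2, 4, 0)."""
    core = version.split("+")[0]  # drop a local build suffix such as +cu124
    parts = []
    for piece in core.split(".")[:3]:
        digits = "".join(takewhile(str.isdigit, piece))
        parts.append(int(digits or 0))
    return tuple(parts)


def fast_inplace_copy(dst, src):
    """Hypothetical fast path: overwrite dst in place."""
    dst[:] = src
    return dst


def slow_fallback_copy(dst, src):
    """Hypothetical fallback: same result, element by element."""
    for i, value in enumerate(src):
        dst[i] = value
    return dst


def choose_copy(installed_version, min_version="2.4.0"):
    """Pick the fast path only when the version guard passes."""
    if parse_version(installed_version) >= parse_version(min_version):
        return fast_inplace_copy
    return slow_fallback_copy


print(choose_copy("2.4.0+cu124").__name__)
print(choose_copy("2.3.1").__name__)
```

Both paths return the same result, so a guard that picks the fallback unexpectedly shows up only as a slowdown, not as an error. That makes this kind of regression hard to spot without timing the load.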