Open · l-berg opened this issue 3 months ago
Not a torchtune author/contributor, but from the memory usage, I'm guessing that the old version performs NF4 quantization on GPU, while the new version performs it on CPU.
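A quick way to sanity-check that hypothesis is to time NF4 quantization of a single weight-sized tensor on CPU vs. GPU. A minimal sketch, assuming `to_nf4` is importable from `torchao.dtypes` (it is in recent torchao releases):

```python
import time
import torch
from torchao.dtypes import to_nf4  # exposed here in recent torchao releases

def time_nf4(device: str, shape=(4096, 4096)) -> float:
    """Time NF4 quantization of one Llama-sized weight matrix on `device`."""
    w = torch.randn(*shape, dtype=torch.bfloat16, device=device)
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    to_nf4(w)  # default block_size / scaler_block_size
    if device == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - start

print(f"CPU : {time_nf4('cpu'):.3f}s per 4096x4096 weight")
if torch.cuda.is_available():
    print(f"CUDA: {time_nf4('cuda'):.3f}s per 4096x4096 weight")
```

An 8B model has a couple hundred weight matrices at least this size, so even a modest per-tensor gap on the CPU path would add up to minutes.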
Makes sense, this was suggested by @msaroufim, too. I will confirm.
@l-berg Apologies for the late response - do you notice the slowdown on 0.2.0 as well? This will help me narrow down where these changes could be coming from.
Yes, upgrading from 0.1.1 to 0.2.0 results in the same increase from ~10s to >5min loading time on my machine.
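(Side note: when bisecting like this, a quick way to double-check which versions are actually active in a given environment — plain `importlib.metadata`, nothing torchtune-specific:)

```python
import importlib.metadata as md
import torch

# Print the installed versions of the packages involved in the bisection.
for pkg in ("torch", "torchtune", "torchao"):
    try:
        print(f"{pkg}=={md.version(pkg)}")
    except md.PackageNotFoundError:
        print(f"{pkg}: not installed")
print("CUDA available:", torch.cuda.is_available())
```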
Hi @l-berg - thanks for bringing this to our attention! The AO folks dug deep into this and found that a version-guarded `inplace_copy` function was the culprit. You can read more about it here: https://github.com/pytorch/ao/issues/642.
This will be fixed by @ebsmothers in #1294
When working on my customized `LoRAFinetuneRecipeSingleDevice` recipe and upgrading from `torchtune` 0.1.1 to 0.2.1 and `torchao` 0.1 to 0.3.1, I noticed that model loading times went up dramatically when using QLoRA. Now, loading Llama3-8B takes about 5 minutes, where it used to take only a few seconds with version 0.1.1. I was able to pinpoint it to the call `model.load_state_dict(base_model_state_dict, strict=False)`.
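To isolate that call, one can time `load_state_dict` for a single quantized layer. A minimal sketch, assuming torchtune's `LoRALinear` accepts `quantize_base=True` with the signature below (which, as far as I can tell, is what the QLoRA model builders pass down):

```python
import time
import torch
from torchtune.modules.peft import LoRALinear

# The recipe builds the model in bf16, so mirror that here.
torch.set_default_dtype(torch.bfloat16)

# One QLoRA-style layer: the base weight is NF4-quantized, the LoRA adapters are not.
layer = LoRALinear(in_dim=4096, out_dim=4096, rank=8, alpha=16.0, quantize_base=True)

# A bf16 "checkpoint" weight for just this layer. Loading it goes through
# NF4Tensor's copy_, the same path the full-model load_state_dict call hits
# once per quantized weight.
sd = {"weight": torch.randn(4096, 4096, dtype=torch.bfloat16)}

start = time.perf_counter()
layer.load_state_dict(sd, strict=False)  # strict=False: no LoRA weights in sd
print(f"one quantized weight: {time.perf_counter() - start:.2f}s")
```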
Here are the steps I took to reproduce this issue in a new conda environment (PyTorch 2.4).
outputs
When I downgrade to version 0.1.1, it's fast again:
outputs
after only a few seconds.
I am on a Slurm machine with 4 CPU threads and an NVIDIA 3090 (24 GB).
Any ideas on what might be the cause? The LoRA recipes without quantization work just fine.
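If it helps narrow things down, profiling the same single-layer load (same assumed `LoRALinear` signature as in the sketch above) shows whether the heavy ops land on CPU or CUDA:

```python
import torch
from torch.profiler import profile, ProfilerActivity
from torchtune.modules.peft import LoRALinear  # signature assumed as above

torch.set_default_dtype(torch.bfloat16)
layer = LoRALinear(in_dim=4096, out_dim=4096, rank=8, alpha=16.0, quantize_base=True)
sd = {"weight": torch.randn(4096, 4096, dtype=torch.bfloat16)}

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    layer.load_state_dict(sd, strict=False)

# If the quantization ops dominate CPU time (and nothing meaningful shows up
# under CUDA), the NF4 conversion is running on the 4 host threads rather
# than on the 3090.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=15))
```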