pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Context Length Increase Results in OOM #856

Closed BedirT closed 4 months ago

BedirT commented 6 months ago

I am testing TorchTune with some settings I trained my models on. My go-to single-device library was unsloth, as it provides great memory and time savings.

Based on my Llama 3 8B comparisons, the fine-tuning speed looks very comparable. However, unlike unsloth, I am getting an OOM when trying larger context sizes. I am using the default LoRA recipe on a single device in a Docker container.

Do you think there could be something I did wrong, or is this expected? What context length was used for the VRAM numbers reported in the README?

[!Note] Testing on a single RTX 4090 with an 8k context length. Tested with 2k and it worked fine, but my GPU usage readings did not match the reported ones; usage climbs all the way up to 24 GB.
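For reference, a minimal sketch of the kind of invocation used, assuming the stock Llama 3 8B LoRA single-device config and torchtune's key=value override syntax (the exact config name and override keys below are assumptions, not the exact command):

```bash
# Sketch only: config name and overrides are assumptions; adjust to your setup.
tune run lora_finetune_single_device \
    --config llama3/8B_lora_single_device \
    batch_size=1 \
    tokenizer.max_seq_len=8192
```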

kartikayk commented 6 months ago

@BedirT thanks so much for sharing this! Glad to hear the perf matches up with some of the other libraries you use.

> What context length was used for the VRAM numbers reported in the README?

The README table includes numbers from the default configs, which currently train on the Alpaca dataset. Those sequences are substantially shorter than the context lengths you're training on.

I set batch_size=1 and ran lora_finetune_single_device with dummy tensors of shape [1, 2048] and [1, 4096]. Both of these seem to run under 24GB peak memory, with 4096 right at ~24GB (caveat: I'm simulating on an A100; I can test on a 4090 in a bit). You're absolutely right that this will OOM at seq len 8192. I think we need to do some work on supporting large sequence lengths. Let me take a look at this and get back to you.
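A minimal sketch of that kind of peak-memory probe is below; the tiny stand-in model, the `peak_mem_gb` helper, and the vocab size are illustrative assumptions rather than the recipe's actual code, and in practice you would substitute the LoRA-wrapped Llama 3 8B model the recipe builds:

```python
# Minimal sketch: probe peak CUDA memory for a forward/backward pass on a
# dummy [1, seq_len] token batch. The stand-in model below is an assumption
# so the snippet runs anywhere; swap in the real (LoRA-wrapped) model.
import torch
import torch.nn as nn


def peak_mem_gb(model: nn.Module, seq_len: int, vocab_size: int) -> float:
    """One forward/backward on a [1, seq_len] dummy batch; return peak GB."""
    model.zero_grad(set_to_none=True)
    torch.cuda.reset_peak_memory_stats()
    tokens = torch.randint(0, vocab_size, (1, seq_len), device="cuda")
    logits = model(tokens)            # [1, seq_len, vocab_size]
    logits.float().mean().backward()  # dummy loss, just to exercise backward
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1e9


vocab_size = 32_000  # illustrative
# Tiny stand-in model so the sketch is self-contained and runnable.
model = nn.Sequential(
    nn.Embedding(vocab_size, 256),
    nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=256, nhead=4, batch_first=True),
        num_layers=2,
    ),
    nn.Linear(256, vocab_size),
).cuda()

for seq_len in (2048, 4096, 8192):
    print(f"seq_len={seq_len}: {peak_mem_gb(model, seq_len, vocab_size):.2f} GB")
```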

HaisongDing commented 6 months ago

It would be great to include the context length alongside the VRAM numbers in the README for reference.

felipemello1 commented 4 months ago

@HaisongDing @BedirT, I see that we updated the README to include the context length. If you still have questions, please feel free to reopen this issue. Thanks!! :)