winglian opened this issue 3 months ago
Adding a comment to track this discussion in Pytorch core https://github.com/pytorch/pytorch/issues/130330
If this lands, we should enable this by default in torchtune until it lands in a PyTorch stable release.
@winglian where did we land on this? The numbers here are for reserved memory, right? Do we see similar improvements on peak allocated memory as well?
I think the prudent thing would be to wait for this to land in core with all the necessary safety updates that go with it, rather than enabling this flag ourselves.
Just wanted to confirm that running on A100, with the flag I can run bs=4, but without it, it OOMs.
Editing for clarity: green has bsz=3. When I go to bsz=4, it doesn't work unless I use the flag. The info is in the run name. It does NOT impact allocated memory, only reserved.
tune run --nproc_per_node 8 lora_finetune_distributed --config llama3/8B_lora \
batch_size=4 \
model.lora_attn_modules="['q_proj', 'v_proj']" \
model.apply_lora_to_mlp=False \
model.apply_lora_to_output=False \
metric_logger.name=LoRA__8192__nproc8__bsz4__ActCkptTrue \
dataset.max_seq_len=8192 \
dataset.source=Yukang/LongAlpaca-12k \
dataset.packed=False \
dataset.split="train[:10%]" \
metric_logger=torchtune.utils.metric_logging.WandBLogger \
metric_logger.project=mem_prealloc \
gradient_accumulation_steps=1 \
log_every_n_steps=1 \
log_peak_memory_stats=True \
max_steps_per_epoch=20 \
epochs=1 \
compile=False \
enable_activation_checkpointing=True
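For anyone who wants to reproduce this without waiting for a torchtune config: PyTorch reads the allocator setting from the `PYTORCH_CUDA_ALLOC_CONF` environment variable at CUDA init time, so a minimal sketch is to export it before rerunning the command above.

```shell
# Opt in to the expandable-segments allocator for this shell session.
# PYTORCH_CUDA_ALLOC_CONF is read by PyTorch when CUDA is initialized,
# so it must be set before the training process starts.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
echo "$PYTORCH_CUDA_ALLOC_CONF"

# Then rerun the command above, e.g.:
# tune run --nproc_per_node 8 lora_finetune_distributed --config llama3/8B_lora batch_size=4 ...
```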
Just wanted to confirm that running on A100, with the flag I can run bs=4, but without it, it OOMs.
This would imply that we should be paying attention to reserved memory too, right? @ebsmothers
Sorry I missed this before. But based on @felipemello1's screenshot I would say that's not the case. It looks like this flag also impacts both active and allocated memory. In that case it seems pretty clear to me that we should enable it
It looks like this flag also impacts both active and allocated memory
No, sorry, that's misleading. Green has bsz=3. When I go to bsz=4, it doesn't work unless I use the flag. The info is in the run name.
It does NOT impact allocated memory, only reserved
@ebsmothers
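For anyone checking this themselves, the two stats come from different allocator counters; a quick sketch to print both (assumes torch is installed; returns zeros on a CPU-only host):

```python
import torch


def cuda_memory_gib() -> dict[str, float]:
    """Return allocated vs reserved CUDA memory in GiB (zeros without CUDA)."""
    if not torch.cuda.is_available():
        return {"allocated": 0.0, "reserved": 0.0}
    gib = 1024 ** 3
    return {
        # Bytes currently held by live tensors.
        "allocated": torch.cuda.memory_allocated() / gib,
        # Bytes the caching allocator has reserved from the driver;
        # this is the number expandable_segments affects here.
        "reserved": torch.cuda.memory_reserved() / gib,
    }


print(cuda_memory_gib())
```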
Thanks @felipemello1 for clarifying, I should've looked at the legend more closely.
So for actual next steps here: it seems like there are cases where anecdotally this does prevent OOMs. However there is some feature gap around CUDA IPC (as discussed in the issue linked by @pbontrager) preventing core from enabling it by default at this exact moment. We can wait for that, but if we want to unblock this sooner, I would say we can:
(1) enable this config on a branch, run on a broader set of recipes to ensure no obvious breakages or huge perf regressions.
Depending on the result of that, either:
(2a) enable this by default for all our recipes, or (2b) gate it behind a config.
In my mind (2b) is not the ideal outcome because it's just one extra bespoke field to worry about. But I'm open to discussing once we have more comprehensive results here.
What's the support matrix for expandable segments? I keep getting denied with "not supported on this platform" using EasyDiffusion and EasyTraining GUIs, but it's not stating whether I'm missing hardware or software support. I'm running CUDA 12.1 on Win10 x64 with an RTX 2070S 8 GB, torch-2.3.1+cu121, torchvision-0.18.1+cu121.
@Seedmanc , seems related: https://github.com/pytorch/pytorch/issues/122057
In Unsloth:
using the defaults in torchtune:
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device
We should dig into why this provides such a large memory improvement and consider making it the default setting, or documenting it somewhere if not.
baseline: 16.39 GiB @ 21.6 s/it
w/ compile: 16.63 GiB @ 8.5 s/it
w/ expandable_segments: 10.831 GiB @ 20.56 s/it
w/ compile + expandable_segments: 10.919 GiB @ 8.5 s/it
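Back-of-the-envelope from the numbers above (values copied from the run log; just arithmetic, not a new measurement):

```python
# Reserved-memory savings implied by the single-device QLoRA numbers above.
baseline_gib = 16.39      # defaults, no flag
expandable_gib = 10.831   # with expandable_segments:True
savings = 1 - expandable_gib / baseline_gib
print(f"{savings:.1%}")   # roughly a third less reserved memory, at similar speed
```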