pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/

[RFC] Config profiling #1252

Open felipemello1 opened 1 month ago

felipemello1 commented 1 month ago

I created a script that measures tokens_per_second and peak_memory for any config/recipe while testing different flags. It still needs some polishing, but I wanted to get feedback on the output before going further and submitting a PR.

Goal:

To have a page with these two graphs and a table for every model config (that's a lot of configs :O), so users can understand their options.

PS: It sounds like it will be frustrating to update all of these anytime we make a major change. If we could automate it with CI, that would be nice, but is that possible?

TLDR:

1) Take the DEFAULT config and iterate over [bsz, max_seq_len].
2) Apply all MEMORY-related flags and iterate over [bsz, max_seq_len].
3) Apply all SPEED-related flags and iterate over [bsz, max_seq_len].
4) Take the smallest bsz and max_seq_len that run without OOM, and iterate over the flags one by one, so we can measure each flag's impact (see the sketch after this list).
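In pseudocode, the sweep looks roughly like this (a minimal sketch; `run_recipe`, the value grids, and the flag-group layout are illustrative, not the actual profile_config.py internals):

```python
import itertools

# Illustrative value grids; the real script's grids may differ.
BSZ_VALUES = [1, 2, 4, 8]
MAX_SEQ_LEN_VALUES = [1024, 2048, 4096, 8192]

def sweep(base_config, flag_groups):
    results = []
    # Steps 1-3: the default config plus each optimized flag group,
    # swept over the (bsz, max_seq_len) grid.
    for name, overrides in [("Default", {}), *flag_groups.items()]:
        for bsz, seq_len in itertools.product(BSZ_VALUES, MAX_SEQ_LEN_VALUES):
            cfg = {**base_config, **overrides,
                   "batch_size": bsz, "max_seq_len": seq_len}
            # run_recipe is a placeholder: it runs a few steps and returns
            # mean tokens/s and peak memory, or records an OOM.
            results.append((name, bsz, seq_len, run_recipe(cfg)))
    # Step 4: fix the smallest bsz and a max_seq_len that fits, then
    # toggle each flag in isolation to measure its individual impact.
    ...
    return results
```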

IMPORTANT: the default values here differ from our shipped configs. They are the values without any optimizations, e.g. no activation checkpointing, regular Adam, etc.

Outcomes:

Under the hood, the flags look like this:

```python
SPEED_FLAGS = {
    "compile": {
        "values": [True],
    },
    "fsdp_sharding_strategy": {
        "values": ["NO_SHARD"],
        "allowed_recipes": ["lora_finetune_distributed", "full_finetune_distributed"],
    },
}
```
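Filtering by allowed_recipes is what keeps, e.g., the FSDP flags out of single-device runs. A minimal sketch of that expansion (`expand_flags` is illustrative, not the script's actual helper):

```python
def expand_flags(flag_table, recipe_type):
    """Turn a flag table like SPEED_FLAGS into one override dict per (flag, value)."""
    overrides = []
    for flag, spec in flag_table.items():
        allowed = spec.get("allowed_recipes")
        if allowed is not None and recipe_type not in allowed:
            # Skip flags the recipe doesn't support, e.g. FSDP flags on single-device.
            continue
        for value in spec["values"]:
            overrides.append({flag: value})
    return overrides

# e.g. expand_flags(SPEED_FLAGS, "lora_finetune_single_device") -> [{"compile": True}]
```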

Blockers

Want these three low-hanging fruits to land first:

Run:

```bash
python profile_config.py --yaml_name llama3_1/8B_lora_single_device --recipe_type lora_finetune_single_device
```
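For context, both metrics can be read from standard PyTorch counters. A minimal sketch of the measurement (`train_step` and the token accounting are placeholders, not the script's actual code):

```python
import time
import torch

def profile_run(train_step, num_steps, tokens_per_step):
    # Reset the CUDA peak-memory counter so only this run is measured.
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(num_steps):
        train_step()  # placeholder for one forward/backward/optimizer step
    torch.cuda.synchronize()  # make sure all kernels finish before timing
    elapsed = time.perf_counter() - start
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    toks_per_s = num_steps * tokens_per_step / elapsed
    return peak_gb, toks_per_s
```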

Outputs (still need polishing):

[Plot: max_peak_memory_alloc vs. max_seq_len]

[Plot: mean_tokens_per_second_per_gpu vs. max_seq_len]

max_seq_len = 1024

| description | memory vs. default (%) | toks/s vs. default (%) | batch_size | max_seq_len | max_peak_memory_alloc (GB) | mean_tokens_per_second_per_gpu |
| -- | -- | -- | -- | -- | -- | -- |
| config_Default | 0 | 0 | 1 | 1024 | 28.97 | 1636.5 |
| config_AllMemoryOptimization | -30.62 | -8.49 | 1 | 1024 | 20.1 | 1497.53 |
| config_AllSpeedOptimization | -5.11 | 9.19 | 1 | 1024 | 27.49 | 1786.84 |
| enable_activation_checkpointing-True | -29.72 | -19.56 | 1 | 1024 | 20.36 | 1316.32 |
| minimal_lora-rank_8__qv_proj_only | -11.32 | 7.71 | 1 | 1024 | 25.69 | 1762.66 |
| compile-True | -5.11 | 9.06 | 1 | 1024 | 27.49 | 1784.82 |
| optimizer._component_-bitsandbytes.optim.PagedAdamW8bit | -0.48 | -6.42 | 1 | 1024 | 28.83 | 1531.44 |
| optimizer_in_bwd-True | 0 | 0.41 | 1 | 1024 | 28.97 | 1643.22 |
| memory_efficient_fsdp_wrap-False | 0 | 0.32 | 1 | 1024 | 28.97 | 1641.78 |
| fsdp_sharding_strategy-NO_SHARD | 0 | 0.27 | 1 | 1024 | 28.97 | 1640.97 |
| fsdp_cpu_offload-True | 0 | 0.08 | 1 | 1024 | 28.97 | 1637.74 |

max_seq_len = 4096

| description | memory vs. default (%) | toks/s vs. default (%) | batch_size | max_seq_len | max_peak_memory_alloc (GB) | mean_tokens_per_second_per_gpu |
| -- | -- | -- | -- | -- | -- | -- |
| config_Default | 0 | 0 | 1 | 4096 | 61.23 | 2286.6 |
| config_AllMemoryOptimization | -58.03 | -13.49 | 1 | 4096 | 25.7 | 1978.14 |
| config_AllSpeedOptimization | -9.64 | 13.99 | 1 | 4096 | 55.33 | 2606.41 |
| enable_activation_checkpointing-True | -57.54 | -23.13 | 1 | 4096 | 26 | 1757.64 |
| minimal_lora-rank_8__qv_proj_only | -20.22 | 8.47 | 1 | 4096 | 48.85 | 2480.22 |
| compile-True | -9.64 | 13.72 | 1 | 4096 | 55.33 | 2600.22 |
| optimizer._component_-bitsandbytes.optim.PagedAdamW8bit | -0.23 | -2.67 | 1 | 4096 | 61.09 | 2225.47 |
| fsdp_cpu_offload-True | 0 | -0.18 | 1 | 4096 | 61.23 | 2282.41 |
| fsdp_sharding_strategy-NO_SHARD | 0 | -0.19 | 1 | 4096 | 61.23 | 2282.29 |
| optimizer_in_bwd-True | 0 | -0.21 | 1 | 4096 | 61.23 | 2281.8 |
| memory_efficient_fsdp_wrap-False | 0 | -0.29 | 1 | 4096 | 61.23 | 2280.05 |
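For clarity: the percentage columns are relative deltas against config_Default. E.g. for config_AllMemoryOptimization at max_seq_len = 1024, memory vs. default = (20.1 − 28.97) / 28.97 × 100 ≈ −30.62%.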
RdoubleA commented 1 month ago

These are really interesting experiments. What would be a bit more helpful is to plot each combination of (bsz, max_seq_len) on a single plot, with the two axes being WPS and peak memory. That way I can immediately see the tradeoff between memory and speed. Although, with the way it's laid out now, I find it interesting that increasing max_seq_len increases throughput up to a certain point, after which it hurts. And batch size seems to be the biggest contributor to increasing throughput, even more than the speed-optimized config.
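For example, something like this matplotlib sketch (assuming one (bsz, max_seq_len, peak memory, toks/s) tuple per run; names are illustrative):

```python
import matplotlib.pyplot as plt

# results: one (bsz, max_seq_len, peak_mem_gb, toks_per_s) tuple per run,
# e.g. pulled from the tables above.
def plot_tradeoff(results):
    fig, ax = plt.subplots()
    for bsz, seq_len, peak_mem, tps in results:
        ax.scatter(peak_mem, tps)
        ax.annotate(f"bsz={bsz}, seq={seq_len}", (peak_mem, tps), fontsize=8)
    ax.set_xlabel("max_peak_memory_alloc (GB)")
    ax.set_ylabel("mean_tokens_per_second_per_gpu")
    ax.set_title("Memory vs. speed per (bsz, max_seq_len)")
    plt.show()
```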

Another thing that would be helpful is if you could list out the flags you're modifying for the memory- and speed-optimized configs.