pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
BSD 3-Clause "New" or "Revised" License

Add 'on-the-fly' sample packing #1109

Closed joecummings closed 3 days ago

joecummings commented 1 week ago

Context

As investigated in #1097, it was shown that the offline approach to constructing the mask consumed waaaaaay too much memory. Therefore, this approach constructs tokens, labels, and input_pos offline and then constructs the mask during access (training). For a max_seq_len of 4096 (default for many models), we can expect the memory of a single pack to look like the following offline:

tokens: 88 (fixed size of Python object) + 8 (size of torch.int64) * 4096 = 32,856
labels: 88 + 8 * 4096 = 32,856
input_pos: 88 + 8 * 4096 = 32,856
seq_lens: 226 <-- varies based on the number of samples per pack, but this is the average from experiments
----------------------------------------
98,794 bytes ~= 0.1 MB

To provide a real-world example, let's use the Web Instruct dataset from TIGER-Lab. It comes in at 3.51 GB with 2.3 million samples. The average sample length (with the instruct template applied) is about 100 tokens, which means roughly 40 samples fit in each pack if we don't split samples across packs. Therefore we can expect about 57,500 packs. At ~0.1 MB each, that is an additional 5.75 GB, bringing the total memory needed to load this dataset (before training) to 9.26 GB, well within reasonable bounds.
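For reference, the arithmetic above can be reproduced with a quick back-of-the-envelope script (the constants are the ones quoted in this description, not values pulled from the code):

```python
# Rough estimate of per-pack and whole-dataset memory when tokens, labels,
# and input_pos are stored offline and only the mask is built on the fly.
# These are the estimates from the description above, not measurements.

PYOBJ_OVERHEAD = 88      # fixed overhead of a Python object, bytes
BYTES_PER_TOKEN = 8      # torch.int64
MAX_SEQ_LEN = 4096
SEQ_LENS_AVG = 226       # average size of the seq_lens metadata, bytes

def pack_bytes(max_seq_len: int = MAX_SEQ_LEN) -> int:
    # tokens + labels + input_pos, plus the small seq_lens list
    per_tensor = PYOBJ_OVERHEAD + BYTES_PER_TOKEN * max_seq_len
    return 3 * per_tensor + SEQ_LENS_AVG

num_samples = 2_300_000
avg_tokens_per_sample = 100
samples_per_pack = MAX_SEQ_LEN // avg_tokens_per_sample  # ~40, no splitting across packs
num_packs = num_samples // samples_per_pack              # ~57,500

print(f"per pack:  {pack_bytes() / 1e6:.2f} MB")              # ~0.1 MB
print(f"all packs: {num_packs * pack_bytes() / 1e9:.2f} GB")  # ~5.7 GB
```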

Why do we need seq_lens? Technically we could recompute it from input_pos, but that would save negligible memory and increase processing time during training, which is undesirable.

Why are you using this dataset? It's a large dataset, downloaded 33,026 times in the last month. As good a baseline as any.

Why did you update the signature to take in a padding_idx and hardcode CROSS_ENTROPY_IGNORE_IDX? Excellent question. Before, the packed dataset assumed that padding_idx = 0 and always used CROSS_ENTROPY_IGNORE_IDX. The former is NOT an assumption we can make, so it should actually be configurable; the latter IS a reasonable assumption, so we should just hardcode it instead of exposing a parameter that won't get used.
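To illustrate the distinction (a simplified sketch, not the actual torchtune code): padding tokens need the tokenizer's real pad id, while padded label positions can always use the cross-entropy ignore index.

```python
# Simplified sketch of the padding behavior described above -- not the exact
# torchtune implementation. padding_idx is caller-supplied (it is NOT always 0),
# while the label ignore index is always the hardcoded constant.
CROSS_ENTROPY_IGNORE_IDX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def pad_pack(tokens: list[int], labels: list[int], max_seq_len: int, padding_idx: int):
    """Pad a pack out to max_seq_len: tokens with the tokenizer's pad id,
    labels with the ignore index so padded positions never contribute to the loss."""
    num_pad = max_seq_len - len(tokens)
    return (
        tokens + [padding_idx] * num_pad,
        labels + [CROSS_ENTROPY_IGNORE_IDX] * num_pad,
    )
```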

Changelog

Test plan

  1. Unit tests

All are passing

(joe-torchtune) [jrcummings@devvm050.nha0 ~/projects/joe-torchtune (pack-mask-on-the-fly)]$ pytest tests/torchtune/datasets/test_packed_dataset.py
================================================================ test session starts =================================================================
platform linux -- Python 3.11.9, pytest-8.1.1, pluggy-1.5.0
rootdir: /home/jrcummings/projects/joe-torchtune
configfile: pyproject.toml
plugins: integration-0.2.3, mock-3.14.0, cov-5.0.0, hypothesis-6.103.1
collected 11 items

tests/torchtune/datasets/test_packed_dataset.py ...........                                                                                    [100%]

================================================================= 11 passed in 5.03s =================================================================
  2. Direct memory / speed comparisons with old version

Using this gist: https://gist.github.com/joecummings/05586af0a08eef0714c7da3c56ee7365

Only packing 1% of the dataset, which is ~23k samples. Using our calculation from above, we expect the new implementation to use about 0.58 GB of additional memory.

Memory is an estimate based on psutil monitoring of virtual memory used. More things affect this than just the dataset, but I think it gives us a good feel for memory usage. Also, I didn't want to figure out how to reset the psutil measurement within a single script, so between runs I just commented out the code I didn't care about in order to get the memory estimates.
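For reference, the measurement was roughly of this shape (a sketch; the actual script is in the gist above):

```python
import psutil

def used_gb() -> float:
    """System-wide virtual memory currently in use, in GB."""
    return psutil.virtual_memory().used / 1e9

before = used_gb()
# ... build the packed dataset here (old or new implementation) ...
after = used_gb()
print(f"additional memory used for packing: {after - before:.2f} GB")
```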

| impl | additional memory used for packing | total additional memory used |
| --- | --- | --- |
| old (all offline) | 27 GB | 27 GB |
| new (mask on the fly) | 0.62 GB | 1.14 GB |

Our calculation looks pretty spot on for how much memory the new implementation should take. And it makes sense that there would be a little more memory used when the mask is constructed during dataloading.

| impl | time for packing |
| --- | --- |
| old (all offline) | 134 s |
| new (mask on the fly) | 22 s |

Not surprising that the old packing takes much longer than the new masking.

  3. Compare iterations e2e during training

Why do we need to do this? Well, the "loading" above is not a true measurement of how packing a dataset will affect training time. For instance, we are now passing in a pre-constructed attention mask instead of relying on SDPA to construct one for us.
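Concretely, the difference is between letting SDPA handle causal masking itself versus handing it the mask built in `__getitem__` (illustrative sizes, not the recipe code):

```python
import torch
import torch.nn.functional as F

# Small illustrative sizes: (batch, heads, seq_len, head_dim)
q = k = v = torch.randn(1, 8, 256, 64)

# Unpacked: SDPA applies causal masking internally and can pick fused kernels.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Packed: we pass the precomputed block-causal mask (boolean, True = may attend).
mask = torch.tril(torch.ones(256, 256, dtype=torch.bool))  # stand-in for a packed mask
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```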

CMD:

tune run lora_finetune_single_device \
    --config llama3/8B_lora_single_device \
    dataset._component_=torchtune.datasets.instruct_dataset \
    dataset.source=TIGER-LAB/WebInstructSub \
    template=torchtune.data.AlpacaInstructTemplate \
    column_map={"instruction":"question","output":"answer"} \
    max_seq_len=4096 \
    packed=True \
    split=train[:1%]

| impl | total time |
| --- | --- |
| old (all offline) | 39 mins |
| new (mask on the fly) | 48 mins |
| no packing | 1 hr 58 mins |

YAY, it's (kinda) just as fast as the old implementation with waaaaaay less memory.

pytorch-bot[bot] commented 1 week ago

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1109

Note: Links to docs will display an error until the docs builds have been completed.

:white_check_mark: No Failures

As of commit cdf5cdfcc4c5ade680e9229ffd02c86cf7891599 with merge base abe798d5f7af7761fcf3064b42fb699c7ef19fcd (image): :green_heart: Looks good so far! There are no failures yet. :green_heart:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ScottHoang commented 6 days ago

Hi Joe, Thank you for your work! I have been trying to do something similar. I'm just reading through your changes. Is the purpose of 'on-the-fly' packing to reduce overall memory overhead by generating the attn-mask on the fly instead of during _add_pack?

codecov-commenter commented 6 days ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 66.72%. Comparing base (abe798d) to head (60d19ab). Report is 5 commits behind head on main.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #1109       +/-   ##
===========================================
+ Coverage   26.67%   66.72%   +40.04%
===========================================
  Files         183      184        +1
  Lines        8337     8586      +249
===========================================
+ Hits         2224     5729     +3505
+ Misses       6113     2857     -3256
```

:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Have feedback on the report? Share it here.

joecummings commented 6 days ago

> Hi Joe, Thank you for your work! I have been trying to do something similar. I'm just reading through your changes. Is the purpose of 'on-the-fly' packing to reduce overall memory overhead by generating the attn-mask on the fly instead of during _add_pack?

Yep, constructing the mask during training reduces memory by about 99% and only slightly slows down processing.
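For anyone skimming, the on-the-fly construction is conceptually a block-causal mask built from `seq_lens` at access time; a rough sketch of the idea (not the PR's exact code):

```python
import torch

def block_causal_mask(seq_lens: list[int], max_seq_len: int) -> torch.Tensor:
    """Causal attention within each sample, no attention across samples.
    Returns a bool mask of shape (max_seq_len, max_seq_len); True = may attend."""
    mask = torch.zeros(max_seq_len, max_seq_len, dtype=torch.bool)
    start = 0
    for n in seq_lens:
        end = start + n
        mask[start:end, start:end] = torch.tril(torch.ones(n, n, dtype=torch.bool))
        start = end
    # Let any padding positions attend to themselves so no row is all-False.
    for i in range(start, max_seq_len):
        mask[i, i] = True
    return mask

# e.g. block_causal_mask([3, 2], max_seq_len=6)
```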

joecummings commented 5 days ago

> Also, how difficult would it be to also move the input_pos creation to getitem as well? Probably not as significant of a memory saver as the mask, but might still be worthwhile

Great question! I think I could actually do this with just the seq_len information, but it would entail generating and concatenating multiple tensors during __getitem__, which would slow down processing and only save a little memory. I could run some tests to confirm the tradeoff, though.
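To make that tradeoff concrete, rebuilding `input_pos` from `seq_lens` inside `__getitem__` would look roughly like this (a sketch, not what this PR does; having padding positions continue counting after the last sample is an assumption here):

```python
import torch

def input_pos_from_seq_lens(seq_lens: list[int], max_seq_len: int) -> torch.Tensor:
    """Rebuild position ids for a pack from per-sample lengths: each sample
    restarts at 0; padding (assumption) keeps counting from the last sample."""
    packed = torch.cat([torch.arange(n) for n in seq_lens])
    num_pad = max_seq_len - packed.numel()
    if num_pad > 0:
        last = int(packed[-1])
        packed = torch.cat([packed, torch.arange(last + 1, last + 1 + num_pad)])
    return packed

# e.g. input_pos_from_seq_lens([3, 2], max_seq_len=7) -> [0, 1, 2, 0, 1, 2, 3]
```

The per-call `arange` allocations and `torch.cat` are exactly the extra work in `__getitem__` that would slow down processing.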