pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License
3.92k stars, 354 forks

[FEATURE REQUEST] Compile model+loss_fn together #1228

Open gau-nernst opened 1 month ago

gau-nernst commented 1 month ago

Right now, when compile=True, only the model is compiled:

https://github.com/pytorch/torchtune/blob/e10142016798cf84f2e5c638a985014384f400a7/recipes/lora_finetune_single_device.py#L383-L386

We can further boost performance by including the loss calculation in the compile step.

From my benchmarks, the improvement is pretty significant. Running tune run lora_finetune_single_device --config llama3/8B_qlora_single_device compile=True on a 4080, throughput goes from ~1100 tok/s to ~1450 tok/s (a ~30% improvement).
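A minimal sketch of the idea, assuming a decoder-only `model` that maps token IDs of shape (batch, seq_len) to logits of shape (batch, seq_len, vocab) and a standard `nn.CrossEntropyLoss`. The names `model_loss` and `build_step` are hypothetical, not torchtune APIs; the point is only that the forward pass, the label shift and the loss live in one callable, so `torch.compile` can trace them as a single graph instead of compiling the model alone.

```python
import torch
import torch.nn as nn


def model_loss(model: nn.Module, loss_fn: nn.Module,
               tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Forward pass plus loss in one function (hypothetical helper)."""
    logits = model(tokens)                      # (batch, seq_len, vocab)
    # Shift so that position i is scored against token i + 1.
    logits = logits[..., :-1, :].contiguous()
    labels = labels[..., 1:].contiguous()
    # CrossEntropyLoss expects the class dim at dim 1: (batch, vocab, seq_len - 1).
    return loss_fn(logits.transpose(1, 2), labels)


def build_step(compile_enabled: bool):
    """Return model_loss, wrapped in torch.compile when compile_enabled is True."""
    if compile_enabled:
        # Model forward, slicing and cross entropy are traced together.
        return torch.compile(model_loss)
    return model_loss
```

Usage would look like `step = build_step(compile_enabled=True)`, then `loss = step(model, loss_fn, tokens, labels)` and `loss.backward()`. The backward call stays outside the function; torch.compile also produces a compiled backward for the traced forward.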

Baseline - main e101420

Step 1 | loss:1.9794920682907104 lr:2.9999999999999997e-06 tokens_per_second_per_gpu:36.1405614689982 
Step 2 | loss:2.132385015487671 lr:5.999999999999999e-06 tokens_per_second_per_gpu:1057.5127828685158 
Step 3 | loss:2.0779435634613037 lr:8.999999999999999e-06 tokens_per_second_per_gpu:1107.9911968768445 
Step 4 | loss:2.107637643814087 lr:1.1999999999999999e-05 tokens_per_second_per_gpu:1050.867526697978 
Step 5 | loss:1.9378103017807007 lr:1.4999999999999999e-05 tokens_per_second_per_gpu:1122.8913807208219 
Step 6 | loss:2.0037074089050293 lr:1.7999999999999997e-05 tokens_per_second_per_gpu:1114.3397000842317 
Step 7 | loss:2.049872636795044 lr:2.1e-05 tokens_per_second_per_gpu:1108.8569776038305 
Step 8 | loss:2.1449005603790283 lr:2.3999999999999997e-05 tokens_per_second_per_gpu:1027.7933852631434 
Step 9 | loss:1.9285515546798706 lr:2.6999999999999996e-05 tokens_per_second_per_gpu:1113.8025743191083 
Step 10 | loss:2.0167834758758545 lr:2.9999999999999997e-05 tokens_per_second_per_gpu:1094.9996764298014 
Step 11 | loss:1.9725338220596313 lr:3.2999999999999996e-05 tokens_per_second_per_gpu:975.2314694129668 
Step 12 | loss:1.974663496017456 lr:3.5999999999999994e-05 tokens_per_second_per_gpu:1044.7823733103148 
Step 13 | loss:1.8836348056793213 lr:3.9e-05 tokens_per_second_per_gpu:1074.1156293314884 
Step 14 | loss:1.9150726795196533 lr:4.2e-05 tokens_per_second_per_gpu:1025.360553285061 
Step 15 | loss:1.7414907217025757 lr:4.4999999999999996e-05 tokens_per_second_per_gpu:1019.1227943080746 
Step 16 | loss:1.5234928131103516 lr:4.7999999999999994e-05 tokens_per_second_per_gpu:1159.2307364007843 
Step 17 | loss:1.5247085094451904 lr:5.1e-05 tokens_per_second_per_gpu:1114.9777996699304 
Step 18 | loss:1.50749933719635 lr:5.399999999999999e-05 tokens_per_second_per_gpu:1061.0344737894352 
Step 19 | loss:1.4408982992172241 lr:5.6999999999999996e-05 tokens_per_second_per_gpu:1064.36640499739 
Step 20 | loss:1.4391398429870605 lr:5.9999999999999995e-05 tokens_per_second_per_gpu:1029.261234123613 

My branch https://github.com/gau-nernst/torchtune/commit/98749ee07e0449d844548629e2f484ce437210d3

Step 1 | loss:1.9793018102645874 lr:2.9999999999999997e-06 tokens_per_second_per_gpu:33.92528594071858 
Step 2 | loss:2.1328954696655273 lr:5.999999999999999e-06 tokens_per_second_per_gpu:1467.253065986233 
Step 3 | loss:2.078062057495117 lr:8.999999999999999e-06 tokens_per_second_per_gpu:1495.8065584468234 
Step 4 | loss:2.1072309017181396 lr:1.1999999999999999e-05 tokens_per_second_per_gpu:1455.5678994475297 
Step 5 | loss:1.9366379976272583 lr:1.4999999999999999e-05 tokens_per_second_per_gpu:1523.1019285925645 
Step 6 | loss:2.00321364402771 lr:1.7999999999999997e-05 tokens_per_second_per_gpu:1503.3646039188136 
Step 7 | loss:2.0503153800964355 lr:2.1e-05 tokens_per_second_per_gpu:1475.4495262068608 
Step 8 | loss:2.144784450531006 lr:2.3999999999999997e-05 tokens_per_second_per_gpu:1429.8221717785034 
Step 9 | loss:1.9278817176818848 lr:2.6999999999999996e-05 tokens_per_second_per_gpu:1513.377603433333 
Step 10 | loss:2.016079902648926 lr:2.9999999999999997e-05 tokens_per_second_per_gpu:1497.2795946343672 
Step 11 | loss:1.971487283706665 lr:3.2999999999999996e-05 tokens_per_second_per_gpu:1385.554533423888 
Step 12 | loss:1.9740395545959473 lr:3.5999999999999994e-05 tokens_per_second_per_gpu:1453.5935243661404 
Step 13 | loss:1.881981372833252 lr:3.9e-05 tokens_per_second_per_gpu:1488.3835529158055 
Step 14 | loss:1.912611484527588 lr:4.2e-05 tokens_per_second_per_gpu:1423.6128942828243 
Step 15 | loss:1.7379623651504517 lr:4.4999999999999996e-05 tokens_per_second_per_gpu:1416.9445297790708 
Step 16 | loss:1.5198403596878052 lr:4.7999999999999994e-05 tokens_per_second_per_gpu:1559.7370398385906 
Step 17 | loss:1.5182536840438843 lr:5.1e-05 tokens_per_second_per_gpu:1494.3229959889875 
Step 18 | loss:1.5012562274932861 lr:5.399999999999999e-05 tokens_per_second_per_gpu:1455.363997890013 
Step 19 | loss:1.4366376399993896 lr:5.6999999999999996e-05 tokens_per_second_per_gpu:1469.8507724317951 
Step 20 | loss:1.4373903274536133 lr:5.9999999999999995e-05 tokens_per_second_per_gpu:1446.520713139556 

There doesn't seem to be any downside to including the loss in the compile step. Has anyone tried this before, or does anyone know of any potential issues with it? I haven't done a full fine-tuning run + model eval, just the small benchmark + sanity check above.

felipemello1 commented 1 month ago

Thanks for sharing it! I will reproduce it later today on my side and confirm the results. I also wonder if we have to put the model/loss together, or if we can just torch.compile(loss_function) too.

joecummings commented 1 month ago

cc @msaroufim

gau-nernst commented 1 month ago

@felipemello1 My understanding is that compiling the loss function separately won't make a difference, because the loss function is just a simple cross entropy. The eager code should already call into an efficient implementation of cross entropy.

Upon closer inspection, I see that the following two .contiguous() calls copy data:

https://github.com/pytorch/torchtune/blob/e10142016798cf84f2e5c638a985014384f400a7/recipes/lora_finetune_single_device.py#L571-L575

By compiling the model+loss_fn together, the data copy (and the .transpose()) can potentially be optimized away. But it also raises the question: why is the eager implementation written this way? Copying the logits can be expensive and memory-consuming, especially with a large vocab size like Llama 3's. I'm guessing the code was written this way to fit what nn.CrossEntropyLoss expects?
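Whether those two `.contiguous()` calls actually copy depends on the shape, at least in eager mode. Slicing off the last sequence position gives a non-contiguous view whenever batch > 1, so `.contiguous()` allocates new memory; with a batch of 1 the slice is already contiguous and `.contiguous()` returns the same tensor. The small check below, with made-up sizes and a hypothetical `contiguous_copies` helper, compares `data_ptr()` before and after to see which case applies.

```python
import torch


def contiguous_copies(batch: int, seq_len: int, vocab: int) -> tuple[bool, bool]:
    """Report whether the shift-then-.contiguous() pattern copies (logits, labels)."""
    logits = torch.randn(batch, seq_len, vocab)
    labels = torch.randint(0, vocab, (batch, seq_len))

    shifted_logits = logits[..., :-1, :]  # view, no copy yet
    shifted_labels = labels[..., 1:]      # view, no copy yet

    # .contiguous() returns self when the view is already contiguous;
    # otherwise it allocates new storage, so the data pointer changes.
    logits_copied = shifted_logits.contiguous().data_ptr() != shifted_logits.data_ptr()
    labels_copied = shifted_labels.contiguous().data_ptr() != shifted_labels.data_ptr()
    return logits_copied, labels_copied
```

For example, `contiguous_copies(2, 6, 10)` reports a copy for both tensors, while `contiguous_copies(1, 6, 10)` reports none, because size-1 dimensions are ignored when PyTorch checks contiguity.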

Off the top of my head, one way to avoid slicing the logits (and thus avoid copying them) is to shift the target labels in the data loader. Then we can calculate the loss at all positions, instead of just the first seq_len-1 positions. It's also kinda strange to me that even though we call them labels, we still have to shift them manually in the main training loop to get the correct label for each position (I know HuggingFace does it this way: https://github.com/huggingface/transformers/blob/81233c069c166af033794134bd8888783ac49ebe/src/transformers/models/llama/modeling_llama.py#L1168-L1178).
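A sketch of the data-loader alternative: shift the labels once when building the batch, pad the final position with the loss's ignore index, and then score every logit position without slicing. `shift_labels` is a hypothetical helper, not an existing torchtune collate function; `-100` is the default `ignore_index` of `nn.CrossEntropyLoss`. With mean reduction the ignored position is excluded from the average, so the result should match the slice-based loss.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # default ignore_index of F.cross_entropy / nn.CrossEntropyLoss


def shift_labels(tokens: torch.Tensor) -> torch.Tensor:
    """labels[:, i] = tokens[:, i + 1]; the last position has no target and is ignored."""
    pad = torch.full_like(tokens[:, :1], IGNORE_INDEX)
    return torch.cat([tokens[:, 1:], pad], dim=1)


def loss_with_slicing(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # Current pattern: slice logits and labels inside the training loop.
    return F.cross_entropy(logits[:, :-1].transpose(1, 2), tokens[:, 1:])


def loss_with_preshifted_labels(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Alternative: labels were shifted when batching, so the logits are used as-is.
    return F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=IGNORE_INDEX)
```

Both functions compute the same mean over the seq_len-1 real targets; only the place where the shift happens differs.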

msaroufim commented 1 month ago

Compiling the loss function has anecdotally made a tremendous difference for workloads I've studied https://github.com/mlcommons/algorithmic-efficiency/pull/597

awgu commented 1 month ago

I have seen good results compiling the cross entropy loss by itself in torchtitan as well.
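For comparison, compiling the loss on its own (the variant felipemello1 asked about, and what msaroufim and awgu describe) only requires wrapping the loss callable, so the recipe structure does not change. A minimal sketch, with `maybe_compile` and `SeparatelyCompiled` as hypothetical names: the model and the loss become two independent compiled regions, and the slicing between them still runs eagerly.

```python
import torch
import torch.nn as nn


def maybe_compile(fn, enabled: bool):
    """Wrap fn (a function or nn.Module) in torch.compile only when enabled."""
    return torch.compile(fn) if enabled else fn


class SeparatelyCompiled:
    """Model and loss compiled as two independent graphs (hypothetical sketch)."""

    def __init__(self, model: nn.Module, compile_enabled: bool):
        self.model = maybe_compile(model, compile_enabled)
        self.loss_fn = maybe_compile(nn.CrossEntropyLoss(), compile_enabled)

    def loss(self, tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = self.model(tokens)
        # This slicing/copy runs in eager mode, between the two compiled regions.
        logits = logits[..., :-1, :].contiguous().transpose(1, 2)
        return self.loss_fn(logits, labels[..., 1:].contiguous())
```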

felipemello1 commented 1 month ago

Upon closer inspection, I see that the following two .contiguous() calls copy data

I removed it completely (the slicing and both .contiguous() calls). The loss becomes horrible, but there is no change in memory, so I don't think the slicing + contiguous is impacting memory. There is a small difference in TPS, though maybe that's just my machine.

```python
logits = self._model(tokens, mask=mask, input_pos=input_pos)
# Shift so that tokens < n predict n
# (disabled for this experiment: each position is now scored against its own
# input token instead of the next one, which is why the loss gets much worse)
# logits = logits[..., :-1, :].contiguous()
# labels = labels[..., 1:].contiguous()
logits = logits.transpose(1, 2)
# Compute loss
loss = self._loss_fn(logits, labels)
loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()
```
gau-nernst commented 1 month ago

Compiling model and loss separately on my machine:

Step 1 | loss:1.9794920682907104 lr:2.9999999999999997e-06 tokens_per_second_per_gpu:36.755080345686565 
Step 2 | loss:2.132385015487671 lr:5.999999999999999e-06 tokens_per_second_per_gpu:1058.22240515447 
Step 3 | loss:2.077510356903076 lr:8.999999999999999e-06 tokens_per_second_per_gpu:1108.0076494783045 
Step 4 | loss:2.1064255237579346 lr:1.1999999999999999e-05 tokens_per_second_per_gpu:1051.6327880166907 
Step 5 | loss:1.9360378980636597 lr:1.4999999999999999e-05 tokens_per_second_per_gpu:1123.9045892277497 
Step 6 | loss:2.0012407302856445 lr:1.7999999999999997e-05 tokens_per_second_per_gpu:1115.0391648126629 
Step 7 | loss:2.047700881958008 lr:2.1e-05 tokens_per_second_per_gpu:1109.8240339511951 
Step 8 | loss:2.141397476196289 lr:2.3999999999999997e-05 tokens_per_second_per_gpu:1028.9185764491856 
Step 9 | loss:1.9249967336654663 lr:2.6999999999999996e-05 tokens_per_second_per_gpu:1114.5460984127874 
Step 10 | loss:2.011702060699463 lr:2.9999999999999997e-05 tokens_per_second_per_gpu:1095.843168540178 
Step 11 | loss:1.9648287296295166 lr:3.2999999999999996e-05 tokens_per_second_per_gpu:975.7730938888137 
Step 12 | loss:1.9662505388259888 lr:3.5999999999999994e-05 tokens_per_second_per_gpu:1045.3402925831765 
Step 13 | loss:1.871081829071045 lr:3.9e-05 tokens_per_second_per_gpu:1074.8372566416333 
Step 14 | loss:1.8996593952178955 lr:4.2e-05 tokens_per_second_per_gpu:1026.27401980707 
Step 15 | loss:1.7271528244018555 lr:4.4999999999999996e-05 tokens_per_second_per_gpu:1020.6450301260769 
Step 16 | loss:1.5104990005493164 lr:4.7999999999999994e-05 tokens_per_second_per_gpu:1160.5853742757959 
Step 17 | loss:1.5073328018188477 lr:5.1e-05 tokens_per_second_per_gpu:1115.639225098661 
Step 18 | loss:1.4909707307815552 lr:5.399999999999999e-05 tokens_per_second_per_gpu:1061.72986127453 
Step 19 | loss:1.4257352352142334 lr:5.6999999999999996e-05 tokens_per_second_per_gpu:1065.7145108114698 
Step 20 | loss:1.426650047302246 lr:5.9999999999999995e-05 tokens_per_second_per_gpu:1029.9716948072396 

So it seems like only compiling them together gives a good speed boost. I tried inspecting the Triton-generated code, but there doesn't seem to be any special fusion going on with cross entropy in the forward pass. Maybe the backward pass is more optimized? Also, I'm using torch nightly 2.5.0.dev20240709+cu121, if that matters.

From a performance perspective, compiling a larger piece of code should give more opportunities for fusion and optimization, similar to how gpt-fast compiles model+sampling together (https://github.com/pytorch-labs/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/generate.py#L324). Maybe in the future the torch compiler can even fuse the matmul (LM head) with cross entropy? (ref: https://github.com/pytorch/pytorch/issues/124480) I suppose you are curious about compiling the loss function separately because the implementation would be cleaner? Then we wouldn't need to wrap the model + loss function in another function or nn.Module.

@felipemello1 Do you have more data from your side about different compile combinations? i.e. compile model only, compile model+loss together, compile model and loss separately.
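To collect comparable numbers for the three combinations (model only, model and loss separately, model and loss together), a small timing harness along these lines could be used. It is a sketch with a toy setup, not the recipe's `tokens_per_second_per_gpu` logging; `benchmark`, `shifted_ce` and `model_and_loss` are hypothetical names, and `compile_fn` defaults to `torch.compile` but can be swapped for an identity function to check the harness itself. Warm-up steps are excluded because the first calls include compilation, which is likely why Step 1 in the logs above is so much slower.

```python
import time
from typing import Callable, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


def shifted_ce(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # Next-token cross entropy: position i is scored against token i + 1.
    return F.cross_entropy(logits[:, :-1].transpose(1, 2), tokens[:, 1:])


def model_and_loss(model: nn.Module, tokens: torch.Tensor) -> torch.Tensor:
    return shifted_ce(model(tokens), tokens)


def benchmark(model: nn.Module, tokens: torch.Tensor, steps: int = 10, warmup: int = 3,
              compile_fn: Callable = torch.compile) -> Dict[str, float]:
    """Tokens/second (forward + backward) for three compile variants."""
    compiled_model = compile_fn(model)
    compiled_ce = compile_fn(shifted_ce)
    compiled_together = compile_fn(model_and_loss)
    variants = {
        "model_only": lambda: shifted_ce(compiled_model(tokens), tokens),
        "separate": lambda: compiled_ce(compiled_model(tokens), tokens),
        "together": lambda: compiled_together(model, tokens),
    }
    results = {}
    for name, step in variants.items():
        for _ in range(warmup):  # first calls include compilation time
            step().backward()
        if tokens.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(steps):
            step().backward()  # grads accumulate; irrelevant for timing
        if tokens.is_cuda:
            torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
        results[name] = tokens.numel() * steps / (time.perf_counter() - start)
    return results
```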

felipemello1 commented 1 month ago

Hey @gau-nernst, thanks for running it! That's very interesting. @ebsmothers said that he was going to investigate it, since he is working on profiling/optimization, so let's see if he can reproduce it too.

ebsmothers commented 1 month ago

Hi @gau-nernst, thanks for this nice find! With this kind of performance improvement, I think it makes sense to update our recipes to compile the model and loss together whenever we set compile=True. @felipemello1 I am happy to do it, but won't have bandwidth for it until Wednesday. I'll leave it up for grabs until then; if anyone wants to pick it up before then, they are welcome to. Otherwise I will assign it to myself as soon as I'm ready to work on it.

gau-nernst commented 1 month ago

@ebsmothers Just to check: have you had time to validate this on your machines yet? Happy to contribute a PR, but I think it's important that others can also reproduce the speedup.

ebsmothers commented 1 month ago

@gau-nernst thanks for the bump, and apologies for not getting to it before now. After patching in your changes, I think I actually see an even bigger speedup than you did. Pasting my results below:

[Screenshot 2024-08-08 at 9:52:02 PM: benchmark results]

This seems pretty compelling to me; we should move forward with this. Would you want to clean up your branch and open a PR?

gau-nernst commented 1 month ago

Wow, that is awesome! May I ask what GPU you're using? Is it an A100?

I will push a PR. Design-wise, is the approach in my current branch OK? i.e. having a model_loss() function outside of the recipe class and assigning it to self._model_loss. I was wondering whether the loss computation could be included in the model definition itself, but it seems like DPO uses a different loss function (as potentially do other fine-tuning schemes that I'm not familiar with). I'm open to other suggestions.
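A rough sketch of that layout: a module-level function, compiled once at setup when the compile flag is on and stored as `self._model_loss`, with scaling and `backward()` left in the training step. `RecipeSketch` and `train_step` are hypothetical stand-ins, not the real recipe's structure; the gradient-accumulation scaling mirrors the `loss / self._gradient_accumulation_steps` line in the recipe excerpt in this thread.

```python
import torch
import torch.nn as nn


def model_loss(model: nn.Module, loss_fn: nn.Module,
               tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Module-level function: torch.compile only sees the model, the loss
    # and the tensors, not the recipe object.
    logits = model(tokens)
    logits = logits[..., :-1, :].contiguous().transpose(1, 2)
    return loss_fn(logits, labels[..., 1:].contiguous())


class RecipeSketch:
    """Hypothetical stand-in for a recipe class, not torchtune's."""

    def __init__(self, model: nn.Module, compile_enabled: bool = False,
                 gradient_accumulation_steps: int = 1):
        self._model = model
        self._loss_fn = nn.CrossEntropyLoss()
        self._gradient_accumulation_steps = gradient_accumulation_steps
        # Compile once at setup; the training loop just calls self._model_loss.
        self._model_loss = torch.compile(model_loss) if compile_enabled else model_loss

    def train_step(self, tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        loss = self._model_loss(self._model, self._loss_fn, tokens, labels)
        loss = loss / self._gradient_accumulation_steps
        loss.backward()  # backward is called outside the compiled function
        return loss.detach()
```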

Apart from the QLoRA recipes, do you know if any other recipes would benefit from this also? Full fine-tune single device low-memory probably doesn't benefit, since "optimizer in backward" (and also my other "CPU offload optimizer" work) is not compatible with torch.compile.

SalmanMohammadi commented 1 month ago

This is great, thanks so much for your work @gau-nernst.

but it seems like DPO uses a different loss function (as potentially do other fine-tuning schemes that I'm not familiar with). I'm open to other suggestions.

The only non-CE losses we have are for alignment fine-tuning at the moment: https://pytorch.org/torchtune/main/api_ref_modules.html#loss, but they could definitely benefit from compile speedups (particularly PPO). Broadly speaking, in the recipes that use these losses, there's a lot of logic between (simplifying here) logits = self._model and self._loss_fn(logits, labels).

I would very much appreciate your help in understanding how this works here: from looking at your branch, to compile the loss and model together, we need to wrap the loss logic and the relevant model calls into a single callable?

gau-nernst commented 1 month ago

Yes, to compile the model and loss together, we have to wrap them into a single unit, either a function or an nn.Module. In projects I work on, my model often has multiple components, so I wrap all of them in a single nn.Module class and then just call .compile() on the wrapper. Wrapping them in a function, like I did in my branch here, is also fine, and simpler.
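The nn.Module variant could look like the sketch below: a thin wrapper that owns the model and the loss as submodules, so a single `.compile()` call on the wrapper (an in-place method on `nn.Module` in recent PyTorch releases) covers both. `ModelWithLoss` is a hypothetical name. Because the original model is only registered as a submodule, it can still be saved or used on its own at inference time.

```python
import torch
import torch.nn as nn


class ModelWithLoss(nn.Module):
    """Wrap a model and its loss so they can be compiled as one unit (sketch)."""

    def __init__(self, model: nn.Module, loss_fn: nn.Module):
        super().__init__()
        self.model = model      # registered submodule: same parameters, no copy
        self.loss_fn = loss_fn

    def forward(self, tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = self.model(tokens)
        logits = logits[..., :-1, :].contiguous().transpose(1, 2)
        return self.loss_fn(logits, labels[..., 1:].contiguous())
```

Usage: `wrapper = ModelWithLoss(model, nn.CrossEntropyLoss())`, then `wrapper.compile()` when compile is enabled, and `loss = wrapper(tokens, labels)` in the training loop.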

Making the function a method of the recipe class is also possible I think, but sometimes the self object may not play nicely with torch.compile(), so I prefer to make it a standalone function.

SalmanMohammadi commented 1 month ago

Thanks!

I think the DPO/PPO losses might be out if this is the case, since the loss computations are not easily factored out of the recipe, right?

I like seeing it as a standalone function here for CE; I'm not sure we'd want to tie losses to the model definition. I think Evan and Felipe will have more intelligent things to say here, but happy to chip in when you've got a PR up : )

gau-nernst commented 1 month ago

DPO/PPO losses should still be possible. For example, you can probably just move this out to a standalone function and compile it (at least at a glance; I might be missing important details 😅): https://github.com/pytorch/torchtune/blob/f1db07480458518d9367ca596ae96cf1e9588eca/recipes/lora_dpo_single_device.py#L503-L538
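As a rough illustration of moving DPO logic into a standalone, compilable function: the sketch computes per-sequence log-probabilities from logits and then the standard sigmoid DPO loss, -log sigmoid(beta * ((log pi(y_w) - log pi_ref(y_w)) - (log pi(y_l) - log pi_ref(y_l)))). It is a generic re-implementation under assumptions (next-token label shifting, `-100` marking prompt/padding positions, a single `beta`), not the code at the linked lines of `lora_dpo_single_device.py`; all function names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed label value for prompt / padding positions


def sequence_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Per-sequence sum of log p(label) over non-ignored, next-token-shifted positions."""
    logits = logits[:, :-1, :]
    labels = labels[:, 1:]
    mask = labels != IGNORE_INDEX
    safe_labels = labels.masked_fill(~mask, 0)  # placeholder index, zeroed out by the mask
    logps = logits.log_softmax(dim=-1)
    token_logps = torch.gather(logps, 2, safe_labels.unsqueeze(-1)).squeeze(-1)
    return (token_logps * mask).sum(dim=-1)


def dpo_loss(policy_chosen: torch.Tensor, policy_rejected: torch.Tensor,
             ref_chosen: torch.Tensor, ref_rejected: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    """Sigmoid DPO loss on per-sequence log-probabilities, averaged over the batch."""
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    return -F.logsigmoid(beta * margin).mean()


def dpo_model_loss(policy: nn.Module,
                   chosen_tokens, chosen_labels, rejected_tokens, rejected_labels,
                   ref_chosen_logps, ref_rejected_logps, beta: float = 0.1) -> torch.Tensor:
    # Policy forwards + loss in one standalone function, so it could be passed
    # to torch.compile as a unit. Reference log-probs are assumed to be
    # computed beforehand under torch.no_grad().
    policy_chosen = sequence_logprobs(policy(chosen_tokens), chosen_labels)
    policy_rejected = sequence_logprobs(policy(rejected_tokens), rejected_labels)
    return dpo_loss(policy_chosen, policy_rejected, ref_chosen_logps, ref_rejected_logps, beta)
```

The long argument list is the cost felipemello1 mentions: every tensor the loss needs has to be passed in explicitly once the logic leaves the recipe method.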

So each recipe can have a standalone (model+loss) function in its own recipe.py file. My only concern is that it feels "detached" or "far away" from the training loop, so it might be confusing for people to look at.

I want to collect people's feedback and thoughts before making a PR, since it seems pointless if I have to redo it when people want another design 😅. The changes will be very minimal, as you can see in my proof-of-concept branch, but making the design feel nice and easy to extend to other recipes is the hard part.

felipemello1 commented 1 month ago

These results are so nice! Regarding the design, I can take a closer look at it later today if you guys haven't solved it by then.

My only concern is that it feels "detached" or "far away" from the training loop, so it might be confusing for people to look at.

I agree; it would be even worse if the function requires a bunch of inputs/outputs. But for this speedup, I think it's worth it. PS: I don't think we would need the "del policy" part if this is in another function. Maybe we could try a simple code mock just to see what it would look like?

SalmanMohammadi commented 1 month ago

So each recipe can have a standalone (model+loss) function in its own recipe.py file. My only concern is that it feels "detached" or "far away" from the training loop, so it might be confusing for people to look at.

FWIW I don't think this would be possible for the PPO loss, there are many model forwards from different models scattered throughout the recipe https://github.com/pytorch/torchtune/blob/main/recipes/ppo_full_finetune_single_device.py.

What you're describing here looks to me like a step() function, which would broadly have the signature def step(self, batch, self._model) -> Tuple[torch.Tensor, ...]*

where the return type is whatever outputs you need to capture from your loss function. In the example in your branch, you've defined the model_loss fn outside the recipe class. Is this a hard requirement? IMO it'd be much, much cleaner to keep it inside.

edit: or, if we don't want any loss.backward() happening, we could call it something like train_forward or batch_forward?

The changes will be very minimal, as you can see in my proof-of-concept branch, but making the design feel nice and easy to extend to other recipes is the hard part.

Thanks for helping us integrate this awesome feature neatly : ) agreed!

gau-nernst commented 1 month ago

you've defined the model_loss fn outside the recipe class. Is this a hard requirement? IMO it'd be much, much cleaner to keep it inside.

I understand your point. I have mentioned it before 😄. In particular, it's about the self object.

Making the function a method of the recipe class is also possible I think, but sometimes the self object may not play nicely with torch.compile(), so I prefer to make it a standalone function.

torch.compile() might have improved since I last tried compiling a class method. You can give it a try with PPO, since I don't know it in detail. I will give it a try for the QLoRA recipe too.

ebsmothers commented 1 month ago

Wow, that is awesome! May I know what is your GPU? Is it A100?

@gau-nernst yes it's A100.

I will push a PR. Design-wise, is the approach in my current branch OK? i.e. having a model_loss() function outside of the recipe class and assigning it to self._model_loss. I was wondering whether the loss computation could be included in the model definition itself, but it seems like DPO uses a different loss function (as potentially do other fine-tuning schemes that I'm not familiar with). I'm open to other suggestions.

Took an initial look at your PR; I like the move into self._model_loss, which was going to be my primary comment anyway. But if you do run into problems compiling the recipe method, we can try to figure out a reasonable alternative here. On that note, I agree with the other comments about keeping separate nn.Modules for loss and model as much as possible; this also has implications for reusability at inference time.

Apart from the QLoRA recipes, do you know if any other recipes would benefit from this also? Full fine-tune single device low-memory probably doesn't benefit, since "optimizer in backward" (and also my other "CPU offload optimizer" work) is not compatible with torch.compile.

I would still try the full finetune recipes. I'd need to check exactly which ones use optimizer_in_backward, but definitely some of the recently added smaller Qwen2 models do not (e.g.). I think this is a great example of a case where we should give more detailed recommendations based on people's hardware, though. For example, if I have an A100, I realistically don't need those low-memory optimizations and should instead use this feature. I want to make sure it's easy for people to understand that and get the best perf they can with minimal friction. Getting off my soapbox: another point is that we will soon move our distributed recipes to FSDP2 (I have an open PR for our full finetune that should land pretty soon). In that case, I think we can try to enable this version of compile there as well.

gau-nernst commented 1 month ago

@ebsmothers Sounds good! I will try to see if I can run full finetune single device with smaller models on my 16GB VRAM GPU 🤣. If not, I will need your help with the final testing.

gau-nernst commented 1 month ago

Tracker

Single device:

Distributed:

I think for non-DPO recipes, the change is the same as what I have done, but I can't test the distributed recipes myself. I'll give DPO single device a try next.

felipemello1 commented 1 month ago

I can run the distributed this week @gau-nernst

gau-nernst commented 1 month ago

@felipemello1 Thank you for the help! The change is simple, so I think it's easier and more convenient if you create the PR and test it yourself. I will help debug if there are any problems (though I don't know much about FSDP2 🥲 or how it interacts with torch.compile()).

joecummings commented 1 month ago

Heads up @felipemello1: Evan merged the FSDP2 changes into the full_finetune_distributed recipe, so it is now officially FSDP2, not FSDP1. And with FSDP2, we've seen a few issues with compile. See #1152.

Curious how your experiments go!