gau-nernst opened 1 month ago
Thanks for sharing it! I will reproduce it later today on my side and confirm the results. I also wonder if we have to put the model/loss together, or if we can just torch.compile(loss_function) too.
cc @msaroufim
@felipemello1 My understanding is that compiling the loss function separately won't make a difference because the loss function is just simple cross entropy. The eager code should already call into an efficient implementation of cross entropy.
Upon closer inspection, I see that the following two .contiguous() calls will copy data.
By compiling the model+loss_fn together, the data copy (and .transpose()) can potentially be optimized away. But it also raises the question: why is the eager implementation done this way? Making a copy of the logits can be expensive and memory-consuming, especially with a large vocab size like in Llama3. I'm guessing the code was written this way to fit what nn.CrossEntropyLoss expects?
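To illustrate the pattern (toy shapes, not the actual recipe code): the slices produce non-contiguous views, so .contiguous() materializes copies, and the transpose is there because nn.CrossEntropyLoss / F.cross_entropy want the class dimension second.

import torch
import torch.nn.functional as F

B, S, V = 2, 8, 128                               # toy batch size, seq len, vocab size
logits = torch.randn(B, S, V)
labels = torch.randint(0, V, (B, S))

# Position i predicts token i+1: drop the last logit and the first label.
shift_logits = logits[..., :-1, :].contiguous()   # materializes a (B, S-1, V) copy
shift_labels = labels[..., 1:].contiguous()       # materializes a (B, S-1) copy

# cross_entropy expects logits as (N, C, ...) and targets as (N, ...), hence the transpose.
loss = F.cross_entropy(shift_logits.transpose(1, 2), shift_labels)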
Off the top of my head, one way to avoid slicing the logits (and thus avoid copying them) is to shift the target labels in the data loader. Then we can calculate the loss at all positions, instead of just the first seq_len-1 positions. It's also a bit strange to me that even though we call them labels, we still have to manually shift them in the main training loop to get the correct label for each position (I know HuggingFace does it this way https://github.com/huggingface/transformers/blob/81233c069c166af033794134bd8888783ac49ebe/src/transformers/models/llama/modeling_llama.py#L1168-L1178).
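A rough sketch of what that collate-time shift could look like (hypothetical helper, not the actual torchtune data pipeline; CROSS_ENTROPY_IGNORE_IDX is an assumed ignore value):

import torch

CROSS_ENTROPY_IGNORE_IDX = -100  # assumed ignore value for the loss

def shift_labels(tokens: torch.Tensor) -> torch.Tensor:
    # Position i is labeled with token i+1; the last position gets the ignore value,
    # so the loss can be computed over all seq_len positions without slicing the logits.
    labels = tokens.new_full(tokens.shape, CROSS_ENTROPY_IGNORE_IDX)
    labels[..., :-1] = tokens[..., 1:]
    return labels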
Compiling the loss function has anecdotally made a tremendous difference for workloads I've studied https://github.com/mlcommons/algorithmic-efficiency/pull/597
I have seen good results compiling the cross entropy loss by itself in torchtitan as well.
Upon closer inspection, I see that the following two .contiguous() calls will copy data.
I removed it completely. The loss becomes horrible, but there is no change in memory. So I don't think the slicing + contiguous is impacting it from a memory perspective. There is a small difference in TPS though, but maybe that's just my machine.
logits = self._model(tokens, mask=mask, input_pos=input_pos)
# Shift so that tokens < n predict n
# logits = logits[..., :-1, :].contiguous()
# labels = labels[..., 1:].contiguous()
logits = logits.transpose(1, 2)
# Compute loss
loss = self._loss_fn(logits, labels)
loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()
Compile model and loss separately on my machine
Step 1 | loss:1.9794920682907104 lr:2.9999999999999997e-06 tokens_per_second_per_gpu:36.755080345686565
Step 2 | loss:2.132385015487671 lr:5.999999999999999e-06 tokens_per_second_per_gpu:1058.22240515447
Step 3 | loss:2.077510356903076 lr:8.999999999999999e-06 tokens_per_second_per_gpu:1108.0076494783045
Step 4 | loss:2.1064255237579346 lr:1.1999999999999999e-05 tokens_per_second_per_gpu:1051.6327880166907
Step 5 | loss:1.9360378980636597 lr:1.4999999999999999e-05 tokens_per_second_per_gpu:1123.9045892277497
Step 6 | loss:2.0012407302856445 lr:1.7999999999999997e-05 tokens_per_second_per_gpu:1115.0391648126629
Step 7 | loss:2.047700881958008 lr:2.1e-05 tokens_per_second_per_gpu:1109.8240339511951
Step 8 | loss:2.141397476196289 lr:2.3999999999999997e-05 tokens_per_second_per_gpu:1028.9185764491856
Step 9 | loss:1.9249967336654663 lr:2.6999999999999996e-05 tokens_per_second_per_gpu:1114.5460984127874
Step 10 | loss:2.011702060699463 lr:2.9999999999999997e-05 tokens_per_second_per_gpu:1095.843168540178
Step 11 | loss:1.9648287296295166 lr:3.2999999999999996e-05 tokens_per_second_per_gpu:975.7730938888137
Step 12 | loss:1.9662505388259888 lr:3.5999999999999994e-05 tokens_per_second_per_gpu:1045.3402925831765
Step 13 | loss:1.871081829071045 lr:3.9e-05 tokens_per_second_per_gpu:1074.8372566416333
Step 14 | loss:1.8996593952178955 lr:4.2e-05 tokens_per_second_per_gpu:1026.27401980707
Step 15 | loss:1.7271528244018555 lr:4.4999999999999996e-05 tokens_per_second_per_gpu:1020.6450301260769
Step 16 | loss:1.5104990005493164 lr:4.7999999999999994e-05 tokens_per_second_per_gpu:1160.5853742757959
Step 17 | loss:1.5073328018188477 lr:5.1e-05 tokens_per_second_per_gpu:1115.639225098661
Step 18 | loss:1.4909707307815552 lr:5.399999999999999e-05 tokens_per_second_per_gpu:1061.72986127453
Step 19 | loss:1.4257352352142334 lr:5.6999999999999996e-05 tokens_per_second_per_gpu:1065.7145108114698
Step 20 | loss:1.426650047302246 lr:5.9999999999999995e-05 tokens_per_second_per_gpu:1029.9716948072396
So it seems like only compiling them together gives a good speed boost. I tried inspecting the Triton-generated code, but there doesn't seem to be any special fusion going on with cross entropy in the forward pass. Maybe the backward pass is more optimized? Also, I use torch nightly 2.5.0.dev20240709+cu121, if that matters.
From a performance perspective, compiling a larger piece of code should give more opportunity for fusion and optimization, similar to how gpt-fast compiles model+sampling together (https://github.com/pytorch-labs/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/generate.py#L324). Maybe in the future the torch compiler can even fuse the matmul (LM head) with cross entropy? (ref: https://github.com/pytorch/pytorch/issues/124480) I suppose you are interested in compiling the loss function separately so the implementation is cleaner? That way there's no need to wrap model + loss function in another function or nn.Module.
@felipemello1 Do you have more data from your side about different compile combinations? i.e. compile model only, compile model+loss together, compile model and loss separately.
Hey @gau-nernst, thanks for running it! That's very interesting. @ebsmothers said that he was going to investigate it, since he is working on profiling/optimization, so let's see if he can reproduce it too.
Hi @gau-nernst thanks for this nice find! With this kind of performance improvement I think it makes sense to update our recipes to compile model and loss together whenever we set compile=True. @felipemello1 I am happy to do it but won't have bandwidth for it until Wednesday. I will leave it up-for-grabs until then; if anyone wants to pick it up before then they are welcome to. Otherwise I will assign it to myself as soon as I'm ready to work on it.
@ebsmothers Just to check. Have you got the time to validate this on your machines yet? Happy to contribute a PR, but I think it's important that others can also reproduce the speedup.
@gau-nernst thanks for the bump and apologies for not getting to it before now. After patching your changes, I think I actually see an even bigger speedup than you did... pasting my results below:
This seems pretty compelling to me; we should move forward with this. Would you want to clean up your branch and publish it as a PR?
Wow, that is awesome! May I know what is your GPU? Is it A100?
I will push a PR. Design-wise, is the one in my current branch OK? i.e. having a model_loss() function outside of the recipe class and assigning it to self._model_loss. I was wondering whether the loss computation could be included in the model definition itself, but it seems like DPO uses a different loss function (and potentially other fine-tuning schemes that I'm not familiar with). I'm open to other suggestions.
Apart from the QLoRA recipes, do you know if any other recipes would benefit from this also? Full fine-tune single device low-memory probably doesn't benefit, since "optimizer in backward" (and also my other "CPU offload optimizer" work) is not compatible with torch.compile.
This is great, thanks so much for your work @gau-nernst.
but it seems like DPO uses a different loss function (and potentially other fine-tuning schemes that I'm not familiar with). I'm open to other suggestions.
The only non-CE losses we have are for alignment fine-tuning atm: https://pytorch.org/torchtune/main/api_ref_modules.html#loss, but they could definitely benefit from compile speedups (particularly PPO). Broadly speaking, in the recipes which use these losses, there's a lot of logic between (simplifying here) logits = self._model(...) and self._loss_fn(logits, labels).
I would very much appreciate your help in understanding how this works here: from looking at your branch, to compile the loss and model together we need to wrap the loss logic and relevant model calls together into a single callable?
Yes, to compile the model and loss together, we have to wrap them into a single unit, either as a function or an nn.Module. Typically for projects I work on, since my model can have multiple components, I often wrap all the components under a single nn.Module class, then I can just call .compile() on the wrapper. Wrapping them in a function is also fine, like I did in my branch here, which is also simpler.
Making the function a method of the recipe class is also possible I think, but sometimes the self object may not play nicely with torch.compile(), so I prefer to make it a standalone function.
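As a minimal sketch (hypothetical names, simplified from what my branch does), the wrapper can just be a function that runs the forward pass and the loss, and the whole thing is passed to torch.compile:

import torch
import torch.nn.functional as F

def model_loss(model: torch.nn.Module, tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Run the forward pass and the loss inside one callable so the compiler sees both.
    logits = model(tokens)                        # (batch, seq_len, vocab)
    return F.cross_entropy(logits.transpose(1, 2), labels)

compiled_model_loss = torch.compile(model_loss)
# In the training loop: loss = compiled_model_loss(self._model, tokens, labels)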
Thanks!
I think the DPO/PPO losses might be out if this is the case, since the loss computations are not easily factored out of the recipe, right?
I like seeing it as a standalone function here for CE; I'm not sure we'd want to tie losses into the model definition. I think Evan and Felipe will have more intelligent things to say here, but happy to chip in when you've got a PR up : )
DPO/PPO losses should still be possible. For example, you can probably just move this out to a standalone function and compile it (at least at a glance; I might be missing important details): https://github.com/pytorch/torchtune/blob/f1db07480458518d9367ca596ae96cf1e9588eca/recipes/lora_dpo_single_device.py#L503-L538
So each recipe can have a standalone (model+loss) function in its own recipe.py file. My only concern is that it feels "detached" or "far away" from the training loop, so it might be confusing for people to look at.
I want to collect people's feedback and thoughts before making a PR, since it seems pointless if I need to re-do it when people want another design. The changes will be very minimal, as you can see in my proof-of-concept branch, but making the design feel nice and easy to extend to other recipes is the hard part.
These results are so nice! Regarding the design, I can take a closer look at it later today if you guys haven't solved it by then.
My only concern is that it feels "detached" or "far away" from the training loop, so it might be confusing for people to look at.
I agree, it would be even worse if the function requires a bunch of inputs/outputs. But for this speedup, I think it's worth it. PS: I don't think we would need the "del policy" part if this is in another function. Maybe we could try to have a simple code mock just to see how it would look?
So each recipe can have a standalone (model+loss) function in its own recipe.py file. My only concern is that it feels "detached" or "far away" from the training loop, so it might be confusing for people to look at.
FWIW I don't think this would be possible for the PPO loss; there are many model forwards from different models scattered throughout the recipe https://github.com/pytorch/torchtune/blob/main/recipes/ppo_full_finetune_single_device.py.
What you're describing here looks to me like a step() function, which would broadly have the signature def step(self, batch, self._model) -> Tuple[torch.Tensor, ...], where the return type is whatever outputs you need to capture from your loss function. In the example in your branch, you've defined the model_loss fn outside the recipe class - is this a hard requirement? IMO it'd be much much cleaner to keep it inside.
edit: or, if we don't want any loss.backward() happening, we could do something like def train_forward? batch_forward?
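Roughly what I have in mind, as a hypothetical sketch (names like train_forward are illustrative, not actual recipe code):

from typing import Tuple

import torch
import torch.nn.functional as F

class Recipe:
    def __init__(self, model: torch.nn.Module) -> None:
        self._model = model

    def train_forward(self, tokens: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # Everything between the model forward and the loss lives in one method,
        # so the whole thing can be handed to torch.compile in one piece.
        logits = self._model(tokens)
        loss = F.cross_entropy(logits.transpose(1, 2), labels)
        return (loss, logits)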
The changes will be very minimal, as you can see in my proof-of-concept branch, but making the design feel nice and easy to extend to other recipes is the hard part.
Thanks for helping us integrate this awesome feature neatly : ) agreed!
you've defined the model_loss fn outside the recipe class - is this a hard requirement? IMO it'd be much much cleaner to keep it inside.
I understand your point. I have mentioned it before. In particular, it's regarding the self object.
Making the function a method of the recipe class is also possible I think, but sometimes the self object may not play nicely with torch.compile(), so I prefer to make it a standalone function.
torch.compile() might have improved since the last time I tried compiling a class method. You can give it a try with PPO, since I don't know it in detail. I will give it a try for the QLoRA recipe too.
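A quick toy sanity check of compiling a bound method (toy module, nothing to do with the torchtune recipes) would be something like:

import torch

class _Demo:
    def __init__(self) -> None:
        self._linear = torch.nn.Linear(16, 16)

    def loss_step(self, x: torch.Tensor) -> torch.Tensor:
        return self._linear(x).square().mean()

demo = _Demo()
demo.loss_step = torch.compile(demo.loss_step)  # compile the bound method directly
print(demo.loss_step(torch.randn(4, 16)))

Attribute reads on self should mostly just turn into extra guards, but I'd want to confirm that on the real recipe.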
Wow, that is awesome! May I know what is your GPU? Is it A100?
@gau-nernst yes it's A100.
I will push a PR. Design-wise, is the one in my current branch OK? i.e. having a model_loss() function outside of the recipe class and assigning it to self._model_loss. I was wondering whether the loss computation could be included in the model definition itself, but it seems like DPO uses a different loss function (and potentially other fine-tuning schemes that I'm not familiar with). I'm open to other suggestions.
Took an initial look at your PR. I like the move into self._model_loss, that was going to be my primary comment anyway. But if you do run into problems compiling the recipe method we can try to figure out a reasonable alternative here. On that note, I agree with other comments about having separate nn.Modules for loss and model as much as possible; this also has implications for reusability at inference time.
Apart from the QLoRA recipes, do you know if any other recipes would benefit from this also? Full fine-tune single device low-memory probably doesn't benefit, since "optimizer in backward" (and also my other "CPU offload optimizer" work) is not compatible with torch.compile.
I would still try the full finetune recipes. I'd need to check exactly which ones use optimizer_in_backward, but definitely some of the recently-added smaller Qwen2 models do not (e.g.). I think this is a great example of a case where we should give more detailed recommendations based on people's hardware though... like if I have an A100 I realistically don't need those low-memory optimizations and should instead use this feature. I want to make sure it's easy for people to understand that and get the best perf they can with minimal friction. Getting off my soapbox... another point is that we will soon move our distributed recipes to FSDP2 (I have an open PR for our full finetune that should land pretty soon). In that case I think we can try to enable this version of compile there as well.
@ebsmothers Sounds good! I will try to see if I can run full finetune single device with smaller models on my 16GB VRAM GPU. If not, I would need your help to do the final testing.
Tracker
Single device:
Distributed:
I think for non-DPO recipes, the changes are the same as what I have done, but I can't test the distributed recipes myself. Will give the DPO single device a try next.
I can run the distributed ones this week @gau-nernst
@felipemello1 Thank you for the help! The change is simple, so I think it's easier and more convenient if you create the PR and test it yourself. I will help debug if there are any problems (though I don't know much about FSDP2 and how it interacts with torch.compile()).
Heads up @felipemello1, Evan merged the FSDP2 changes into the full_finetune_distributed recipe so that is now officially FSDP2, not FSDP1. And for FSDP2, we've seen a few issues with compile. See #1152.
Curious how your experiments go!
Right now when compile=True, only the model is compiled: https://github.com/pytorch/torchtune/blob/e10142016798cf84f2e5c638a985014384f400a7/recipes/lora_finetune_single_device.py#L383-L386
We can further boost performance by including the loss calculation in the compile step.
From my benchmarks, the improvement is pretty significant. Running tune run lora_finetune_single_device --config llama3/8B_qlora_single_device compile=True on a 4080, the improvement is ~1100 tok/s -> ~1450 tok/s (30% improvement).
Baseline - main e101420
My branch https://github.com/gau-nernst/torchtune/commit/98749ee07e0449d844548629e2f484ce437210d3
There doesn't seem to be any downside to including the loss in the compile step. Has anyone tried this before or know of any potential issues with it? I haven't done a full fine-tuning + model eval, just did a small benchmark + sanity check above.