Lightning-AI / litgpt

Pretrain, finetune, deploy 20+ LLMs on your own data. Uses state-of-the-art techniques: flash attention, FSDP, 4-bit, LoRA, and more.
https://lightning.ai
Apache License 2.0

Eliminate cuda syncs #1374

Closed robieta closed 2 weeks ago

robieta commented 2 weeks ago

This PR fixes two CUDA syncs that I ran across when optimizing Gemma:

1) max(1, non_masked_elems): this punts to a Python int (forcing a device-to-host sync) before being implicitly converted back to a Tensor. (I'm pretty sure I'm responsible for this one.) We need to use the uglier but more performant non_masked_elems.maximum(torch.ones_like(non_masked_elems)), which keeps the clamp on device (see the sketch after this list).

2) torch.tensor(self.lora_ind, device=result.device): this one is a little harder because we genuinely do need to move data from host to device. However, lora_ind is set in __init__ and never changes, so the best we can do is cache the tensor the first time we see it on a given device (also sketched below).

NOTE: It's very important that we do our own caching rather than use functools.cache, as the latter extends the life of self by storing it in the cache.
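
For illustration, here is a minimal sketch of both fixes; the class and method names (LoRALayerSketch, lora_ind_to, _lora_ind_cache) are placeholders for this example, not the actual litgpt code:

```python
import torch


def clamped_count(non_masked_elems: torch.Tensor) -> torch.Tensor:
    # max(1, non_masked_elems) compares a Python int with a Tensor, which
    # forces a device-to-host sync; Tensor.maximum keeps the clamp on device.
    return non_masked_elems.maximum(torch.ones_like(non_masked_elems))


class LoRALayerSketch(torch.nn.Module):
    def __init__(self, lora_ind: list) -> None:
        super().__init__()
        self.lora_ind = lora_ind        # fixed in __init__, never changes
        self._lora_ind_cache = {}       # device -> cached index tensor

    def lora_ind_to(self, device: torch.device) -> torch.Tensor:
        # Cache on the instance itself; functools.cache would keep `self`
        # alive by holding it as part of the cache key.
        if device not in self._lora_ind_cache:
            self._lora_ind_cache[device] = torch.tensor(self.lora_ind, device=device)
        return self._lora_ind_cache[device]
```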

rasbt commented 2 weeks ago

Thanks a lot for the PR! Do you have a rough estimate of the performance before and after? E.g., if it's a noticeable difference, it could potentially be related to #1369

robieta commented 2 weeks ago

@rasbt It's going to be very case-dependent. (The LoRA one is definitely the more important of the two.) I saw a ~5% improvement, but host-device syncs can range from no measurable difference to a several-fold slowdown. For https://github.com/Lightning-AI/litgpt/issues/1369 it's impossible to say anything without a profile. (It's not clear to me that it should be related, but stranger things have happened.)
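
As a side note, one way to check whether a sync like this is actually firing is PyTorch's built-in sync debug mode; a minimal sketch, assuming a CUDA build (the tensors here are just made-up stand-ins, not litgpt code):

```python
import torch

# "warn" prints a warning at every synchronizing call; "error" raises instead.
torch.cuda.set_sync_debug_mode("warn")

x = torch.randn(1024, device="cuda")
non_masked_elems = (x > 0).sum()

max(1, non_masked_elems)                                     # warns: implicit device-to-host sync
non_masked_elems.maximum(torch.ones_like(non_masked_elems))  # silent: stays on device
```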

robieta commented 2 weeks ago

By the way, I did an audit of the other uses of torch.tensor and was pleasantly surprised to find no other cases that looked problematic. (Which is very unusual for a codebase of this size and complexity.) Thanks for keeping the bar high, everyone!

robieta commented 2 weeks ago

Added comments and fixed the lora_ind issue.

Andrei-Aksionov commented 1 week ago

Hah, I had already forgotten that I created a PR to eliminate the unnecessary CUDA sync during the zero_pad call. I remember there was an issue with a CUDA stream overflow during the backward pass, which made the backward call slower, but thanks to a speedup during the forward pass the overall training time was smaller. The funniest part is that I started investigating it after @carmocca recommended watching the video where Taylor explained how to do profiling. Eventually, @robieta fixed the issue himself 🙃.