Closed robieta closed 2 weeks ago
Thanks a lot for the PR! Do you have some rough estimates in terms of how the performance is before and after? E.g., if it is a noticeable difference, it could potentially be related to #1369
@rasbt It's going to be super case dependent. (The LoRA one is definitely the much more important one.) I saw ~5%, but host-device syncs can vary from no difference to several-fold slowdown. For https://github.com/Lightning-AI/litgpt/issues/1369 it's impossible to say anything without a profile. (It's not clear to me that it should be related, but stranger things have happened.)
By the way, I did an audit of other uses of `torch.tensor` and was pleasantly surprised to find no other cases that looked problematic. (Which is very unusual for a codebase of this size and complexity.) Thanks for keeping the bar high everyone!
Added comments and fixed the `lora_ind` issue.
Hah, I had already forgotten that I created a PR to eliminate an unnecessary CUDA sync during the `zero_pad` call. I remember there was an issue with a CUDA stream overflow during the backward pass, which made the backward call slower, but thanks to a speedup during the forward pass the overall training time was still lower.
The funniest part is that I started investigating it after @carmocca recommended watching the video where Taylor explained how to do profiling.
Eventually, @robieta fixed the issue himself 🙃.
This PR fixes two CUDA syncs that I ran across when optimizing Gemma:

1) `max(1, non_masked_elems)`

   This punts to a Python int before being implicitly converted back to a Tensor. (I'm pretty sure I'm responsible for this one.) We need to use the uglier but more performant `non_masked_elems.maximum(torch.ones_like(non_masked_elems))`.
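A minimal sketch of the difference (variable names are illustrative, not litgpt's actual loss code): Python's built-in `max` has to call `bool()` on the elementwise comparison, which forces a device-to-host sync when the tensor lives on the GPU, whereas `Tensor.maximum` stays on device.

```python
import torch

# Count of unmasked loss terms; can be zero when every token is masked.
non_masked_elems = torch.tensor(0)

# Naive form: Python's max() compares 1 against the tensor and calls
# bool() on the result, which syncs host and device for CUDA tensors.
# slow = max(1, non_masked_elems)

# Sync-free form: the value never leaves the device.
clamped = non_masked_elems.maximum(torch.ones_like(non_masked_elems))
```

`non_masked_elems.clamp(min=1)` would be an equivalent sync-free spelling; the `maximum`/`ones_like` form is just what the diff uses.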
2) `torch.tensor(self.lora_ind, device=result.device)`

   This one is a little harder because we genuinely do need to move data from host to device. However, `lora_ind` is set in `__init__` and doesn't change, so the best we can do is cache the tensor the first time we see it on a given device.

   NOTE: It's very important that we do our own caching rather than use `functools.cache`, as the latter extends the life of `self` by storing it in the cache.
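The per-device caching idea can be sketched as follows. The class and method names here are hypothetical stand-ins for the real litgpt LoRA layer; the point is only the dict-on-the-instance pattern, which dies with the object instead of pinning `self` alive in a global cache the way `functools.cache` on a method would.

```python
import torch

class ZeroPadSketch:
    """Hypothetical stand-in for the LoRA layer that owns lora_ind."""

    def __init__(self, lora_ind):
        self._lora_ind = lora_ind        # plain Python list, set once, never mutated
        self._lora_ind_cache = {}        # torch.device -> cached index Tensor

    def lora_ind_on(self, device):
        # First call per device pays the host-to-device copy; later calls
        # reuse the cached tensor, avoiding a fresh torch.tensor(...) sync.
        dev = torch.device(device)
        if dev not in self._lora_ind_cache:
            self._lora_ind_cache[dev] = torch.tensor(self._lora_ind, device=dev)
        return self._lora_ind_cache[dev]

pad = ZeroPadSketch([0, 2, 5])
first = pad.lora_ind_on("cpu")
second = pad.lora_ind_on("cpu")   # cache hit: same tensor object back
```

When `pad` is garbage-collected the cache goes with it, which is exactly the lifetime behavior `functools.cache` would break.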