pytorch / torchtitan

A native PyTorch Library for large model training

Improved `repeat_kv` eager perf #418

Closed · awgu closed this 3 months ago

awgu commented 3 months ago

Stack from ghstack (oldest at bottom):

Overview

This PR replaces `x[:, :, :, None, :]` with `torch.unsqueeze(x, dim=3)` to avoid unnecessary device copies and fill kernels in backward (namely, 4 fills and 4 copies, each with shape `(bs, seq_len, n_kv_heads, head_dim)`).
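For reference, a sketch of the change in context, assuming the Llama-style `repeat_kv` (variable names follow the reference implementation; treat this as illustrative rather than the exact diff):

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, repeats=n_rep, dim=2)."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        # Before: x[:, :, :, None, :] -- records 4 aten::slice ops, whose
        # backwards each allocate a zeros buffer and copy the grad into it.
        # After: a single view op whose backward is just a squeeze.
        torch.unsqueeze(x, dim=3)
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
```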

Traces

Existing forward (CPU):

*(screenshot: profiler trace)*

Existing backward (CPU):

*(screenshot: profiler trace)*

Existing backward (GPU):

*(screenshot: profiler trace)*

Each `aten::slice` in the forward produces a `SliceBackward0` in the backward, which runs `aten::zeros` followed by `aten::copy_`.
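A minimal sketch of why the two spellings differ under autograd (the tensor shape here is illustrative): each `:` in the indexing expression records a slice view whose backward materializes a gradient buffer, while `unsqueeze` is a single view whose backward is just a squeeze.

```python
import torch

x = torch.randn(2, 4, 8, 16, requires_grad=True)

# Indexing with `:` records an aten::slice for each full slice (four here),
# and each slice's SliceBackward0 does aten::zeros -> aten::copy_.
y_slice = x[:, :, :, None, :]

# unsqueeze records one view op; its backward is just a squeeze.
y_unsq = torch.unsqueeze(x, dim=3)

# Identical results, cheaper backward.
assert torch.equal(y_slice, y_unsq)
```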

New forward (CPU):

*(screenshot: profiler trace)*

New backward (CPU):

*(screenshot: profiler trace)*

New backward (GPU), with no more fill kernels or device copies:

*(screenshot: profiler trace)*
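Traces like the ones above can be captured with the PyTorch profiler; here is a minimal reproduction sketch (the shapes, `n_rep`, and the helper name `old_repeat_kv` are illustrative, not taken from this PR):

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Illustrative shapes, not the ones used for the traces above.
bs, seq_len, n_kv_heads, head_dim, n_rep = 2, 1024, 8, 128, 4
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(bs, seq_len, n_kv_heads, head_dim, device=device, requires_grad=True)

def old_repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # The pre-PR spelling, using indexing instead of unsqueeze.
    bs, slen, n_kv, hd = x.shape
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv, n_rep, hd)
        .reshape(bs, slen, n_kv * n_rep, hd)
    )

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)
with profile(activities=activities) as prof:
    old_repeat_kv(x, n_rep).sum().backward()

# Look for SliceBackward0 / aten::zeros / aten::copy_ entries in the table.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))
```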

Test Plan

awgu commented 3 months ago

@tianyu-l I opened a PR into the meta-llama/llama3 repo with this change.

Personally, I think this kind of change is valid for us to make in torchtitan. It does not affect the model structure, it is not intrusive (the llama3 repo could adopt it pretty easily if they wanted to), and it is numerics-preserving.