NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License
8.42k stars, 1.4k forks

[FusedRoPE] Fuse type conversion and cos/sin #1752

Closed: yaox12 closed this pull request 12 months ago

yaox12 commented 1 year ago

This PR fuses the sin/cos computation on `freqs` and the data type conversion into the fused RoPE kernel, eliminating 4 tiny element-wise kernels.
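To illustrate what the fusion saves, here is a minimal pure-Python sketch, not the apex CUDA code. It assumes the common "rotate half" RoPE formulation `t * cos(freqs) + rotate_half(t) * sin(freqs)`. It also assumes, without confirmation from the PR, that the 4 removed kernels are `cos`, `sin`, and one dtype cast for each. The layout is simplified to a single 1-D vector (one token, one head), and fp16 rounding via `struct`'s `"e"` format stands in for casting to the input tensor's dtype. The helper names `to_half`, `rotate_half`, `rope_unfused`, and `rope_fused` are hypothetical.

```python
import math
import struct


def to_half(x: float) -> float:
    """Round a Python float to IEEE fp16 (assumed stand-in for a cast to t.dtype)."""
    return struct.unpack("e", struct.pack("e", x))[0]


def rotate_half(x):
    """Split the vector into halves (x1, x2) and return concat(-x2, x1)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rope_unfused(t, freqs):
    # Each comprehension stands for a separate element-wise GPU kernel that
    # reads and writes a full intermediate buffer (assumed breakdown).
    cos = [math.cos(f) for f in freqs]  # kernel 1: cos
    cos = [to_half(c) for c in cos]     # kernel 2: cast cos
    sin = [math.sin(f) for f in freqs]  # kernel 3: sin
    sin = [to_half(s) for s in sin]     # kernel 4: cast sin
    rot = rotate_half(t)
    # The RoPE kernel itself: t * cos + rotate_half(t) * sin.
    return [ti * c + ri * s for ti, ri, c, s in zip(t, rot, cos, sin)]


def rope_fused(t, freqs):
    # Single pass: cos, sin, and the cast are computed per element, so no
    # intermediate cos/sin buffers are materialized.
    half = len(t) // 2
    out = []
    for i, (ti, f) in enumerate(zip(t, freqs)):
        # Inline rotate_half: element i pairs with its partner in the other half.
        ri = -t[i + half] if i < half else t[i - half]
        c = to_half(math.cos(f))
        s = to_half(math.sin(f))
        out.append(ti * c + ri * s)
    return out


t = [0.5, -1.0, 2.0, 0.25]
freqs = [0.1, 0.2, 0.1, 0.2]
print(rope_fused(t, freqs) == rope_unfused(t, freqs))
```

Because both versions perform the same arithmetic in the same order, their results match exactly. The fused version simply avoids launching separate kernels and writing the intermediate `cos`/`sin` buffers to memory.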

yaox12 commented 12 months ago

@crcrpar Ready for merging. Thanks.