NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

[FusedRoPE] Fuse type conversion and cos/sin #1752

Closed: yaox12 closed this 10 months ago

yaox12 commented 10 months ago

This PR fuses the sin/cos computation on freqs and the data type conversion into the fused RoPE kernel, which eliminates 4 tiny element-wise kernels.
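As a rough illustration of the idea (not the actual CUDA kernel from this PR), the pure-Python sketch below compares an unfused path, where cos and sin of freqs are materialized in separate element-wise passes, with a fused path that evaluates them inside the single rotation loop. The function names `rope_unfused` and `rope_fused` are hypothetical, and the rotate-half convention (`[x1, x2] -> [-x2, x1]`) is an assumption about the layout. The dtype conversions are not modeled, since Python floats have a single precision.

```python
import math


def rotate_half(x):
    """Map [x1, x2] to [-x2, x1], where x1 and x2 are the two halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])


def rope_unfused(t, freqs):
    """Unfused reference: cos and sin are computed in separate passes.

    On a GPU, each of these list comprehensions would correspond to its own
    small element-wise kernel launch before the rotation itself.
    """
    cos_ = [math.cos(f) for f in freqs]  # pass 1
    sin_ = [math.sin(f) for f in freqs]  # pass 2
    rot = rotate_half(t)
    return [x * c + r * s for x, c, r, s in zip(t, cos_, rot, sin_)]


def rope_fused(t, freqs):
    """Fused sketch: cos/sin are evaluated per element inside one loop."""
    half = len(t) // 2
    out = []
    for i, (x, f) in enumerate(zip(t, freqs)):
        # Partner element from the other half, with the rotate-half sign.
        r = -t[i + half] if i < half else t[i - half]
        out.append(x * math.cos(f) + r * math.sin(f))
    return out


if __name__ == "__main__":
    t = [1.0, 2.0, 3.0, 4.0]
    freqs = [0.1, 0.2, 0.1, 0.2]
    a = rope_unfused(t, freqs)
    b = rope_fused(t, freqs)
    # Both paths should agree up to floating-point rounding.
    print(all(math.isclose(x, y) for x, y in zip(a, b)))
```

Both paths produce the same values; the fused form simply avoids writing the intermediate cos and sin buffers, which is the saving the PR describes at the kernel level.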

yaox12 commented 10 months ago

@crcrpar Ready for merging. Thanks.