NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License
8.43k stars 1.4k forks

A fused `apply_rotary_pos_emb` implementation for Megatron-Core #1746

Closed (yaox12 closed this 1 year ago)

yaox12 commented 1 year ago

This is a fused `apply_rotary_pos_emb` implementation for Megatron-Core.
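For reference, the unfused computation that a fused kernel replaces can be sketched as below. This is a minimal NumPy sketch modeled on (not copied from) Megatron-Core's eager path, so it runs without a GPU. The tensor layout `[seq, batch, heads, hidden_size]`, the frequency base of 10000, the rule `rot_dim = int(hidden_size * rotary_percent)`, and the helper names `rotary_freqs`, `rotate_half` and `apply_rotary_pos_emb_unfused` are all assumptions for illustration. Only the first `rot_dim` channels are rotated and the rest pass through unchanged, which is assumed to be what `rotary_percent < 1.0` controls.

```python
import numpy as np

def rotary_freqs(seq_length: int, rot_dim: int, base: float = 10000.0) -> np.ndarray:
    """Rotation angles of shape [seq_length, 1, 1, rot_dim] (assumed layout)."""
    # One inverse frequency per channel pair: base^(-2i / rot_dim).
    inv_freq = 1.0 / (base ** (np.arange(0, rot_dim, 2, dtype=np.float32) / rot_dim))
    pos = np.arange(seq_length, dtype=np.float32)
    freqs = np.outer(pos, inv_freq)                 # [seq, rot_dim // 2]
    emb = np.concatenate([freqs, freqs], axis=-1)   # [seq, rot_dim]
    return emb[:, None, None, :]                    # broadcast over batch and heads

def rotate_half(x: np.ndarray) -> np.ndarray:
    # Split the last dim into halves (x1, x2) and return (-x2, x1).
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

def apply_rotary_pos_emb_unfused(t: np.ndarray, freqs: np.ndarray) -> np.ndarray:
    """t: [seq, batch, heads, hidden_size]; rotates only the first rot_dim channels."""
    rot_dim = freqs.shape[-1]
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    # Slice, cos, sin, multiply, add, concatenate: several separate elementwise steps.
    t_rot = t_rot * np.cos(freqs) + rotate_half(t_rot) * np.sin(freqs)
    return np.concatenate([t_rot, t_pass], axis=-1)

# Usage with small, hypothetical sizes (hidden_size=128, rotary_percent=0.5).
seq, batch, heads, hidden, rotary_percent = 16, 2, 4, 128, 0.5
rot_dim = int(hidden * rotary_percent)
t = np.random.default_rng(0).standard_normal((seq, batch, heads, hidden)).astype(np.float32)
out = apply_rotary_pos_emb_unfused(t, rotary_freqs(seq, rot_dim))
```

In eager PyTorch, each step of this path typically runs as its own kernel and materializes an intermediate tensor; fusing them into a single kernel avoids that extra memory traffic, which is the usual motivation for a fused RoPE.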

In my preliminary benchmark, the fused version is roughly 2x to 4.5x faster than the unfused version. batch_size=2 and head_num=64 are fixed across all runs.

All runs use dtype=torch.float32.

| seq_length | hidden_size | rotary_percent | unfused rope (ms) | fused rope (ms) |
|---|---|---|---|---|
| 2048 | 128 | 0.5 | 0.45 | 0.14 |
| 2048 | 128 | 1.0 | 0.67 | 0.15 |
| 2048 | 256 | 0.5 | 0.84 | 0.27 |
| 2048 | 256 | 1.0 | 1.3 | 0.3 |
| 4096 | 128 | 0.5 | 0.85 | 0.23 |
| 4096 | 128 | 1.0 | 1.3 | 0.3 |
| 4096 | 256 | 0.5 | 1.6 | 0.75 |
| 4096 | 256 | 1.0 | 2.6 | 0.58 |
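The per-configuration speedup is simply unfused time divided by fused time. A small sketch that recomputes it from the timings reported above (the `RESULTS` list and `speedups` helper are illustrative, not part of the PR):

```python
# (seq_length, hidden_size, rotary_percent, unfused_ms, fused_ms), from the reported timings
RESULTS = [
    (2048, 128, 0.5, 0.45, 0.14),
    (2048, 128, 1.0, 0.67, 0.15),
    (2048, 256, 0.5, 0.84, 0.27),
    (2048, 256, 1.0, 1.3, 0.3),
    (4096, 128, 0.5, 0.85, 0.23),
    (4096, 128, 1.0, 1.3, 0.3),
    (4096, 256, 0.5, 1.6, 0.75),
    (4096, 256, 1.0, 2.6, 0.58),
]

def speedups(rows):
    """Return (seq_length, hidden_size, rotary_percent, unfused / fused) for each row."""
    return [(s, h, p, unfused / fused) for s, h, p, unfused, fused in rows]

for s, h, p, x in speedups(RESULTS):
    print(f"seq_length={s} hidden_size={h} rotary_percent={p}: {x:.1f}x")
```

The ratios range from about 2.1x (seq_length=4096, hidden_size=256, rotary_percent=0.5) to about 4.5x, and for each (seq_length, hidden_size) pair the rotary_percent=1.0 run shows the larger speedup.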