Synchronize to the latest composable_kernel commit, which added an inline-asm implementation of the fp32-to-bf16 RTN (round-to-nearest) conversion. Using the inline-asm RTN conversion improves performance when BF16+RTN is used.
Add compiler options for compiling the C++ extension on ROCm/HIP, which improves the performance of the HIP FMHA BWD on ROCm 6.2.
The following are benchmark results compared with Triton when using RTN, with those compiler options added, on ROCm 6.2:
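For reference, the rounding the inline-asm path implements can be sketched in plain Python via bit manipulation. This is only a host-side sketch of the numerics (round-to-nearest, ties-to-even on the 16 dropped mantissa bits), not the GPU asm itself, and it does not special-case NaN:

```python
import struct

def fp32_to_bf16_rtn(x: float) -> int:
    """Return the bfloat16 bit pattern of a float32, rounded to nearest even.

    Sketch of the RTN numerics only; the composable_kernel change performs
    the same rounding with GPU inline asm. NaN inputs are not special-cased.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Adding a bias of 0x7FFF plus the LSB of the kept mantissa implements
    # round-to-nearest with ties-to-even over the 16 bits being dropped.
    lsb = (bits >> 16) & 1
    bits += 0x7FFF + lsb
    return (bits >> 16) & 0xFFFF
```

For example, 1.0 + 2**-8 sits exactly halfway between two bf16 values and rounds down to the even pattern 0x3F80, while 1.0 + 2**-7 + 2**-8 ties upward to 0x3F82.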
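Such options are typically threaded through the extension build via `extra_compile_args`; a hypothetical setup sketch using PyTorch's `torch.utils.cpp_extension` (the extension name, source files, and flags below are placeholders for illustration, not the options this PR adds — see the diff for those):

```python
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# On a ROCm build of PyTorch, CUDAExtension compiles device sources with
# hipcc; the "nvcc" key carries the device-compiler flags. All names and
# flags here are placeholders, not the ones added by this PR.
setup(
    name="xformers_ck_fmha",  # hypothetical extension name
    ext_modules=[
        CUDAExtension(
            name="xformers_ck_fmha",
            sources=["fmha_bwd.cpp", "fmha_bwd_kernel.hip"],  # illustrative
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": ["-O3"],  # placeholder device-compiler flags
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```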
Run reference fwd:
Reference fwd time: 28.90159034729004
Run reference bwd:
Reference bwd time: 48.68329620361328
Run triton fwd:
Triton fwd time: 2.0252671241760254
Run triton bwd:
Triton bwd time: 6.977703094482422
Run CK fwd:
xformers fwd time: 1.8350895643234253
Run CK bwd:
xformers bwd time: 7.089707374572754
(triton_dq - ref_dq).abs().mean()=tensor(0.0002, device='cuda:0', dtype=torch.bfloat16)
(triton_dk - ref_dk).abs().mean()=tensor(0.0001, device='cuda:0', dtype=torch.bfloat16)
(triton_dv - ref_dv).abs().mean()=tensor(0.0004, device='cuda:0', dtype=torch.bfloat16)
(xformer_dq - ref_dq).abs().mean()=tensor(0.0002, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
(xformer_dk - ref_dk).abs().mean()=tensor(0.0002, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
(xformer_dv - ref_dv).abs().mean()=tensor(6.7234e-05, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
The following are benchmark results compared with Triton when using RTN, without those compiler options, on ROCm 6.2:
Run reference fwd:
Reference fwd time: 28.867050170898438
Run reference bwd:
Reference bwd time: 48.91793441772461
Run triton fwd:
Triton fwd time: 2.056668996810913
Run triton bwd:
Triton bwd time: 6.982858180999756
Run CK fwd:
xformers fwd time: 1.8234171867370605
Run CK bwd:
xformers bwd time: 7.428786754608154
(triton_dq - ref_dq).abs().mean()=tensor(0.0002, device='cuda:0', dtype=torch.bfloat16)
(triton_dk - ref_dk).abs().mean()=tensor(0.0001, device='cuda:0', dtype=torch.bfloat16)
(triton_dv - ref_dv).abs().mean()=tensor(0.0004, device='cuda:0', dtype=torch.bfloat16)
(xformer_dq - ref_dq).abs().mean()=tensor(0.0002, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
(xformer_dk - ref_dk).abs().mean()=tensor(0.0002, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
(xformer_dv - ref_dv).abs().mean()=tensor(8.7738e-05, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
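Comparing the two runs, the compiler options mainly move the CK backward time; a quick check of the relative change using the times reported above (in ms):

```python
# CK (xformers) backward times from the two benchmark runs above, in ms.
ck_bwd_without_opts = 7.428786754608154  # without the compiler options
ck_bwd_with_opts = 7.089707374572754     # with the compiler options

# Relative improvement of the HIP FMHA BWD from the compiler options.
speedup = (ck_bwd_without_opts - ck_bwd_with_opts) / ck_bwd_without_opts
print(f"{speedup:.1%}")  # roughly 4.6%
```

Forward times and the error metrics against the reference are essentially unchanged between the two runs, as expected.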
This PR provides: