pytorch / FBGEMM

FB (Facebook) + GEMM (General Matrix-Matrix Multiplication) - https://code.fb.com/ml-applications/fbgemm/

FP8 Triton matmul code silently requires contiguous tensors #2713

Open rationalism opened 1 month ago

rationalism commented 1 month ago

Hello! Thank you very much for this FP8 rowwise matmul code; it's been extremely helpful. However, there is a subtle hidden requirement when calling this code, e.g. here:

https://github.com/pytorch/FBGEMM/blob/735f27b9070ceff43b289bd6fbeb8bf11c141adf/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py#L97

This works great, but only if the second matrix is contiguous in transposed format (e.g., for (M, N, K) = (4096, 2048, 1024), the second matrix must be contiguous with shape (2048, 1024)). If it is not contiguous, the matmul still completes, but the results are numerically nonsensical.
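
To make the failure mode concrete, here is a minimal sketch of a defensive wrapper. The kernel is passed in as a parameter rather than imported, since the exact entry point and signature (`matmul_fp8_row(xq, wq, x_scale, w_scale)` as in the linked benchmark) should be checked against the repo; everything else is plain PyTorch.

```python
import torch

def checked_fp8_rowwise_matmul(matmul_fp8_row, xq, wq, x_scale, w_scale):
    """Sketch of a guard around a rowwise FP8 matmul kernel.

    `matmul_fp8_row` is assumed to follow the call in the linked benchmark
    (fp8_gemm_benchmark.py); the guard itself is plain PyTorch.
    """
    # The Triton kernel reads the second operand as a row-major (N, K) buffer.
    # A transposed or sliced view has the same shape but different strides,
    # and the kernel then completes with numerically nonsensical output
    # instead of raising an error.
    if not wq.is_contiguous():
        # Materialize the expected layout (or raise here, if a silent copy is
        # undesirable for performance reasons). Assumes copies are supported
        # for the FP8 dtype in use.
        wq = wq.contiguous()
    return matmul_fp8_row(xq, wq, x_scale, w_scale)
```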

q10 commented 1 month ago

CC @choutim

sryap commented 4 weeks ago

Hello @rationalism, thank you for your questions.

These may be related: https://github.com/triton-lang/triton/pull/3952 and https://github.com/pytorch/pytorch/issues/125437.
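
Without presuming what those links cover, the way a non-contiguous second operand typically arises is easy to show in plain PyTorch: a transpose swaps strides but not storage, so a shape check alone does not catch it.

```python
import torch

# Plain-PyTorch illustration (no FBGEMM dependency): a transposed view has the
# shape the matmul expects but is not contiguous.
w = torch.randn(1024, 2048)   # stored row-major with shape (K, N)
wt = w.t()                    # shape (2048, 1024) == (N, K), same storage

print(wt.shape)                         # torch.Size([2048, 1024])
print(wt.stride())                      # (1, 2048) instead of (1024, 1)
print(wt.is_contiguous())               # False
print(wt.contiguous().is_contiguous())  # True after materializing a copy
```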

rationalism commented 1 week ago

@q10 @sryap Tri Dao just released a paper on FlashAttention-3, which also has to deal with layout issues for contiguous FP8 matmuls. It might be helpful:

https://tridao.me/publications/flash3/flash3.pdf