philipturner / metal-flash-attention

FlashAttention (Metal Port)
MIT License

`bfloat16` support #17

Closed · cloneable closed this issue 1 week ago

cloneable commented 5 months ago

Hi!

I was wondering if you would be interested in adding bf16 support to MFA, or at least to the GEMM kernels? For MLX, Apple defined a custom type: https://github.com/ml-explore/mlx/blob/76c919b4ecf0cccaa1cfef214d12be0ad71485cc/mlx/backend/metal/kernels/bf16.h (MIT licensed), so I understand that supporting this is not easy, and maybe not even desirable, since it's not a native type and performance is not great anyway.

Btw, I came here via huggingface/candle. It uses libMFA for matmul and FA.

philipturner commented 5 months ago

@ivarflakstad has been working on BF16 support for a while. A very rough draft of a BF16 kernel achieved 50-60% of FP32 performance and 40-50% of FP16 performance.

I had a hypothesis that we could get roughly the same performance as FP32, simply by unpacking BF16 into the upper 16 bits of an FP32 and zero-filling the rest. Better yet, we could ignore the remaining bits entirely to avoid the compute cost of zero-filling. That loses one bit of precision compared to the MLX header, but is likely 2-3x faster on earlier M1 chips (which lack hardware BF16 support). The bit patterns below illustrate the idea; a sketch in code follows them.

```
BF16(0b10101010_10101010)
->
FP32(0b10101010_10101010_00000000_00000000)
or
FP32(0b10101010_10101010_garbageX_XXXXXXXX)

The latter could round to:
BF16(0b10101010_10101010)
BF16(0b10101010_10101011)
depending on the garbage bits

Therefore:
BF15(0b10101010_1010101X)
```
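
Here is a minimal C++ sketch of that bit manipulation (illustration only, not MFA's kernel code; the function names are made up for this comment):

```cpp
#include <cstdint>
#include <cstring>

// Exact widening: place the BF16 bits in the upper half of an FP32 and
// zero-fill the lower half.
float bf16_to_fp32_zero_fill(uint16_t bf16_bits) {
  uint32_t fp32_bits = uint32_t(bf16_bits) << 16;
  float value;
  std::memcpy(&value, &fp32_bits, sizeof(value));
  return value;
}

// Round-to-nearest-even conversion back to BF16 (NaN handling omitted).
// Whatever sits in the low 16 bits decides whether the last kept mantissa
// bit rounds up, which is why leaving garbage there costs about one bit
// of precision.
uint16_t fp32_to_bf16_rne(float value) {
  uint32_t fp32_bits;
  std::memcpy(&fp32_bits, &value, sizeof(fp32_bits));
  fp32_bits += 0x7FFF + ((fp32_bits >> 16) & 1);
  return uint16_t(fp32_bits >> 16);
}
```

The zero-fill path is an exact widening; the garbage path skips that work and only risks flipping the last kept mantissa bit on the round trip.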
cloneable commented 5 months ago

Thanks a lot for the quick reply! Seems like the candle team is already on it. That's great! I'm going to follow ivarflakstad's work then.

> we could get ~same performance as FP32, simply by unpacking BF16 [..] to FP32

That could indeed be a great option for M1s. Maybe the data could even be quantized to bf15+0 to reduce the influence of the garbage bits. Not sure if it's worth it; it needs to be tested.
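
A tiny sketch of that bf15+0 idea, purely illustrative (the function name is invented):

```cpp
#include <cstdint>

// Clear the least-significant mantissa bit of a BF16 value ahead of time,
// so that bit carries no stored information and a flip caused by rounding
// the garbage-filled FP32 back to BF16 discards nothing that was kept.
uint16_t quantize_bf16_to_bf15(uint16_t bf16_bits) {
  return bf16_bits & 0xFFFEu;
}
```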

Anyway, thanks again!

cloneable commented 5 months ago

Sorry, one last quick question: do you know if the M1 GPU understands bf16 but the M1 CPU does not? I just came across this doc: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf, which says bf16 is supported (search "brain float"). Is that possible? The M1 CPU is ARMv8.5-based, which doesn't mandate bf16 support.

ivarflakstad commented 5 months ago

Hey! Haven’t had time to finish the work. It looked promising though so I’ll try to find the time.

Your machine can run Metal code with bfloat as a language concept, but the GPU doesn't have built-in hardware support the way it does for float. I believe that came with the M3.

philipturner commented 5 months ago

> Do you know if M1 GPU understands bf16, but M1 CPU does not?

The M2 CPU was the first to introduce hardware support for BF16, in both the NEON and AMX units. BFloat16 is not available from Swift SIMD vectors, Accelerate, or BLAS/LAPACK; however, you can access it from inlined assembly, and it might be used internally by CoreML. The same goes for regular Float16, except that Float16 is accessible from Swift SIMD vectors.
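
For illustration, here is a rough, untested sketch of what reaching the BFDOT instruction through inline assembly might look like. It assumes a compiler/assembler target with the BF16 extension enabled, and uses uint16x8_t to carry raw BF16 bit patterns so no bfloat16 intrinsic types are required:

```cpp
#include <arm_neon.h>

// Rough, untested sketch: BF16 dot product via inline assembly.
// "w" constrains the operands to SIMD/FP registers.
float32x4_t bfdot_accumulate(float32x4_t acc, uint16x8_t a, uint16x8_t b) {
  // BFDOT: for each 32-bit lane of acc, multiply two adjacent BF16 pairs
  // from a and b and add both products into that lane.
  asm("bfdot %0.4s, %1.8h, %2.8h" : "+w"(acc) : "w"(a), "w"(b));
  return acc;
}
```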

For practical tasks (outside of AI), I have never found a speedup from using 16-bit lanes in SIMD-vectorized CPU code, except for compressing 16-bit integers and bitmasking. On the GPU, it also provides negligible speedup, except for (1) incrementally decreasing register pressure or (2) decreasing memory bandwidth. With the M3 architecture, there might be serious compute speedups, though, if someone figures out how to utilize the FP32 and FP16/BF16 units simultaneously.