ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[Question] Performance of mx.fast.scaled_dot_product_attention #1193

Open cccyf opened 2 months ago

cccyf commented 2 months ago

Describe the bug: I implemented a model with mx.fast.scaled_dot_product_attention and observed that performance improves significantly when I apply mask = mask.astype(q.dtype) before calling sdpa. The model dtype is bfloat16, and the mask has dtype float32 before the cast. So I ran an experiment to benchmark mx.fast.scaled_dot_product_attention with different input dtypes.

To Reproduce

import time

import mlx.core as mx

def benchmark_sdpa(q: mx.array, k: mx.array, v: mx.array, mask: mx.array):
    mx.eval(q, k, v, mask)
    for _ in range(100):
        mx.eval(mx.fast.scaled_dot_product_attention(
            q, k, v, scale=1, mask=mask
        ))
    toi = time.perf_counter()
    for _ in range(3200):
        mx.eval(mx.fast.scaled_dot_product_attention(q, k, v, scale=1, mask=mask))
    toc = time.perf_counter()
    tpi = 1e3 * (toc - toi) / 3200
    print(f"q {q.dtype}, k {k.dtype}, v {v.dtype}, mask {mask.dtype} takes {tpi} ms")

if __name__ == "__main__":
    print(mx.default_device())
    q = mx.random.uniform(shape=(1, 32, 1, 4096 // 32)).astype(mx.bfloat16)
    k = mx.random.uniform(shape=(1, 32, 16, 4096 // 32)).astype(mx.bfloat16)
    v = mx.random.uniform(shape=(1, 32, 16, 4096 // 32)).astype(mx.bfloat16)
    mask = mx.zeros(shape=(1, 16)).astype(mx.bfloat16)
    mx.eval(q, k, v, mask)
    for i in range(5):
        mx.metal.clear_cache()
        print(f"run {i}")
        # O = softmax(Q @ K.T * scale + mask, dim=-1) @ V
        benchmark_sdpa(q, k, v, mask) 
        benchmark_sdpa(q, k, v, mask.astype(mx.float32)) 
        benchmark_sdpa(q, k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa(q.astype(mx.float32), k, v, mask.astype(mx.float32))
        benchmark_sdpa(q.astype(mx.float32), k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa(q.astype(mx.float32), k.astype(mx.float32), v.astype(mx.float32), mask.astype(mx.float32))

Output

Device(gpu, 0)
run 0
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.29217307281214744 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.3049883465610037 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.4575847265550692 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.5957424481312046 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.28829277343902504 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2711911718779447 ms
run 1
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.2815349999946193 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.30608838531406946 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.4103653778111038 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.5485320703155594 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2936811981271603 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.271009414063883 ms
run 2
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.2813810156294494 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.42186102844425477 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5979819399999542 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.5894057940622588 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.29002313781347766 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2728101821958262 ms
run 3
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.2963425521920726 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.3069471484377573 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.4481344268697285 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.5586622524970153 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2904496224982722 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.27179540375072975 ms
run 4
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.2822949871824676 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.3853401953165303 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5780678125029226 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.5801494400020601 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.29036419280600967 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2726223828176444 ms
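For reference, the formula in the script's comment, O = softmax(Q @ K.T * scale + mask, dim=-1) @ V, can be written as a plain NumPy function with the same shapes. This is a hedged correctness reference, not MLX's fused kernel. NumPy is an assumed extra dependency, `reference_sdpa` is a hypothetical helper name, and only float32 is used because NumPy has no built-in bfloat16.

```python
import numpy as np


def reference_sdpa(q, k, v, scale, mask=None):
    """Naive O = softmax(Q @ K^T * scale + mask, axis=-1) @ V.

    q is (B, H, L_q, D); k and v are (B, H, L_kv, D); an additive mask
    must broadcast to (B, H, L_q, L_kv).
    """
    scores = (q @ np.swapaxes(k, -1, -2)) * scale
    if mask is not None:
        scores = scores + mask
    # Subtract the row max before exponentiating for numerical stability.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


# Same shapes as the reproduction script (head dim 4096 // 32 = 128).
rng = np.random.default_rng(0)
q = rng.uniform(size=(1, 32, 1, 128)).astype(np.float32)
k = rng.uniform(size=(1, 32, 16, 128)).astype(np.float32)
v = rng.uniform(size=(1, 32, 16, 128)).astype(np.float32)
mask = np.zeros((1, 16), dtype=np.float32)

out = reference_sdpa(q, k, v, scale=1.0, mask=mask)
print(out.shape, out.dtype)  # (1, 32, 1, 128) float32
```

A float32 MLX result should be convertible with np.array(...) and compared to this with np.allclose. The tolerance should suit the dtype, since a fused kernel may accumulate in a different order.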

Expected behavior: Before this experiment, I expected bfloat16 to be the fastest and float32 to be the slowest.

However, the results suggest that mx.fast.scaled_dot_product_attention is fast when q, k, v and mask all have the same dtype (either bfloat16 or float32). When q, k, v and mask have different dtypes, the computation seems slower. Could you help me understand why? Is it due to the implicit dtype conversion?

Since torch's sdpa requires q, k, v and mask to have the same dtype, I also wonder whether the implicit dtype conversion in MLX is intended.
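Whatever the intended behavior is, the conversion can be made explicit at the call site. The sketch below shows two options and uses NumPy as a stand-in. NumPy has no built-in bfloat16, so float16 plays that role. `require_same_dtype` mimics PyTorch's same-dtype requirement, and `cast_to_common_dtype` performs the upcast that the next comment describes sdpa doing implicitly. Both are hypothetical helpers, not MLX APIs. In MLX, the second option is simply `mask.astype(q.dtype)`, as in the original report.

```python
import numpy as np


def require_same_dtype(**arrays):
    """Torch-style guard: raise TypeError unless all inputs share one dtype.

    Only `.dtype` and `!=` are used, so this should also work with MLX arrays
    (an assumption; it is exercised with NumPy below).
    """
    dtypes = {name: a.dtype for name, a in arrays.items()}
    first = next(iter(dtypes.values()))
    if any(d != first for d in dtypes.values()):
        raise TypeError(f"inputs must share one dtype, got {dtypes}")


def cast_to_common_dtype(*arrays):
    """Explicit version of implicit promotion: upcast all to the common type."""
    common = np.result_type(*arrays)
    return tuple(a.astype(common) for a in arrays)


# float16 stands in for bfloat16, which NumPy does not provide natively.
q = np.zeros((1, 32, 1, 128), dtype=np.float16)
k = np.zeros((1, 32, 16, 128), dtype=np.float16)
v = np.zeros((1, 32, 16, 128), dtype=np.float16)
mask = np.zeros((1, 16), dtype=np.float32)

# Option 1: upcast everything to the common type.
print([a.dtype.name for a in cast_to_common_dtype(q, k, v, mask)])
# ['float32', 'float32', 'float32', 'float32']

# Option 2: downcast only the small mask to q's dtype (the fix from the report).
require_same_dtype(q=q, k=k, v=v, mask=mask.astype(q.dtype))
```

For inference, option 2 touches only a (1, 16) array instead of upcasting q, k and v. bfloat16 shares float32's exponent range, so large negative mask values survive the cast, but with reduced precision.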


awni commented 2 months ago

The performance differences you are seeing are likely due to implicit casting. When you call sdpa, it promotes q, k, v and mask to a common data type. If any of them is float32, they all become float32 and the output is float32. So you are paying for all of those type casts, as well as for doing the computation in fp32, which can be a little slower for inference.
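To put rough numbers on "paying for the casts", here is a back-of-the-envelope byte count for the shapes in the reproduction script. This is an illustration, not a measurement. It assumes 2 bytes per bfloat16 element and 4 bytes per float32 element, and it ignores caching, kernel fusion and the attention math itself.

```python
from math import prod

# Shapes from the reproduction script (head dim 4096 // 32 = 128).
Q_SHAPE = (1, 32, 1, 4096 // 32)
KV_SHAPE = (1, 32, 16, 4096 // 32)
MASK_SHAPE = (1, 16)
BYTES = {"bfloat16": 2, "float32": 4}

QKV_ELEMENTS = prod(Q_SHAPE) + 2 * prod(KV_SHAPE)
ALL_ELEMENTS = QKV_ELEMENTS + prod(MASK_SHAPE)


def input_bytes(dtype: str) -> int:
    """Bytes needed to hold q, k, v and mask, all stored in `dtype`."""
    return ALL_ELEMENTS * BYTES[dtype]


def upcast_qkv_bytes() -> int:
    """Bytes read plus written when q, k, v go from bfloat16 to float32."""
    return QKV_ELEMENTS * (BYTES["bfloat16"] + BYTES["float32"])


print(f"inputs in bfloat16: {input_bytes('bfloat16')} bytes")  # 270368
print(f"inputs in float32:  {input_bytes('float32')} bytes")   # 540736
print(f"upcasting q, k, v:  {upcast_qkv_bytes()} bytes")        # 811008
```

These inputs come to roughly 0.27 MB vs 0.54 MB per call, which is small. Fixed per-call costs, such as the synchronous mx.eval in every loop iteration, may therefore dominate. That could explain why the differences are only fractions of a millisecond, but it is a hypothesis and not something measured in this thread.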

cccyf commented 2 months ago

Thanks @awni for the reply! I ran another experiment that explicitly casts q, k and v to the mask's dtype (float32 in most cases) before calling sdpa. Here are the code and output:

import time

import mlx.core as mx

def benchmark_sdpa(q: mx.array, k: mx.array, v: mx.array, mask: mx.array):
    mx.eval(q, k, v, mask)
    for _ in range(100):
        mx.eval(mx.fast.scaled_dot_product_attention(
            q, k, v, scale=1, mask=mask
        ))
    toi = time.perf_counter()
    for _ in range(3200):
        mx.eval(mx.fast.scaled_dot_product_attention(q, k, v, scale=1, mask=mask))
    toc = time.perf_counter()
    tpi = 1e3 * (toc - toi) / 3200
    print(f"q {q.dtype}, k {k.dtype}, v {v.dtype}, mask {mask.dtype} takes {tpi} ms")

def benchmark_sdpa_with_explicit_dtype_cast(q: mx.array, k: mx.array, v: mx.array, mask: mx.array):
    mx.eval(q, k, v, mask)
    for _ in range(100):
        mx.eval(mx.fast.scaled_dot_product_attention(
            q.astype(mask.dtype), k.astype(mask.dtype), v.astype(mask.dtype), scale=1, mask=mask
        ))
    toi = time.perf_counter()
    for _ in range(3200):
        mx.eval(mx.fast.scaled_dot_product_attention(q.astype(mask.dtype), k.astype(mask.dtype), v.astype(mask.dtype), scale=1, mask=mask))
    toc = time.perf_counter()
    tpi = 1e3 * (toc - toi) / 3200
    print(f"all casted to {mask.dtype}, q {q.dtype}, k {k.dtype}, v {v.dtype}, mask {mask.dtype} takes {tpi} ms")

if __name__ == "__main__":
    print(mx.default_device())
    q = mx.random.uniform(shape=(1, 32, 1, 4096 // 32)).astype(mx.bfloat16)
    k = mx.random.uniform(shape=(1, 32, 16, 4096 // 32)).astype(mx.bfloat16)
    v = mx.random.uniform(shape=(1, 32, 16, 4096 // 32)).astype(mx.bfloat16)
    mask = mx.zeros(shape=(1, 16)).astype(mx.bfloat16)
    mx.eval(q, k, v, mask)
    for i in range(5):
        mx.metal.clear_cache()
        print(f"run {i}")
        # O = softmax(Q @ K.T * scale + mask, dim=-1) @ V
        benchmark_sdpa(q, k, v, mask) 
        benchmark_sdpa(q, k, v, mask.astype(mx.float32)) 
        benchmark_sdpa(q, k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa(q.astype(mx.float32), k, v, mask.astype(mx.float32))
        benchmark_sdpa(q.astype(mx.float32), k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa(q.astype(mx.float32), k.astype(mx.float32), v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q, k, v, mask) 
        benchmark_sdpa_with_explicit_dtype_cast(q, k, v, mask.astype(mx.float32)) 
        benchmark_sdpa_with_explicit_dtype_cast(q, k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q.astype(mx.float32), k, v, mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q.astype(mx.float32), k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q.astype(mx.float32), k.astype(mx.float32), v.astype(mx.float32), mask.astype(mx.float32))

Output

Device(gpu, 0)
run 0
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.26594493499942473 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.3516292837502988 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.37887302093622566 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.4674200912540982 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2810128124929179 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.23619820312887896 ms
all casted to mlx.core.bfloat16, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.24766160156104888 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.2771399868743174 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2908905468757439 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.47396273437698255 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.28099125000153435 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.26967585937200056 ms
run 1
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.2781498568765528 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.2924935415649088 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.3050028774941893 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.4326933203083172 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2690416015593655 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.26142856749174825 ms
all casted to mlx.core.bfloat16, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.27783303375144897 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.2790427603213175 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.3378241015616368 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.4084278253139928 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.24933923186836182 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2695965625025565 ms
run 2
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.25753208343303413 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.33224005218471575 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.3961995574991306 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.4720461718716251 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.28775282562492066 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2675226562496391 ms
all casted to mlx.core.bfloat16, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.2700479428131075 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.2767542578112625 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.27714999999261636 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.3950258071836288 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2968962240538531 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2749186587516306 ms
run 3
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.2650807553072809 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.29339286436879775 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.36821903656345967 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.45831914062546275 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2849734112533042 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2833158984412876 ms
all casted to mlx.core.bfloat16, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.2937799609389913 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.27214130218453647 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.29444730468640046 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.36963549468055135 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.27984914062471944 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2861127084452164 ms
run 4
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.29158317718611215 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.40893740874707873 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5962891406215931 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.5590144531288388 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2735604556164617 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2754762762469909 ms
all casted to mlx.core.bfloat16, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.28663395843977924 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.2904960024989123 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2926441143699776 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.41776882811973337 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.27934273437494994 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2723772787521739 ms

Please correct me if I'm wrong, but based on the output, casting before sdpa appears to actually speed things up. Could you help me understand why implicit casting inside sdpa seems slower than explicit casting outside of it? I didn't expect it to make such a difference.

awni commented 2 months ago

> Could you help me understand why implicit casting inside sdpa seems slower than explicit casting outside of it? I didn't expect it to make such a difference.

Sorry, there are a lot of results there and I'm not sure which ones you are looking at. Could you point me to the two cases that are unexpected to you?

cccyf commented 2 months ago

Sure! For example:

q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.37887302093622566 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2908905468757439 ms

It appears that casting all inputs to mx.float32 before calling sdpa is faster than letting sdpa do the implicit dtype casting.
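Each output line is a single mean over 3200 calls, and the runs vary noticeably. It may therefore help to summarize all five runs of the cited pair rather than compare one line against another. The snippet below uses only the numbers already posted in the second output above.

```python
from statistics import median

# The five runs of the cited pair (q bf16, k bf16, v f32, mask f32),
# copied from the second output in this thread.
implicit_ms = [0.37887302093622566, 0.3050028774941893, 0.3961995574991306,
               0.36821903656345967, 0.5962891406215931]
explicit_ms = [0.2908905468757439, 0.3378241015616368, 0.27714999999261636,
               0.29444730468640046, 0.2926441143699776]

for name, runs in (("implicit", implicit_ms), ("explicit", explicit_ms)):
    print(f"{name}: median {median(runs):.3f} ms, "
          f"range {min(runs):.3f}-{max(runs):.3f} ms")
# implicit: median 0.379 ms, range 0.305-0.596 ms
# explicit: median 0.293 ms, range 0.277-0.338 ms
```

On this machine the medians do differ, but the ranges overlap. A single pair of lines is therefore not conclusive on its own.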

awni commented 2 months ago

I don't see the same results on my M1 Max. The two cases look pretty similar, though there is some variance in the timings in general:

q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.45112710937500006 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5807739062499999 ms
run 1
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5941705990625 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5872773306250001 ms
run 2
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5863978646875001 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5813158334375002 ms
run 3
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.4938347265624998 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5938519790625002 ms
run 4
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.573169166875 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.6042057031249992 ms
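Given that variance, reporting a spread may be more informative than one mean per configuration. Below is a minimal stdlib-only sketch. It assumes the same warmup and iteration counts as the scripts above. `time_per_call_ms` and `summarize` are hypothetical names, not MLX APIs.

```python
import statistics
import time


def time_per_call_ms(fn, *, warmup=100, iters=3200, repeats=5):
    """Return per-call timings (ms) for `repeats` independent batches.

    Mirrors the thread's loop: warm up, then time `iters` calls per batch.
    For MLX, `fn` should call mx.eval(...) itself so the GPU work finishes
    inside the timed region.
    """
    for _ in range(warmup):
        fn()
    results = []
    for _ in range(repeats):
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        results.append(1e3 * (time.perf_counter() - start) / iters)
    return results


def summarize(results):
    """Median plus min/max, which is less sensitive to one noisy batch."""
    return statistics.median(results), min(results), max(results)


if __name__ == "__main__":
    # Stand-in workload; swap in a lambda that evaluates the sdpa call.
    timings = time_per_call_ms(lambda: sum(range(1000)), iters=200)
    med, lo, hi = summarize(timings)
    print(f"median {med:.4f} ms (min {lo:.4f}, max {hi:.4f})")
```

With MLX, `fn` could be `lambda: mx.eval(mx.fast.scaled_dot_product_attention(q, k, v, scale=1, mask=mask))`, with all inputs evaluated beforehand. You could also interleave the implicit and explicit variants instead of running them back to back. That may reduce drift from thermals or background load.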