cccyf opened this issue 2 months ago
The performance difference you are seeing is likely due to implicit casting. When you call sdpa it will promote q, k, v, and mask to a common data type. If any of them is float32 then they will all become float32 and the output will be float32. So you are paying for all of those type casts, as well as doing the computation in fp32, which can be a little slower for inference.
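A minimal sketch of that promotion behavior (hypothetical small shapes; assuming mx.zeros defaults to float32): a single float32 input is enough to promote the whole call, and the output comes back as float32.

import mlx.core as mx

q = mx.random.uniform(shape=(1, 8, 1, 64)).astype(mx.bfloat16)
k = mx.random.uniform(shape=(1, 8, 16, 64)).astype(mx.bfloat16)
v = mx.random.uniform(shape=(1, 8, 16, 64)).astype(mx.bfloat16)
mask = mx.zeros(shape=(1, 16))  # float32 by default

# the float32 mask promotes q, k, and v inside sdpa, so the output is float32
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
print(out.dtype)  # expected: mlx.core.float32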
Thanks @awni for the reply! I did another experiment to explicitly cast the dtypes to float32 before calling sdpa. Here's the code and output
import time

import mlx.core as mx


def benchmark_sdpa(q: mx.array, k: mx.array, v: mx.array, mask: mx.array):
    mx.eval(q, k, v, mask)
    # warm-up
    for _ in range(100):
        mx.eval(mx.fast.scaled_dot_product_attention(
            q, k, v, scale=1, mask=mask
        ))
    # timed loop
    toi = time.perf_counter()
    for _ in range(3200):
        mx.eval(mx.fast.scaled_dot_product_attention(q, k, v, scale=1, mask=mask))
    toc = time.perf_counter()
    tpi = 1e3 * (toc - toi) / 3200
    print(f"q {q.dtype}, k {k.dtype}, v {v.dtype}, mask {mask.dtype} takes {tpi} ms")


def benchmark_sdpa_with_explicit_dtype_cast(q: mx.array, k: mx.array, v: mx.array, mask: mx.array):
    mx.eval(q, k, v, mask)
    # warm-up: explicitly cast q, k, v to the mask dtype before calling sdpa
    for _ in range(100):
        mx.eval(mx.fast.scaled_dot_product_attention(
            q.astype(mask.dtype), k.astype(mask.dtype), v.astype(mask.dtype), scale=1, mask=mask
        ))
    # timed loop
    toi = time.perf_counter()
    for _ in range(3200):
        mx.eval(mx.fast.scaled_dot_product_attention(q.astype(mask.dtype), k.astype(mask.dtype), v.astype(mask.dtype), scale=1, mask=mask))
    toc = time.perf_counter()
    tpi = 1e3 * (toc - toi) / 3200
    print(f"all casted to {mask.dtype}, q {q.dtype}, k {k.dtype}, v {v.dtype}, mask {mask.dtype} takes {tpi} ms")


if __name__ == "__main__":
    print(mx.default_device())
    q = mx.random.uniform(shape=(1, 32, 1, 4096 // 32)).astype(mx.bfloat16)
    k = mx.random.uniform(shape=(1, 32, 16, 4096 // 32)).astype(mx.bfloat16)
    v = mx.random.uniform(shape=(1, 32, 16, 4096 // 32)).astype(mx.bfloat16)
    mask = mx.zeros(shape=(1, 16)).astype(mx.bfloat16)
    mx.eval(q, k, v, mask)
    for i in range(5):
        mx.metal.clear_cache()
        print(f"run {i}")
        # O = softmax(Q @ K.T * scale + mask, dim=-1) @ V
        benchmark_sdpa(q, k, v, mask)
        benchmark_sdpa(q, k, v, mask.astype(mx.float32))
        benchmark_sdpa(q, k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa(q.astype(mx.float32), k, v, mask.astype(mx.float32))
        benchmark_sdpa(q.astype(mx.float32), k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa(q.astype(mx.float32), k.astype(mx.float32), v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q, k, v, mask)
        benchmark_sdpa_with_explicit_dtype_cast(q, k, v, mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q, k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q.astype(mx.float32), k, v, mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q.astype(mx.float32), k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q.astype(mx.float32), k.astype(mx.float32), v.astype(mx.float32), mask.astype(mx.float32))
Output
Device(gpu, 0)
run 0
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.26594493499942473 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.3516292837502988 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.37887302093622566 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.4674200912540982 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2810128124929179 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.23619820312887896 ms
all casted to mlx.core.bfloat16, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.24766160156104888 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.2771399868743174 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2908905468757439 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.47396273437698255 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.28099125000153435 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.26967585937200056 ms
run 1
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.2781498568765528 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.2924935415649088 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.3050028774941893 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.4326933203083172 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2690416015593655 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.26142856749174825 ms
all casted to mlx.core.bfloat16, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.27783303375144897 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.2790427603213175 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.3378241015616368 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.4084278253139928 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.24933923186836182 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2695965625025565 ms
run 2
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.25753208343303413 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.33224005218471575 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.3961995574991306 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.4720461718716251 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.28775282562492066 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2675226562496391 ms
all casted to mlx.core.bfloat16, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.2700479428131075 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.2767542578112625 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.27714999999261636 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.3950258071836288 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2968962240538531 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2749186587516306 ms
run 3
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.2650807553072809 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.29339286436879775 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.36821903656345967 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.45831914062546275 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2849734112533042 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2833158984412876 ms
all casted to mlx.core.bfloat16, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.2937799609389913 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.27214130218453647 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.29444730468640046 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.36963549468055135 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.27984914062471944 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2861127084452164 ms
run 4
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.29158317718611215 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.40893740874707873 ms
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5962891406215931 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.5590144531288388 ms
q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2735604556164617 ms
q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2754762762469909 ms
all casted to mlx.core.bfloat16, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.bfloat16 takes 0.28663395843977924 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.2904960024989123 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2926441143699776 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.bfloat16, mask mlx.core.float32 takes 0.41776882811973337 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.27934273437494994 ms
all casted to mlx.core.float32, q mlx.core.float32, k mlx.core.float32, v mlx.core.float32, mask mlx.core.float32 takes 0.2723772787521739 ms
Please correct me if I'm wrong: based on the output, it appears that casting explicitly before calling sdpa can actually be faster. Could you help me understand why implicit casting inside sdpa seems slower than explicit casting outside of it? I didn't expect it to make such a difference.
Could you help me understand why implicit casting inside sdpa seems slower than explicit casting outside of it? I didn't expect it to make such a difference.
Sorry, there's a lot of results there, I'm not sure what you are looking at. Could you point me to the two cases that are unexpected to you?
Sure! For example
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.37887302093622566 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2908905468757439 ms
It appears that casting all inputs to mx.float32 before calling sdpa is faster than letting sdpa do the implicit dtype casting.
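To make the comparison concrete, here is a rough sketch (reusing q, k, v, and mask from the script above) of the two call patterns those timings correspond to:

v32 = v.astype(mx.float32)
mask32 = mask.astype(mx.float32)

# implicit: sdpa promotes the bfloat16 q and k to float32 internally
mx.fast.scaled_dot_product_attention(q, k, v32, scale=1, mask=mask32)

# explicit: cast q and k up front so sdpa sees uniform float32 inputs
mx.fast.scaled_dot_product_attention(
    q.astype(mx.float32), k.astype(mx.float32), v32, scale=1, mask=mask32
)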
I don't see the same results on my M1 Max. They look pretty similar though there is some variance in the timings in general:
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.45112710937500006 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5807739062499999 ms
run 1
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5941705990625 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5872773306250001 ms
run 2
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5863978646875001 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5813158334375002 ms
run 3
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.4938347265624998 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.5938519790625002 ms
run 4
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.573169166875 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.6042057031249992 ms
Describe the bug
I implemented a model with mx.fast.scaled_dot_product_attention but observed that performance improves significantly when I apply mask = mask.astype(q.dtype) before sdpa. The model dtype is bfloat16, and the mask before the astype has dtype float32. So I ran an experiment to benchmark mx.fast.scaled_dot_product_attention with different input dtypes.

To Reproduce
Output
Expected behavior
Before this experiment, I thought bfloat16 would be the fastest and float32 would be the slowest. However, from the experiment results, it seems that mx.fast.scaled_dot_product_attention is fast when q, k, v, and mask all have the same dtype (either bfloat16 or float32), and slower when they have different dtypes. Could you help me understand the reason? Is it due to the implicit dtype conversion? Given that torch's sdpa requires q, k, v, and mask to share a dtype, I also wonder whether the implicit dtype conversion in mlx is intended.
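For reference, a minimal sketch of the workaround mentioned above (hypothetical shapes; the model runs in bfloat16 and the mask arrives as float32):

import mlx.core as mx

q = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.bfloat16)
k = mx.random.uniform(shape=(1, 32, 16, 128)).astype(mx.bfloat16)
v = mx.random.uniform(shape=(1, 32, 16, 128)).astype(mx.bfloat16)
mask = mx.zeros(shape=(1, 16))  # float32 by default

# casting the mask to the model dtype keeps the whole sdpa call in bfloat16
out = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=1.0, mask=mask.astype(q.dtype)
)
print(out.dtype)  # expected: mlx.core.bfloat16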
Desktop (please complete the following information):