tairov / llama2.mojo

Inference Llama 2 in one file of pure 🔥
https://www.modular.com/blog/community-spotlight-how-i-built-llama2-by-aydyn-tairov

improve readability in batch_matmul #75

Closed by mikowals 9 months ago

mikowals commented 9 months ago

Llamatune results on an M1 Pro (V1 = master, V2 = this PR):

| Benchmark | Command | Mean ± σ | User | System | Range (min … max) | Runs |
|---|---|---|---|---|---|---|
| 1 | run_v1 stories15M.bin | 327.2 ms ± 8.6 ms | 1945.5 ms | 65.6 ms | 316.4 ms … 352.2 ms | 30 |
| 2 | run_v2 stories15M.bin | 325.2 ms ± 5.6 ms | 1928.7 ms | 65.6 ms | 315.0 ms … 338.6 ms | 30 |
| 3 | run_v1 stories42M.bin | 756.8 ms ± 16.8 ms | 4164.0 ms | 157.5 ms | 732.3 ms … 811.0 ms | 30 |
| 4 | run_v2 stories42M.bin | 750.0 ms ± 23.4 ms | 4110.9 ms | 147.9 ms | 718.2 ms … 811.6 ms | 30 |
| 5 | run_v1 stories110M.bin | 1.979 s ± 0.020 s | 11.201 s | 0.372 s | 1.952 s … 2.027 s | 30 |
| 6 | run_v2 stories110M.bin | 1.969 s ± 0.023 s | 11.241 s | 0.354 s | 1.939 s … 2.038 s | 30 |
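(These figures are in hyperfine's output format; assuming `run_v1` and `run_v2` are builds of the two branches, a comparison like the one above can be reproduced with something like `hyperfine --runs 30 'run_v1 stories15M.bin' 'run_v2 stories15M.bin'`.)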

A tiny improvement, but it is consistent across all three model sizes, so it is probably real. I expect the removal of reduce_add() from softmax contributed most.
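For readers following along: the usual shape of such a change is to stop collapsing each SIMD chunk to a scalar inside the loop and instead keep a per-lane accumulator that is reduced once after the loop. Below is a minimal sketch of that pattern, written in C++ for self-containment since the repo is Mojo; the function names are invented for illustration, plain lane arrays stand in for SIMD registers, this is not the PR's actual diff, and tail elements (when the length is not a multiple of the lane width) are ignored.

```cpp
#include <array>
#include <cmath>
#include <cstddef>
#include <vector>

// Stand-in for the SIMD width (Mojo's `nelts`); an array of lanes
// emulates a SIMD register so the sketch stays self-contained.
constexpr std::size_t kLanes = 8;

// Before: every chunk's lane values are collapsed to a scalar inside
// the loop, i.e. one horizontal reduction (reduce_add) per chunk.
float sum_exp_reduce_per_chunk(std::vector<float>& x, float max_val) {
    float ssum = 0.0f;
    for (std::size_t i = 0; i + kLanes <= x.size(); i += kLanes) {
        float chunk_sum = 0.0f;
        for (std::size_t l = 0; l < kLanes; ++l) {
            x[i + l] = std::exp(x[i + l] - max_val);
            chunk_sum += x[i + l];  // per-chunk horizontal reduce
        }
        ssum += chunk_sum;
    }
    return ssum;
}

// After: keep per-lane partial sums in a vector accumulator and reduce
// them to a scalar exactly once, after the loop.
float sum_exp_reduce_once(std::vector<float>& x, float max_val) {
    std::array<float, kLanes> acc{};  // per-lane accumulator, zero-initialized
    for (std::size_t i = 0; i + kLanes <= x.size(); i += kLanes) {
        for (std::size_t l = 0; l < kLanes; ++l) {
            x[i + l] = std::exp(x[i + l] - max_val);
            acc[l] += x[i + l];  // purely lane-wise, no reduction here
        }
    }
    float ssum = 0.0f;
    for (float v : acc) ssum += v;  // single horizontal reduce at the end
    return ssum;
}
```

On real SIMD hardware the lane-wise adds are cheap while the horizontal reduction is the comparatively expensive step, so hoisting it out of the loop replaces one reduction per chunk with one per call.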

tairov commented 9 months ago

Thank you for the consistent improvements!