tairov / llama2.mojo

Inference Llama 2 in one file of pure 🔥
https://www.modular.com/blog/community-spotlight-how-i-built-llama2-by-aydyn-tairov
MIT License
2.09k stars, 140 forks

Increase the number of SIMD registers used to speed up execution #5

Closed: VMois closed this issue 1 year ago

VMois commented 1 year ago

I spent some time investigating why the parallelized + vectorized version of matmul is slower than the vectorized-only version.

Older Matmul examples showed that the multi-core + vectorized version was faster. For me, however, both the Matmul notebook example on the Playground and the Matmul example from the repo, run on a GitHub Codespaces instance (4 cores, 16GB), showed that the multi-core version was slower.

I tried two approaches: mojo examples/matmul.mojo, and mojo build examples/matmul.mojo followed by running the binary. Both gave the same result: multi-core was slower. Using htop, I also confirmed that the multi-core version was utilizing all cores.
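For readers unfamiliar with the multi-core variant, a common way to parallelize matmul is to split the rows of C across workers and run the (vectorizable) inner loop inside each worker. The Python sketch below shows that decomposition with concurrent.futures. The names matmul_rows and matmul_parallel are hypothetical, not the Mojo example's code. In CPython the GIL stops these threads from running the arithmetic in parallel, so this illustrates the work split, not the speed-up.

```python
from concurrent.futures import ThreadPoolExecutor


def matmul_rows(C, A, B, row_start, row_end):
    """Accumulate rows [row_start, row_end) of C += A @ B with a naive loop."""
    inner = len(B)       # shared dimension K
    cols = len(B[0])     # columns of B and C
    for i in range(row_start, row_end):
        row_a, row_c = A[i], C[i]
        for k in range(inner):
            a_ik, row_b = row_a[k], B[k]
            # In a vectorized kernel, this j-loop is the part that would be
            # processed nelts elements at a time.
            for j in range(cols):
                row_c[j] += a_ik * row_b[j]


def matmul_parallel(A, B, num_workers=4):
    """Compute A @ B by giving each worker a contiguous block of rows of C."""
    m, cols = len(A), len(B[0])
    C = [[0.0] * cols for _ in range(m)]
    # Ceiling division so every row is covered; at least 1 keeps range() valid.
    chunk = max(1, -(-m // num_workers))
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [
            pool.submit(matmul_rows, C, A, B, start, min(start + chunk, m))
            for start in range(0, m, chunk)
        ]
        for fut in futures:
            fut.result()  # re-raise any exception from a worker
    return C


A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
B = [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]
print(matmul_parallel(A, B, num_workers=2))
```

Because each worker owns a disjoint block of rows, no two threads ever write the same element of C, so no locking is needed.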

I found this PR, https://github.com/modularml/mojo/pull/742, where the vector width returned by simdwidthof is multiplied by a constant factor. On the GitHub Codespaces instance, my base value from simdwidthof was 8, so I benchmarked higher values: 16 (2x), 32 (4x), and 64 (8x). You can see the results below:

[Image: benchmark results for the tested nelts values]

I believe adjusting the nelts value in llama2.mojo should bring additional speed-ups:

https://github.com/tairov/llama2.mojo/blob/86a34c95cd3631137ca1a1505deb96446c5a881c/llama2.mojo#L24

CPU details:

System information: 
    OS          :  linux
    CPU         :  znver3
    Arch        :  x86_64-unknown-linux-gnu
    Num Cores   :  4
    CPU Features:  avx2
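To make the nelts idea concrete, here is a minimal Python sketch of a dot product that steps through the data nelts = multiplier * simd_width elements at a time and finishes with a scalar tail. Being plain Python, it shows the loop structure rather than real SIMD speed. SIMD_WIDTH = 8 is taken from the simdwidthof value reported above; chunked_dot and its parameters are hypothetical names. As far as I know, Mojo's vectorize follows a similar full-width-then-tail pattern, but treat that as an assumption.

```python
SIMD_WIDTH = 8  # float32 lanes reported by simdwidthof on this AVX2 machine


def chunked_dot(x, y, multiplier=2, simd_width=SIMD_WIDTH):
    """Dot product processed nelts = multiplier * simd_width elements per step.

    Returns (result, number of full chunks, number of tail elements).
    """
    nelts = multiplier * simd_width
    n = len(x)
    main_end = n - n % nelts  # end of the region covered by full chunks
    total = 0.0
    # Main loop: one iteration per full chunk. In Mojo, each iteration would
    # be a single nelts-wide SIMD load, multiply, and accumulate.
    for base in range(0, main_end, nelts):
        total += sum(x[base + i] * y[base + i] for i in range(nelts))
    # Tail: the n % nelts leftover elements, handled one at a time.
    for i in range(main_end, n):
        total += x[i] * y[i]
    return total, main_end // nelts, n - main_end


x = [float(i) for i in range(20)]
y = [1.0] * 20
print(chunked_dot(x, y, multiplier=2))
```

With 20 elements and multiplier=2 (nelts = 16), this does one full chunk and a 4-element tail. One plausible reason, not established in this thread, why a multiplier above 1 helps: a vector wider than one register gives the CPU several independent multiply-adds per step, which can hide instruction latency. Past some point, extra register pressure and longer tails may cost more than they gain, which would be consistent with slow-downs at larger multipliers.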
VMois commented 1 year ago

From the CPU info (AVX2), it looks like one SIMD register holds only 8 32-bit floats, which is what simdwidthof returns, yet using higher values gives a performance boost. Strange. I don't understand this well enough to explain it. I will try to dig deeper, but I'd be happy if someone more knowledgeable could explain.

VMois commented 1 year ago

Here is Matmul with alias nelts = 16 * simdwidthof[DType.float32]() (x16). If you use x32, the program crashes. I assume this is because the CPU has only 16 SIMD registers.

[Image: Matmul benchmark results with nelts = 16 * simdwidthof[DType.float32]()]
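The register arithmetic behind these two comments can be checked in a few lines. An AVX2 (YMM) register is 256 bits wide, so it holds 256 / 32 = 8 float32 lanes. simdwidthof therefore reports lanes per register, not a register count. x86-64 with AVX2 has 16 YMM registers. The sketch below (vector_footprint is a hypothetical helper) counts how many registers one nelts-wide float32 vector spans for each multiplier. It counts a single vector only. A matmul step keeps several operands live at once, and whether register count explains the crash at x32 remains the assumption stated above.

```python
YMM_BITS = 256      # width of one AVX2 (YMM) register
FLOAT32_BITS = 32   # width of one float32 value
NUM_YMM_REGS = 16   # YMM0..YMM15 available on x86-64 with AVX2


def vector_footprint(multiplier):
    """Return (lanes per register, nelts, registers spanned by one nelts-wide vector)."""
    lanes = YMM_BITS // FLOAT32_BITS   # 8 float32 lanes: what simdwidthof reports here
    nelts = multiplier * lanes
    regs = -(-nelts // lanes)          # ceiling division
    return lanes, nelts, regs


for m in (1, 2, 4, 8, 16, 32):
    lanes, nelts, regs = vector_footprint(m)
    status = "within" if regs <= NUM_YMM_REGS else "beyond"
    print(f"x{m}: nelts={nelts} spans {regs} register(s), {status} the {NUM_YMM_REGS}-register file")
```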
tairov commented 1 year ago

Looks cool, thanks @VMois for researching this topic. I got about a 20% improvement when I multiplied nelts by 2, though further multiplications led to degradation. I think this is related to the nature of the data you're manipulating, the sizes of the matrices, etc. I still couldn't get parallelize working, but I'll check out the PR you found.

VMois commented 1 year ago

I don't know why I don't get roughly a 4x speed-up on matmul with 4 cores. I understand that context switches and similar overhead would slow execution down, but in my tests the speed-up is only 1.9x.
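One way to put the 1.9x figure in perspective is parallel efficiency (speed-up divided by core count, here 1.9 / 4, about 0.48). If one assumes Amdahl's law applies, one can also solve for the serial fraction s that would produce it, from S = 1 / (s + (1 - s) / N). Amdahl's law is an added framing, not something measured in this thread. For matmul, memory bandwidth and thread start-up overhead are just as likely to be the limit, so read s as an "effective overhead" rather than literally serial code. The helper names below are hypothetical.

```python
def parallel_efficiency(speedup, cores):
    """Fraction of ideal linear scaling actually achieved."""
    return speedup / cores


def amdahl_serial_fraction(speedup, cores):
    """Invert Amdahl's law S = 1 / (s + (1 - s) / N) to get s.

    Rearranged: 1/S - 1/N = s * (1 - 1/N).
    """
    return (1 / speedup - 1 / cores) / (1 - 1 / cores)


S, N = 1.9, 4  # speed-up and core count reported in this comment
print(f"efficiency: {parallel_efficiency(S, N):.2f}")
print(f"equivalent serial fraction: {amdahl_serial_fraction(S, N):.2f}")
```

For S = 1.9 and N = 4 this gives s of about 0.37: the run behaves as if roughly a third of its time did not benefit from the extra cores.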

> I got about a 20% improvement when I multiplied nelts by 2, though further multiplications led to degradation.

Could you please run deviceinfo.mojo and post your device details here? I'm curious what CPU you have. Also, make sure to try all multipliers from 2 to 16 (if you haven't already).

> I think this is related to the nature of the data you're manipulating, the sizes of the matrices, etc.

You are probably right. The Matmul example is quite simple; your code is more advanced.

> I still couldn't get parallelize working...

By "couldn't get parallelize working", do you mean it is slower than the vectorized version, or that you got errors? If you have some multi-core code ready, maybe you could push it to a new branch so others can see it and experiment. Maybe someone will figure it out.

Thank you for the cool project!

VMois commented 1 year ago

You could try using Tensor instead of Matrix3. One small note, though: Tensor has nelts hardcoded to 1.

P.S. Never mind this comment, it is not good :)

tairov commented 1 year ago

@VMois, where did you find the hardcoded nelts value = 1? Also, do you know how Tensor could help improve llama2 performance? As I understand it, Tensor is just another wrapper around data: DTypePointer.

VMois commented 1 year ago

> Where did you find the hardcoded nelts value = 1?

I looked in the wrong place in the docs. nelts = 1 applies when getting a single item; for load, it can be set.

> Also, do you know how Tensor could help improve llama2 performance?

Not really. I just forked your repo and managed to replicate your speed-up results from alias nelts = 2 x .... I'm looking into Tensor right now.

> As I understand it, Tensor is just another wrapper around data: DTypePointer.

Probably. I'm looking through your Matrix3 code for optimization opportunities, but have found nothing so far. I'm considering profiling the code to see what takes the most time.

tairov commented 1 year ago

@VMois From my experience tinkering with llama2.c and then porting it to Python (llama2.py), most of the CPU time is spent in matmul, probably around 80-90%. Also, llama2.c contributors got small improvements by implementing sorted_vocabs.
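For context on sorted_vocabs: when encoding a prompt, the tokenizer repeatedly looks up strings (for example, candidate merged pairs) to find their token ids. A linear scan costs O(V) per lookup for a vocabulary of size V. Sorting the vocabulary once and binary-searching costs O(log V) per lookup. The Python sketch below shows that idea with the standard bisect module. build_sorted_vocab and str_lookup are hypothetical helper names, and this is an illustration of the technique, not a port of the llama2.c code.

```python
from bisect import bisect_left


def build_sorted_vocab(vocab):
    """Sort (token, id) pairs by token once, up front.

    Returns two parallel lists: sorted token strings and their original ids.
    """
    pairs = sorted((tok, idx) for idx, tok in enumerate(vocab))
    keys = [tok for tok, _ in pairs]
    ids = [idx for _, idx in pairs]
    return keys, ids


def str_lookup(token, keys, ids):
    """Binary-search a token string; return its id, or -1 if it is not in the vocab."""
    pos = bisect_left(keys, token)
    if pos < len(keys) and keys[pos] == token:
        return ids[pos]
    return -1


vocab = ["<unk>", "he", "llo", "hello", " world"]
keys, ids = build_sorted_vocab(vocab)
print(str_lookup("hello", keys, ids), str_lookup("missing", keys, ids))
```

The sort is paid once per tokenizer, while the lookups happen many times per prompt, which is why the trade pays off.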

Anyway, I would love to see some profiling reports.

PS. I don't think Matrix3 is a worthwhile place to look for efficiency wins.
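A minimal way to get such a profiling report without external tools, assuming a Python port such as llama2.py, is to wrap the hot functions in a timing decorator and compare each one's accumulated time to the total. For the compiled Mojo binary, a sampling profiler such as Linux perf would be the more usual choice. timed, timings, and the stand-in matmul/other_work functions below are hypothetical. The overall_speedup helper applies Amdahl's law to the 80-90% estimate above.

```python
import time
from collections import defaultdict
from functools import wraps

# Accumulated elapsed seconds per function name.
timings = defaultdict(float)


def timed(fn):
    """Decorator that adds each call's elapsed time to timings[fn.__name__]."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            timings[fn.__name__] += time.perf_counter() - start
    return wrapper


def overall_speedup(p, k):
    """Amdahl's law: total speed-up when a fraction p of runtime gets k times faster."""
    return 1.0 / ((1.0 - p) + p / k)


@timed
def matmul(A, B):
    # Naive matmul standing in for the real kernel.
    cols, inner = len(B[0]), len(B)
    return [[sum(row[k] * B[k][j] for k in range(inner)) for j in range(cols)] for row in A]


@timed
def other_work(n):
    # Placeholder for the non-matmul parts of a forward pass.
    return sum(i * i for i in range(n))


start = time.perf_counter()
M = [[1.0] * 32 for _ in range(32)]
for _ in range(3):
    matmul(M, M)
    other_work(10_000)
total = time.perf_counter() - start

for name, secs in sorted(timings.items(), key=lambda kv: -kv[1]):
    print(f"{name:>12}: {secs:.4f} s ({secs / total:.0%} of total)")
```

If matmul takes a fraction p of the runtime, even an infinitely fast matmul caps the overall speed-up at 1 / (1 - p): 5x for p = 0.8 and 10x for p = 0.9.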

VMois commented 1 year ago

This commit applies the nelts multiplication: https://github.com/tairov/llama2.mojo/commit/06c6076a9dc1702d527279db9c368090da5f5868. I think it is safe to close this issue.