tairov / llama2.mojo

Inference Llama 2 in one file of pure 🔥
https://www.modular.com/blog/community-spotlight-how-i-built-llama2-by-aydyn-tairov
MIT License

replace Matrix with Tensor #38

Closed · mikowals closed this 11 months ago

mikowals commented 11 months ago

Performance seems the same as the Matrix version. I got between 550 and 650 tokens/sec in both versions for stories15M.bin, running on a Codespace instance with 4 hardware threads and a SIMD width of 16.

tairov commented 11 months ago

Thanks for the PR @mikowals. It looks good as a first iteration. I was exploring this area myself; specifically, I was curious whether it is possible to pre-define/pre-allocate all tensors during weight loading. The tensor manipulation methods could probably be simplified further. Ideally, matmuls should be implemented as C = B * A; do you see any opportunity to implement that kind of definition?

mikowals commented 11 months ago

I think it is pretty close to pre-allocating all memory before the transformer forward pass. All the computation results are saved in pre-existing, reused RunState variables, and all the weights are in Tensors. But there could be intermediate copies that I missed, and eliminating them would probably improve performance.
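For illustration, here is a minimal sketch of that kind of pre-allocated state; the field set and shapes are assumptions based on the llama2.c layout, not the PR's exact code, and it targets the Mojo Tensor API of the time.

```mojo
from tensor import Tensor

# Sketch only (fields and shapes assumed, not the PR's code): buffers allocated
# once up front and reused on every forward pass.
struct RunState:
    var x: Tensor[DType.float32]            # current activation, shape [dim]
    var xb: Tensor[DType.float32]           # scratch buffer, shape [dim]
    var logits: Tensor[DType.float32]       # output logits, shape [vocab_size]
    var key_cache: Tensor[DType.float32]    # shape [n_layers, seq_len, dim]
    var value_cache: Tensor[DType.float32]  # shape [n_layers, seq_len, dim]

    fn __init__(inout self, dim: Int, n_layers: Int, seq_len: Int, vocab_size: Int):
        self.x = Tensor[DType.float32](dim)
        self.xb = Tensor[DType.float32](dim)
        self.logits = Tensor[DType.float32](vocab_size)
        self.key_cache = Tensor[DType.float32](n_layers, seq_len, dim)
        self.value_cache = Tensor[DType.float32](n_layers, seq_len, dim)
```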

Putting all the weights into Tensors and storing them as a list of Tensors by layer is a common approach, but I don't think there is a viable way to do that currently. Eventually, when there is a List type that can store Tensors and Tensors support subscript slicing notation, the two approaches will look and perform very similarly. So I am not sure there is much benefit to creating all the Tensors up front. With the current state of Mojo, I think the slice implementation is the cleanest way to get the weights by layer as needed in the forward pass.

That said, I am not sure that using state.k and state.v as slices pointing into key_cache and value_cache is any clearer than just using the caches directly. I will also probably try a version that uses local variables instead of state variables to store the intermediate calculations in the transformer pass (removing state.x, state.xb, etc.). I think that makes the code clearer, since there is less confusion about which values, like the caches, live on across future loop iterations. But I am unsure how that impacts performance.
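For a concrete picture of what these slices index into, here is a hedged sketch of the offset arithmetic, assuming the usual llama2.c layouts and the Tensor/DTypePointer API of the Mojo release current at the time; the helper names are made up for this example.

```mojo
from tensor import Tensor
from memory.unsafe import DTypePointer

# Hypothetical helpers (not the PR's code): non-owning views obtained by offset
# arithmetic into flat Tensors, assuming the usual llama2.c layouts.

# Weights for layer `l` inside a Tensor that stores all layers' (rows x cols)
# blocks back to back.
fn layer_weights(w: Tensor[DType.float32], l: Int, rows: Int, cols: Int) -> DTypePointer[DType.float32]:
    return w.data().offset(l * rows * cols)

# Key vector for layer `l` at time step `pos` inside a key_cache assumed to be
# laid out as [n_layers, seq_len, dim].
fn key_at(key_cache: Tensor[DType.float32], l: Int, pos: Int, seq_len: Int, dim: Int) -> DTypePointer[DType.float32]:
    return key_cache.data().offset((l * seq_len + pos) * dim)
```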

If by C = B * A you mean using = assignment to get the result of a matmul, then yes, I prefer that notation. I actually had it implemented but took it out while tracking down other performance issues. Improving the notation or implementation of the operations is definitely something to consider.
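To make the notation point concrete, here is an unoptimized sketch of a matmul that allocates and returns its result so call sites read C = matmul(B, A); this is an illustration only, not the PR's implementation, and a real version would vectorize and parallelize the loops as the existing matmul does.

```mojo
from tensor import Tensor

# Sketch only: B is a (rows x cols) matrix, A is a length-cols vector, and the
# length-rows result is allocated here and returned, enabling `=` assignment at
# the call site instead of passing C in as an output argument.
fn matmul(B: Tensor[DType.float32], A: Tensor[DType.float32]) -> Tensor[DType.float32]:
    var rows = B.dim(0)
    var cols = B.dim(1)
    var C = Tensor[DType.float32](rows)
    var b = B.data()
    var a = A.data()
    var c = C.data()
    for i in range(rows):
        var acc: Float32 = 0.0
        for j in range(cols):
            acc += b.load(i * cols + j) * a.load(j)
        c.store(i, acc)
    return C
```

The trade-off against the in-place matmul(C, B, A) style is an allocation per call, so a returning form would likely need to reuse buffers to match the current performance.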

magician-blue commented 11 months ago

The inference speed of the Tensor and Matrix versions is similar on 15M and 110M. But I don't know why the Tensor version gets killed when I try to run TinyLlama-1.1B; maybe there is some memory limitation?

num hardware threads:  6
SIMD vector width:  16
checkpoint size:  4400767004 [ 4196 MB ]
./test.sh: line 3: 18728 Killed                  mojo llama2.mojo ../model/tl-chat.bin -z ../model/tok_tl-chat.bin -n 256 -t 0 -s 100 -i "<|im_start|>user\nGive me a python function to generate Fibonacci sequence<|im_end|>\n<|im_start|>assistant\n"

magician-blue commented 11 months ago

Now it works well on 15M/110M, but the output on 1.1B is gibberish:

num hardware threads:  6
SIMD vector width:  16
checkpoint size:  4400767004 [ 4196 MB ]
n layers:  22
vocab size:  32003
<|im_start|>user
Give me a python function to generate Fibonacci sequence<|im_end|>
<|im_start|>assistant
WHEREASKnowhereby the other than ouslady proportions, the other than ousselves of the other than ousselves of the other than ousselves of the other than ousselves of the other than ousselves of the other than ousselves of the other than ousselves of the other than ousselves of latexhat the other than ousselves of latexhat the other than ousselves of latexhat the other than ougntookay all the other than oughttpastarrive to come to come to come to come to come to come to come to come to come to come to come to come to come to come to come to come to whompastarrive to whompastarrive to whompastarrive to whompastarrive to whompastarrive to whompastarrive to whompastarrive to whompastarrive to whompastarrive to whompastar

mikowals commented 11 months ago

Hi @magician-blue. Thanks for pointing out those issues.

Are you saying the change I made earlier reduced the memory use and allowed the code to run on your machine? I think I reduced the memory use somewhat, but this branch would still use more memory than the master branch while the weights are loaded, because it holds two copies of the weights at that point. I am not sure whether the memory peak happens during loading or while the network runs, though.

But I did find the bug causing the garbage output with TinyLlama. In the master branch, some weight shapes have the row and column dimensions flipped in TransformerWeights.__init__(). This didn't matter in master because those shapes only set the number of elements to read, and the dimensions were corrected in the per-layer matrices used in transformer(). In this branch I only specify the shapes once (during reading) and then rely on them throughout. Two of the weights with flipped dimensions just happened to have matching rows and columns in stories15M and stories110M, which is why those models still produced correct output.

I added dimension checks to the matmul to surface bugs like this immediately.
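A hedged sketch of what such a check might look like (the helper name and signature are assumptions, not the PR's actual code); the point is that a flipped (rows, cols) shape now fails loudly instead of silently producing garbage.

```mojo
from tensor import Tensor

# Hypothetical guard a matmul could run before computing C = W @ x:
# W must be (C.dim(0) x x.dim(0)), so flipped weight shapes trip this immediately.
fn check_matmul_dims(C: Tensor[DType.float32], W: Tensor[DType.float32], x: Tensor[DType.float32]) raises:
    if W.dim(0) != C.dim(0) or W.dim(1) != x.dim(0):
        raise Error("matmul dimension mismatch")
```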

mikowals commented 11 months ago

Closed since #39 merged the same changes.