tairov / llama2.mojo

Inference Llama 2 in one file of pure 🔥
https://www.modular.com/blog/community-spotlight-how-i-built-llama2-by-aydyn-tairov
MIT License

Stop generating when EOS/BOS appear #21

Closed · tic-top closed 12 months ago

tic-top commented 12 months ago

Stop generating when the EOS or BOS token appears.
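
A minimal Python sketch of the behavior this PR adds (the repo itself is Mojo; `sample_next` and the loop shape are illustrative stand-ins, and the ids 1 for BOS and 2 for EOS follow the Llama SentencePiece convention):

```python
BOS_ID = 1  # Llama SentencePiece beginning-of-sequence id
EOS_ID = 2  # Llama SentencePiece end-of-sequence id

def generate(sample_next, steps: int) -> list[int]:
    """Generate up to `steps` tokens, stopping early at BOS/EOS."""
    tokens = [BOS_ID]
    pos = 1
    while pos < steps:
        nxt = sample_next(tokens)    # stand-in for forward pass + sampling
        if nxt in (BOS_ID, EOS_ID):  # the check this PR introduces:
            break                    # stop instead of generating past the end
        tokens.append(nxt)
        pos += 1
    return tokens
```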

tic-top commented 12 months ago

Instead of computing k and v and then storing them into k_cache and v_cache, we can do the computation directly on k_cache and v_cache. This speeds up inference by about 10%.
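
A NumPy sketch of the idea (a Python stand-in for the Mojo change; shapes and names are illustrative, not the repo's actual API). The "before" version computes k and v into scratch buffers and copies them into the cache; the "after" version writes the matmul results straight into the cache rows:

```python
import numpy as np

dim, seq_len = 8, 16
w_k = np.random.randn(dim, dim).astype(np.float32)
w_v = np.random.randn(dim, dim).astype(np.float32)
k_cache = np.zeros((seq_len, dim), dtype=np.float32)
v_cache = np.zeros((seq_len, dim), dtype=np.float32)

def step_with_copy(x: np.ndarray, pos: int) -> None:
    # Before: temporaries k, v, then an extra copy into the cache.
    k = w_k @ x
    v = w_v @ x
    k_cache[pos] = k
    v_cache[pos] = v

def step_in_place(x: np.ndarray, pos: int) -> None:
    # After: compute directly into the cache row, no temporaries.
    np.matmul(w_k, x, out=k_cache[pos])
    np.matmul(w_v, x, out=v_cache[pos])
```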

tairov commented 12 months ago

Hi @tic-top, thanks for the PR, I'll take a look. The improvement sounds cool; did you benchmark this?

tairov commented 12 months ago

Found an issue with the tok/s calculation. We should use pos - 1 instead of steps - 1, since generation can now stop early at EOS, leaving pos smaller than steps:

print("\nachieved tok/s: ", (pos - 1) / (end - start) * 1000)
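
For example (numbers illustrative): if steps = 256 but generation hits EOS at pos = 50 after 2000 ms, steps - 1 would report 255 / 2000 * 1000 = 127.5 tok/s, while the 49 tokens actually generated correspond to only 24.5 tok/s. An early stop therefore makes the old formula overstate the speed.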

After this fix, it seems the current PR doesn't show any significant improvement. Maybe I got a few percent average generation speed increase.

I would keep the original notation unchanged; Q, K, V are important parts of the transformer algorithm.

tic-top commented 12 months ago

Yes, you're right. The apparent improvement was caused by the tok/s calculation. I have made a small change to the previous code; this version keeps the original notation almost unchanged. What I do is turn state.k and state.v into temporary matrices pointing into k_cache and v_cache, and then do the calculation on the cache. Although this modification doesn't improve inference speed significantly, it saves the memory of the standalone k and v buffers.
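
In NumPy terms (again an illustrative stand-in for the Mojo code), this refinement keeps the familiar k and v names but binds them to views into the caches, so writes land in the cache without separate buffers:

```python
# Reuses dim, w_k, w_v, k_cache, v_cache from the sketch above.
x = np.random.randn(dim).astype(np.float32)
pos = 3
k = k_cache[pos]            # a view into the cache, not a new buffer
v = v_cache[pos]
np.matmul(w_k, x, out=k)    # results land directly in k_cache[pos]
np.matmul(w_v, x, out=v)
assert np.shares_memory(k, k_cache)  # same storage, original names kept
```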