Closed: tic-top closed this pull request 12 months ago
Instead of calculating k and v and then copying them into k_cache and v_cache, we can do the calculation directly on k_cache and v_cache. This speeds up inference by about 10%.
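The idea can be sketched in NumPy: compute each step's k and v directly into the cache rows via views, instead of filling separate buffers and copying. The dimensions and weight names (`wk`, `wv`) here are made up for illustration, not taken from the PR.

```python
import numpy as np

# hypothetical sizes and weights, for illustration only
dim, seq_len = 8, 16
rng = np.random.default_rng(0)
wk = rng.standard_normal((dim, dim)).astype(np.float32)
wv = rng.standard_normal((dim, dim)).astype(np.float32)

# the caches hold one k/v row per sequence position
k_cache = np.zeros((seq_len, dim), dtype=np.float32)
v_cache = np.zeros((seq_len, dim), dtype=np.float32)

def attention_step(x, pos):
    # instead of computing k, v into separate buffers and then
    # copying, write the matmul results straight into the cache
    # rows; k_cache[pos] / v_cache[pos] are views into the cache
    np.matmul(wk, x, out=k_cache[pos])
    np.matmul(wv, x, out=v_cache[pos])
```

This avoids both the extra copy and the memory for standalone k/v buffers, which is the claimed source of the speedup.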
Hi @tic-top, thanks for the PR, I'll take a look. The improvement sounds cool, have you benchmarked it?
Found an issue with the tok/s calculation. We should use `pos - 1` instead of `steps - 1`:

```python
print("\nachieved tok/s: ", (pos - 1) / (end - start) * 1000)
```
After this fix, the current PR doesn't show any significant improvement; I only got a few percent average increase in generation speed.
I would keep the original Q, K, V notation unchanged; these seem to be important, recognizable parts of the transformer algorithm.
Yes, you're right. The apparent improvement came from the tok/s calculation. I have made a small change to the previous code; this version keeps the original notation almost unchanged. What I do is change state.k and state.v into temporary matrices pointing into k_cache and v_cache, and then do the calculation directly on the cache. Although this modification doesn't improve inference speed significantly, it saves the memory of the separate k and v buffers.
Stop generating when EOS or BOS appears.
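A minimal sketch of that stopping rule (the token ids are assumptions for illustration; Llama's tokenizer conventionally uses BOS = 1 and EOS = 2, but check the actual model config):

```python
# hypothetical special-token ids (assumption, verify against the tokenizer)
BOS, EOS = 1, 2

def generate(next_token_fn, max_steps):
    """Collect tokens until max_steps or a BOS/EOS token is produced."""
    tokens = []
    for _ in range(max_steps):
        tok = next_token_fn()
        # stop as soon as the model emits BOS or EOS
        if tok in (BOS, EOS):
            break
        tokens.append(tok)
    return tokens
```

For example, if the model would emit `5, 9, 2, 7`, generation stops at the `2` (EOS) and only `[5, 9]` is returned.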