jaymody / picoGPT

An unnecessarily tiny implementation of GPT-2 in NumPy.
MIT License
3.21k stars, 414 forks

feat: added k, v cache for inference speed up #7

Open immortal3 opened 1 year ago

immortal3 commented 1 year ago

Hi @jaymody, awesome blog post. I was interested in learning about the KV cache for inference and searched around, but existing articles on the KV cache don't focus on the implementation side. So I decided to implement it in picoGPT.
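For anyone reading this later, the core idea looks roughly like the following. This is a minimal single-head sketch in NumPy, not the exact code in this PR; the names `attention_step` and `cache` are just illustrative.

```python
import numpy as np

def softmax(x):
    # numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_step(x_new, w_q, w_k, w_v, cache):
    # x_new: [1, d_model] embedding of the newest token only
    q = x_new @ w_q  # only the new token's query is needed
    k = x_new @ w_k
    v = x_new @ w_v

    # append the new key/value to the cache instead of recomputing
    # keys/values for the whole sequence on every decoding step
    cache["k"] = k if cache["k"] is None else np.vstack([cache["k"], k])
    cache["v"] = v if cache["v"] is None else np.vstack([cache["v"], v])

    # attend over all cached keys/values; no causal mask is needed
    # because the new token is always the last position
    scores = q @ cache["k"].T / np.sqrt(q.shape[-1])  # [1, T]
    return softmax(scores) @ cache["v"]               # [1, d_head]

# toy usage: random weights, one token fed per decoding step
rng = np.random.default_rng(0)
d_model, d_head = 8, 8
w_q, w_k, w_v = (rng.standard_normal((d_model, d_head)) for _ in range(3))
cache = {"k": None, "v": None}
for t in range(5):
    x_new = rng.standard_normal((1, d_model))
    out = attention_step(x_new, w_q, w_k, w_v, cache)
print(out.shape, cache["k"].shape)  # (1, 8) (5, 8)
```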

Are you interested in writing a post on optimizing inference time? I would love to collaborate on it.

jaymody commented 1 year ago

Thanks for the implementation!

What kind of speedups did you get with this and did you get an identical output to the non-kv cache version?

Just FYI, I'm going to leave this unmerged to keep the implementation as simple as possible. However, I'll keep this PR open in case people want to reference it in the future.

There's also an inference optimization section in my blog post with some further resources to read up on.

immortal3 commented 1 year ago

Yes, the output is identical. I am seeing a 25% speedup on CPU. Generated text:

the most powerful machines on the planet.

The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.

Yeah, it makes sense not to merge it. We could probably create another file, gpt2_inference_speed.py, that holds these sorts of optimizations.

clam004 commented 9 months ago

Hi @immortal3, I love the minimal implementation. I'm having trouble reproducing the 25% speedup, though. I've been using `time` to compare the two implementations, with the 125M model generating the above output text. If you are up for it, a before-and-after comparison in your own repo would be so cool and very compelling.
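Something like this is the kind of before/after measurement I have in mind (a rough sketch; `generate_without_cache` and `generate_with_kv_cache` are placeholders for the two implementations, not actual functions in this repo):

```python
import time

def benchmark(generate_fn, prompt_ids, n_tokens, n_runs=5):
    # run generation several times and keep the best wall-clock time,
    # which is less noisy than a single shell `time` invocation
    best = float("inf")
    for _ in range(n_runs):
        start = time.perf_counter()
        generate_fn(prompt_ids, n_tokens)
        best = min(best, time.perf_counter() - start)
    return best

# hypothetical usage, assuming the two generate functions are importable:
# t_base  = benchmark(generate_without_cache, prompt_ids, 40)
# t_cache = benchmark(generate_with_kv_cache, prompt_ids, 40)
# print(f"speedup: {t_base / t_cache:.2f}x")
```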

immortal3 commented 9 months ago

@clam004 I don't remember exactly how I ended up at the 25% speedup, but it was definitely not a scientific measurement. 😄

The speedup will depend heavily on the CPU/memory combination and the length of the input sequence. So you might not see exactly 25%, but try feeding a sufficiently long sequence; that should clearly show a performance improvement for the KV-cache version over a normal forward pass.

As for a proper comparison, I'm not sure it would be worth the time at this point to do one thoroughly.