meta-llama / llama

Inference code for Llama models

An ingenious way to speed up inference! πŸš€ #256

Open elephantpanda opened 1 year ago

elephantpanda commented 1 year ago

I thought of a way to speed up inference by using batches. This assumes that you can run a batch of 2 much faster than you can run 2 separate passes, so it will work on GPUs with a lot of compute cores or on multi-GPU setups. The algorithm scales: the more computing power (more GPUs) you have, the faster it goes.

First, create a dictionary that maps each token to the most common token that follows it, e.g. the most common token to follow 'there' might be 'was'. You could probably get this data by going through a corpus with a window of 1, recording the most likely next token for each token, and storing the results in a dictionary.
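A minimal sketch of building such a lookup table (plain Python over a toy word-level corpus; a real version would run over LLaMA token IDs from a large corpus):

```python
from collections import Counter, defaultdict

def build_next_token_table(tokens):
    """Map each token to the token that most often follows it in the corpus."""
    counts = defaultdict(Counter)
    for cur, nxt in zip(tokens, tokens[1:]):
        counts[cur][nxt] += 1
    return {tok: followers.most_common(1)[0][0] for tok, followers in counts.items()}

# Toy word-level corpus; any tokenized text would do.
corpus = "once upon a time there was a princess and there was a dragon".split()
table = build_next_token_table(corpus)
print(table["there"])  # -> 'was'
```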

Say your tokens are this:

[Once, upon, a, time, there]

Then you put them into a batch of two like this. In the second row of the batch, you simply guess the next token using your dictionary. (In this case your dictionary says that the most common word to follow 'there' is 'was'.)

[ _, Once, upon, a, time, there]
[Once, upon, a, time, there, was]

So now, if the output is this:

[Once, upon, a, time, there, was]
[upon, a, time, there, was, a]

It means you have got two tokens for the price of one: [was, a]. I'm not sure what percentage of the time you will get lucky like this. You might only do a double batch when you are fairly confident of the next word(s), use bigger batches when you are less certain, or even guess several words ahead.

Thus, with dictionary lookups and guessing ahead, you might be able to speed up inference by as much as two times!
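Here is a hedged sketch of that guess-and-verify step, with a small Hugging Face causal LM standing in for Llama (the model name "gpt2" and greedy decoding are assumptions). Because attention is causal, appending the guessed token and running one forward pass gives the model's prediction at both positions, which is equivalent to the two shifted rows in the batch above:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any causal LM works for this sketch; "gpt2" is just a small stand-in for Llama.
name = "gpt2"
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).eval()

prompt_ids = tok("Once upon a time there", return_tensors="pt").input_ids
guess_id = tok(" was", add_special_tokens=False).input_ids[0]  # the dictionary's guess

with torch.no_grad():
    ids = torch.cat([prompt_ids, torch.tensor([[guess_id]])], dim=1)
    logits = model(ids).logits                      # (1, seq_len, vocab_size)

pred_after_prompt = logits[0, -2].argmax().item()   # model's own choice after "...there"
pred_after_guess = logits[0, -1].argmax().item()    # model's choice after the guessed "was"

if pred_after_prompt == guess_id:
    new_tokens = [guess_id, pred_after_guess]       # guess accepted: two tokens for one pass
else:
    new_tokens = [pred_after_prompt]                # guess rejected: one token, as usual
```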

This is the simplest way; a more complicated approach would be to train a very small neural network (or use the same NN but with a very small window) to guess the next word before running the full neural network. This means that if the small NN guesses correctly, you skip ahead several tokens! 🚀
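If the small network drafts several tokens at once, the big model's single verification pass can accept the longest matching prefix, plus its own token at the first mismatch. A sketch of that greedy acceptance rule (the function and argument names are hypothetical; sampling-based variants use a probabilistic acceptance rule instead):

```python
def accept_drafted(draft_ids, big_preds):
    """Greedy acceptance rule for speculative decoding (illustrative names).

    draft_ids: the k tokens guessed by the small model.
    big_preds: k + 1 greedy predictions from one big-model pass over
               context + draft_ids; entry i is what the big model would
               emit in place of draft i, and the last entry is its
               prediction after the final drafted token.
    """
    accepted = []
    for i, drafted in enumerate(draft_ids):
        if drafted == big_preds[i]:
            accepted.append(drafted)            # guess matched, keep going
        else:
            accepted.append(big_preds[i])       # first mismatch: take the big model's token, stop
            return accepted
    accepted.append(big_preds[len(draft_ids)])  # every guess matched: one bonus token for free
    return accepted

# Drafts ['was', 'a']; the big model agrees on 'was' but would have said 'the' instead of 'a'.
print(accept_drafted(["was", "a"], ["was", "the", "dog"]))  # -> ['was', 'the']
```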

(I wonder if such an algorithm is implemented by Chat GPT or Bard πŸ€”)

Unfortunately, using the "window of 1" method, the most common token to follow any word is usually one of these:

,
.
and
to
of
the

This may make the method not so useful 🤔 although for some words, such as 'suggest', the most likely word to follow is 'that'.


I have found that I can use a smaller LLM, such as the 111M Cerebras model, to make a good initial guess for the next word in 0.1 seconds and then run a batch of 2. It gets the guess right a lot of the time. So in this way you can use a bad model to speed up a good model!
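For reference, a minimal sketch of that drafting step with a Hugging Face checkpoint (cerebras/Cerebras-GPT-111M is how that model is published on the Hub; pairing it with a larger model assumes both share the same tokenizer and vocabulary):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# The small "drafter". The big model it is paired with must use the same
# vocabulary for the token IDs to line up.
tok = AutoTokenizer.from_pretrained("cerebras/Cerebras-GPT-111M")
draft = AutoModelForCausalLM.from_pretrained("cerebras/Cerebras-GPT-111M").eval()

ids = tok("Once upon a time there", return_tensors="pt").input_ids
with torch.no_grad():
    guess = draft(ids).logits[0, -1].argmax().item()  # the cheap guess for the next token

# 'guess' now plays the same role as the dictionary lookup above: append it to the
# sequence and verify it with the big model exactly as in the earlier sketch.
```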

elephantpanda commented 1 year ago

Another way to speed things up is to have the GPU truncate the output so it only sends the final token to the CPU.
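A small sketch of that, assuming a PyTorch/Hugging Face setup like the ones above: slice to the final position and take the argmax on the GPU, so only one token ID per sequence is copied back instead of the full logits tensor.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained("gpt2")                       # small stand-in model again
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

ids = tok("Once upon a time there", return_tensors="pt").input_ids.to(device)
with torch.no_grad():
    logits = model(ids).logits                 # (batch, seq_len, vocab_size), stays on the device
    next_id = logits[:, -1, :].argmax(dim=-1)  # reduce to one token ID per sequence on-device
print(next_id.cpu())                           # only that single ID is transferred to the CPU
```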

brucewlee commented 1 year ago

This is interesting, but what use is the LLM if you are using word frequency to guess a word or two?

andrewPoulton commented 1 year ago

This is indeed a great idea - it's called speculative decoding. Your specific idea of having a dictionary lookup is close to staged speculative decoding, where there is a hierarchy of LMs, starting with (essentially) a lookup n-gram model, then a small (transformer) LM, then the biggest "oracle" LM.