turboderp / exllama

A more memory-efficient rewrite of the HF transformers implementation of Llama for use with quantized weights.
MIT License

KV caching? #238

Open bryanhpchiang opened 1 year ago

bryanhpchiang commented 1 year ago

Where is it being done in the code?

sleepwalker2017 commented 1 year ago

Hey, I can try to answer the question. It seems it's done here: [screenshot]

turboderp commented 1 year ago

Sorry, I apparently missed this one. The cache is contained in the ExLlamaCache class, which is just a wrapper for two lists of preallocated tensors, one pair for each layer of the model.
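
Roughly, the structure is just two lists of preallocated tensors plus a counter for how many positions are currently filled. Here is a minimal sketch of that idea (not the actual ExLlamaCache code; the names and the (batch, heads, seq, head_dim) layout are assumptions inferred from the narrow() calls quoted below):

    # Minimal sketch of a preallocated per-layer K/V cache. Not the actual
    # ExLlamaCache class; the (batch, heads, seq, head_dim) layout is an
    # assumption inferred from the narrow() calls quoted below.
    import torch

    class KVCacheSketch:

        def __init__(self, num_layers, batch_size, num_heads, max_seq_len, head_dim,
                     dtype = torch.float16):
            # One preallocated key tensor and one value tensor per layer
            self.key_states = [
                torch.zeros(batch_size, num_heads, max_seq_len, head_dim, dtype = dtype)
                for _ in range(num_layers)
            ]
            self.value_states = [
                torch.zeros(batch_size, num_heads, max_seq_len, head_dim, dtype = dtype)
                for _ in range(num_layers)
            ]
            self.current_seq_len = 0  # number of positions filled so far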

Caching is performed in the attention function, here:

        # Add keys and values to cache

        # Narrow views into this layer's preallocated cache: dim 2 is the
        # sequence dimension (slots past_len .. past_len + q_len), dim 0 is
        # the batch dimension
        new_keys = cache.key_states[self.index].narrow(2, past_len, q_len).narrow(0, 0, bsz)
        new_values = cache.value_states[self.index].narrow(2, past_len, q_len).narrow(0, 0, bsz)

        # In-place copy of the freshly computed keys/values into those slots
        new_keys.copy_(key_states)
        new_values.copy_(value_states)

This creates a narrow view onto the K/V cache for the given layer and copies the keys and values computed for the current hidden state into it. It then takes another view on the cache tensors to feed into the attention step:


        # Key/value tensors with past

        # Views covering every position filled so far: the past tokens plus
        # the q_len tokens just written above
        key_states = cache.key_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)
        value_states = cache.value_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)

This is the regular HF-like version of the function. The faster C++ implementation is called in ExLlamaAttention.fused(), where the cuda_ext.exllama_ext.q4_attn() function does the same copy operation, but with a custom kernel defined in q4_attn.cu.
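
If it helps, here is the overall pattern in a self-contained form. This is a simplified plain-PyTorch sketch of the idea, not the exllama code or the fused kernel; the shapes, names, and the use of F.scaled_dot_product_attention are illustrative assumptions:

    # Simplified, self-contained sketch of the caching pattern described
    # above, in plain PyTorch. Shapes, names and the use of
    # scaled_dot_product_attention are assumptions for illustration only.
    import torch
    import torch.nn.functional as F

    bsz, num_heads, max_seq_len, head_dim = 1, 32, 2048, 128

    # Preallocated cache for a single layer
    cache_k = torch.zeros(bsz, num_heads, max_seq_len, head_dim)
    cache_v = torch.zeros(bsz, num_heads, max_seq_len, head_dim)

    def attend(query_states, key_states, value_states, past_len):
        q_len = query_states.shape[2]

        # Copy the new keys/values into their slots in the preallocated cache
        cache_k.narrow(2, past_len, q_len).copy_(key_states)
        cache_v.narrow(2, past_len, q_len).copy_(value_states)

        # Attend over everything cached so far (past tokens + current ones).
        # A causal mask is only needed when several query positions are
        # processed at once (the prefill pass); a single-token decode step
        # attends to the whole prefix unmasked.
        keys = cache_k.narrow(2, 0, past_len + q_len)
        values = cache_v.narrow(2, 0, past_len + q_len)
        return F.scaled_dot_product_attention(
            query_states, keys, values, is_causal = (past_len == 0 and q_len > 1))

    # Prefill with a 10-token prompt, then decode one more token
    q = torch.randn(bsz, num_heads, 10, head_dim)
    out_prefill = attend(q, torch.randn_like(q), torch.randn_like(q), past_len = 0)

    q1 = torch.randn(bsz, num_heads, 1, head_dim)
    out_step = attend(q1, torch.randn_like(q1), torch.randn_like(q1), past_len = 10)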