huggingface / candle

Minimalist ML framework for Rust

Processing text prompts in batches for LLMs #2108

Closed tbogdala closed 3 weeks ago

tbogdala commented 3 weeks ago

When prompts get longer than trivial sizes, memory usage spikes because the whole prompt is packed into one Tensor and sent through a forward pass of the model at whatever length it happens to be. These spikes can be reduced by processing the prompt in chunks.
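
For illustration, the chunked prompt pass in the example looks roughly like the sketch below. This is only a sketch: chunk_size is a made-up knob, the variable names are illustrative, and the llama.forward(&input, index_pos, &mut cache) signature is assumed to match the current llama example.

let chunk_size = 128;
let mut index_pos = 0;
let mut last_logits = None;
for chunk in prompt_tokens.chunks(chunk_size) {
    // Shape the chunk as a (1, chunk_len) batch, just like the unchunked prompt pass.
    let input = Tensor::new(chunk, &device)?.unsqueeze(0)?;
    // The kv cache accumulates across chunks; index_pos tells the model where
    // this chunk starts in the overall sequence.
    last_logits = Some(llama.forward(&input, index_pos, &mut cache)?);
    index_pos += chunk.len();
}
// Generation then continues from last_logits with seq_len == 1 per step, as before.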

The implementation of CausalSelfAttention for Llama inside candle-transformers only handles two cases: a seq_len of 1, which occurs while generating text, and a seq_len that matches the whole prompt size with an index_pos of 0, which occurs once when processing the whole prompt. If you attempt to split the prompt into chunks of tokens, processing the second chunk fails with a broadcasting error: the mask that gets generated is sized to the prompt chunk, but the kv cache has grown k and v to include the previous data, so the shapes no longer match. (E.g. for a chunk size of 128, the mask will be [128, 128], but on the second chunk the att Tensor ends up as [1, 32, 128, 256].)

This can be fixed by creating the mask in a different way:

let mut mask = cache.mask(seq_len)?;
if index_pos != 0 {
    // Prepend a block of zeros covering the index_pos cached positions, so the
    // history stays fully visible to every row of the current chunk while the
    // causal mask still applies within the chunk.
    let zero_history = Tensor::zeros((seq_len, index_pos), mask.dtype(), mask.device())?;
    mask = Tensor::cat(&[zero_history, mask], 1)?;
}
mask = mask.broadcast_as(att.shape())?;

I'm not sure how efficient that is, but it produces the same results when processing the prompt in chunks as when sending it all in at once. It's also only a drop-in change for the Llama model, since that one has a kv cache; quantized_llama.rs's ModelWeights doesn't have a kv cache to modify...

I have a modified llama example with the batch processing, plus this change in the model struct, that I could submit as a PR if you'd like to see it, but the above code is all that's needed in candle-transformers.

LaurentMazare commented 3 weeks ago

Sounds like a good thing to add (as mentioned on Discord), but rather than doing a cat, what about just modifying the way the mask is defined so that it creates the appropriate mask from the beginning? Obviously you will have to pass index_pos to the mask function for that, but then it should be pretty straightforward.

tbogdala commented 3 weeks ago

Okay, that is a cleaner choice. After changing the hashmap to have a (usize, usize) key ...

masks: HashMap<(usize, usize), Tensor>,

... this implementation of mask works as you'd want, I think:

fn mask(&mut self, t: usize, u: usize) -> Result<Tensor> {
    // t is the current chunk length, u is the total sequence length so far
    // (index_pos + t), i.e. the width of the cached k/v.
    if let Some(mask) = self.masks.get(&(t, u)) {
        Ok(mask.clone())
    } else {
        // Row i corresponds to absolute position (u - t) + i, so mask every
        // column j beyond that position.
        let mask: Vec<_> = (0..t)
            .flat_map(|i| (0..u).map(move |j| u8::from(j > i + (u - t))))
            .collect();
        let mask = Tensor::from_slice(&mask, (t, u), &self.device)?;
        self.masks.insert((t, u), mask.clone());
        Ok(mask)
    }
}
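
As a quick sanity check (not from the thread, just to illustrate the shape): with t = 2 new tokens appended after index_pos = 2 cached tokens, so u = 4, the mask only hides positions after each row's absolute index:

let mask = cache.mask(2, 4)?; // shape (2, 4)
assert_eq!(
    mask.to_vec2::<u8>()?,
    vec![
        vec![0, 0, 0, 1], // row 0 is absolute position 2, may attend to 0..=2
        vec![0, 0, 0, 0], // row 1 is absolute position 3, may attend to 0..=3
    ],
);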

Then, in forward, the mask can be created much as it originally was:

let mask = cache.mask(seq_len, index_pos + seq_len)?.broadcast_as(att.shape())?;
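
For context, here is roughly how that line slots into the attention forward pass; this is a sketch assuming the existing seq_len == 1 special case and the masked_fill helper in llama.rs, with only the mask construction changed:

let att = if seq_len == 1 {
    // Single-token generation step: there are no future positions to hide.
    att
} else {
    // Prompt (or prompt chunk) processing: hide everything after each token's
    // absolute position, taking the index_pos cached tokens into account.
    let mask = cache.mask(seq_len, index_pos + seq_len)?.broadcast_as(att.shape())?;
    masked_fill(&att, &mask, f32::NEG_INFINITY)?
};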

Tested it with batch sizes of 1, 64, 73, 128, 256, and 1024 on a prompt of 892 tokens, and it seems to work well, though a batch size of one takes ages, as expected.

LaurentMazare commented 3 weeks ago

Cool, happy to get a PR for this if you want to make one. One nitpick: I would suggest using j + t > i + u as the condition, so as to avoid having to think about why u - t has to be non-negative (which is the case given how you call the function, but it would result in an underflow if it weren't called properly).
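
Applied to the snippet above, that suggestion is just a one-expression change (sketch):

let mask: Vec<_> = (0..t)
    .flat_map(|i| (0..u).map(move |j| u8::from(j + t > i + u)))
    .collect();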

tbogdala commented 3 weeks ago

Nothing left to ask, so I'm closing the issue.