huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Batch llama prompt #2111

Closed · tbogdala closed this 3 weeks ago

tbogdala commented 3 weeks ago

This PR is discussed in #2108 and updates mask creation for the Llama model so that a user-supplied prompt can be processed in token batches instead of all at once. The key change is to Cache::mask(), which now takes a second usize and uses it to build the appropriately sized vector that gets turned into a Tensor there.
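For intuition, here is a minimal sketch of what a batched causal mask along these lines could look like. It assumes the second usize is an offset (called seqlen_offset below) counting the tokens already in the KV cache; the parameter name and exact shape handling are assumptions for illustration, so see the PR diff for the real change.

```rust
use candle::{Device, Result, Tensor};

// Sketch only: `seqlen_offset` is a hypothetical name for the second
// usize; the actual implementation lives in candle-transformers.
fn mask(t: usize, seqlen_offset: usize, device: &Device) -> Result<Tensor> {
    // Row i is the i-th token of the current batch. It may attend to all
    // `seqlen_offset` cached positions plus the first i + 1 new positions,
    // so entries with j > i + seqlen_offset are masked out (set to 1).
    let mask: Vec<u8> = (0..t)
        .flat_map(|i| {
            (0..t + seqlen_offset).map(move |j| u8::from(j > i + seqlen_offset))
        })
        .collect();
    Tensor::from_slice(&mask, (t, t + seqlen_offset), device)
}
```

With seqlen_offset = 0 this reduces to the original square (t, t) mask, so processing the whole prompt in one pass behaves as before.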

The code in candle-examples/examples/llama/main.rs in this PR may need smoothing, but other than that, I've tested the example both with and without the new --prompt-batch-size CLI parameter, at a variety of batch sizes.
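Roughly, the example-side change amounts to feeding the prompt to the model in fixed-size chunks rather than in a single forward pass. The following is a hypothetical sketch of that loop, not the PR's exact code; the names and the forward signature are taken from candle-transformers at the time of writing and may differ between versions.

```rust
use candle::{Device, Error, Result, Tensor};
use candle_transformers::models::llama::{Cache, Llama};

// Hypothetical helper: run the prompt through the model in chunks of
// `batch_size` tokens, returning the logits from the final chunk.
fn process_prompt(
    llama: &Llama,
    cache: &mut Cache,
    tokens: &[u32],
    batch_size: usize,
    device: &Device,
) -> Result<Tensor> {
    let mut index_pos = 0;
    let mut logits = None;
    for chunk in tokens.chunks(batch_size) {
        let input = Tensor::new(chunk, device)?.unsqueeze(0)?;
        // `index_pos` offsets the new tokens against the KV cache so the
        // causal mask lines up with previously processed positions.
        logits = Some(llama.forward(&input, index_pos, cache)?);
        index_pos += chunk.len();
    }
    logits.ok_or_else(|| Error::Msg("empty prompt".to_string()))
}
```

In the example this would be driven by the new --prompt-batch-size flag, with an invocation along the lines of cargo run --example llama --release -- --prompt-batch-size 64 --prompt "...".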

LaurentMazare commented 3 weeks ago

Yeah, the change in the example part indeed seems a bit complex. Maybe we should have just the model change in this PR, so that users of the candle-transformers crate can benefit from it and we don't need to adapt the example for now.