Open EricLBuehler opened 6 months ago
@LaurentMazare, is this a mistake on my part?
I'm not sure I understand; this model has been designed to be passed the prompt and then one token at a time, so it fails if, after the prompt, you pass it multiple tokens at once, which is somewhat expected. Do you mean that the error message should be more explicit about why this is failing?
For speculative decoding, we need to run the target model on multiple tokens at once, once per step. If we instead had to re-run the target model with the full prompt each step, that would be a big performance hit, which is why I tried to do this. Is there some workaround, like disabling the attention mask?
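(For context, a minimal sketch of the verification step in question, using greedy acceptance for brevity; real speculative sampling uses an acceptance/rejection test on the probabilities. `draft_step` and `target_logits` are hypothetical closures standing in for the two models; none of these names come from mistral.rs or candle.)

```rust
// Hypothetical helper illustrating one speculative-decoding step with greedy
// acceptance. `draft_step` proposes candidate tokens; `target_logits` runs the
// target model ONCE over all of them (the multi-token forward discussed here)
// and returns one logits row per candidate.
fn speculative_step(
    ctx: &mut Vec<u32>,
    draft_step: impl Fn(&[u32]) -> Vec<u32>,
    target_logits: impl Fn(&[u32], usize) -> Vec<Vec<f32>>,
) {
    let offset = ctx.len(); // tokens already in the target model's kv cache
    let candidates = draft_step(ctx);
    let logits = target_logits(&candidates, offset);
    for (tok, row) in candidates.iter().zip(logits.iter()) {
        let target_choice = row
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i as u32)
            .unwrap_or(0);
        if target_choice == *tok {
            ctx.push(*tok); // target agrees: accept the drafted token
        } else {
            ctx.push(target_choice); // first disagreement: keep the target's token and stop
            break;
        }
    }
}
```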
I think disabling the attention mask would be incorrect: you want the tokens in the batch you're processing to be causal between themselves and to be able to attend to all tokens already in the kv cache. So you would want a mask that is rectangular rather than square, based on how many tokens are in the kv cache at the moment. It should look like the following for a batch of 4 tokens and a kv cache that already has 5 tokens processed (1 marks positions that must not be attended to):
000000111
000000011
000000001
000000000
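A minimal sketch of building such a rectangular mask with candle's `Tensor::from_slice` (the function name and signature are mine, not an existing candle API):

```rust
use candle_core::{Device, Result, Tensor};

// Build a (tgt_len, seqlen_offset + tgt_len) mask where 1 marks positions that
// must not be attended to: new token i may see every cached token plus the new
// tokens up to and including itself.
fn causal_mask(tgt_len: usize, seqlen_offset: usize, device: &Device) -> Result<Tensor> {
    let total = seqlen_offset + tgt_len;
    let mask: Vec<u8> = (0..tgt_len)
        .flat_map(|i| (0..total).map(move |j| u8::from(j > seqlen_offset + i)))
        .collect();
    Tensor::from_slice(&mask, (tgt_len, total), device)
}
```

With tgt_len = 4 and seqlen_offset = 5 this reproduces the 4×9 pattern above, and with seqlen_offset = 0 it degenerates to the usual square causal mask used for the prompt.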
Ok. Would this be similar to #2111?
Indeed, that looks like the mask part at the bottom. It would be great if you could make a fresh PR with that change for the model that you care about.
Ok, so just to confirm: it is this part?
I can add a PR for this to some of the models if you think it is a good idea.
Yep, exactly this part. It would probably be good to support this for at least llama and quantized-llama (and others too, but they might need a bit more work as the mask generation is different). A sketch of how such a mask gets applied follows below.
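(For reference, this is roughly how such a mask is then consumed, following the masked-fill pattern used in the candle llama implementations; `att` is the `(b, heads, tgt_len, kv_len)` attention-score tensor.)

```rust
use candle_core::{Result, Tensor};

// Broadcast the (tgt_len, kv_len) mask over the attention scores and replace
// masked positions with -inf before the softmax.
fn apply_mask(att: &Tensor, mask: &Tensor) -> Result<Tensor> {
    let mask = mask.broadcast_as(att.shape())?;
    let neg_inf = Tensor::new(f32::NEG_INFINITY, att.device())?.broadcast_as(att.dims())?;
    mask.where_cond(&neg_inf, att)
}
```

The `cannot broadcast [3, 3] to [1, 32, 3, 5]` error from the original report comes from exactly this broadcast: the mask is still built as a square over the 3 new tokens, while the scores also cover the tokens already in the kv cache.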
I was able to make a general causal masker implementation here:
It works for all models that use a causal or causal + sliding-window mask. Should I submit this as a PR?
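(A sliding-window variant of the same idea might look like the sketch below; this is illustrative, not the actual mistral.rs masker, and `sliding_window` here counts the current position itself.)

```rust
use candle_core::{Device, Result, Tensor};

// Like the causal mask above, but each new token may only attend to the last
// `sliding_window` positions (itself included), as in Mistral-style attention.
fn causal_sliding_mask(
    tgt_len: usize,
    seqlen_offset: usize,
    sliding_window: usize,
    device: &Device,
) -> Result<Tensor> {
    let total = seqlen_offset + tgt_len;
    let mask: Vec<u8> = (0..tgt_len)
        .flat_map(|i| {
            let pos = seqlen_offset + i; // absolute position of this new token
            (0..total).map(move |j| u8::from(j > pos || j + sliding_window <= pos))
        })
        .collect();
    Tensor::from_slice(&mask, (tgt_len, total), device)
}
```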
Hello all,
Thanks for your great work here. We are implementing speculative decoding at mistral.rs, and were in the final stages of testing when we discovered some incredibly strange behavior. Specifically, the following error results when sending multiple tokens at once during the completion steps:
Error: cannot broadcast [3, 3] to [1, 32, 3, 5]
Reproducing this error is simple:
In the quantized example, modify the per-token decode loop around quantized/main.rs:578 to pass multiple tokens at once after the prompt has been processed, roughly as in the sketch below.
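A hedged sketch of the change that triggers it, written from memory rather than copied from that exact line; `model`, `device`, `prompt_tokens`, `index`, and `next_token` are the surrounding variables of the example's decode loop, and `drafted` is a hypothetical batch of candidate tokens:

```rust
// Per-token decode step in the example (roughly):
//   let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
//   let logits = model.forward(&input, prompt_tokens.len() + index)?;

// Passing several tokens at once after the prompt instead:
let drafted: Vec<u32> = vec![next_token; 3]; // hypothetical 3-token batch
let input = Tensor::new(drafted.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
// Fails with the broadcast error above: the attention mask is built as a
// square (3, 3) while the scores also cover the tokens already in the kv cache.
```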
Is this a bug?