This could be a result of right-aligning and padding the inputs in the batch. The rotary embeddings shouldn't care what the starting position is as long as any prior tokens are masked out during attention, and I haven't had issues with it before, even mixing very long and very short prompts in a batch. CodeLlama does use a very large rotary embedding base (1e6), so that could be amplifying any numerical precision issues.
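As a sanity check on that claim, here's a minimal self-contained sketch (a toy RoPE, not the library's actual implementation) showing that the attention score under rotary embeddings depends only on the relative offset between positions, so shifting both positions by a constant changes nothing beyond float error:

```python
import torch

# Toy RoPE (layout simplified; not exllamav2's code).
def rope(x, pos, base=1e6):
    d = x.shape[-1]
    inv = base ** (-torch.arange(0, d, 2).float() / d)  # per-pair frequencies
    ang = pos * inv
    cos, sin = torch.cos(ang), torch.sin(ang)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q, k = torch.randn(64), torch.randn(64)
s0 = rope(q, torch.tensor(5.0)) @ rope(k, torch.tensor(2.0))      # positions 5 and 2
s1 = rope(q, torch.tensor(105.0)) @ rope(k, torch.tensor(102.0))  # both shifted by 100
print(torch.allclose(s0, s1, atol=1e-4))  # True: only the relative offset matters
```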
I'll have a look and see if maybe the padding approach isn't enough and there needs to be another channel for communicating position offsets to the forward pass. In the meantime, to rule out that it's a simple tokenization issue, you could verify that the attention mask looks reasonable in generate_simple():

```python
mask = self.tokenizer.padding_mask(ids) if batch_size > 1 else None
```
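For instance, something along these lines (method names taken from the snippet above; the batch-encoding step is an assumption) should show masked positions only at the start of the shorter rows:

```python
# Hypothetical check: encode a mixed-length batch and inspect the mask.
# With right-aligned inputs, padding sits on the left, so the mask should
# only blank out the leading positions of shorter rows.
ids = tokenizer.encode(["a short prompt", "a much, much longer prompt " * 8])
mask = tokenizer.padding_mask(ids)
print(mask)
```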
Thanks for the quick reply. I checked, and I can confirm that the mask is generated right-aligned and properly. Looking forward to a fix. Perhaps for now I'll batch copies of the same prompt together; I was going for pass@10 anyway.
Thanks again, this code is very stable, unlike other packages I've worked with.
Well, I pushed an update that shifts the RoPE position IDs according to the length of each item in the batch. This at least ensures that the first token in a sequence will be encoded as position zero regardless of how many padding tokens precede it. It doesn't fix the issue, though, which comes down to the fact that flash-attn doesn't support padding masks. So all those padding tokens are still attended to, and you simply get wrong output when they're present.
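The idea behind the shift, roughly (a sketch, not the actual commit; the mask convention is assumed to be 1 for real tokens and 0 for padding):

```python
import torch

def shifted_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for left padding
    pad_lens = (attention_mask == 0).sum(dim=1, keepdim=True)       # (batch, 1)
    positions = torch.arange(attention_mask.shape[1]).unsqueeze(0)  # (1, seq_len)
    # Subtract each row's padding length so the first real token lands at position 0
    return (positions - pad_lens).clamp(min=0)
```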
You can fix it for now with `config.no_flash_attn = True`, and I seem to be getting the expected output in your example using that option.
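For reference, setting the flag looks something like this (the model path is a placeholder, and the surrounding setup follows exllamav2's usual config pattern):

```python
from exllamav2 import ExLlamaV2Config

config = ExLlamaV2Config()
config.model_dir = "/path/to/CodeLlama"  # placeholder path
config.prepare()
config.no_flash_attn = True              # fall back to the masked, non-flash attention path
```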
While flash-attn does provide ways to process variable-length sequences in a batch, rewriting the code to use those would make flash-attn a requirement rather than an option, and it's still problematic for Windows and AMD users. So it could take some time to come up with a proper solution.
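For context, the variable-length interface referred to is flash-attn's varlen entry point, which takes packed, unpadded sequences plus cumulative length offsets instead of a padded batch; a rough illustration (shapes per the flash-attn 2.x docs, values made up, CUDA required):

```python
import torch
from flash_attn import flash_attn_varlen_func

# Three unpadded sequences packed into one tensor; cu_seqlens marks boundaries.
seqlens = [7, 3, 5]
total, nheads, headdim = sum(seqlens), 8, 64
q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
cu_seqlens = torch.tensor([0, 7, 10, 15], device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    causal=True,
)
```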
On the last point in your message, I'm not sure I follow; perhaps I've just misunderstood. But isn't it possible to have both implementations (the current one and the fix) available in the code and switch between the two based on the value of `config.no_flash_attn`? Like, a completely new class that implements the fix, and when you're instantiating an object, you'd choose the class based on whether the user has set `no_flash_attn` or not. That way, by resetting this variable, the user can switch back to the current implementation if they cannot run flash attention.
Perhaps I'm missing something and it's not that simple. In any case, thanks for the reply.
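In code, the suggestion amounts to something like this (class names are hypothetical, purely to illustrate the dispatch):

```python
class FlashAttention:
    """Path that calls into flash-attn (no padding-mask support)."""
    def forward(self, q, k, v, mask=None): ...

class MaskedAttention:
    """Fallback path that applies the padding mask explicitly."""
    def forward(self, q, k, v, mask=None): ...

def make_attention(config):
    # Pick the implementation once, when the module is instantiated
    return MaskedAttention() if config.no_flash_attn else FlashAttention()
```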
The fix I pushed today doesn't really relate to flash-attn; it's just always on, shifting position IDs whenever you generate in a batch.
But regardless of that, flash-attn currently won't work with batches of varying sequence length because of the padding-mask issue. So you can disable it at runtime with the `no_flash_attn` flag, and then batching will work correctly.
Turns out there was a bug in how the position embeddings were being applied. I'm getting consistent results with the latest commit, although still only with flash-attn disabled.
Now looking at PR #240, which looks promising as a solution for flash-attn.
This should be fully addressed with the dynamic generator, which does not require attention masking and starts all sequences in a batch at position zero, regardless of uneven lengths.
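For anyone landing here later, batched generation with the dynamic generator looks roughly like this (model/cache/tokenizer setup omitted; API names as of recent exllamav2 versions, so double-check against the examples in the repo):

```python
from exllamav2.generator import ExLlamaV2DynamicGenerator

generator = ExLlamaV2DynamicGenerator(model=model, cache=cache, tokenizer=tokenizer)

# Mixed-length prompts are fine: each sequence starts at position zero
# internally, so no padding mask is needed.
outputs = generator.generate(prompt=prompts, max_new_tokens=256)
```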
When I group more than one prompt into a batch, if the prompts are of different sizes, the generated output for the shorter prompts suffers from repeating tokens. This does not happen when all the prompts are the same length.
Here's an example (the prompts are taken from HumanEval):
In this example, we first generate 4 outputs for a batch of prompts of different sizes. Then each of the same prompts is copied into a batch of its own. This behaviour is reproducible, but I could not find the reason why/how it happens, other than that it has something to do with the lengths of the input prompts.
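(The original script isn't reproduced here; the following sketch just shows the shape of the comparison, with the generator setup and the four HumanEval prompts assumed to be defined already.)

```python
from exllamav2.generator import ExLlamaV2Sampler

settings = ExLlamaV2Sampler.Settings()
prompts = [p0, p1, p2, p3]  # four HumanEval prompts of different lengths (assumed defined)

# One batch mixing all four lengths...
mixed = generator.generate_simple(prompts, settings, 256)

# ...versus four homogeneous batches, one per prompt
per_prompt = [generator.generate_simple([p] * 4, settings, 256) for p in prompts]
```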
Batch of different prompt sizes:
Batch of same prompt (0) sizes:
Batch of same prompt (1) sizes:
Batch of same prompt (2) sizes:
Batch of same prompt (3) sizes: