Looks 1:1 between PR and master for me (dropout defaults to 0), but good to have the correct value on both methods on the off-chance it's used for something down the line.
Hi, I've been doing quite a bit of testing and it looks like batches and non-xformers are working just fine.
The original issue with torch SDP was batch size > 2, mostly because the attention mask here was the result of literal guesswork. Larger batch sizes, e.g. 4/8 (especially with the positive and negative prompts being the same length), usually make that really visible, because the mask only really accounts for the simple batch size = 2 scenario and doesn't handle anything beyond it.
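For reference (general PyTorch behaviour, not code from this PR): `scaled_dot_product_attention` expects `attn_mask` to be broadcastable to `(batch, heads, L, S)`, where `L` is the query length and `S` the key length, so whatever fix goes in needs to derive those last two dims from the actual `q`/`k` for any batch size, not just `B = 2`. A minimal sketch with made-up shapes:

```python
import torch
import torch.nn.functional as F

B, H, L, S, D = 4, 8, 4096, 154, 64   # batch, heads, query len, key len, head dim (made up)
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Boolean mask, True = "may attend"; (B, 1, L, S) broadcasts over the head dim.
attn_mask = torch.ones(B, 1, L, S, dtype=torch.bool)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
print(out.shape)  # torch.Size([4, 8, 4096, 64])
```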
Thanks, how would I go about trying to fix that? Any hints / tips?
Well, good question. I'm not sure what the mask is supposed to look like for batch sizes > 2, but I know for a fact the current one is wrong: it always has the same initial width (no matter how many mask elements we have) and its height is set to `q.shape[2] // 2`, so it's not even symmetric.
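As a debugging aid (my own sketch, not code from the repo), you can at least check whether a candidate mask will broadcast against the real query/key lengths for the current batch before handing it to SDP:

```python
import torch

def check_attn_mask(attn_mask: torch.Tensor, q: torch.Tensor, k: torch.Tensor) -> None:
    """Raise if attn_mask can't broadcast to (B, H, L, S) for these q/k tensors."""
    B, H, L, _ = q.shape
    S = k.shape[2]
    # broadcast_shapes raises a RuntimeError when the shapes are incompatible,
    # e.g. a mask whose height is q.shape[2] // 2 won't broadcast against the full L.
    torch.broadcast_shapes(attn_mask.shape, (B, H, L, S))

# Usage: check_attn_mask(mask, q, k) right before the SDP call; no exception = shapes line up.
```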
With how it's set up, B=2 is fine, but B=4 looks like how it's depicted in the image below. I think forcing it to be the same as the left (B=2) one isn't the way to go, since the overall shape changes, but I might be wrong. The xformers function is called "BlockDiagonalMask", so maybe it expects something like the one pictured on the right?
(note: the shapes change based on a lot of factors, this is just a random example)
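If it helps to see what `BlockDiagonalMask` actually expects, xformers can materialize it as a dense bias (a sketch assuming a reasonably recent xformers; the sequence lengths below are just invented examples):

```python
import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

# Four sequences packed back-to-back into one long sequence;
# each block may only attend to itself, hence "block diagonal".
seqlens = [77, 77, 154, 154]
bias = BlockDiagonalMask.from_seqlens(seqlens)

# Dense view: 0.0 inside each block, -inf everywhere else.
total = sum(seqlens)
dense = bias.materialize((total, total))
print(dense.shape)                            # torch.Size([462, 462])
print(torch.isfinite(dense[0]).sum().item())  # 77 -> row 0 only sees the first block
```

The SDP-side equivalent would presumably be a boolean mask with the same block structure (e.g. built with `torch.block_diag` over per-sequence all-True blocks), which matches the right-hand picture.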
Add dropout_p to SDP attention