This might be related to Mixture-of-Depths: for some reason, when running inference on short sequence lengths, it sends zero tokens through.
Yeah, it's a mix of issues with MoD and infini-attention. MoD had a weird optimization where the prompt tokens would also be reduced by the capacity factor, which can somehow result in no tokens making it through. Simply removing that optimization gets a bit further, but then infini-attention barfs when the number of tokens is less than the segment length:
ValueError: Sequence length must be divisible by segment length. seq_len: 6 segment_len: 256
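For context, the zero-token failure mode is just integer rounding; a toy illustration (the numbers and names here are made up for illustration, not the actual MoD routing code):

```python
# Capacity-based routing keeps a fixed fraction of tokens; on a short prompt
# the rounded count can hit zero, so nothing gets sent through the block.
seq_len = 6
capacity_factor = 0.125  # e.g. route 12.5% of tokens through the block

num_routed = int(capacity_factor * seq_len)  # int(0.75) == 0
print(num_routed)  # -> 0: no tokens survive routing
```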
In the CompressiveMemory module, handle the case where the input sequence length is less than the segment length more gracefully:
if seq_len < self.segment_len:
    # Sequence is shorter than one segment: fall back to ordinary scaled dot-product attention
    q = self.proj_q(x).view(batch_size, seq_len, self.num_heads, self.dim_key).transpose(1, 2)
    k = self.proj_k(x).view(batch_size, seq_len, self.num_heads, self.dim_key).transpose(1, 2)
    v = self.proj_v(x).view(batch_size, seq_len, self.num_heads, self.dim_value).transpose(1, 2)
    scores = q @ k.transpose(-2, -1) / self.dim_key ** 0.5  # scale by sqrt(d_k)
    att = nn.functional.softmax(scores, dim=-1) @ v
    att = att.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.dim_value)
    return self.proj_out(att)
This will skip the memory-based attention and just perform regular scaled dot-product attention when the sequence is too short, avoiding the segment length divisibility issue.
For MoD in particular, I'm planning two alternatives:

1. Sampling tokens across the full sequence, segmenting the sampled tokens, and applying CompressiveMemory on those segments without modification.
2. A version of CompressiveMemory that independently samples the specified fraction of tokens within each segment (sketched below).

If there's another mechanism for combining the two techniques, I'm always open to suggestions!
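For example, the second option could look roughly like this (an illustrative sketch only; the function and parameter names aren't from the repo):

```python
import torch

def sample_tokens_per_segment(x: torch.Tensor, router_logits: torch.Tensor,
                              segment_len: int, capacity_factor: float) -> torch.Tensor:
    """Independently keep the top capacity_factor fraction of tokens in each segment.

    x: (batch, seq_len, dim); router_logits: (batch, seq_len).
    Assumes seq_len is a multiple of segment_len.
    """
    batch, seq_len, dim = x.shape
    num_segments = seq_len // segment_len
    k = max(1, int(capacity_factor * segment_len))  # never route zero tokens in a segment

    x = x.view(batch, num_segments, segment_len, dim)
    scores = router_logits.view(batch, num_segments, segment_len)
    top_idx = scores.topk(k, dim=-1).indices.sort(dim=-1).values  # keep original token order
    # Gather the selected tokens per segment -> (batch, num_segments, k, dim)
    return torch.gather(x, 2, top_idx.unsqueeze(-1).expand(-1, -1, -1, dim))
```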
As for the current behavior where the sequence length is less than the segment length, I'm on the fence.
On one hand, raising an error instead of defaulting back to ordinary SDP attention would alert users to problems in their upstream code that are producing the short sequence lengths. On the other hand, if the short sequence length is intentional, then it makes sense to bypass the infini-former attention mechanism completely.
Thoughts?
Hi, just dropping in. For the sake of general usability (maybe a bit weird to want to use this technique for short sequence lengths, but who knows), I'd consider defaulting to the fallback and logging a warning that tells the user the sequence length you're getting and that you're bypassing infini-attention.
@db7894 I like that a lot. I'll change the error to a warning and fall back to regular attention for short sequences.
Just pushed the change. For now, I'm giving a warning and falling back to full attention if the sequence length is less than the segment length or not divisible by the segment length.
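For reference, the guard is roughly along these lines (a sketch of the behavior described above, with an illustrative function name, not the exact code that was pushed):

```python
import warnings

def can_use_compressive_memory(seq_len: int, segment_len: int) -> bool:
    """Return True if the segmented memory path applies; otherwise warn and signal a fallback."""
    if seq_len < segment_len or seq_len % segment_len != 0:
        warnings.warn(
            f"seq_len ({seq_len}) is shorter than or not divisible by segment_len "
            f"({segment_len}); bypassing compressive memory and using full attention."
        )
        return False
    return True
```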
Going to go back over the math to make sure there isn't a workaround for the latter, so I'm leaving this issue open for the moment.
I'll try to play around and see if I notice anything re: the divisibility issue when I have time (not sure when that'll be), but I do wonder whether people would actually want to do that. Maybe a dumb/throwaway idea is to pad the q/k/v tensors, but that would probably introduce other issues.
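For what it's worth, the padding idea would look something like this minimal sketch (a hypothetical helper; as noted, it would still need a matching attention mask so the pad tokens don't pollute the compressive memory):

```python
import torch
import torch.nn.functional as F

def pad_to_segment_multiple(x: torch.Tensor, segment_len: int):
    """Pad (batch, seq_len, dim) on the sequence axis up to the next multiple of segment_len.

    Returns the padded tensor plus the original length so the output can be sliced back.
    """
    seq_len = x.size(1)
    pad = (-seq_len) % segment_len
    if pad:
        x = F.pad(x, (0, 0, 0, pad))  # pad the seq dimension (dim=1) on the right with zeros
    return x, seq_len
```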
Haven't done the math yet, but while I was fixing a bug with the inference case for the .forward() method, I realized the number of routed tokens has to have a variable length. I might not get around to mathing out the correct way to keep the dimensions lined up until next weekend, but rest assured, I'm on it!
Just pushed the fix for the original issue. For now, I have a workaround for inference time where I force the batch size to be one. While going over the math and the tensor shapes, I found that the normalization vector was being computed incorrectly, so I fixed that as well.
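For anyone following along, the normalization vector here is the z term from the Infini-attention formulation (Munkhdalai et al., 2024, "Leave No Context Behind"), which should accumulate sigma(K) summed over each segment. A self-contained sketch of the bookkeeping, with toy shapes rather than the module's actual attributes:

```python
import torch
import torch.nn.functional as F

# Toy shapes just to illustrate the update; sigma is ELU + 1 as in the paper.
batch, heads, segment_len, dim_key, dim_value = 2, 4, 8, 16, 16
q = torch.randn(batch, heads, segment_len, dim_key)
k = torch.randn(batch, heads, segment_len, dim_key)
v = torch.randn(batch, heads, segment_len, dim_value)
M = torch.zeros(batch, heads, dim_key, dim_value)  # running compressive memory
z = torch.zeros(batch, heads, dim_key)             # running normalization vector

sigma_q = F.elu(q) + 1.0
sigma_k = F.elu(k) + 1.0

a_mem = (sigma_q @ M) / (sigma_q @ z.unsqueeze(-1) + 1e-6)  # memory retrieval, normalized by sigma(Q) z
M = M + sigma_k.transpose(-2, -1) @ v                        # linear memory update (delta rule omitted)
z = z + sigma_k.sum(dim=-2)                                  # z accumulates sigma(K) summed over the segment
```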
Going to wait until people have had a chance to try the fixes before I close this issue.
It's been a few days and nobody's reported any problems relating to sequence lengths (other than Issue #10). Going to close this issue. If anyone encounters any issues, I'll reopen it.
Thoughts about this exception during inference?