dingo-actual / infini-transformer

PyTorch implementation of Infini-Transformer from "Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention" (https://arxiv.org/abs/2404.07143)
MIT License

inference issue #2

Closed · winglian closed this issue 6 months ago

winglian commented 6 months ago

thoughts about this exception during inference?

  File "/workspace/voltronformers/src/voltronformer/model.py", line 303, in forward                               
    output = residual + self.attn(h, position_ids=position_ids)[0]                                                                                                                                                                  
modules/module.py", line 1520, in _call_impl    tor h/nn/modules/module.py",eline 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)                                                                       
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl             
    return forward_call(*args, **kwargs)                                                                                                                                                                                            
  File "/workspace/voltronformers/src/voltronformer/infini_attention.py", line 109, in forward                                                                                                                                      
    return torch.concat(out, dim=1)                                                                                                                                                                                                 
RuntimeError: torch.cat(): expected a non-empty list of Tensors                                                   
winglian commented 6 months ago

this might be related to Mixture-of-Depths: for some reason, when running inference on short sequence lengths, it sends zero tokens through.

winglian commented 6 months ago

yeah, it's a mix of issues with MoD and infini-attention. MoD had a weird optimization where the prompt tokens would also be reduced by the capacity factor, which can somehow result in no tokens making it through. Simply removing that optimization gets a bit further, but now infini-attention barfs when the number of tokens is less than the segment length:

ValueError: Sequence length must be divisible by segment length. seq_len: 6 segment_len: 256
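
For context, the zero-token case comes from the capacity factor rounding down on short prompts. A minimal illustration, with made-up numbers and not the actual MoD routing code:

import torch

# Made-up numbers, not the actual MoD code: with a small capacity factor,
# the number of routed tokens rounds down to zero on a short prompt, the
# per-segment output list stays empty, and torch.cat() raises.
seq_len, capacity_factor = 6, 0.125
num_routed = int(seq_len * capacity_factor)               # -> 0
out = [torch.randn(1, 1, 64) for _ in range(num_routed)]  # empty list
torch.cat(out, dim=1)  # RuntimeError: torch.cat(): expected a non-empty list of Tensors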
muditbhargava66 commented 6 months ago

In the CompressiveMemory module, handle the case where the input sequence length is less than the segment length more gracefully:

if seq_len < self.segment_len:
    # Sequence is shorter than the segment length: skip the compressive memory
    # and fall back to ordinary scaled dot-product attention.
    q = self.proj_q(x).view(batch_size, seq_len, self.num_heads, self.dim_key).transpose(1, 2)
    k = self.proj_k(x).view(batch_size, seq_len, self.num_heads, self.dim_key).transpose(1, 2)
    v = self.proj_v(x).view(batch_size, seq_len, self.num_heads, self.dim_value).transpose(1, 2)

    # Standard attention: softmax(QK^T / sqrt(dim_key)) V
    scores = q @ k.transpose(-2, -1) / (self.dim_key ** 0.5)
    att = nn.functional.softmax(scores, dim=-1) @ v
    att = att.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.dim_value)

    return self.proj_out(att)

This will skip the memory-based attention and just perform regular scaled dot-product attention when the sequence is too short, avoiding the segment length divisibility issue.

dingo-actual commented 6 months ago

For MoD in particular, I'm planning two alternatives:

  1. Let the MoD block sample at the segment level and apply CompressiveMemory on those segments without modification.
  2. Create an MoD variant of CompressiveMemory that independently samples the specified fraction of tokens within each segment (rough sketch of this option below).

If there's another mechanism for combining the two techniques, I'm always open to suggestions!
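
To make option 2 concrete, a rough sketch of per-segment token sampling might look like the following; the function name and router scoring here are placeholders, not anything that exists in the repo:

import torch

def sample_tokens_per_segment(x, router_scores, segment_len, capacity_factor):
    # x: (batch, seq_len, dim), router_scores: (batch, seq_len)
    # Independently keep the top capacity_factor fraction of tokens in each
    # segment, preserving their original order. Assumes seq_len is a multiple
    # of segment_len.
    batch_size, seq_len, dim = x.shape
    num_segments = seq_len // segment_len
    k = max(1, int(segment_len * capacity_factor))  # never route zero tokens

    x = x.view(batch_size, num_segments, segment_len, dim)
    scores = router_scores.view(batch_size, num_segments, segment_len)

    keep = scores.topk(k, dim=-1).indices.sort(dim=-1).values
    idx = keep.unsqueeze(-1).expand(-1, -1, -1, dim)
    selected = x.gather(2, idx)  # (batch, num_segments, k, dim)
    return selected, keep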

dingo-actual commented 6 months ago

As for the current behavior where the sequence length is less than the segment length, I'm on the fence.

On one hand, raising an error instead of defaulting back to ordinary SDP attention would alert users to problems in their upstream code that are creating the short sequence lengths. On the other hand, if the intention is to use a short sequence length, then it makes sense to bypass the infini-former attention mechanism completely.

Thoughts?

db7894 commented 6 months ago

Hi, just dropping in. For the sake of general usability (it may be a bit weird to want to use this technique for short sequence lengths, but who knows), I'd consider defaulting to regular attention and logging a warning that tells the user the sequence length you're getting and that you're bypassing infini-attention.

dingo-actual commented 6 months ago

@db7894 I like that a lot. I'll change the error to a warning and fall back to regular attention for short sequences.

dingo-actual commented 6 months ago

Just pushed the change. For now, I'm giving a warning and falling back to full attention if the sequence length is less than the segment length or not divisible by the segment length.
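
Roughly, the new check amounts to something like this (paraphrased, not the exact code; _full_attention stands in for whatever the non-segmented path ends up being called):

import warnings

# Paraphrased sketch of the fallback, not the exact code in the repo.
if seq_len < self.segment_len or seq_len % self.segment_len != 0:
    warnings.warn(
        f"Sequence length ({seq_len}) is shorter than or not divisible by the "
        f"segment length ({self.segment_len}); falling back to standard SDP attention."
    )
    return self._full_attention(x)  # placeholder name for the non-segmented path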

Going to go back over the math to make sure there isn't a workaround for the latter, so I'm leaving this issue open for the moment.

db7894 commented 6 months ago

I'll try to play around and see if I notice anything re: the divisibility issue when I have time (not sure when that'll be), but I do wonder whether people would actually want to do that. Maybe a dumb/throwaway idea is to pad the q/k/v tensors, but that would probably introduce other issues.
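
For what it's worth, the padding version would roughly be something like the following (purely hypothetical; it also glosses over the masking you'd need so padded positions don't pollute the memory update):

import torch.nn.functional as F

# Hypothetical: right-pad the sequence dimension of a (batch, seq_len, dim)
# input up to the next multiple of segment_len. The padded positions would
# still need to be masked out of the attention scores and the memory /
# normalization updates, which is where the extra issues come in.
pad_len = (-seq_len) % segment_len
if pad_len > 0:
    x = F.pad(x, (0, 0, 0, pad_len))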

dingo-actual commented 6 months ago

Haven't done the math yet, but while I was fixing a bug with the inference case for the .forward() method, I realized the number of routed tokens has to be allowed to vary. I might not get around to math-ing out the correct way to keep the dimensions lined up until next weekend, but rest assured, I'm on it!

dingo-actual commented 6 months ago

Just pushed the fix for the original issue. For now, I have a workaround for inference time where I force the batch size to be one. While going over the math and the tensor shapes, I found that the normalization vector was being computed incorrectly, so I fixed that as well.
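
For anyone following along, the normalization vector is the z term from the paper's linear memory. A rough sketch of the retrieval and update it's involved in, in the paper's notation (sigma = ELU + 1, without the delta-rule variant); this is not the repo's exact code:

import torch
import torch.nn.functional as F

def sigma(x):
    # ELU + 1 nonlinearity used for the linear memory in the paper
    return F.elu(x) + 1.0

def retrieve_and_update(q, k, v, mem, z):
    # q, k: (num_heads, seg_len, dim_key), v: (num_heads, seg_len, dim_value)
    # mem: (num_heads, dim_key, dim_value), z: (num_heads, dim_key, 1)
    sq, sk = sigma(q), sigma(k)
    a_mem = (sq @ mem) / (sq @ z + 1e-8)                    # retrieval, normalized by z
    mem = mem + sk.transpose(-2, -1) @ v                    # M <- M + sigma(K)^T V
    z = z + sk.sum(dim=-2, keepdim=True).transpose(-2, -1)  # z <- z + sum_t sigma(K_t)
    return a_mem, mem, z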

Going to wait until people have had a chance to try the fixes before I close this issue.

dingo-actual commented 6 months ago

It's been a few days and nobody's reported any problems relating to sequence lengths (other than Issue #10). Going to close this issue. If anyone encounters any issues, I'll reopen it.