dingo-actual / infini-transformer

PyTorch implementation of Infini-Transformer from "Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention" (https://arxiv.org/abs/2404.07143)
MIT License

Handle batch input inference for MoD Infini-Former more gracefully #10

Open: dingo-actual opened this issue 6 months ago

dingo-actual commented 6 months ago

Currently, token sampling for the MoD Infini-Former at inference time can produce sequences of different lengths for each observation in the batch. The current workaround is to force the batch size to one and loop over the observations in the batch, which is highly inefficient.
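To make the problem concrete, here is a minimal sketch assuming an inference-time router that keeps every token whose router logit exceeds a threshold (the repo's actual sampling rule may differ). Because each row keeps a different number of tokens, the selections cannot be stacked into one tensor, which forces the per-observation loop. `threshold_route`, `per_observation_loop`, and `block` are hypothetical names, not the repo's API.

```python
import torch


def threshold_route(router_logits: torch.Tensor, threshold: float = 0.0) -> list:
    """Hypothetical inference-time router.

    router_logits: (batch, seq_len). Returns one 1-D index tensor per observation
    holding the positions whose logit exceeds `threshold`. Lengths generally differ.
    """
    keep = router_logits > threshold  # (batch, seq_len) boolean mask
    return [torch.nonzero(row, as_tuple=True)[0] for row in keep]


def per_observation_loop(x: torch.Tensor, router_logits: torch.Tensor, block, threshold: float = 0.0):
    """Batch-size-one workaround: run `block` on each observation's routed tokens separately.

    x: (batch, seq_len, dim). Routed tokens receive a residual update from `block`;
    unrouted tokens pass through unchanged.
    """
    out = x.clone()
    for b, idx in enumerate(threshold_route(router_logits, threshold)):
        if idx.numel() == 0:
            continue  # nothing routed for this observation
        selected = x[b : b + 1, idx]                   # (1, n_b, dim), n_b varies per row
        out[b, idx] = x[b, idx] + block(selected)[0]   # residual add on routed positions only
    return out
```

Each iteration launches the block on a batch of one, which is where the inefficiency comes from.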

There are two main options for handling this efficiently:

  1. Pad the sampled sequences to the longest sequence length in such a way that the additional tokens contribute nothing to downstream calculations.
  2. Wait for PyTorch to implement a ragged tensor type.
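Option 1 can be sketched as a gather into a right-padded batch plus a validity mask, followed by a scatter that writes results back to the original positions and drops the padding slots. This is a minimal sketch assuming a boolean routing mask per observation; `pad_selected` and `scatter_back` are hypothetical helpers, and the sketch handles padding with an explicit mask rather than by choosing a special padding token value.

```python
import torch


def pad_selected(x: torch.Tensor, keep: torch.Tensor):
    """Gather each observation's routed tokens into a right-padded batch.

    x:    (batch, seq_len, dim) input
    keep: (batch, seq_len) boolean routing mask (hypothetical)
    Returns:
      padded: (batch, max_kept, dim) with zeros in padding slots
      valid:  (batch, max_kept) True where the slot holds a real token
      index:  (batch, max_kept) source position of each slot (0 for padding)
    """
    counts = keep.sum(dim=1)  # number of routed tokens per observation
    max_kept = int(counts.max()) if counts.numel() else 0
    # A stable ascending sort on (~keep) moves routed positions to the front
    # while preserving their original order.
    order = torch.sort((~keep).long(), dim=1, stable=True).indices
    index = order[:, :max_kept]
    valid = torch.arange(max_kept, device=x.device)[None, :] < counts[:, None]
    index = torch.where(valid, index, torch.zeros_like(index))
    padded = torch.gather(x, 1, index[..., None].expand(-1, -1, x.size(-1)))
    padded = padded * valid[..., None].to(x.dtype)  # zero out padding slots
    return padded, valid, index


def scatter_back(x: torch.Tensor, update: torch.Tensor, valid: torch.Tensor, index: torch.Tensor):
    """Residual-add `update` onto the routed positions of `x`; padding slots add zero."""
    update = update * valid[..., None].to(update.dtype)
    return x.scatter_add(1, index[..., None].expand_as(update), update)
```

Padded slots still flow through the block, so any operation that mixes tokens (attention, memory updates) must also respect `valid`. Per-token operations such as MLPs are harmless here because their padded outputs are discarded in `scatter_back`.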

I'm likely to pursue the first option because there's no telling how long it will be before the PyTorch devs add ragged tensors.

muditbhargava66 commented 6 months ago

I worked on this issue in #15.

muditbhargava66 commented 6 months ago

Should this issue be closed, or do you need any more changes? Please let me know if you have any further questions.

dingo-actual commented 6 months ago

Unfortunately, the fix you introduced assumes that calling .forward_() on the original input produces a valid result. What needs to happen during inference is for .forward() to use sample_mask_seg to pad the samples along the token dimension until they all have the same length. The part I haven't gotten around to yet is working through the math to find a choice of padding token that doesn't affect downstream calculations.
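Part of that math can be worked out for the compressive memory, assuming the paper's linear (non-delta) update M_s = M_{s-1} + σ(K)ᵀV, z_s = z_{s-1} + Σ_t σ(K_t), with σ = ELU + 1. A zero value vector leaves M unchanged, but σ(k) = ELU(k) + 1 is mathematically strictly positive for every finite k (it equals k + 1 for k > 0 and exp(k) for k ≤ 0), so no finite padding key leaves z unchanged. The sketch below sidesteps the choice of padding token by zeroing σ(K) at padded positions with a validity mask. `masked_memory_update`, `memory_retrieve`, and the `valid` mask are hypothetical, not the repo's code, and the `eps` in the denominator is an added numerical safeguard rather than part of the paper's formula.

```python
import torch
import torch.nn.functional as F


def sigma(t: torch.Tensor) -> torch.Tensor:
    """Memory nonlinearity ELU(t) + 1, positive for all finite t."""
    return F.elu(t) + 1.0


def masked_memory_update(mem, z, k, v, valid):
    """Linear Infini-attention memory update that ignores padded positions.

    mem:   (batch, d_key, d_value) compressive memory M
    z:     (batch, d_key) normalization term
    k:     (batch, n, d_key); v: (batch, n, d_value), possibly right-padded
    valid: (batch, n) True for real tokens (hypothetical mask from padding)
    """
    sk = sigma(k) * valid[..., None].to(k.dtype)  # zero sigma(K) rows for padding slots
    mem = mem + sk.transpose(1, 2) @ v            # M += sigma(K)^T V
    z = z + sk.sum(dim=1)                         # z += sum_t sigma(K_t)
    return mem, z


def memory_retrieve(mem, z, q, eps: float = 1e-6):
    """A_mem = sigma(Q) M / (sigma(Q) z); q: (batch, n_q, d_key)."""
    sq = sigma(q)
    return (sq @ mem) / ((sq @ z[..., None]) + eps)
```

With the mask applied, a padded observation ends up with exactly the memory state it would have had if processed alone, so any padding values can be used. Padded query rows still produce outputs, but those are discarded when results are scattered back to the original positions.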

For the moment, I'm going to revert the change, just to maintain functionality (slow as it is). I really appreciate your putting in time on this though!