lucidrains / routing-transformer

Fully featured implementation of Routing Transformer
MIT License

AutoregressiveWrapper expects different input lengths based on type #5

Closed. tomweingarten closed this issue 4 years ago

tomweingarten commented 4 years ago

When the AutoregressiveWrapper receives a tensor input, it trims the input by one token. When it receives a non-tensor input, it applies padding instead. This is a bit confusing, since it means you need to provide different input lengths depending on the type. Normally this wouldn't matter, but with axial positional encoding an exact input length is expected, so it can fail when the input length is off by one.

        if isinstance(x, torch.Tensor):
            # tensor input: trim one token to form the input/target pair
            xi, xo = x[:, :-1], x[:, 1:]
            annotations = annotations[:, :-1]
        else:
            # non-tensor input: trim each sequence, then pad to a common length
            xi = pad(list(map(lambda t: t[:-1], x)))
            xo = pad(list(map(lambda t: t[1:], x)))
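
To illustrate the mismatch, here is a minimal standalone sketch of the two branches; the pad helper below is only a stand-in for the wrapper's own padding, not a function from the library:

import torch
from torch.nn.utils.rnn import pad_sequence

def pad(tensors, value = 0):
    # stand-in for the wrapper's padding of variable-length sequences
    return pad_sequence(tensors, batch_first = True, padding_value = value)

seq = torch.arange(10)

# tensor input: one token is trimmed off, so seq_len + 1 tokens are needed
x_tensor = seq.unsqueeze(0)                       # (1, 10)
xi, xo = x_tensor[:, :-1], x_tensor[:, 1:]        # both (1, 9)

# list input: each sequence is trimmed, then padded to a common length
x_list = [seq, seq[:7]]
xi = pad([t[:-1] for t in x_list])                # (2, 9)
xo = pad([t[1:] for t in x_list])                 # (2, 9)
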
lucidrains commented 4 years ago

@tomweingarten Hey! Actually, the current axial positional encoding can take in any sequence length, so that shouldn't be a problem. I agree the AutoregressiveWrapper introduces an area of confusion, since it should accept an input of sequence length + 1 so it can be split into input and output of seq_len -> seq_len, but neither kmeans nor axial positional encoding should require a fixed sequence length. Local attention requires the length to be a multiple of the window size, but I should have taken care of that for you too (please send me the error message if not, and I'll fix it). How are you passing in the training data? As one tensor or an array of tensors?
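
For concreteness, a rough sketch of the seq_len + 1 convention described above, assuming the README-style API of RoutingTransformerLM and AutoregressiveWrapper; the hyperparameter values here are illustrative, not prescriptive:

import torch
from routing_transformer import RoutingTransformerLM, AutoregressiveWrapper

model = RoutingTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    heads = 8,
    max_seq_len = 4096,
    causal = True,
    window_size = 128
)
model = AutoregressiveWrapper(model)

# pass max_seq_len + 1 tokens; the wrapper splits them into an input of
# x[:, :-1] and a target of x[:, 1:], each of length max_seq_len
x = torch.randint(0, 20000, (1, 4097))
loss = model(x, return_loss = True)
loss.backward()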

lucidrains commented 4 years ago

import torch
from axial_positional_embedding import AxialPositionalEmbedding

x = torch.randn(1, 4095, 512)

# axial_shape of (64, 64) covers any sequence length up to 64 * 64 = 4096
emb = AxialPositionalEmbedding(512, axial_shape = (64, 64))
print(emb(x).shape) # (1, 4095, 512)

https://github.com/lucidrains/axial-positional-embedding
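
For what it's worth, the same embedding should also accept shorter inputs, since axial_shape = (64, 64) covers any length up to 4096. A quick check under that assumption:

import torch
from axial_positional_embedding import AxialPositionalEmbedding

emb = AxialPositionalEmbedding(512, axial_shape = (64, 64))
x_short = torch.randn(1, 1000, 512)
print(emb(x_short).shape) # expected: (1, 1000, 512)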

tomweingarten commented 4 years ago

You are correct. I had been messing around with the generate function to play with some novel input shapes and forgot I had disabled the Autopadder in the process. I should've tried this on some clean code before reporting :)