facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

Encoder-decoder arch doesn't work when sequence lengths are different #138

Closed kashif closed 2 years ago

kashif commented 2 years ago

šŸ› Bug

I get an error when the sequence lengths to the encoder and decoder are different, e.g. in the code snippet below:

Command

import torch

from xformers.factory.model_factory import xFormer, xFormerConfig

EMB = 384
SEQ_ENC = 128
SEQ_DEC = 64
BATCH = 16
VOCAB = 64

my_config = [
    # A list of the encoder or decoder blocks which constitute the Transformer.
    # Note that a sequence of different encoder blocks can be used, same for decoders
    {
        "reversible": False,  # Optionally make these layers reversible, to save memory
        "block_type": "encoder",
        "num_layers": 3,  # Optional, this means that this config will repeat N times
        "dim_model": EMB,
        "layer_norm_style": "pre",  # Optional, pre/post
        "position_encoding_config": {
            "name": "vocab",  # whatever position encoding makes sense
            "seq_len": SEQ_ENC,
            "vocab_size": VOCAB,
        },
        "multi_head_config": {
            "num_heads": 4,
            "residual_dropout": 0,
            "attention": {
                "name": "linformer",  # whatever attention mechanism
                "dropout": 0,
                "causal": False,
                "seq_len": SEQ_ENC,
            },
        },
        "feedforward_config": {
            "name": "MLP",
            "dropout": 0,
            "activation": "relu",
            "hidden_layer_multiplier": 4,
        },
    },
    {
        "reversible": False,  # Optionally make these layers reversible, to save memory
        "block_type": "decoder",
        "num_layers": 3,  # Optional, this means that this config will repeat N times
        "dim_model": EMB,
        "layer_norm_style": "pre",  # Optional, pre/post
        "position_encoding_config": {
            "name": "vocab",  # whatever position encoding makes sense
            "seq_len": SEQ_DEC,
            "vocab_size": VOCAB,
        },
        "multi_head_config_masked": {
            "num_heads": 4,
            "residual_dropout": 0,
            "attention": {
                "name": "nystrom",  # whatever attention mechanism
                "dropout": 0,
                "causal": True,
                "seq_len": SEQ_DEC,
            },
        },
        "multi_head_config_cross": {
            "num_heads": 4,
            "residual_dropout": 0,
            "attention": {
                "name": "favor",  # whatever attention mechanism
                "dropout": 0,
                "causal": True,
                "seq_len": SEQ_DEC,
            },
        },
        "feedforward_config": {
            "name": "MLP",
            "dropout": 0,
            "activation": "relu",
            "hidden_layer_multiplier": 4,
        },
    },
]

# This part of xFormers is entirely type checked and needs a config object,
# could be changed in the future
config = xFormerConfig(my_config)
model = xFormer.from_config(config)

#  Test out with dummy inputs
src = (torch.rand((BATCH, SEQ_ENC)) * VOCAB).abs().to(torch.int)
tgt = (torch.rand((BATCH, SEQ_DEC)) * VOCAB).abs().to(torch.int)
y = model(src=src, tgt=tgt)

print(y.shape)

Expected behavior

torch.Size([16, 64, 384])

however, I get:

RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [64, 128, 96, 96]->[64, 128, 96, 96] [64, 64, 96]->[64, 64, 1, 96]
blefaudeux commented 2 years ago

oh, thanks for the report and sorry for the delay! I don't think that we've ever tried that, I'll have a look, it should not be too big of an issue

kashif commented 2 years ago

Yeah, no worries, I too think it's just some variable issue where you're assuming both sequence lengths are the same… just a bit harder for me to fix than for you, since you know where to look… thanks!

blefaudeux commented 2 years ago

ah, I just realized that you're mixing different attention mechanisms, that's great! favor needs some work on the causal side, WIP (it's correct I believe, but the causal path is a memory hog right now). In that case the crash is in favor; it looks like it does not support q/k having different lengths, I'll have a look!

kashif commented 2 years ago

@blefaudeux thanks! Yes even with:

 "multi_head_config_masked": {
                "num_heads": 4,
                "residual_dropout": 0,
                "attention": {
                    "name": "nystrom",  # whatever attention mechanism
                    "dropout": 0,
                    "causal": True,
                    "seq_len": SEQ_DEC,
                },
            },
            "multi_head_config_cross": {
                "num_heads": 4,
                "residual_dropout": 0,
                "attention": {
                    "name": "nystrom",  # whatever attention mechanism
                    "dropout": 0,
                    "causal": True,
                    "seq_len": SEQ_DEC,
                },
            },

I get:

RuntimeError: The size of tensor a (128) must match the size of tensor b (64) at non-singleton dimension 2
blefaudeux commented 2 years ago

ok, that's not normal, it should work for all the attentions which are allow-listed here. I didn't have a lot of time to look into that, sorry, it's not forgotten though

blefaudeux commented 2 years ago

ok @kashif, I finally got the time to have a look, and it has to do with how causality is handled (if you don't ask for causality then it all works as it should, at least for the attentions which support different k and q dimensions). I think it makes sense: if you ask for causality and k and q have different lengths, it's not obvious how they relate to one another. Is the shorter one related to the beginning, middle or end of the other one? I do think that we need to properly output an error about that though, putting up a PR with a unit test right now.

On the user side, what you can do is pad SEQ_DEC where it makes sense (beginning, middle, end) so that they have the same length. Pulling in @fmassa and @dianaml0 for a sanity check.
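For reference, a minimal sketch of that padding workaround, building on the snippet at the top of the issue. It assumes the two sequences are aligned at their start, uses 0 as a padding token id, and the decoder-side seq_len entries in the config would also need to be set to SEQ_ENC; these choices are assumptions, not an official recipe.

import torch.nn.functional as F

# Right-pad the decoder tokens so that src and tgt share the same length.
tgt_padded = F.pad(tgt, (0, SEQ_ENC - SEQ_DEC), value=0)  # (BATCH, SEQ_ENC)

y = model(src=src, tgt=tgt_padded)  # (BATCH, SEQ_ENC, EMB)

# Drop the outputs that correspond to padding before using them downstream.
y = y[:, :SEQ_DEC, :]
print(y.shape)  # torch.Size([16, 64, 384])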

kashif commented 2 years ago

thanks, @blefaudeux, for looking into this... so I am planning to use this in the context of time series forecasting, where the encoder takes a sequence (of arbitrary length) and outputs the memory for the decoder, which then takes a sequence whose length is set by the problem and predicts the next value (so I need a causal mask in the decoder of size SEQ_DEC x SEQ_DEC). So I am also not sure, without working through the code, which one should be used... haha, hopefully that makes sense?

blefaudeux commented 2 years ago

thanks, @blefaudeux, for looking into this... so I am planning to use this in the context of time series forecasting, where the encoder takes a sequence (of arbitrary length)

the devil is there basically, "takes a sequence of arbitrary length". Causal needs a 1:1 mapping, because you want to say "don't look into the future", and you have to know what the future is. If you have [. . . . . .] and [. . .], where is the future in the second sequence relative to the first one? If you know that the sequence beginnings are aligned, then you can pad the second one (see torch.nn.functional.pad as linked above) and dump the corresponding output (so it becomes [. . . . . .] and [. . . x x x] and you drop the end of the layer output); if it's aligned in another way you can pad accordingly.

From the lib's perspective it cannot know how the two relate to one another given a causality request and different sequence lengths, so I think that it's better to hard fail, but we should catch and explain that properly. In the encoder/decoder setup the cross MHA mixes the encoder and decoder sequences, and that's where the causality request will fail. You can try not asking for causality there, but to my understanding this would be wrong and the model would cheat (overfit).
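To make that ambiguity concrete, here is a small illustrative sketch in plain PyTorch (not xFormers code) of the two cross-attention causal masks that a start-aligned and an end-aligned interpretation would imply for a query of length SEQ_DEC attending over keys of length SEQ_ENC; the library has no way to pick between them on its own.

import torch

SEQ_ENC, SEQ_DEC = 128, 64  # encoder (key) and decoder (query) lengths

# Start-aligned: decoder step i may attend to encoder positions 0..i.
mask_start = torch.tril(torch.ones(SEQ_DEC, SEQ_ENC))

# End-aligned: decoder step i corresponds to encoder position i + (SEQ_ENC - SEQ_DEC),
# so the allowed region shifts to the right.
mask_end = torch.tril(torch.ones(SEQ_DEC, SEQ_ENC), diagonal=SEQ_ENC - SEQ_DEC)

# Both are valid (SEQ_DEC, SEQ_ENC) "causal" masks, but they allow very
# different attention patterns: 2080 vs 6176 visible key positions in total.
print(int(mask_start.sum()), int(mask_end.sum()))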

and outputs the memory for the decoder, which then takes a sequence whose length is set by the problem and predicts the next value (so I need a causal mask in the decoder of size SEQ_DEC x SEQ_DEC). So I am also not sure, without working through the code, which one should be used... haha, hopefully that makes sense?

I'm missing a lot of context, but from a distance it also looks like GPT: if the goal is to predict the next item in an autoregressive way, we have this example if that helps

blefaudeux commented 2 years ago

thoughts @SeanNaren and @tchaton, my API experts, does the above make sense?

kashif commented 2 years ago

thanks @blefaudeux I will go over your thoughts and think about it too...

SeanNaren commented 2 years ago

Spoke to @blefaudeux offline and agree with him on this; I think the correct approach would be to instantiate the two separate attention blocks and, within the forward function, pad to the appropriate sequence length for the second attention block!
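For what it's worth, a rough sketch of that idea at the whole-model level; the wrapper class, its name and the padding id are hypothetical and not part of the xFormers API, and it again assumes start-aligned sequences and a config whose decoder-side seq_len entries match the encoder length.

import torch
import torch.nn as nn
import torch.nn.functional as F


class PaddedEncoderDecoder(nn.Module):
    """Hypothetical wrapper: pad the decoder tokens inside forward, then
    trim the output back to the original decoder length."""

    def __init__(self, xformer_model: nn.Module, enc_seq_len: int, pad_id: int = 0):
        super().__init__()
        self.model = xformer_model
        self.enc_seq_len = enc_seq_len
        self.pad_id = pad_id

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        dec_seq_len = tgt.shape[1]
        # Right-pad the decoder tokens so both sequences share one length
        # (assumes the two sequences are aligned at their start).
        tgt = F.pad(tgt, (0, self.enc_seq_len - dec_seq_len), value=self.pad_id)
        out = self.model(src=src, tgt=tgt)  # (batch, enc_seq_len, dim)
        # Drop the part of the output that corresponds to the padding.
        return out[:, :dec_seq_len, :]

Used with the snippet from the issue, y = PaddedEncoderDecoder(model, SEQ_ENC)(src, tgt) would then give the expected torch.Size([16, 64, 384]).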

blefaudeux commented 2 years ago

is that ok to close, @kashif? Open to discussing this more

kashif commented 2 years ago

Thanks