CUNY-CL / yoyodyne

Small-vocabulary sequence-to-sequence generation with optional feature conditioning
Apache License 2.0

Different sized encoder for TransformerDecoder #182

Open Adamits opened 2 months ago

Adamits commented 2 months ago

It would be convenient to allow the encoder output_size to differ from the TransformerDecoder embedding size. To illustrate the issue, the code snippet below

import torch
import math

def generate_square_subsequent_mask(length: int) -> torch.Tensor:
    return torch.triu(torch.full((length, length), -math.inf), diagonal=1)

# INITIALIZE A TRANSFORMER WITH THIS HIDDEN AND EMBEDDING SIZE
hid=128
emb=64
decoder_layer = torch.nn.TransformerDecoderLayer(
    d_model=emb,
    dim_feedforward=hid,
    nhead=2,
    dropout=0.2,
    activation="relu",
    batch_first=True,
)
frank_transformer = torch.nn.TransformerDecoder(
    decoder_layer=decoder_layer,
    num_layers=2,
    norm=torch.nn.LayerNorm(emb),
)

# INITIALIZE TARGETS WITH EMBEDDING SIZE
# AND A FAKE ENCODER OUTPUT WITH HIDDEN SIZE
b = 4
seq_len = 10
target_embedding = torch.randn((b, seq_len, emb))
encoder_hidden = torch.randn(b, seq_len, hid)
target_sequence_length = target_embedding.size(1)
# -> seq_len x seq_len.
causal_mask = generate_square_subsequent_mask(target_sequence_length)
# -> B x seq_len x d_model.
output = frank_transformer(
    target_embedding,
    encoder_hidden,
    tgt_mask=causal_mask,
    # memory_key_padding_mask=source_mask,
    # tgt_key_padding_mask=target_mask,
)

throws:

File "test.py", line 34, in <module>
    output = frank_transformer(
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/transformer.py", line 460, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/transformer.py", line 847, in forward
    x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
  File "torch/nn/modules/transformer.py", line 865, in _mha_block
    x = self.multihead_attn(x, mem, mem,
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "torch/nn/functional.py", line 5300, in multi_head_attention_forward
    q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
  File "torch/nn/functional.py", line 4836, in _in_projection_packed
    kv_proj = linear(k, w_kv, b_kv)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (40x128 and 64x128)

But if I change encoder_hidden = torch.randn(b, seq_len, hid) to encoder_hidden = torch.randn(b, seq_len, emb), it works fine.
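A minimal sketch of where the shapes clash, assuming a PyTorch 2.x build like the one in the traceback: with the default kdim = vdim = d_model, each layer's cross-attention (multihead_attn) holds one packed in_proj_weight of shape (3 * d_model, d_model). Its key/value rows only accept d_model-sized memory vectors, which is exactly the (40x128 and 64x128) mismatch above (B * seq_len = 40 flattened memory rows of size hid = 128).

```python
import torch

emb, hid = 64, 128
layer = torch.nn.TransformerDecoderLayer(
    d_model=emb, nhead=2, dim_feedforward=hid, batch_first=True
)
# Packed query/key/value projection used for cross-attention over the memory.
w = layer.multihead_attn.in_proj_weight
print(w.shape)  # torch.Size([192, 64])

# The last 2 * emb rows project keys and values; they map emb-sized inputs only.
w_kv = w[emb:]
memory = torch.randn(4 * 10, hid)  # flattened B * seq_len x hid memory
raised = False
try:
    # linear computes memory @ w_kv.T, i.e. (40 x 128) @ (64 x 128): incompatible.
    torch.nn.functional.linear(memory, w_kv)
except RuntimeError as err:
    raised = True
    print(err)
```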

Essentially, we need the self-attention and the multi-head (cross-)attention blocks to accept different input sizes (which may also require changing the layer norms).

I am putting this up and will try to work out a solution. The easiest way to allow this in yoyodyne would be to project the encoder output into the decoder embedding size, or vice versa, but I feel that this changes the architecture more than necessary. Instead, I would like to consider whether there is an elegant way to update _sa_block and _mha_block so that they do not break other parts of the transformer (e.g., the layer norms).
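One possibility worth checking for the _mha_block route, sketched under the assumption that a custom decoder layer is acceptable: torch.nn.MultiheadAttention itself accepts kdim and vdim arguments, so cross-attention keys and values can come from a differently sized memory while the query, the attention output, the residual sum, and LayerNorm(emb) all stay at emb. TransformerDecoderLayer does not expose kdim/vdim, so this would mean writing (or subclassing) the layer rather than using the stock one, and it effectively moves the projection into the attention's own key/value in-projections. The sizes follow the snippet above; src_len = 12 is an arbitrary choice.

```python
import torch

emb, hid = 64, 128
b, tgt_len, src_len = 4, 10, 12

# Cross-attention: keys/values come from a hid-sized encoder output,
# while queries and outputs stay in the emb-sized decoder space.
cross_attn = torch.nn.MultiheadAttention(
    embed_dim=emb, num_heads=2, kdim=hid, vdim=hid, batch_first=True
)
norm = torch.nn.LayerNorm(emb)

x = torch.randn(b, tgt_len, emb)       # decoder states
memory = torch.randn(b, src_len, hid)  # encoder output
attn_out, _ = cross_attn(x, memory, memory, need_weights=False)

# attn_out is B x tgt_len x emb, so the post-norm residual step still applies.
out = norm(x + attn_out)
print(out.shape)  # torch.Size([4, 10, 64])
```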

Adamits commented 2 months ago

I thought about this more. Since the residual connections in the transformer decoder sum the outputs of the self-attention and multi-head cross-attention blocks into the same stream (with a layer norm after each sum), I don't think we can make this change without fundamentally altering the transformer architecture (e.g., by concatenating them, or by projecting one into the size of the other).

I think the best thing to do is either:

  1. Require the encoder output size to equal the decoder embedding size (raise an error if it does not).
  2. Detect when the encoder output size differs from the decoder embedding size and add a layer to the yoyodyne model class that projects the encoder output into the decoder input size.
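Both options can be sketched with one helper. The name make_memory_projection and its strict flag are hypothetical, not part of yoyodyne: strict=True mirrors option 1 (a pre-training check that raises), and strict=False mirrors option 2 (a learned linear projection, one simple choice of projection layer).

```python
import torch


def make_memory_projection(
    encoder_output_size: int, decoder_embedding_size: int, strict: bool = False
) -> torch.nn.Module:
    """Hypothetical helper mapping encoder outputs to the decoder size."""
    if encoder_output_size == decoder_embedding_size:
        # Sizes already agree: nothing to project.
        return torch.nn.Identity()
    if strict:
        # Option 1: refuse the configuration before training begins.
        raise ValueError(
            f"Encoder output size ({encoder_output_size}) must equal "
            f"decoder embedding size ({decoder_embedding_size})"
        )
    # Option 2: learned linear projection into the decoder's space.
    return torch.nn.Linear(encoder_output_size, decoder_embedding_size)


projection = make_memory_projection(128, 64)
memory = projection(torch.randn(4, 10, 128))
print(memory.shape)  # torch.Size([4, 10, 64])
```

With the sizes from the original snippet, the projected memory has the emb = 64 last dimension that frank_transformer expects in place of encoder_hidden.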

One place where option 1 causes trouble is pairing an LSTM encoder with a transformer decoder: the encoder outputs vectors of size hidden_size * num_directions, while the transformer expects embedding_size. This restricts the space of valid architectures quite a bit, though I'm not sure whether that is a problem.
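A small illustration of the LSTM case, assuming a bidirectional encoder (num_directions = 2); input_size = 32 is an arbitrary choice. The output's last dimension is 2 * hidden_size, so under option 1 the decoder's d_model would have to be 256 (and nhead would have to divide it).

```python
import torch

# Bidirectional LSTM encoder: forward and backward states are concatenated.
lstm = torch.nn.LSTM(
    input_size=32, hidden_size=128, bidirectional=True, batch_first=True
)
source = torch.randn(4, 10, 32)  # B x seq_len x input_size
encoder_output, _ = lstm(source)
print(encoder_output.shape)  # torch.Size([4, 10, 256])
```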

kylebgorman commented 2 months ago

I think either would be fine. This is a good example of a second type of presupposition we will want to test for before training begins.

bonham79 commented 3 weeks ago

I want to say there was a transformer variant a while back that approached this problem (Sumformer, maybe). But I think the ideal solution would be to simply add an additional perceptron layer to force alignment. Personally, I don't think that is too much of a departure from the transformer architecture, since everyone and their grandmother builds an in-house variant. (You'll note that no one uses PyTorch's base form.)

Regarding layer norms: what you can mess with is swapping them out for batch norm. It's a bit too late for me to do the maths for the main issue, but it may give more flexibility with variations in depth.
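A minimal sketch of what the swap looks like on a B x seq_len x emb decoder tensor, assuming post-norm placement as in the snippet above. LayerNorm normalizes each position's vector on its own; BatchNorm1d normalizes each feature over batch and time and expects channels on dim 1, hence the transposes. Both are constructed with the feature size emb, so the swap on its own leaves the residual sums at emb.

```python
import torch

emb = 64
x = torch.randn(4, 10, emb)  # B x seq_len x emb

# LayerNorm: statistics over the emb features of each position.
layer_norm = torch.nn.LayerNorm(emb)
y_ln = layer_norm(x)

# BatchNorm1d: statistics per feature over the batch and time dimensions.
# It expects B x C x L input, so move emb to dim 1 and back.
batch_norm = torch.nn.BatchNorm1d(emb)
y_bn = batch_norm(x.transpose(1, 2)).transpose(1, 2)

print(y_ln.shape, y_bn.shape)
```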

kylebgorman commented 3 weeks ago

@bonham79 is right that everyone has a slightly different transformer variant, and that's okay.