ml-jku / hopfield-layers

Hopfield Networks is All You Need
https://ml-jku.github.io/hopfield-layers/

Transformer decoder target mask wrong shape error #18

Closed: kashif closed this issue 2 years ago.

kashif commented 2 years ago

Hello, I have an encoder-decoder setup that passes a tgt_mask to the decoder, as follows:

import torch.nn as nn
from hflayers import HopfieldLayer
from hflayers.transformer import HopfieldDecoderLayer, HopfieldEncoderLayer

# enc input [B, C, E]
encoder_association = HopfieldLayer(input_size=E, num_heads=num_heads)
encoder_layer = HopfieldEncoderLayer(
    encoder_association,
    dim_feedforward=E*2,
    dropout=dropout_rate,
    activation=act_type,
)
transformer_encoder = nn.TransformerEncoder(
    encoder_layer, num_layers=num_encoder_layers
)

# dec input [B, P, E]
decoder_association_self = HopfieldLayer(
    input_size=E, num_heads=num_heads
)
decoder_association_cross = HopfieldLayer(
    input_size=P, num_heads=num_heads
)
decoder_layer = HopfieldDecoderLayer(
    hopfield_association_self=decoder_association_self,
    hopfield_association_cross=decoder_association_cross,
    dim_feedforward=E*2,
    dropout=dropout_rate,
    activation=act_type,
)
transformer_decoder = nn.TransformerDecoder(
    decoder_layer, num_layers=num_decoder_layers
)

# Transformer
transformer = nn.Transformer(
    d_model=E,
    nhead=num_heads,
    custom_encoder=transformer_encoder,
    custom_decoder=transformer_decoder,
    batch_first=True,
)

I create the mask via:

tgt_mask = transformer.generate_square_subsequent_mask(P)
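For reference, generate_square_subsequent_mask(P) returns a 2D [P, P] float mask: 0.0 on and below the diagonal and -inf above it, so position i can only attend to positions 0..i. The pure-Python sketch below rebuilds that pattern without torch to make the [P, P] shape explicit; square_subsequent_mask is a hypothetical helper, not part of PyTorch or hflayers.

```python
import math


def square_subsequent_mask(size: int) -> list[list[float]]:
    """Pure-Python stand-in for a causal (square subsequent) attention mask.

    Row i is the query position, column j the key position.
    Entries with j <= i are 0.0 (allowed); entries with j > i are -inf,
    which become zero attention weight after the softmax.
    """
    return [
        [0.0 if j <= i else -math.inf for j in range(size)]
        for i in range(size)
    ]


for row in square_subsequent_mask(3):
    print(row)
# [0.0, -inf, -inf]
# [0.0, 0.0, -inf]
# [0.0, 0.0, 0.0]
```

Note that the mask is square by construction: it assumes the keys have the same length as the queries, which is exactly what fails in the report below.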

And when I run it I get:

 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
--> 314                         raise RuntimeError('The size of the 2D attn_mask is not correct.')

For example, with P=28 I have:

attn_mask.shape
torch.Size([1, 28, 28])

and query has shape:

query.shape
torch.Size([28, B, E])

and key is:

key.shape
torch.Size([1, B, E])

The key has length 1 for some reason, even though the decoder inputs have these shapes:

dec_output = transformer.decoder(
    dec_input,          # [B, P, E]
    enc_out,            # [B, C, E]
    tgt_mask=tgt_mask,  # [P, P]
)

Would you know what I am missing? Thanks!
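The failing check compares the mask against the query and key lengths: in the sequence-first layout seen in the traceback (query is (L, N, E), key is (S, N, E)), a 2D mask is unsqueezed to [1, L, S] and must match [1, query length, key length]. The pure-Python sketch below plugs in the shapes reported above. check_2d_attn_mask is a hypothetical helper that only mirrors the two traceback lines, not PyTorch's full validation, and B=4, E=16 are illustrative values.

```python
def check_2d_attn_mask(mask_shape, query_shape, key_shape):
    """Mirror of the 2D attn_mask size check quoted in the traceback.

    query_shape is (L, N, E) and key_shape is (S, N, E) (sequence first).
    The 2D mask is unsqueezed to (1, L', S') and compared to [1, L, S].
    Returns (ok, expected, actual).
    """
    if len(mask_shape) != 2:
        raise ValueError("this sketch only covers 2D masks")
    actual = [1, *mask_shape]                      # attn_mask.unsqueeze(0)
    expected = [1, query_shape[0], key_shape[0]]   # [1, L, S]
    return actual == expected, expected, actual


B, E, P = 4, 16, 28  # illustrative sizes; P=28 as in the report

# Key of length 1, as reported: a [28, 28] mask cannot match [1, 28, 1].
print(check_2d_attn_mask((P, P), (P, B, E), (1, B, E)))
# (False, [1, 28, 1], [1, 28, 28])

# Key of length P, as in ordinary self-attention over the decoder input.
print(check_2d_attn_mask((P, P), (P, B, E), (P, B, E)))
# (True, [1, 28, 28], [1, 28, 28])
```

So the mask itself is fine; the mismatch comes from the key having length 1 instead of P.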

bschaefl commented 2 years ago

Hi @kashif,

for the default encoder-decoder Transformer setting, one has to use Hopfield instead of HopfieldLayer, as the latter uses learnable parameters instead of the layer inputs for the keys and values (the stored patterns and pattern projections); only the query is taken from the input. See https://github.com/ml-jku/hopfield-layers/blob/1497a4d3eaaa0003a8f73484a562329865a61d02/hflayers/__init__.py#L12-L15

and https://github.com/ml-jku/hopfield-layers/blob/1497a4d3eaaa0003a8f73484a562329865a61d02/hflayers/__init__.py#L619-L623

for more information. Moreover, the input size of decoder_association_cross needs to be equal to the number of features of a single instance/token, which is E in your case:

decoder_association_cross = Hopfield(input_size=E, num_heads=num_heads)
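Both points in the reply can be illustrated with the retrieval rule from the Hopfield Networks is All You Need paper, in which each state pattern (query) is replaced by a softmax-weighted sum over the stored patterns (keys). The pure-Python sketch below is a simplified, hypothetical stand-in: hopfield_retrieve is not an hflayers function, and it drops the learned projections, multiple heads and masking that Hopfield provides. It shows that the key length is whatever is passed as the stored patterns: the encoder output (length C) in cross-attention, or a learned table, which here has a single row to match the key length 1 in the report. It also checks that every token has the same feature size E, which is why input_size must be E rather than P.

```python
import math


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def hopfield_retrieve(state, stored, projection, beta=1.0):
    """Simplified modern-Hopfield retrieval (single head, no learned projections).

    state:      L query vectors, each with E features
    stored:     S key vectors (stored patterns), each with E features
    projection: S value vectors (pattern projections)
    Each query attends over all S stored patterns, so a mask would be L x S.
    """
    dim = len(state[0])
    for vec in state + stored:
        if len(vec) != dim:
            raise ValueError("feature size mismatch: every token needs the same E")
    out = []
    for q in state:
        scores = [beta * sum(a * b for a, b in zip(q, k)) for k in stored]
        weights = softmax(scores)
        out.append([
            sum(w * v[d] for w, v in zip(weights, projection))
            for d in range(len(projection[0]))
        ])
    return out


decoder_tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # P=3 tokens, E=2
encoder_tokens = [[1.0, 0.0], [0.0, 1.0]]              # C=2 tokens, E=2

# Inputs as keys/values (the Hopfield case): keys have length C=2.
cross = hopfield_retrieve(decoder_tokens, encoder_tokens, encoder_tokens)

# A learned table with one stored pattern (key length 1, as in the report).
learned = [[0.5, 0.5]]
lookup = hopfield_retrieve(decoder_tokens, learned, learned)

print(len(cross), len(lookup))  # 3 3
print(lookup)  # [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
```

The output always has one row per query (P rows), so a mask for such a call has shape [P, number of stored patterns]; with a one-row learned table that is [P, 1], not [P, P].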

Please let me know if the issue is resolved.

kashif commented 2 years ago

Thanks @bschaefl, let me check and get back to you!

kashif commented 2 years ago

@bschaefl yes, sorry, it works now after replacing all the HopfieldLayer instances with Hopfield and applying the input_size fix. Thanks!