lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

How to use "src_key_padding_mask" #253

Closed: LutherLin closed this issue 2 months ago

LutherLin commented 2 months ago

In PyTorch's official nn.TransformerEncoder, there is a parameter called src_key_padding_mask, which represents the mask for source data keys in each batch (optional). Does the x_transformers library offer a similar optional masking method, specifically designed to mask only the keys?

        self.seqTransEncoder = ContinuousTransformerWrapper(
            dim_in = self.latent_dim, dim_out = self.latent_dim,
            emb_dropout = self.dropout,
            max_seq_len = 1024,
            attn_layers = Encoder(
                dim = self.latent_dim,
                depth = num_layers,
                heads = num_heads,
                ff_mult = int(np.round(ff_size / self.latent_dim)), 
                layer_dropout = self.dropout, cross_attn_tokens_dropout = 0,
            )
        )

I have defined the network structure above; then I want to use it like this:

padding_mask = ...  # torch.Size([32, 50])
xseq = ...          # torch.Size([50, 32, 384])
output = self.seqTransEncoder(xseq, .........=padding_mask)[1:]

Which mask should I use?

lucidrains commented 2 months ago

@LutherLin yup of course

in this repository, it is simply the mask keyword argument on the forward call that serves as the key padding mask; the accepted shape is (batch, seq), and True denotes attend, False denotes do not attend
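
For reference, a minimal sketch of how that mask is passed, reusing the dimensions from the question (batch 32, seq 50, feature 384); the depth and heads below are illustrative placeholders, not taken from the question. Note the semantics are inverted relative to PyTorch's src_key_padding_mask, where True marks positions to ignore:

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

model = ContinuousTransformerWrapper(
    dim_in = 384, dim_out = 384,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 384, depth = 6, heads = 8)   # illustrative depth / heads
)

x = torch.randn(32, 50, 384)        # (batch, seq, feature) - batch-first
mask = torch.ones(32, 50).bool()    # (batch, seq); True = attend, False = do not attend
mask[:, 40:] = False                # e.g. treat the last 10 positions of every sequence as padding

out = model(x, mask = mask)         # (32, 50, 384)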

lucidrains commented 2 months ago

in your example

padding_mask #   (batch, seq) - (32, 50)
xseq #   (batch, seq, feature dimension) - (32, 50, 384), so the (50, 32, 384) tensor needs a transpose

all my repositories adopt batch-first
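
A minimal sketch of that adaptation, assuming batch = 32 and seq = 50, that xseq is currently in the seq-first layout used by nn.TransformerEncoder, and that padding_mask follows PyTorch's convention of True = padding (so it has to be inverted for x-transformers):

xseq = xseq.transpose(0, 1)                         # (50, 32, 384) -> (32, 50, 384), batch-first
mask = ~padding_mask                                # flip convention: True now means attend
output = self.seqTransEncoder(xseq, mask = mask)    # (32, 50, 384)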