facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Assertion compares dimension of key_padding_mask with query dimension in xformers MHA #5075

Open

sarthakgarg commented 1 year ago

❓ Questions and Help

What is your question?

I came across this assertion: https://github.com/facebookresearch/fairseq/blob/3f6ba43f07a6e9e2acf957fc24e57251a7a3f55c/fairseq/modules/multihead_attention.py#L385, which compares the sequence-length dimension of the key padding mask against `tgt_len`, the sequence-length dimension of the query. This check fails whenever the key and query have different sequence lengths (e.g. in cross-attention). Shouldn't the check be `key_padding_mask.size(1) == key.size(0)` instead?

Code

```python
tgt_len, bsz, embed_dim = query.size()  # tgt_len is the query's sequence length

if key_padding_mask is not None:
    assert key_padding_mask.size(0) == bsz
    # The padding mask describes the key sequence, yet it is compared against
    # the query length, which breaks when query and key lengths differ.
    assert key_padding_mask.size(1) == tgt_len
```
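
For concreteness, here is a minimal standalone sketch of the cross-attention case (tensor names and sizes below are illustrative, not taken from fairseq): with a query of length 5 and a key of length 7, the padding mask has one entry per key position, so the current check against `tgt_len` fails while the proposed check against `key.size(0)` passes.

```python
import torch

bsz, tgt_len, src_len, embed_dim = 2, 5, 7, 16

# Cross-attention: the query comes from the decoder, key/value from the encoder,
# so their sequence lengths differ. fairseq uses the (seq_len, batch, dim) layout.
query = torch.randn(tgt_len, bsz, embed_dim)
key = torch.randn(src_len, bsz, embed_dim)

# The padding mask has one entry per *key* position: shape (batch, src_len).
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)

tgt_len_, bsz_, embed_dim_ = query.size()
assert key_padding_mask.size(0) == bsz_            # passes
assert key_padding_mask.size(1) == key.size(0)     # proposed check: passes (7 == 7)
# assert key_padding_mask.size(1) == tgt_len_      # current check: fails (7 != 5)
```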
VarunGumma commented 1 year ago

Just a dumb question: I am training a transformer model with fairseq and want to use xformers. Is it enough to install the xformers library in my environment and start training, or do I need to pass any additional arguments?