pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

Padding mask for BERT #38

Open kchl5 opened 6 days ago

kchl5 commented 6 days ago

Is it worth investing the effort to write a custom create_mask to handle padding? Can we even pass padding information derived from the input to the function?

Many thanks in advance, Kevin

drisspg commented 6 days ago

Would you mind adding a little more detail here? I think it depends on how you plan to do padding + masking, but this should work and should be achievable with a mask_mod.

kchl5 commented 6 days ago

The padding mask is created as follows, where the size is (batch_size, sequence_length) with varying sample lengths within a batch:

def generate_padding(input_id):
    '''
    Description:
    ------------
    Generate a padding mask for a target tensor.
    In the target tensor, the pad token id is 0 and non-pad tokens are non-zero.
    The tensor is converted to a boolean tensor where pad positions are True
    and non-pad positions are False.
    Can also be applied to generate a source padding mask.

    Parameters:
    -----------
    input_id: `torch.Tensor`
        Tensor of input ids.

    Returns:
    --------
    pad: `torch.Tensor`
        Boolean padding mask.
    '''
    pad = input_id == 0
    return pad

For the BERT mask, we apply it by setting the masked token to -100, which is subsequently ignored by the self-supervised masking loss (CE loss). Thus, we only need to provide the padding mask to the attention. Presumably, we would have to write our own masking function for the padding mask, reshape and broadcast the mask, and pass it as a mask_mod. Would you agree with that?
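For context, a minimal sketch of the -100 convention we rely on, assuming the standard ignore_index behaviour of torch.nn.CrossEntropyLoss (the tensor names, shapes, and label values below are illustrative only):

import torch
import torch.nn as nn

# Illustrative shapes: logits (batch, seq_len, vocab), labels (batch, seq_len).
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # positions labeled -100 contribute nothing to the loss

logits = torch.randn(2, 8, 30522)
labels = torch.full((2, 8), -100)   # every position ignored by default
labels[0, 3] = 1012                 # only positions carrying a real label are scored
labels[1, 5] = 2054

loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))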

drisspg commented 5 days ago

Yeah, I think you would want a mask_mod that knows which tokens should be masked, somewhat similar to this pattern: https://github.com/pytorch-labs/attention-gym/blob/75867424a1d4391bff49527029d3612a09dd67e2/attn_gym/masks/document_mask.py#L53 where you can prefill a tensor and use it as a lookup.
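For example, a minimal sketch of that lookup pattern applied to a padding mask, assuming the torch.nn.attention.flex_attention API (the shapes, pad id, and the padding_mask_mod name are illustrative, not part of the library):

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 4, 8, 128, 64
input_ids = torch.randint(1, 30522, (B, S), device="cuda")
input_ids[:, 96:] = 0                      # simulate right-padding with pad id 0

# Prefill the lookup tensor once per batch: True at pad positions.
pad = input_ids == 0

def padding_mask_mod(b, h, q_idx, kv_idx):
    # Allow attention only to non-pad key/value positions.
    # Pad queries still attend to real tokens; their outputs are ignored downstream.
    return ~pad[b, kv_idx]

block_mask = create_block_mask(padding_mask_mod, B=B, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
out = flex_attention(q, k, v, block_mask=block_mask)

Note that because the padding changes from batch to batch, the pad tensor and the block_mask would need to be rebuilt for each batch, just like the prefilled lookup tensor in the linked document-mask example.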