google-research / bigbird

Transformers for Longer Sequences
https://arxiv.org/abs/2007.14062
Apache License 2.0

Variable error with the full_bigbird_mask method in the multi-head attention class #35

Open BetikuOluwatobi opened 9 months ago

BetikuOluwatobi commented 9 months ago

There is a variable error in the `full_bigbird_mask` method used by the multi-head attention class: the BigBird mask is built with `MAX_SEQ_LEN` instead of the `from_seq_length` that is passed in. This affects the creation of the attention mask in `convert_attn_list_to_mask(self, rand_attn)`, which calls the method like this:

```python
temp_mask = [
    full_bigbird_mask(  # pylint: disable=g-complex-comprehension
        self.from_seq_length, self.to_seq_length,
        self.from_block_size, self.to_block_size,
        rand_attn=rand_attn[i])
    for i in range(self.num_attention_heads)
]
```

But inside `full_bigbird_mask` the mask is allocated and iterated with `MAX_SEQ_LEN`:

```python
def full_bigbird_mask(from_seq_length, to_seq_length, from_block_size,
                      to_block_size, rand_attn):
  """Calculate BigBird attention pattern as a full dense matrix.

  Args:
    from_seq_length: int. length of from sequence.
    to_seq_length: int. length of to sequence.
    from_block_size: int. size of block in from sequence.
    to_block_size: int. size of block in to sequence.
    rand_attn: adjacency matrix for random attention.

  Returns:
    attention mask matrix of shape [from_seq_length, to_seq_length]
  """

  attn_mask = np.zeros((MAX_SEQ_LEN, MAX_SEQ_LEN), dtype=np.int32)
  for i in range(1, (MAX_SEQ_LEN // from_block_size) - 1):
    ...
```

Because `full_bigbird_mask` uses `MAX_SEQ_LEN` rather than `from_seq_length` or `to_seq_length`, the method is not dynamic with respect to the sequence length actually passed in: `MAX_SEQ_LEN` is only defined once at the top of the module, and this mismatch seems to be causing a glitch in the `convert_attn_list_to_mask` method.
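A minimal sketch of one possible fix, sizing the mask from the arguments rather than the module-level constant. The window/random/global block structure below is an assumption about the rest of the function body (which is not quoted above), not the repository's patched code:

```python
import numpy as np


def full_bigbird_mask(from_seq_length, to_seq_length, from_block_size,
                      to_block_size, rand_attn):
  """Sketch: build the dense BigBird mask from the passed lengths.

  Allocates [from_seq_length, to_seq_length] instead of
  [MAX_SEQ_LEN, MAX_SEQ_LEN], so the block loop and the rand_attn
  indexing stay consistent with the sequence length actually passed in.
  """
  attn_mask = np.zeros((from_seq_length, to_seq_length), dtype=np.int32)
  for i in range(1, (from_seq_length // from_block_size) - 1):
    # Sliding-window attention: each block attends to its neighbours.
    attn_mask[i * from_block_size:(i + 1) * from_block_size,
              (i - 1) * to_block_size:(i + 2) * to_block_size] = 1
    # Random attention: blocks listed in rand_attn for this row block
    # (rand_attn assumed to have one row per non-global row block).
    for j in rand_attn[i - 1, :]:
      attn_mask[i * from_block_size:(i + 1) * from_block_size,
                j * to_block_size:(j + 1) * to_block_size] = 1
  # Global attention: first and last blocks attend to, and are attended
  # by, every position.
  attn_mask[:from_block_size, :] = 1
  attn_mask[-from_block_size:, :] = 1
  attn_mask[:, :to_block_size] = 1
  attn_mask[:, -to_block_size:] = 1
  return attn_mask
```

With something along these lines, passing a `from_seq_length` smaller than `MAX_SEQ_LEN` keeps the block loop and the `rand_attn` indexing consistent, so `convert_attn_list_to_mask` would produce a mask of the expected shape per head.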