graykode / xlnet-Pytorch

Simple XLNet implementation with Pytorch Wrapper
https://arxiv.org/pdf/1906.08237.pdf
Apache License 2.0

Confusion about the relative position embedding with attn_type='bi' but bsz=1 #12

Open NotANumber124 opened 5 years ago

NotANumber124 commented 5 years ago

The default setting uses bidirectional data with attn_type='bi', but bsz=1. In this function,

https://github.com/graykode/xlnet-Pytorch/blob/cb793a1c75bdc59e3360f04ec641af726719811f/xlnet.py#L371

bidirectional data is only handled when bsz % 2 == 0, yet the default is bsz = 1. So I am confused: if bsz=1, is the setting of beg and end in the following code still correct?

https://github.com/graykode/xlnet-Pytorch/blob/cb793a1c75bdc59e3360f04ec641af726719811f/xlnet.py#L380-L387

Could anyone help me clear up this confusion?

Asichurter commented 3 years ago

@NotANumber124 Actually, the beg and end values of the positional encoding are fine regardless of whether bidirectional data is used, because beg and end only define the range of relative distances.

To see this, imagine a sequence of hidden states going into self-attention with shape [mlen+qlen, hidden_dim] (the mlen memory states first, followed by the qlen input states; batch ignored), where 'mlen' refers to the memories and 'qlen' to the input sequence. Because relative positions are modeled in place of absolute positions, we have to determine the range of relative distances (i - j) by finding the maximum and minimum values of i - j and embedding that range. With bidirectional attention, the maximum relative distance comes from the last element of the sequence (index = mlen+qlen-1) attending to the left-most element, which gives (mlen+qlen) - 0 = mlen+qlen. Similarly, the minimum relative distance comes from the first element of the input sequence (index = mlen) attending to the right-most element of the sequence, which gives mlen - (mlen+qlen) = -qlen. In summary, the range of relative distances is [-qlen, mlen+qlen], the same as in the code. Note that the mlen memory states can only be attended to as keys (K); they never act as queries (Q), because memories are not queries.
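For illustration, here is a minimal sketch of that range calculation and the resulting sinusoidal embedding (illustrative values, not the repository's exact code; it assumes klen = mlen + qlen, matching the linked lines):

```python
import torch

mlen, qlen, d_model = 4, 6, 8
klen = mlen + qlen  # assumed: memory length + input length

# Bidirectional attention: relative distances i - j range over [-qlen, klen],
# mirroring beg, end = klen, -qlen in the linked code.
beg, end = klen, -qlen
fwd_pos_seq = torch.arange(beg, end, -1.0)        # klen, klen-1, ..., -qlen+1

# Sinusoidal embedding of each relative distance, Transformer-XL style.
freq_seq = torch.arange(0, d_model, 2.0)
inv_freq = 1.0 / (10000 ** (freq_seq / d_model))
sinusoid = torch.einsum('i,d->id', fwd_pos_seq, inv_freq)
pos_emb = torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1)
print(pos_emb.shape)  # torch.Size([16, 8]) == [klen + qlen, d_model]
```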

However, I found another place confusing. When bidirectional data is used, the forward and backward positional sequences are concatenated into a single pos_emb tensor along dim=1, which means this dimension refers to the direction of the data:

https://github.com/graykode/xlnet-Pytorch/blob/cb793a1c75bdc59e3360f04ec641af726719811f/xlnet.py#L401-L404
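Continuing the sketch above, here is a rough shape trace of this dim=1 concatenation (again illustrative only, not the repository's exact code):

```python
# Backward stream uses the negated relative distances.
bwd_pos_seq = torch.arange(-beg, -end, 1.0)
bwd_sinusoid = torch.einsum('i,d->id', bwd_pos_seq, inv_freq)
bwd_pos_emb = torch.cat([bwd_sinusoid.sin(), bwd_sinusoid.cos()], dim=-1)[:, None, :]

fwd_pos_emb = pos_emb[:, None, :]                  # [klen+qlen, 1, d_model]
both = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
print(both.shape)  # torch.Size([16, 2, 8]): dim=1 holds the two directions, not the batch
```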

However, this same dimension of the pos_emb tensor is misused as the batch_size dimension when calculating the relative attention score:

https://github.com/graykode/xlnet-Pytorch/blob/cb793a1c75bdc59e3360f04ec641af726719811f/xlnet.py#L219

No errors occur only because batch_size=1 by default and the direction dimension of the data happens to coincide with the batch size. When batch_size increases, this may cause dimension-incompatibility errors. I hope an explanation and a fix can be provided.
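For concreteness, here is a toy shape check of the mismatch described above; the tensor names and sizes are hypothetical, not the repository's exact call:

```python
import torch

# If dim=1 of the positional stream holds the two directions while dim=1 of
# the query stream holds the batch, the einsum for the position-based score
# only lines up when those two sizes happen to agree.
qlen, plen, bsz, n_head, d_head = 6, 16, 4, 2, 8
q_head = torch.randn(qlen, bsz, n_head, d_head)    # content stream: dim=1 is the batch
k_head_r = torch.randn(plen, 2, n_head, d_head)    # positional stream: dim=1 is the direction

try:
    bd = torch.einsum('ibnd,jbnd->ijbn', q_head, k_head_r)
except RuntimeError as e:
    print('fails once batch size != number of directions:', e)
```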