kimiyoung / transformer-xl


Some questions about pytorch code and details. #8

Open wlhgtc opened 5 years ago

wlhgtc commented 5 years ago

Hi there: so nice that you released the original code. It is still a little difficult for me to reproduce :( After nearly 1.5 days of matching your paper against the code, I still have some questions about the model structure. Hope you could help; some of them may be foolish ...

  1. What's the difference between RelLearnableMultiHeadAttn and RelPartialLearnableMultiHeadAttn? It seems the most important part is the construction of the (A + B + C + D) terms, but the first one doesn't use the positional embedding from "Attention Is All You Need"?

  2. Can you explain the function _rel_shift in detail for me? Especially the first few lines of code; I don't understand why we need them.

  3. What happens when the param div_val > 1, and what's the meaning of the cutoff_xxx parameters? More specifically, I think what we need is only the code path for div_val == 1.

Hope you could help me, thx.

zihangdai commented 5 years ago
  1. RelLearnableMultiHeadAttn corresponds to the "relative positional encoding" proposed by Shaw et al. (2018), which merges the multiplication W^k R into a single trainable matrix \hat{R} (see the last paragraph of page 5). RelPartialLearnableMultiHeadAttn is the "relative positional encoding" we propose in this work.

  2. It is easier to give an example. To perform relative attention, we want to relatively shift the attention score matrix as follows:

    a00 a01 a02      a02  0   0
    a10 a11 a12  =>  a11 a12  0
    a20 a21 a22      a20 a21 a22

What _rel_shift does is just a clear way of achieving the transformation above: pad a zero column on the left (3x3 -> 3x4), reshape to 4x3 so the rows become mis-aligned, drop the first row, and (optionally) zero out the upper-right triangle:

a00 a01 a02      0 a00 a01 a02       0  a00 a01      a02  0  a10     a02  0   0
a10 a11 a12  =>  0 a10 a11 a12  =>  a02  0  a10  =>  a11 a12  0  =>  a11 a12  0
a20 a21 a22      0 a20 a21 a22      a11 a12  0       a20 a21 a22     a20 a21 a22
                                    a20 a21 a22
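
For reference, here is a minimal, illustrative PyTorch sketch of the shift for a single 2-D score matrix (the real _rel_shift in pytorch/mem_transformer.py carries extra batch and head dimensions, but the idea is the same):

    import torch

    def rel_shift(x, zero_triu=False):
        # x: attention scores of shape (qlen, klen)
        qlen, klen = x.size(0), x.size(1)
        # 1) pad one zero column on the left -> (qlen, klen + 1)
        zero_pad = torch.zeros(qlen, 1, dtype=x.dtype, device=x.device)
        x_padded = torch.cat([zero_pad, x], dim=1)
        # 2) view the padded matrix as (klen + 1, qlen): the extra zeros
        #    mis-align the rows, which is exactly the relative shift we want
        x_padded = x_padded.view(klen + 1, qlen)
        # 3) drop the first row and view back as (qlen, klen)
        x = x_padded[1:].view(qlen, klen)
        if zero_triu:
            # 4) optionally zero out the upper-right triangle, matching the
            #    target matrix shown above
            ones = torch.ones(qlen, klen, dtype=x.dtype, device=x.device)
            x = x * torch.tril(ones, diagonal=klen - qlen)
        return x

    # Reproduces the 3x3 example above with a00..a22 = 0..8
    a = torch.arange(9.0).view(3, 3)
    print(rel_shift(a, zero_triu=True))
    # tensor([[2., 0., 0.],
    #         [4., 5., 0.],
    #         [6., 7., 8.]])
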
  3. The div_val is the ratio used to reduce the embedding dimension for each successive bin, where the cutoffs are the boundaries of the bins. The names are adapted from the original PyTorch class.
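
To make the div_val point concrete, here is a small, purely illustrative snippet (the cutoff values are made up, not taken from the repo's configs): each later bin covers rarer tokens and gets an embedding dimension reduced by another factor of div_val, mirroring the d_embed // (div_val ** i) computation in AdaptiveEmbedding, and each bin is projected back to d_embed before entering the model.

    # Illustration only: how div_val and the cutoffs carve up the vocabulary.
    d_embed, div_val = 512, 2
    vocab_size = 267735                       # e.g. the WikiText-103 vocabulary
    cutoffs = [20000, 60000, 180000]          # hypothetical bin boundaries
    cutoff_ends = [0] + cutoffs + [vocab_size]

    for i in range(len(cutoff_ends) - 1):
        lo, hi = cutoff_ends[i], cutoff_ends[i + 1]
        d_emb_i = d_embed // (div_val ** i)   # rarer bins get smaller embeddings
        print(f"tokens [{lo}, {hi}): embedding dim {d_emb_i}, projected to {d_embed}")
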
wlhgtc commented 5 years ago

So glad to see your reply. I list some of my own understanding below; could you help me correct it?

  1. About the shift operation in 2., it seems to be an easy way to calculate the relative position term described in Appendix B?
  2. There seems to be some redundant code in the PyTorch version, e.g. the _shift function?
wlhgtc commented 5 years ago

By the way, it seems you add the positional embedding at each layer. In your ablation study, is there any improvement compared with adding it only to the word embedding? @zihangdai

kimiyoung commented 5 years ago

@wlhgtc

wlhgtc commented 5 years ago

It seems your positional embedding conflicts with the original version in ???: your layer looks like the second column (sin, sin, ..., sin, cos, cos, ..., cos), but it should look like the first column (sin, cos, sin, cos, ...). [image]

wlhgtc commented 5 years ago

And thanks for your help, I finished extracting the whole Transformer-XL model code into a single file. Still one question about your training process: https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/train.py#L433-L437 It seems you split the whole context into several chunks, and mems[i] is used when training data[i+1]. But this code doesn't seem to show that? Or is there something special in the BalancedDataParallel class?

wlhgtc commented 5 years ago

@kimiyoung Hope you could help

zihangdai commented 5 years ago
  1. For the position embedding, the two columns are equivalent, simply because they are consumed by the matrix multiplication, which is permutation-invariant (see the sketch after this list for the layout the code produces).
  2. Just to copy what we explained in the README file ==> --batch_chunk: this option allows one to trade speed for memory. For batch_chunk > 1, the program will split each training batch into batch_chunk sub-batches and perform forward and backward passes on each sub-batch sequentially, with the gradients accumulated and divided by batch_chunk. Hence, the memory usage will be proportionally lower while the computation time will be correspondingly higher.
  3. For BalancedDataParallel, see #5
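
Regarding point 1, here is a minimal sketch of the concatenated sin/cos layout that the repo's PositionalEmbedding produces (the model dimension and sequence length below are just for illustration):

    import torch

    d_model = 8
    # one inverse frequency per sin/cos pair
    inv_freq = 1.0 / (10000 ** (torch.arange(0.0, d_model, 2.0) / d_model))
    pos_seq = torch.arange(15.0, -1.0, -1.0)    # relative positions klen-1 .. 0
    sinusoid = torch.ger(pos_seq, inv_freq)     # outer product: (seq_len, d_model // 2)
    # concatenated layout [sin, ..., sin, cos, ..., cos] along the last dim,
    # instead of the interleaved [sin, cos, sin, cos, ...] of the original paper
    pos_emb = torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1)   # (seq_len, d_model)

Since pos_emb is only ever consumed through a learned linear projection, permuting its columns (interleaved vs. concatenated) just permutes the corresponding rows of that projection, so the two layouts give the same model.
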
wlhgtc commented 5 years ago

Yeah, but I mean that when we train on data[i], we need mems[i-1], the memory from the previous chunk. But ret = para_model(data_i, target_i, *mems[i]) seems to use mems[i]?

zihangdai commented 5 years ago

No. The split by batch_chunk is along the batch dimension. In this case, mems is a Python list (line 424), where mems[i] corresponds to the i-th chunk, i.e., mini-batch.
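
A condensed sketch of the loop around train.py#L433-L437 may make this clearer (data, target, para_model, mems, batch_chunk and optimizer stand in for the corresponding objects in train.py; fp16 handling and logging are omitted):

    # Each chunk i keeps its own memory mems[i]; within one training step the
    # chunks are independent mini-batches along the batch dimension, not
    # consecutive segments in time.
    data_chunks = torch.chunk(data, batch_chunk, dim=1)
    target_chunks = torch.chunk(target, batch_chunk, dim=1)
    for i in range(batch_chunk):
        data_i = data_chunks[i].contiguous()
        target_i = target_chunks[i].contiguous()
        ret = para_model(data_i, target_i, *mems[i])
        loss, mems[i] = ret[0], ret[1:]               # updated memory for chunk i
        loss = loss.float().mean() / batch_chunk      # gradient accumulation
        loss.backward()
    optimizer.step()
    optimizer.zero_grad()
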

wlhgtc commented 5 years ago

Fine, I re-read the code. It seems the mems are updated when a batch finishes and are then used in the next batch, am I right?
But according to Figure 2 in your paper, I think the mems should flow between different segments. So I regarded each chunk as a different segment, but it seems that the different batches from the iterator are the different segments?

zihangdai commented 5 years ago

Please refer to the _update_mems function for how a single mem is updated.

When batch_chunk is used, each element mems[i] in mems is updated in exactly the same way and then returned to the train loop so that it can be used for the next segment (see this line for how the mems[i] is returned).
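
In essence, the memory update concatenates the cached states with the current segment's hidden states along the time dimension and keeps the most recent mem_len steps, detached from the graph. A simplified sketch (ignoring the ext_len handling of the real _update_mems):

    import torch

    def update_mems(hids, mems, mem_len):
        # hids: per-layer hidden states of the current segment, each (qlen, bsz, d_model)
        # mems: per-layer cached states from previous segments, each (mlen, bsz, d_model)
        with torch.no_grad():                             # recurrence carries no gradient
            new_mems = []
            for h, m in zip(hids, mems):
                cat = torch.cat([m, h], dim=0)            # (mlen + qlen, bsz, d_model)
                new_mems.append(cat[-mem_len:].detach())  # keep the latest mem_len steps
        return new_mems
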

BenjaminWinter commented 5 years ago

@zihangdai Gonna piggyback on this issue since my question is somewhat related: could you explain in a little more detail how the segment-level recurrence works in code? I can see that you calculate the mems for an entire batch of (I assume consecutive) segments and then reuse them in the next step, but I am confused about how recurrence between consecutive segments inside one batch works.

I'm asking this because I'm thinking about how to adapt this model to a different task like question answering, and I can't really wrap my head around how to build the segmentation when you have to distinguish between the contexts for different questions and can't treat the entire corpus as one large chunk of text.

abhitopia commented 5 years ago

@kimiyoung @zihangdai Following up,

It seems that, by default, the zeroing of the upper triangular matrix is set to False.

https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L194

What is the reason for that?

LindgeW commented 4 years ago

@BenjaminWinter I'm also confused about the segment-level recurrence described in the paper.

aleSuglia commented 3 years ago

@zihangdai could you please clarify this issue? I can't find anywhere how you deal with multiple segments.