abheesht17 opened this issue 2 years ago
@abheesht17 Thanks for opening this feature request!
I have one question: why are we still doing masking at step 3? I am not very familiar with permutation language modeling, but from the articles I have read, it seems that masking is no longer applied?
Hello, @chenmoneygithub! I think the reason is as follows:
XLNet uses multiple factorisation orders: the input sequence itself stays in its original order, but the order in which tokens are predicted is permuted, and that permutation is realised through attention masks. Suppose our input text is [1, 2, 3, 4], and assume that XLNet samples two factorisation orders - [3, 2, 4, 1] and [2, 4, 3, 1]. In the first case, if we want to compute the updated representation of token "3", we mask "2", "4" and "1" (since they come after "3" in the order); in the second case, we mask only "1". That's why XLNet needs a "permutation mask".
(Sorry for the late reply)
The figures in this blog post explain it well and give a more concrete explanation: https://www.borealisai.com/en/blog/understanding-xlnet/
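To make that concrete, here is a tiny NumPy sketch (the function name and the 1-based toy token labels are just for illustration, not the actual keras-nlp or HF code) that derives the mask for a single factorisation order, with `mask[i, j] = 1` meaning token `i` cannot attend to token `j`:

```python
import numpy as np

def permutation_mask(order):
    """`order` is a factorisation order over tokens labelled 1..len(order)."""
    order = np.asarray(order) - 1           # back to 0-based token positions
    pos_in_order = np.empty(len(order), dtype=int)
    pos_in_order[order] = np.arange(len(order))
    # mask[i, j] = 1 -> token i cannot attend to token j,
    # because j comes after i in the factorisation order.
    return (pos_in_order[None, :] > pos_in_order[:, None]).astype(np.float32)

# For [3, 2, 4, 1] the row for token "3" masks "2", "4" and "1";
# for [2, 4, 3, 1] it masks only "1".
print(permutation_mask([3, 2, 4, 1]))
print(permutation_mask([2, 4, 3, 1]))
```

In the full two-stream setup, the query stream for a prediction target additionally cannot see that token's own content; the strictly-before rule above just captures the core idea.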
Building on https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/layers/preprocessing/mlm_mask_generator.py, which dynamically masks tokens, I was wondering if we could implement a layer that generates permutation masks for XLNet's inputs (Permutation Language Modelling).
This function from Hugging Face Transformers, which generates inputs for XLNet, is a good reference: https://github.com/huggingface/transformers/blob/72728be3dbca26c70dddc8b724eb2c8d901e97dc/src/transformers/data/data_collator.py#L1230
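For reference, here is a rough sketch of the logic such a layer would have to produce (all names here, e.g. `make_plm_inputs`, are hypothetical). It loosely mirrors the HF collator linked above but is deliberately simplified: it uses independent Bernoulli masking instead of HF's span masking, works on a single unbatched sequence in NumPy, and skips padding/special-token handling. A real keras-nlp layer would express this with TF ops inside a `keras.layers.Layer`:

```python
import numpy as np

def make_plm_inputs(token_ids, plm_probability=1 / 6, rng=None):
    """Return (perm_mask, target_mapping, labels) for one sequence."""
    if rng is None:
        rng = np.random.default_rng()
    token_ids = np.asarray(token_ids)
    seq_len = len(token_ids)

    # 1. Choose prediction targets (simplified: independent Bernoulli draws,
    #    whereas the HF collator masks contiguous spans).
    masked = rng.random(seq_len) < plm_probability

    # 2. Sample one factorisation order for the whole sequence.
    order = rng.permutation(seq_len)
    pos_in_order = np.empty(seq_len, dtype=int)
    pos_in_order[order] = np.arange(seq_len)

    # 3. perm_mask[i, j] = 1 -> token i cannot attend to token j.
    #    j is hidden from i only if j is a prediction target and does not come
    #    strictly before i in the factorisation order; non-target tokens stay
    #    visible to everyone, since their content is never predicted.
    perm_mask = (
        (pos_in_order[None, :] >= pos_in_order[:, None]) & masked[None, :]
    ).astype(np.float32)

    # 4. One-hot matrix selecting the target positions, plus their labels.
    target_positions = np.flatnonzero(masked)
    target_mapping = np.zeros((len(target_positions), seq_len), dtype=np.float32)
    target_mapping[np.arange(len(target_positions)), target_positions] = 1.0
    labels = token_ids[target_positions]

    return perm_mask, target_mapping, labels
```

The returned `perm_mask` follows the same convention as above (1 = cannot attend), and `target_mapping` is the one-hot matrix used to pick out the prediction targets.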
Would love to take this up!