keras-team / keras-hub

Pretrained model hub for Keras 3
Apache License 2.0

Layer for Permutation Language Modelling [XLNet] #141

Open abheesht17 opened 2 years ago

abheesht17 commented 2 years ago

Building on https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/layers/preprocessing/mlm_mask_generator.py, which dynamically masks tokens, I was wondering whether we could implement a layer that generates permutation masks for its inputs the way XLNet does (Permutation Language Modelling).

This function generates XLNet inputs and is a good reference: https://github.com/huggingface/transformers/blob/72728be3dbca26c70dddc8b724eb2c8d901e97dc/src/transformers/data/data_collator.py#L1230

The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
    0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
    1. Sample a `span_length` from the interval `[1, max_span_length]` (length of the span of tokens to be masked).
    2. Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked.
    3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`.
    4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the sequence to be processed), repeat from Step 1.
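The span-sampling steps above can be sketched in plain Python as follows. This is a minimal, framework-free illustration, not the Hugging Face code and not a proposed KerasNLP API: the helper name `sample_plm_target_mask` and its default values are hypothetical, it assumes `0 < plm_probability <= 1`, and it truncates `context_length` to an integer so it can be used as an index.

```python
import random


def sample_plm_target_mask(max_len, max_span_length=5, plm_probability=1 / 6, rng=None):
    """Hypothetical helper: return one boolean per position, True = prediction target.

    Sketches steps 0-4 of the span-sampling algorithm. Assumes 0 < plm_probability <= 1,
    which guarantees context_length >= span_length.
    """
    if rng is None:
        rng = random.Random()
    mask = [False] * max_len
    cur_len = 0  # Step 0: number of tokens processed so far.
    while cur_len < max_len:
        # Step 1: sample a span length uniformly from [1, max_span_length] (inclusive).
        span_length = rng.randint(1, max_span_length)
        # Step 2: reserve a context window around the span (truncated to an int).
        context_length = int(span_length / plm_probability)
        # Step 3: sample start_index in [cur_len, cur_len + context_length - span_length]
        # and mark the span, clipping it at the end of the sequence.
        start_index = cur_len + rng.randint(0, context_length - span_length)
        for i in range(start_index, min(start_index + span_length, max_len)):
            mask[i] = True
        # Step 4: advance past the whole context window, then repeat while tokens remain.
        cur_len += context_length
    return mask


# With plm_probability = 1.0 the window equals the span, so every position is a target.
print(all(sample_plm_target_mask(10, max_span_length=3, plm_probability=1.0)))  # True
```

With smaller values of `plm_probability`, each window of `context_length` tokens contains one span of `span_length` targets, so roughly a `plm_probability` fraction of the sequence ends up marked.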

Would love to take this up!

chenmoneygithub commented 2 years ago

@abheesht17 Thanks for opening this feature request!

I have one question: why do we still apply masking in step 3? I am not very familiar with permutation language modeling, but from the articles I have read, it seems that it no longer applies masks?

abheesht17 commented 2 years ago

Hello, @chenmoneygithub! I think the reason is as follows:

XLNet considers multiple factorisation orders, since it permutes the order in which the tokens of the input sequence are predicted (the tokens stay in their original positions; the permutation is realised through the attention mask). Suppose our input text is [1, 2, 3, 4], and assume that XLNet samples two permutations: [3, 2, 4, 1] and [2, 4, 3, 1]. In the first case, to compute the updated representation of token "3", we mask "2", "4" and "1" (since they come after "3" in the factorisation order); in the second case, we mask only "1". That's why XLNet needs a "permutation mask".
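The permutation mask from the example above can be sketched as follows: for a given factorisation order, each token may only attend to tokens that precede it in that order. This is an illustrative sketch with hypothetical helper names (`permutation_mask`, `hidden_from`). It works on token values, so it assumes all tokens are distinct as in the toy example, and it lets a token see itself, as in XLNet's content stream (the query stream additionally hides the token being predicted).

```python
def permutation_mask(tokens, order):
    """Hypothetical helper: mask[i][j] is True if tokens[i] may NOT attend to tokens[j].

    `order` lists the tokens in factorisation order. A token is hidden from another
    token if it comes strictly later in that order. Assumes distinct token values.
    """
    rank = {tok: r for r, tok in enumerate(order)}
    return [[rank[tj] > rank[ti] for tj in tokens] for ti in tokens]


def hidden_from(tokens, order, token):
    """Return, in input order, the tokens that `token` cannot attend to under `order`."""
    row = permutation_mask(tokens, order)[tokens.index(token)]
    return [t for t, hidden in zip(tokens, row) if hidden]


tokens = [1, 2, 3, 4]
print(hidden_from(tokens, [3, 2, 4, 1], 3))  # [1, 2, 4]
print(hidden_from(tokens, [2, 4, 3, 1], 3))  # [1]
```

Note that the input order `[1, 2, 3, 4]` is never changed; only the mask differs between the two factorisation orders.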

(Sorry for the late reply)

abheesht17 commented 2 years ago

This figure explains it well: [image]

abheesht17 commented 2 years ago

A more concrete explanation: [image] (https://www.borealisai.com/en/blog/understanding-xlnet/)