zihangdai / xlnet

XLNet: Generalized Autoregressive Pretraining for Language Understanding

perm_mask during fine tuning #195

Open ardofski opened 5 years ago

ardofski commented 5 years ago

What should be the perm_mask during fine tuning or which part of the code creates that perm_mask ? Thanks.

langfield commented 5 years ago

The code that generates perm_mask is in the parser function, which is located at line 657 in the data_utils.py file. The following snippet shows how it's created from the output of the _local_perm function.

perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
    inputs[:reuse_len],
    target[:reuse_len],
    is_masked[:reuse_len],
    perm_size,
    reuse_len)

perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
    inputs[reuse_len:],
    target[reuse_len:],
    is_masked[reuse_len:],
    perm_size,
    non_reuse_len)

perm_mask_0 = tf.concat([perm_mask_0, tf.ones([reuse_len, non_reuse_len])],
                        axis=1)
perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
                        axis=1)
perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
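
If it helps to see the block structure in isolation, here is a minimal stand-alone sketch (plain numpy, toy sizes; the zero matrices just stand in for whatever _local_perm returns). Per the docstring in xlnet.py, a 1 in perm_mask means "cannot attend" and a 0 means "can attend", so the ones block prevents the first reuse_len tokens from attending to the later tokens, while the zeros block lets the later tokens attend to the whole reuse part:

import numpy as np

reuse_len, non_reuse_len = 3, 2                 # toy sizes, not real config values
pm0 = np.zeros((reuse_len, reuse_len))          # stand-in for _local_perm output on inputs[:reuse_len]
pm1 = np.zeros((non_reuse_len, non_reuse_len))  # stand-in for _local_perm output on inputs[reuse_len:]

pm0 = np.concatenate([pm0, np.ones((reuse_len, non_reuse_len))], axis=1)   # reuse rows: block the later columns
pm1 = np.concatenate([np.zeros((non_reuse_len, reuse_len)), pm1], axis=1)  # later rows: allow all reuse columns
perm_mask = np.concatenate([pm0, pm1], axis=0)                             # final [seq_len, seq_len] mask
print(perm_mask)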

Take a look at _local_perm in the same file as well. The inputs to that function are created in create_tfrecords on line 395. The important parts happen in the nested for loops starting at line 448, where a call to _split_a_and_b samples two roughly equal-sized slices from data[idx] whose total length is tot_len. Two calls to _sample_mask then mask n-grams of variable size (randomly chosen within a range) along the samples for partial prediction: one call for inp, a sample of length reuse_len, immediately followed by a second call on the A/B sample produced by the split function. The output of this block is used to create the inputs, target, and is_masked variables that are fed into _local_perm; a simplified sketch of the n-gram masking follows.
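
For intuition only, here is a self-contained sketch of that kind of n-gram masking. The function name, the uniform n-gram length choice, and the skip-ahead heuristic are all illustrative rather than the repo's exact sampling scheme (_sample_mask also respects SentencePiece word boundaries and skews toward shorter n-grams):

import numpy as np

def sample_ngram_mask(seq_len, goal_num_predict, max_gram=5, seed=0):
  """Illustrative only: mask contiguous n-grams until goal_num_predict tokens are masked."""
  rng = np.random.default_rng(seed)
  is_masked = np.zeros(seq_len, dtype=bool)
  num_masked, cur = 0, 0
  while cur < seq_len and num_masked < goal_num_predict:
    n = int(rng.integers(1, max_gram + 1))          # n-gram length (uniform here; the repo uses a skewed distribution)
    n = min(n, goal_num_predict - num_masked, seq_len - cur)
    is_masked[cur:cur + n] = True                   # mask a contiguous span
    num_masked += n
    cur += n + int(rng.integers(1, 2 * n + 1))      # skip ahead so masked spans are not adjacent
  return is_masked

print(sample_ngram_mask(seq_len=20, goal_num_predict=6))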

xingchensong commented 5 years ago

perm_mask will not be used in fine-tuning.

ardofski commented 5 years ago

Is it set to None or to a matrix of all zeros?

langfield commented 5 years ago

@stephen-song can it be used during fine-tuning, though? For example, if your downstream task is exactly the same as the pretraining task (masked LM objective)? Good catch, though; I should have mentioned that above. It doesn't make sense for most applications, and I don't think the code supports it for fine-tuning.
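
For reference, the classification fine-tuning path builds the model without a perm_mask at all; per the docstring in xlnet.py, perm_mask=None means every position attends to every other (an all-zeros mask would have the same effect). Roughly, a sketch along the lines of function_builder.py, where the config objects and the [seq_len, bsz] input tensors are assumed to have been built as in run_classifier.py:

import xlnet

xlnet_model = xlnet.XLNetModel(
    xlnet_config=xlnet_config,
    run_config=run_config,
    input_ids=input_ids,      # [seq_len, bsz] int32
    seg_ids=seg_ids,
    input_mask=input_mask)    # no perm_mask argument, so it defaults to None
summary = xlnet_model.get_pooled_out(summary_type="last")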