allenai / longformer

Longformer: The Long-Document Transformer
https://arxiv.org/abs/2004.05150
Apache License 2.0

Does the multi-task training include predicting relevant paragraphs? #156

Open Fan-Luo opened 3 years ago

Fan-Luo commented 3 years ago

Hi, the paper mentions in section 6.1:

We use a two-stage model that first selects the most relevant paragraphs then passes them to a second stage for answer extraction. Both stages concatenate question and context into one sequence, run it through Longformer, then use task-specific prediction layers. We train the models in a multi-task way to predict relevant paragraphs, evidence sentences, answer spans and question types (yes/no/span) jointly.

My understanding is that the first stage concatenates the question and the whole context as input to Longformer, and only predicts relevant paragraphs as a binary classification task for each paragraph. The second stage concatenates the question and the predicted relevant paragraphs as input, and trains another Longformer model in a multi-task way to predict relevant paragraphs, evidence sentences, answer spans and question types (yes/no/span) jointly.
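For concreteness, here is a rough sketch of the kind of first-stage paragraph scoring I have in mind; the names, the pooling choice, and the shapes are just my assumptions for illustration, not the actual code:

```python
# Minimal sketch of stage 1: encode "question + full context" once, pool one
# vector per paragraph, and score each paragraph with a binary relevance head.
# (Hypothetical names; a random tensor stands in for Longformer's output.)
import torch
import torch.nn as nn

hidden_size = 768        # Longformer-base hidden size
num_paragraphs = 10
seq_len = 4096

# Stand-in for Longformer's token-level output on the concatenated input.
token_embeddings = torch.randn(1, seq_len, hidden_size)

# Assume we know which token starts each paragraph (e.g. a special token).
paragraph_start_positions = torch.randint(0, seq_len, (num_paragraphs,))

# Pool one vector per paragraph from its first token (one possible choice).
paragraph_vectors = token_embeddings[0, paragraph_start_positions]   # (P, H)

# Binary relevance head: one logit per paragraph.
paragraph_head = nn.Linear(hidden_size, 1)
paragraph_logits = paragraph_head(paragraph_vectors).squeeze(-1)     # (P,)

# Training would use BCE against gold "supporting paragraph" labels.
loss = nn.BCEWithLogitsLoss()(paragraph_logits, torch.zeros(num_paragraphs))
print(paragraph_logits.shape, loss.item())
```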

Does the multi-task training in stage 2 still need to include predicting relevant paragraphs?

Thank you

armancohan commented 3 years ago

Thanks for reaching out! Both models are identical (they both predict relevant paragraphs, evidence sentences, answer spans, and question types). The difference is that the second-stage model uses up to 5 high-scoring paragraphs from the first-stage model (the paragraph score should be higher than a threshold, as mentioned in the appendix).
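As a small illustration of that selection rule (the threshold value, the sigmoid scoring, and the example logits below are placeholders, not the paper's numbers):

```python
# Keep paragraphs whose stage-1 score exceeds a threshold, capped at the
# 5 highest-scoring ones; their text is what the stage-2 model sees.
import torch

paragraph_logits = torch.tensor([2.3, -0.4, 1.1, 0.2, 3.0, -1.5, 0.9])
scores = torch.sigmoid(paragraph_logits)

threshold = 0.5                                   # hypothetical value
above = (scores > threshold).nonzero(as_tuple=True)[0]

# Rank the surviving paragraphs by score and keep at most 5.
order = scores[above].argsort(descending=True)
selected = above[order][:5]

print(selected.tolist())   # indices of paragraphs passed to the stage-2 model
```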

Fan-Luo commented 3 years ago

Thank you, Arman!

Do you mind if I ask for more clarification: when you said 'Both models are identical', I believe you meant the architectures are the same, but are the weights also the same? That is, did you reuse the layers of stage 1 in stage 2? In other words, did the whole model size grow during stage 2?

Also, did you add the losses of the two stages for each of the sub-tasks (relevant paragraphs, evidence sentences, answer spans, and question types), as in:

paragraph_loss = paragraph_loss_1 + paragraph_loss_2
sentence_loss = sentence_loss_1 + sentence_loss_2
span_loss = span_loss_1 + span_loss_2
type_loss = type_loss_1 + type_loss_2
total_loss = lambda1 * span_loss + lambda2 * sentence_loss + lambda3 * paragraph_loss + lambda4 * type_loss
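Written out as a small PyTorch-style sketch of the combination I am asking about (the lambdas, the dict layout, and the stage-wise summation are exactly the parts I am unsure of, so this is only my guess, not the paper's formulation):

```python
# Hypothetical weighted sum of per-task losses across the two stages.
import torch

def total_loss(stage1, stage2, lambdas=(1.0, 1.0, 1.0, 1.0)):
    """stage1/stage2: dicts with keys 'span', 'sentence', 'paragraph', 'type'."""
    l1, l2, l3, l4 = lambdas
    span = stage1["span"] + stage2["span"]
    sentence = stage1["sentence"] + stage2["sentence"]
    paragraph = stage1["paragraph"] + stage2["paragraph"]
    qtype = stage1["type"] + stage2["type"]
    return l1 * span + l2 * sentence + l3 * paragraph + l4 * qtype

# Example with dummy scalar losses:
dummy = {k: torch.tensor(0.5) for k in ("span", "sentence", "paragraph", "type")}
print(total_loss(dummy, dummy))
```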

One more question I would like to ask for clarification on:

For answer span extraction we use BERT’s QA model (Devlin et al., 2019) with addition of a question type (yes/no/span) classification head over the first special token ([CLS]).

I believe this means predicting the answer span and question type in a similar way to BERT's QA model (Devlin et al., 2019), rather than actually using BERT. Could you please correct me if my understanding is wrong?
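In other words, something along these lines (a minimal sketch of my reading of that sentence, with placeholder shapes and a random tensor standing in for the encoder output, not the repo's implementation):

```python
# BERT-style span extraction head (start/end logits per token) plus a
# question-type head over the first special token's representation.
import torch
import torch.nn as nn

hidden_size, seq_len = 768, 512
sequence_output = torch.randn(1, seq_len, hidden_size)   # stand-in for Longformer output

qa_head = nn.Linear(hidden_size, 2)     # start and end logits, as in BERT's QA model
type_head = nn.Linear(hidden_size, 3)   # yes / no / span

start_logits, end_logits = qa_head(sequence_output).split(1, dim=-1)
type_logits = type_head(sequence_output[:, 0])   # head over the first special token

print(start_logits.shape, end_logits.shape, type_logits.shape)
```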

Thank you