luyug / GC-DPR

Train Dense Passage Retriever (DPR) with a single GPU

Gradient caching vs Model dropout #12

Open harveyp123 opened 7 months ago

harveyp123 commented 7 months ago

GC-DPR computes each update in two steps:

  1. The first step runs a full-batch forward pass without gradients to compute the full-batch contrastive loss and the gradient of that loss with respect to each embedding.
  2. The second step runs a forward pass on each mini-batch, assigns the cached embedding gradients, and calls backward. Looping over all mini-batches of the full batch computes and accumulates the complete parameter gradient.
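
The two steps can be sketched in plain Python. This is a minimal, dependency-free illustration, not GC-DPR's code: the encoders are hypothetical linear maps (`Wq`, `Wp`) with hand-derived gradients instead of transformer models under autograd, and the loss is assumed to be an in-batch softmax over query-passage dot products with each query's positive passage at the same index. All function names are made up for the sketch.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def encode(W, x):
    # Hypothetical linear encoder: embedding = W @ x (W is a list of rows).
    return [dot(row, x) for row in W]

def loss_and_grads(qs, ps):
    # In-batch softmax contrastive loss (query i's positive is passage i).
    # Returns the loss plus dL/dq_i and dL/dp_j: the "embedding gradients".
    n, d = len(qs), len(qs[0])
    loss = 0.0
    gq = [[0.0] * d for _ in qs]
    gp = [[0.0] * d for _ in ps]
    for i in range(n):
        s = [dot(qs[i], p) for p in ps]
        m = max(s)
        e = [math.exp(v - m) for v in s]
        z = sum(e)
        loss += (m + math.log(z) - s[i]) / n
        for j in range(n):
            coef = (e[j] / z - (1.0 if i == j else 0.0)) / n  # dL/ds_ij
            for k in range(d):
                gq[i][k] += coef * ps[j][k]
                gp[j][k] += coef * qs[i][k]
    return loss, gq, gp

def accumulate(dW, g, x):
    # Chain rule for e = W @ x: dL/dW[r][c] += dL/de[r] * x[c].
    for r, row in enumerate(dW):
        for c in range(len(row)):
            row[c] += g[r] * x[c]

def grad_cache_step(Wq, Wp, xs, ys, chunk):
    # Step 1: full-batch forward "without gradient" (plain floats here),
    # then cache the loss gradient w.r.t. every embedding.
    qs = [encode(Wq, x) for x in xs]
    ps = [encode(Wp, y) for y in ys]
    loss, gq, gp = loss_and_grads(qs, ps)
    # Step 2: loop over mini-batches and push the cached embedding gradients
    # back into the encoder weights, accumulating as we go. A real network
    # would re-run its forward here to rebuild the autograd graph; a linear
    # map's Jacobian only needs the inputs.
    dWq = [[0.0] * len(xs[0]) for _ in Wq]
    dWp = [[0.0] * len(ys[0]) for _ in Wp]
    for start in range(0, len(xs), chunk):
        for i in range(start, min(start + chunk, len(xs))):
            accumulate(dWq, gq[i], xs[i])
            accumulate(dWp, gp[i], ys[i])
    return loss, dWq, dWp

rng = random.Random(0)
Wq = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(2)]
Wp = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(2)]
xs = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(4)]
ys = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(4)]
loss, dWq, dWp = grad_cache_step(Wq, Wp, xs, ys, chunk=2)
```

In a real network, step 2 only has to hold one mini-batch's activation graph at a time, while the summed weight gradient still equals the full-batch gradient (as long as nothing random differs between the two passes).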

However, there may be one issue with this computation:

  1. The backbone model applies random dropout, which makes steps 1 and 2 inconsistent. The dropout masks sampled in step 1 differ from those in step 2, so the embedding gradients from step 1 cannot be applied directly in step 2; they would have to be recomputed for every mini-batch. This bug could be fixed with a more careful procedure that keeps steps 1 and 2 consistent.

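A small sketch of the inconsistency, assuming a hypothetical linear encoder followed by inverted dropout, with Python's `random` module standing in for the framework RNG: encoding the same input once for step 1 and again for step 2 draws two independent masks.

```python
import random

def encode_with_dropout(W, x, p, rng):
    # Hypothetical encoder: a linear map followed by inverted dropout.
    h = [sum(w * xi for w, xi in zip(row, x)) for row in W]
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in h]

rng = random.Random(0)
W = [[0.1 * (r + c + 1) for c in range(4)] for r in range(32)]
x = [1.0, -0.5, 0.25, 2.0]  # every pre-dropout unit of W @ x is nonzero

# Step 1 (no-grad pass) and step 2 (re-encoding pass) each draw a mask.
cached = encode_with_dropout(W, x, 0.5, rng)
recomputed = encode_with_dropout(W, x, 0.5, rng)

# Any unit kept in one pass and dropped in the other makes the embeddings
# differ, so the gradient cached at `cached` is applied at another point.
mismatched = sum(a != b for a, b in zip(cached, recomputed))
print(mismatched, "of", len(cached), "units differ")
```
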
harveyp123 commented 6 months ago

In short: in the second for loop, for each mini-batch, put the freshly computed query and passage embeddings back into the original (cached) batch, recompute the loss, and calculate the gradient for the current queries/passages before calling backward. That way, the dropout behavior doesn't distort your gradient too much.
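
A minimal sketch of this splicing idea in plain Python. It assumes a toy setup rather than a real training loop: hypothetical linear encoders with inverted dropout, hand-derived gradients, an in-batch softmax loss with positives on the diagonal, and Python's `random` as the RNG; all names are illustrative. Step 2 re-encodes each mini-batch with fresh masks, swaps those embeddings into the cached batch, and backpropagates the gradient of the recomputed loss with respect to the fresh embeddings only.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def loss_and_grads(qs, ps):
    # In-batch softmax contrastive loss; positives sit on the diagonal.
    n, d = len(qs), len(qs[0])
    loss = 0.0
    gq = [[0.0] * d for _ in qs]
    gp = [[0.0] * d for _ in ps]
    for i in range(n):
        s = [dot(qs[i], p) for p in ps]
        m = max(s)
        e = [math.exp(v - m) for v in s]
        z = sum(e)
        loss += (m + math.log(z) - s[i]) / n
        for j in range(n):
            coef = (e[j] / z - (1.0 if i == j else 0.0)) / n  # dL/ds_ij
            for k in range(d):
                gq[i][k] += coef * ps[j][k]
                gp[j][k] += coef * qs[i][k]
    return loss, gq, gp

def encode(W, x, p, rng):
    # Hypothetical linear encoder + inverted dropout; also returns the
    # per-unit scale (mask / (1 - p)) needed for the backward pass.
    scale = [0.0 if rng.random() < p else 1.0 / (1.0 - p) for _ in W]
    return [sc * dot(row, x) for sc, row in zip(scale, W)], scale

def accumulate(dW, g, scale, x):
    # e[r] = scale[r] * (W[r] . x)  =>  dL/dW[r][c] += g[r] * scale[r] * x[c]
    for r, row in enumerate(dW):
        for c in range(len(row)):
            row[c] += g[r] * scale[r] * x[c]

def spliced_grad_cache_step(Wq, Wp, xs, ys, chunk, p, rng):
    # Step 1: no-grad full-batch encode; cache the embeddings.
    qs = [encode(Wq, x, p, rng)[0] for x in xs]
    ps = [encode(Wp, y, p, rng)[0] for y in ys]
    dWq = [[0.0] * len(xs[0]) for _ in Wq]
    dWp = [[0.0] * len(ys[0]) for _ in Wp]
    for start in range(0, len(xs), chunk):
        idx = range(start, min(start + chunk, len(xs)))
        # Step 2: re-encode the mini-batch (fresh dropout masks) ...
        fresh_q = {i: encode(Wq, xs[i], p, rng) for i in idx}
        fresh_p = {i: encode(Wp, ys[i], p, rng) for i in idx}
        # ... splice the fresh embeddings into the cached full batch ...
        q_mix = [fresh_q[i][0] if i in fresh_q else q for i, q in enumerate(qs)]
        p_mix = [fresh_p[i][0] if i in fresh_p else v for i, v in enumerate(ps)]
        # ... and backprop the loss gradient w.r.t. the fresh embeddings only.
        _, gq, gp = loss_and_grads(q_mix, p_mix)
        for i in idx:
            accumulate(dWq, gq[i], fresh_q[i][1], xs[i])
            accumulate(dWp, gp[i], fresh_p[i][1], ys[i])
    return dWq, dWp

rng = random.Random(0)
Wq = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(2)]
Wp = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(2)]
xs = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(4)]
ys = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(4)]
dWq, dWp = spliced_grad_cache_step(Wq, Wp, xs, ys, chunk=2, p=0.1, rng=rng)
```

The cost is one extra full-batch loss evaluation per mini-batch. Embeddings outside the current mini-batch still carry their step-1 masks, so the accumulated result is close to, but not exactly, the gradient of one consistent full-batch forward pass.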

luyug commented 6 months ago

In our training code, the random states are snapshotted with the RandContext class https://github.com/luyug/GC-DPR/blob/79e1fe06bd879662a5e8415efd278080225b9892/train_dense_encoder.py#L53-L69 during the first forward pass and restored at the beginning of the second, so what you described shouldn't be a problem.
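
The snapshot/restore idea can be sketched without a deep learning framework. This is a hypothetical analogue, not the linked class: `RandSnapshot` saves Python's `random` state before each mini-batch's step-1 forward pass and rewinds to it during step 2, so dropout replays the same masks and the cached gradients stay valid. (As we read the linked code, RandContext does the equivalent for PyTorch's CPU and CUDA RNG states.)

```python
import random

class RandSnapshot:
    """Hypothetical stand-in for RandContext: capture the RNG state at
    construction, replay it inside a `with` block, then resume the outer
    stream on exit."""

    def __init__(self):
        self.state = random.getstate()

    def __enter__(self):
        self.outer = random.getstate()  # where the global stream currently is
        random.setstate(self.state)     # rewind to the snapshot
        return self

    def __exit__(self, *exc):
        random.setstate(self.outer)     # continue the outer stream
        return False

def dropout(h, p):
    # Inverted dropout drawing from the global `random` stream.
    return [0.0 if random.random() < p else v / (1.0 - p) for v in h]

random.seed(0)
h = [float(k + 1) for k in range(32)]
batches = [h, h]

# Step 1: snapshot the RNG before each mini-batch's no-grad forward.
snapshots, cached = [], []
for b in batches:
    snapshots.append(RandSnapshot())
    cached.append(dropout(b, 0.5))
after_step1 = random.getstate()

# Step 2: replay each mini-batch under its snapshot, reproducing its masks.
replayed = []
for b, snap in zip(batches, snapshots):
    with snap:
        replayed.append(dropout(b, 0.5))
```

Restoring the outer state on exit keeps the RNG stream advancing normally for whatever runs after the replay.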

harveyp123 commented 6 months ago

Oh, okay. I was using DeepSpeed with gradient caching: the model is wrapped in a DeepSpeed-defined object, and RandContext doesn't work on my side. But it's good to learn from your code :)