Open tarunbhatiaind opened 2 years ago
We optimize the student model only based on the losses of the tokens which the student model predicts with high confidence. The corresponding line in the code is https://github.com/cliang1453/BOND/blob/32f26988a58ee44eb4f50772c6d6c6eb116c83cf/model_utils.py#L88.
Thanks for the reply. But in the line above, 'pred_labels' comes from the teacher model and you are getting the mask of the teacher model's confident predictions, right? What I understood was that you check which of the teacher's token predictions are confident and then calculate the student's loss for those tokens only. Please correct me if I'm wrong.
What you understand is correct. I was saying that "we optimize the student model only based on the losses of the tokens which the teacher model predicts with high confidence". Sorry for the typo.
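To make the corrected statement concrete, here is a minimal sketch of the idea in plain Python (no PyTorch, so it runs without dependencies). It is not the code from `model_utils.py`: the function names, the `threshold` value, and the choice of the teacher's maximum softmax probability as the confidence score are all assumptions made for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one token's class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def masked_student_loss(teacher_logits, student_logits, threshold=0.9):
    """Mean cross-entropy of the student against the teacher's hard
    pseudo-labels, counting only tokens where the teacher is confident.

    `threshold` is a hypothetical hyperparameter, not a value from BOND.
    """
    losses = []
    for t_logits, s_logits in zip(teacher_logits, student_logits):
        t_probs = softmax(t_logits)
        confidence = max(t_probs)
        if confidence < threshold:
            continue  # teacher is unsure: this token adds no loss
        pseudo_label = t_probs.index(confidence)
        s_probs = softmax(s_logits)
        losses.append(-math.log(s_probs[pseudo_label]))
    # If no token passes the threshold there is nothing to optimize.
    return sum(losses) / len(losses) if losses else 0.0


# Token 0: teacher is confident; token 1: teacher is not, so it is masked out.
teacher = [[5.0, 0.0, 0.0], [0.1, 0.0, 0.0]]
student = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
loss = masked_student_loss(teacher, student)
```

Note that the mask is computed from the teacher's outputs only; the student's own confidence never enters the selection.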
Thanks @cliang1453. Just one more query. My task is also token-level classification. Would it make sense to just use a mean teacher for training? Something like:

1. Learn a model in stage 1 with less data.
2. Use that model to initialize the teacher and the student for the second stage.
3. Give the teacher all the unlabeled data so that it generates pseudo-labels for the student to train on (calculating cross-entropy loss on the teacher's confident predictions).
4. Use a consistency cost between the soft labels of the student and the teacher, and then update the teacher with an exponential moving average of the student's weights.
5. Continue this for later epochs.

It would be really helpful if I could get a comment on this. Thanks in advance!
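For reference, the two mean-teacher ingredients in the proposal above (the consistency cost and the exponential-moving-average teacher update) can be sketched as follows. This is a plain-Python illustration under assumptions, not BOND's training loop: weights are represented as a dict of floats, the consistency cost is taken to be mean squared error, and `decay` is a hypothetical hyperparameter.

```python
def ema_update(teacher_weights, student_weights, decay=0.99):
    """Return new teacher weights as an exponential moving average:
    teacher <- decay * teacher + (1 - decay) * student.
    """
    return {
        name: decay * teacher_weights[name] + (1.0 - decay) * student_weights[name]
        for name in teacher_weights
    }


def consistency_cost(student_probs, teacher_probs):
    """Mean squared error between student and teacher soft labels.

    Both arguments are lists of per-token class-probability lists.
    MSE is an assumption here; KL divergence is another common choice.
    """
    total, count = 0.0, 0
    for s_tok, t_tok in zip(student_probs, teacher_probs):
        for s, t in zip(s_tok, t_tok):
            total += (s - t) ** 2
            count += 1
    return total / count if count else 0.0


# One training step's bookkeeping, with made-up numbers.
teacher_w = {"w": 1.0}
student_w = {"w": 0.0}
teacher_w = ema_update(teacher_w, student_w, decay=0.9)
cost = consistency_cost([[0.5, 0.5]], [[1.0, 0.0]])
```

In a real setup the teacher receives no gradient updates; it only changes through `ema_update` after each student step.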
In the paper, it's mentioned that 'we select samples based on the prediction confidence of the student model to further improve the quality of soft labels.' But it's also mentioned that 'we discard all pseudo-labels from the (t-1)-th iteration, and only train the student model using pseudo-labels generated by the teacher model at the t-th iteration'.
Is the first statement talking about calculating the loss of the student model only on the high-confidence pseudo-labels, or is it something else? In the code I couldn't find any other justification for this line. Please suggest.