salesforce / ALBEF

Code for ALBEF: a new vision-language pre-training method

About MLM soft-label implementation #58

Open · zhezh opened this issue 2 years ago

zhezh commented 2 years ago

near the code https://github.com/salesforce/ALBEF/blob/f224b67caa0c7294cb1a3d807640688b3aa58cad/models/xbert.py#L1429


soft_labels = F.softmax(logits_m, dim=-1)
if soft_labels is not None:
    loss_distill = -torch.sum(F.log_softmax(prediction_scores, dim=1) * soft_labels, dim=-1)
    loss_distill = loss_distill[labels != -100].mean()
    masked_lm_loss = (1 - alpha) * masked_lm_loss + alpha * loss_distill

The logits_m and prediction_scores are both of shape (batch, seq_len, n_vocab).

My question is: why is the log_softmax over prediction_scores taken along dim=1 rather than dim=-1?
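To illustrate what the two choices normalize over (a toy sketch with made-up sizes, not the actual model dimensions):

import torch
import torch.nn.functional as F

# hypothetical toy sizes, not the real config
batch, seq_len, n_vocab = 2, 5, 11
prediction_scores = torch.randn(batch, seq_len, n_vocab)

# softmax over dim=1 normalizes across sequence positions, so the values for
# a single token position do NOT sum to 1 over the vocabulary
p_dim1 = F.softmax(prediction_scores, dim=1)
print(p_dim1.sum(dim=-1)[0, 0])   # generally != 1

# softmax over dim=-1 normalizes across the vocabulary, giving a proper
# token distribution per position, which is the axis soft_labels also uses
p_vocab = F.softmax(prediction_scores, dim=-1)
print(p_vocab.sum(dim=-1)[0, 0])  # == 1 (up to float error)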

LiJunnan1992 commented 2 years ago

It is a typo and has been fixed. Thanks for pointing it out! We will run experiments to make sure that results are not affected.
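For later readers: presumably the fix just swaps dim=1 for dim=-1 in the log_softmax; a sketch of the corrected term under that assumption (check the current xbert.py in the repo for the authoritative version):

if soft_labels is not None:
    # assumed fix: log-probabilities over the vocabulary (dim=-1), the same
    # axis on which soft_labels is normalized
    loss_distill = -torch.sum(F.log_softmax(prediction_scores, dim=-1) * soft_labels, dim=-1)
    loss_distill = loss_distill[labels != -100].mean()
    masked_lm_loss = (1 - alpha) * masked_lm_loss + alpha * loss_distill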

sanyalsunny111 commented 1 year ago

if soft_labels is not None:
    loss_distill = -torch.sum(F.log_softmax(prediction_scores, dim=1) * soft_labels, dim=-1)
    loss_distill = loss_distill[labels != -100].mean()
    masked_lm_loss = (1 - alpha) * masked_lm_loss + alpha * loss_distill

My question: based on Equation 7 in the ALBEF paper, you claim to use a KL divergence, but I don't see a KL divergence anywhere in this loss function. Could you please explain?

LiJunnan1992 commented 1 year ago

KL divergence is equivalent to the cross-entropy loss when the target distribution does not receive gradients: KL(p || q) = -sum(p * log q) - H(p), where H(p) is the entropy of the soft labels. Since the soft labels are detached, H(p) is constant with respect to the model, so minimizing the cross-entropy in the code minimizes the KL divergence in Equation 7.
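A small self-contained check of that statement (a toy sketch, not the ALBEF training code; the tensor names here are made up): the cross-entropy -sum(p * log q) equals KL(p || q) plus H(p), and since H(p) does not depend on the student logits when p is detached, both losses give identical gradients.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
student_logits = torch.randn(4, 11, requires_grad=True)  # toy (tokens, vocab)
teacher_logits = torch.randn(4, 11)

p = F.softmax(teacher_logits, dim=-1)          # soft labels, no gradient
log_q = F.log_softmax(student_logits, dim=-1)  # student log-probabilities

ce = -(p * log_q).sum(dim=-1).mean()                          # cross-entropy, as in the code above
kl = F.kl_div(log_q, p, reduction="none").sum(-1).mean()      # KL(p || q)
entropy = -(p * p.clamp_min(1e-12).log()).sum(dim=-1).mean()  # H(p), constant w.r.t. the student

print(torch.allclose(ce, kl + entropy))        # True: CE = KL + H(p)

# Gradients w.r.t. the student logits are identical, because H(p) does not
# depend on them.
g_ce = torch.autograd.grad(ce, student_logits, retain_graph=True)[0]
g_kl = torch.autograd.grad(kl, student_logits)[0]
print(torch.allclose(g_ce, g_kl, atol=1e-6))   # True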