salesforce / ALBEF

Code for ALBEF: a new vision-language pre-training method
BSD 3-Clause "New" or "Revised" License

The ITC loss computation in the code does not correspond to the original paper #133

Open · Whisht opened this issue 1 year ago

Whisht commented 1 year ago

The ITC loss computation in [ALBEF](https://github.com/salesforce/ALBEF/blob/b9727e43c3040491774d1b22cc27718aa7772fac/models/model_pretrain.py#L103C3-L103C3) differs slightly from the original paper. Let's take loss_i2t as an example.

```python
with torch.no_grad():
    self._momentum_update()
    image_embeds_m = self.visual_encoder_m(image)
    image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
    text_output_m = self.text_encoder_m.bert(text.input_ids, attention_mask = text.attention_mask,
                                                return_dict = True, mode = 'text')
    text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
    text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)

    sim_i2t_m = image_feat_m @ text_feat_all / self.temp

    sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
    sim_targets.fill_diagonal_(1)

    sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets

sim_i2t = image_feat @ text_feat_all / self.temp

loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
```

This loss should be

$$\mathcal{L}_{\text{itc}}^{\text{mod}} = (1-\alpha)\,\mathcal{L}_{\text{itc}} + \frac{\alpha}{2}\,\mathbb{E}_{(I,T)\sim D}\Big[\mathrm{KL}\big(q^{\text{i2t}}(I)\,\|\,p^{\text{i2t}}(I)\big) + \mathrm{KL}\big(q^{\text{t2i}}(T)\,\|\,p^{\text{t2i}}(T)\big)\Big]$$

In the code cited above, image_feat_m is $I_m \in \mathbb{R}^{n\times d}$, text_feat_all is $T_a \in \mathbb{R}^{d\times(n+n_q)}$, sim_targets is denoted $y \in \mathbb{R}^{n\times(n+n_q)}$, $p^{\text{i2t}}(I)=\mathop{\text{softmax}}(S(I,T_a))$, and $q^{\text{i2t}}(I)=\mathop{\text{softmax}}(S(I_m,T_a))$. Here, the subscript $m$ denotes the momentum encoder.

Suppose the batch size is $n = 2$ and the queue holds two batches, so $n_q = 2 \times 2 = 4$.

Expanding loss_i2t with this notation gives

$$-\frac{1}{n}\sum_{i=1}^{n}\sum_{j=1}^{n+n_q}\big[\alpha\, q_{ij} + (1-\alpha)\, y_{ij}\big]\log p_{ij} = \alpha\,CE(q, p) + (1-\alpha)\,CE(y, p)$$

The first term is not the KL divergence between $q$ and $p$: the self-entropy term $H(q)$ is missing. So, does this affect the performance of ALBEF? I think it should be a good regularization term.
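
For concreteness, here is a small self-contained sketch (not the ALBEF code; the shapes and values are made up) that computes the loss as implemented in model_pretrain.py and the loss as written in the paper, and shows that the two differ by exactly $\alpha \cdot H(q)$:

```python
import torch
import torch.nn.functional as F

# Toy sketch: batch size n = 2, queue length n_q = 4, feature dim d = 8.
# image_feat / image_feat_m play the roles of I / I_m, text_feat_all plays T_a.
torch.manual_seed(0)
n, n_q, d, temp, alpha = 2, 4, 8, 0.07, 0.4

image_feat    = F.normalize(torch.randn(n, d), dim=-1)
image_feat_m  = F.normalize(torch.randn(n, d), dim=-1)
text_feat_all = F.normalize(torch.randn(d, n + n_q), dim=0)

sim_i2t   = image_feat   @ text_feat_all / temp    # S(I,  T_a), shape (n, n + n_q)
sim_i2t_m = image_feat_m @ text_feat_all / temp    # S(I_m, T_a)

y = torch.zeros(n, n + n_q)
y.fill_diagonal_(1)                                 # one-hot targets, as in sim_targets

p = F.softmax(sim_i2t, dim=1)                       # p^i2t(I)
q = F.softmax(sim_i2t_m, dim=1)                     # q^i2t(I)

# Loss as implemented in model_pretrain.py: cross-entropy against soft targets.
soft_targets = alpha * q + (1 - alpha) * y
loss_code = -torch.sum(F.log_softmax(sim_i2t, dim=1) * soft_targets, dim=1).mean()

# Loss as written in the paper: (1 - alpha) * CE(y, p) + alpha * KL(q || p).
ce_y_p = -torch.sum(y * torch.log(p), dim=1).mean()
kl_q_p = torch.sum(q * (torch.log(q) - torch.log(p)), dim=1).mean()
loss_paper = (1 - alpha) * ce_y_p + alpha * kl_q_p

# The two differ exactly by alpha * H(q), the entropy of the momentum distribution q.
h_q = -torch.sum(q * torch.log(q), dim=1).mean()
print(loss_code - loss_paper)   # equals alpha * h_q
print(alpha * h_q)
```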

MengHao666 commented 1 year ago

I also could not find the KL-divergence loss in the code. Do you have any idea?

Whisht commented 1 year ago

Add the negative entropy of $q = \mathrm{softmax}(\texttt{sim\_i2t\_m})$, multiplied by the coefficient $\alpha$, to the total loss; you will then get the KL divergence between $q$ and $p$. Some equations for reference:

$$ KL(q\,\|\,p) = \mathbb{E}_q \log \frac{q}{p} = -\mathbb{E}_q \log p - (-\mathbb{E}_q \log q) = CE(q,p) - H(q) $$
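
A minimal sketch of that suggestion (assuming sim_i2t, sim_i2t_m, sim_i2t_targets, and alpha are the tensors from the pre-training code above; the helper name itc_i2t_loss_with_kl is made up for illustration):

```python
import torch
import torch.nn.functional as F

def itc_i2t_loss_with_kl(sim_i2t, sim_i2t_m, sim_i2t_targets, alpha):
    # Original loss: alpha * CE(q, p) + (1 - alpha) * CE(y, p).
    loss_ce = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
    # Negative entropy of the momentum distribution q = softmax(sim_i2t_m).
    q = F.softmax(sim_i2t_m, dim=1)
    neg_entropy_q = torch.sum(q * torch.log(q), dim=1).mean()   # = -H(q)
    # Adding alpha * (-H(q)) turns the first term into alpha * KL(q || p),
    # since KL(q || p) = CE(q, p) - H(q).
    return loss_ce + alpha * neg_entropy_q
```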