DerrickWang005 / CRIS.pytorch

An official PyTorch implementation of the CRIS paper
MIT License

About loss function #5

Status: Open (opened by wudongming97 2 years ago)

wudongming97 commented 2 years ago

Hi, I found that the loss used in this repo is a binary cross-entropy loss between the prediction and the mask:

loss = F.binary_cross_entropy_with_logits(pred, mask)

However, the loss described in the paper is a contrastive loss between visual and textual features.
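For reference, here is a small NumPy sketch of what the loss line quoted above computes, following the formula given in the PyTorch documentation for `binary_cross_entropy_with_logits` with the default `reduction='mean'`. Every pixel of `pred` is treated as an independent binary classification against the matching `mask` value. The function name `bce_with_logits` and the toy arrays are illustrative and are not taken from the repo.

```python
import numpy as np

def bce_with_logits(logits, target):
    """Mean binary cross-entropy computed directly on raw logits.

    Uses the numerically stable rewrite
        -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
      = max(x, 0) - x*y + log(1 + exp(-|x|)),
    averaged over all elements (PyTorch's default 'mean' reduction).
    """
    x = np.asarray(logits, dtype=np.float64)
    y = np.asarray(target, dtype=np.float64)
    per_elem = np.maximum(x, 0.0) - x * y + np.log1p(np.exp(-np.abs(x)))
    return per_elem.mean()

# toy (B=1, 1, 2, 2) prediction logits and a binary ground-truth mask
pred = np.array([[[[2.0, -1.0], [0.5, -3.0]]]])
mask = np.array([[[[1.0, 0.0], [1.0, 0.0]]]])
print(bce_with_logits(pred, mask))
```

The stable form avoids taking `log` of a sigmoid that has underflowed to 0 or rounded to 1 for large-magnitude logits.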

Deepayan137 commented 2 years ago

I have the same query. Can the authors please clarify?

Deepayan137 commented 2 years ago

Hello! I implemented the contrastive learning part by following the instructions in the paper. However, when I train the model with only the contrastive loss, the training IoU does not seem to improve. Below are my code snippet and the training IoU and precision curves. Training ran for only 1 epoch. The brown curves are for the cross-entropy loss and the blue curves are for the contrastive loss. I would be grateful if you could let me know what I am doing wrong, and also whether the contrastive loss is supposed to be used in addition to the cross-entropy loss. Thanks

import torch
import torch.nn.functional as F

def forward(self, x, word, mask):
    x = self.vis(x)                       # visual features, (B, C, H, W)
    B, C, H, W = x.size()
    word = self.txt(word)                 # sentence features, (B, C)
    x = x.permute(0, 2, 3, 1)             # (B, H, W, C)
    # per-pixel similarity between each image and its own sentence
    out = torch.einsum('nhwc,nc->nhw', x, word).unsqueeze(1)
    out = torch.sigmoid(out)              # sigmoid of zt dot zv
    loss = torch.zeros((H, W), device=x.device)
    pos_count, neg_count = 0, 0
    for i in range(word.size(0)):         # sentence index
        zt = word[i]                      # (C,)
        for j in range(x.size(0)):        # image index
            # flatten pixels to (H*W, C); reshaping the (H, W, C) tensor
            # straight to (C, H*W) would scramble channels and pixels
            zv = x[j].reshape(H * W, C)
            prod = (zv @ zt).reshape(H, W)
            if i == j:
                # matched image-sentence pair: every pixel is a positive
                loss += -F.logsigmoid(prod)       # -log(sigmoid(prod))
                pos_count += 1
            else:
                # mismatched pair: every pixel is a negative
                loss += -F.logsigmoid(-prod)      # -log(1 - sigmoid(prod))
                neg_count += 1
    total = pos_count + neg_count
    loss = torch.mean(loss)
    if out.shape[-2:] != mask.shape[-2:]:
        mask = F.interpolate(mask, out.shape[-2:],
                             mode='nearest').detach()
    return out, loss / total, mask

[Figures: W&B charts "09_07_2022, 10_18_53" and "09_07_2022, 10_18_40", training IoU and precision curves]

DerrickWang005 commented 2 years ago

Please follow our implementation: https://github.com/DerrickWang005/CRIS.pytorch/blob/0df39f073acfb9e6e17d83536a916548905ecfc3/model/layers.py#L47-L84

Deepayan137 commented 2 years ago

Hello Derrick,

I had already seen this implementation. In your paper, you present equations 9 and 10 as the contrastive loss between the pixel embeddings and the text features. I do not understand how they are handled in the code snippet above.
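As I read the paper, equations 9 and 10 define the loss per pixel i of the upsampled visual embedding z_v, with P and N the sets of pixels labeled 1 and 0 in the ground-truth mask (this reading is an assumption; check it against the paper). Written out:

```latex
\mathcal{L}_{con}^{i}(z_t, z_v^{i}) =
\begin{cases}
  -\log \sigma(z_t \cdot z_v^{i}), & i \in \mathcal{P} \\
  -\log\bigl(1 - \sigma(z_t \cdot z_v^{i})\bigr), & i \in \mathcal{N}
\end{cases}
\qquad
\mathcal{L}_{con}(z_t, z_v) = \frac{1}{|\mathcal{P} \cup \mathcal{N}|}
  \sum_{i \in \mathcal{P} \cup \mathcal{N}} \mathcal{L}_{con}^{i}(z_t, z_v^{i})

% With m_i = 1 for i in P, m_i = 0 for i in N, s_i = z_t . z_v^i,
% and |P u N| = HW when every pixel is labeled:
\mathcal{L}_{con} = \frac{1}{HW} \sum_{i=1}^{HW}
  -\Bigl[ m_i \log \sigma(s_i) + (1 - m_i) \log\bigl(1 - \sigma(s_i)\bigr) \Bigr]
```

Under that reading, the last line is exactly a pixel-wise binary cross-entropy on the logits s_i with mean reduction, i.e. the same form as `F.binary_cross_entropy_with_logits(pred, mask)` when `pred` holds the per-pixel text-visual scores.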

tiger990111 commented 1 year ago

I have the same query. Can the authors please clarify?

FabianRitter commented 1 year ago

No follow-up? The code looks like plain supervised learning. I assume something is missing from it.

Starboy-at-earth commented 1 year ago

@DerrickWang005 Could you please release the code snippet for the contrastive learning loss?

clownrat6 commented 1 year ago

Actually, the implementation is consistent with the description in the paper. However, it is not standard contrastive learning.
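To illustrate this point, here is a minimal NumPy sketch, assuming equations 9 and 10 take the per-pixel form discussed in this thread (positives P are the pixels where the mask is 1, negatives N the pixels where it is 0). Computing the loss from explicit P and N sets gives the same number as a pixel-wise binary cross-entropy on the logits z_t · z_v^i, which is the form of the loss line quoted at the start of the thread. The sketch also shows how this differs from the loop posted earlier in the thread: there, positives are chosen by matching image and sentence indices (`i == j`), so every pixel of a matched pair, background included, is treated as a positive. Function names and shapes are illustrative.

```python
import numpy as np

def log_sigmoid(s):
    # numerically stable log(sigmoid(s)) = -log(1 + exp(-s))
    return -np.logaddexp(0.0, -s)

def text_to_pixel_loss(z_t, z_v, mask):
    """Pixel-level sigmoid loss in the per-pixel form of equations 9 and 10.

    z_t:  (C,) sentence embedding
    z_v:  (C, H, W) pixel embeddings
    mask: (H, W) binary ground truth; 1 -> positive set P, 0 -> negative set N
    """
    s = np.einsum('c,chw->hw', z_t, z_v)        # s_i = z_t . z_v^i for every pixel
    pos = -log_sigmoid(s[mask == 1])            # i in P: -log(sigmoid(s_i))
    neg = -log_sigmoid(-s[mask == 0])           # i in N: -log(1 - sigmoid(s_i))
    return (pos.sum() + neg.sum()) / mask.size  # average over |P u N| = H*W

def bce_with_logits(logits, target):
    # mean binary cross-entropy on logits, stable form
    return np.mean(np.maximum(logits, 0.0) - logits * target
                   + np.log1p(np.exp(-np.abs(logits))))

rng = np.random.default_rng(0)
C, H, W = 8, 4, 4
z_t = rng.normal(size=C)
z_v = rng.normal(size=(C, H, W))
mask = (rng.random((H, W)) > 0.5).astype(np.float64)

logits = np.einsum('c,chw->hw', z_t, z_v)
print(text_to_pixel_loss(z_t, z_v, mask), bce_with_logits(logits, mask))
```

Both printed values should agree, because -log(1 - sigmoid(s)) equals -log(sigmoid(-s)).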

Fake10086 commented 1 year ago

You may take a deeper look at the code the author linked above. You'll find that the conv2d there, whose kernel is generated from the text feature, effectively acts as an element-wise product between the text and image features, which can be viewed as equations 9 and 10.
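The idea can be sketched in NumPy as a per-sample ("dynamic") convolution whose kernel and bias are produced from the sentence feature. This is an illustration under stated assumptions, not the repo's code: it assumes the sentence feature is linearly projected to C·k·k + 1 numbers (kernel plus bias), and the names `dynamic_conv` and `W_txt` are hypothetical; check `layers.py` for the actual shapes and kernel size. With a 1×1 kernel, the output at pixel i is the dot product between the generated kernel (playing the role of z_t) and the pixel feature z_v^i, plus a bias. With a larger kernel, each logit also mixes a small neighbourhood, so it generalizes the dot product rather than matching it exactly.

```python
import numpy as np

def dynamic_conv(x, weight, bias):
    """Per-sample cross-correlation (what conv2d computes) with a text-made kernel.

    x:      (C, H, W) visual features of one image
    weight: (C, k, k) kernel produced from the sentence feature (k odd)
    bias:   scalar, also produced from the sentence feature
    Returns (H, W) logits; zero padding keeps the spatial size.
    """
    C, H, W = x.shape
    k = weight.shape[-1]
    p = k // 2
    xp = np.pad(x, ((0, 0), (p, p), (p, p)))
    out = np.full((H, W), float(bias))
    for dy in range(k):
        for dx in range(k):
            # shift the padded map and accumulate channel-wise products
            out += np.einsum('c,chw->hw', weight[:, dy, dx],
                             xp[:, dy:dy + H, dx:dx + W])
    return out

rng = np.random.default_rng(0)
C, H, W, k, D = 4, 5, 5, 1, 6
x = rng.normal(size=(C, H, W))                # pixel embeddings z_v
sentence = rng.normal(size=D)                 # sentence feature
W_txt = rng.normal(size=(C * k * k + 1, D))   # hypothetical text projection
params = W_txt @ sentence
weight, bias = params[:-1].reshape(C, k, k), params[-1]
logits = dynamic_conv(x, weight, bias)        # (H, W): z_t . z_v^i + b
```

Feeding these logits and the mask into a binary cross-entropy with logits gives the per-pixel loss discussed above.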

lyu-yx commented 8 months ago

I have the same question. Could the authors release the latest version of the code? @DerrickWang005

DerrickWang005 commented 8 months ago

I think this article can answer your question to some extent. @lyu-yx https://arxiv.org/pdf/2303.15343.pdf
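The linked article (arXiv 2303.15343, "Sigmoid Loss for Language Image Pre-Training") replaces the softmax-normalized contrastive loss with a pairwise sigmoid loss, in which every image-text pair is scored as an independent binary classification. That appears to share its structure with equations 9 and 10, with text-pixel pairs in place of image-text pairs, which is presumably why it was suggested here. Below is a minimal NumPy sketch of the batch-level sigmoid loss; the temperature `t` and bias `b` are fixed for illustration (the paper learns them), and the function name is illustrative.

```python
import numpy as np

def sigmoid_pairwise_loss(img_emb, txt_emb, t=10.0, b=-10.0):
    """Pairwise sigmoid loss over a batch of image/text embeddings.

    img_emb, txt_emb: (B, D) arrays; row i of each forms the matching pair.
    Every (i, j) pair is an independent binary problem with label +1 on the
    diagonal and -1 elsewhere; the summed loss is divided by the batch size.
    """
    x = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    y = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    logits = t * x @ y.T + b                  # (B, B) pair scores
    labels = 2.0 * np.eye(len(x)) - 1.0       # +1 matching, -1 otherwise
    # -log(sigmoid(label * logit)), computed stably
    return np.logaddexp(0.0, -labels * logits).sum() / len(x)

rng = np.random.default_rng(0)
B, D = 4, 16
img = rng.normal(size=(B, D))
txt = rng.normal(size=(B, D))
print(sigmoid_pairwise_loss(img, img), sigmoid_pairwise_loss(img, txt))
```

Unlike a softmax-based contrastive loss, no term normalizes across the batch, so each pair contributes on its own, just as each pixel does in the per-pixel loss.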

ccccchenllll commented 7 months ago

I have the same question. I couldn't find the code for the contrastive loss.