ruotianluo / self-critical.pytorch

Unofficial pytorch implementation for Self-critical Sequence Training for Image Captioning. and others.
MIT License

dataloader.py get_captions() random sample times #261

Open Eajay opened 2 years ago

Eajay commented 2 years ago

original:

        if ncap < seq_per_img:
            # we need to subsample (with replacement)
            seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
            for q in range(seq_per_img):
                ixl = random.randint(ix1,ix2)
                seq[q, :] = self.label[ixl, :self.seq_length]

Suppose seq_per_img=5 and ncap=3. The loop draws a caption at random 5 times with replacement, so some of the 3 captions might never be selected.
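To put a number on this, here is a small standalone sketch (not code from the repo) that enumerates every possible sequence of draws and counts how often at least one caption is missed. It assumes each draw is uniform over the ncap captions, which matches `random.randint(ix1, ix2)`; the function name `prob_some_caption_missed` is made up for illustration.

```python
from itertools import product


def prob_some_caption_missed(ncap, seq_per_img):
    """Exact probability that uniform sampling with replacement
    leaves at least one of the ncap captions unselected."""
    total = 0
    missed = 0
    # Enumerate every possible outcome of seq_per_img independent draws.
    for draws in product(range(ncap), repeat=seq_per_img):
        total += 1
        if len(set(draws)) < ncap:
            missed += 1
    return missed / total


# ncap=3, seq_per_img=5: 93 of the 243 outcomes miss a caption (about 0.38).
print(prob_some_caption_missed(3, 5))
```

So in this example, roughly 38% of the time an image is loaded, at least one of its captions is not used.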

How about changing to:

        if ncap < seq_per_img:
            # we need to subsample (with replacement)
            seq = np.zeros([seq_per_img, self.seq_length], dtype='int')
            seq[: ncap, :] = self.label[ix1: ix1+ncap, :self.seq_length]
            for q in range(ncap, seq_per_img):
                ixl = random.randint(ix1, ix2)
                seq[q, :] = self.label[ixl, :self.seq_length]

This way all 3 captions are kept, and only the remaining 2 slots are filled by random sampling.
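The proposal can be checked in isolation with a sketch that works on row indices only, without `self.label` or NumPy. The helper name `pad_caption_indices` is hypothetical, and `ncap = ix2 - ix1 + 1` is an assumption inferred from the inclusive `random.randint(ix1, ix2)` call above. It is only meant for the `ncap < seq_per_img` branch.

```python
import random


def pad_caption_indices(ix1, ix2, seq_per_img):
    """Return seq_per_img row indices for the case ncap < seq_per_img.

    Every caption in the inclusive range [ix1, ix2] appears once;
    the remaining slots are filled by sampling with replacement.
    """
    ncap = ix2 - ix1 + 1  # assumption: ix2 is the last caption row (inclusive)
    # Keep each existing caption exactly once, in order.
    indices = list(range(ix1, ix1 + ncap))
    # Fill the remaining slots with random picks from the same range.
    for _ in range(ncap, seq_per_img):
        indices.append(random.randint(ix1, ix2))
    return indices


idx = pad_caption_indices(10, 12, 5)
print(idx)  # first three entries are always 10, 11, 12; the last two vary
```

With this ordering the original captions always occupy the first ncap rows. If the position of the duplicates matters downstream, the result could be shuffled with `random.shuffle(indices)`.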