JDAI-CV / image-captioning

Implementation of 'X-Linear Attention Networks for Image Captioning' [CVPR 2020]

Diversity augmented beam_search for XLAN with group_size set as 1?? #20

Closed by ChenYutongTHU 3 years ago

ChenYutongTHU commented 3 years ago

Hello, thanks for this great work. While reading the code, I found that models/att_basic_model.py recommends DBS (diversity-augmented beam search) for the X-LAN model:

    # For the experiments of X-LAN, we use the following beam search code, 
    # which achieves slightly better results but much slower.

    #def decode_beam(self, **kwargs):
    #    beam_size = kwargs['BEAM_SIZE']
    #    gv_feat, att_feats, att_mask, p_att_feats = self.preprocess(**kwargs)
    #    batch_size = gv_feat.size(0)
    #
    #    sents = Variable(torch.zeros((cfg.MODEL.SEQ_LEN, batch_size), dtype=torch.long).cuda())
    #    logprobs = Variable(torch.zeros(cfg.MODEL.SEQ_LEN, batch_size).cuda())   
    #    self.done_beams = [[] for _ in range(batch_size)]
    #    for n in range(batch_size):
    #        state = self.init_hidden(beam_size)
    #        gv_feat_beam = gv_feat[n:n+1].expand(beam_size, gv_feat.size(1)).contiguous()
    #        att_feats_beam = att_feats[n:n+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous()
    #        att_mask_beam = att_mask[n:n+1].expand(*((beam_size,)+att_mask.size()[1:]))
    #        p_att_feats_beam = p_att_feats[n:n+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous() if p_att_feats is not None else None
    #
    #        wt = Variable(torch.zeros(beam_size, dtype=torch.long).cuda())
    #        kwargs = self.make_kwargs(wt, gv_feat_beam, att_feats_beam, att_mask_beam, p_att_feats_beam, state, **kwargs)
    #        logprobs_t, state = self.get_logprobs_state(**kwargs)
    #
    #        self.done_beams[n] = self.beam_search(state, logprobs_t, **kwargs)
    #        sents[:, n] = self.done_beams[n][0]['seq'] 
    #        logprobs[:, n] = self.done_beams[n][0]['logps']
    #    return sents.transpose(0, 1), logprobs.transpose(0, 1)

However, in models/basic_model.py, where DBS is implemented, the group size is hard-coded to 1, which means that diversity-augmented beam search degenerates into standard beam search:

        beam_size = kwargs['BEAM_SIZE']
        group_size = 1 #kwargs['GROUP_SIZE']
        diversity_lambda = 0.5 #kwargs['DIVERSITY_LAMBDA']

So how can DBS with group_size=1 slightly outperform standard beam search for X-LAN, as the comment above says? Thanks a bunch!
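The claim that a group size of 1 reduces DBS to standard beam search can be checked on a toy example. The sketch below is not the repository's code: `toy_logprobs`, `expand`, `beam_search` and `diverse_beam_search` are hypothetical names, the "model" is a deterministic stand-in, and the diversity term is assumed to be a Hamming-style penalty (each token is penalized by `diversity_lambda` times the number of times earlier groups chose it at the current step). With a single group there are no earlier groups, so the penalty is always zero.

```python
import math
from collections import Counter

VOCAB = 4  # toy vocabulary size (hypothetical)

def toy_logprobs(prefix):
    """Deterministic stand-in for a captioning model: next-token log-probs."""
    seed = sum((i + 1) * (t + 1) for i, t in enumerate(prefix))
    scores = [((seed + 3 * v) % 7) + 1.0 for v in range(VOCAB)]
    total = sum(scores)
    return [math.log(s / total) for s in scores]

def expand(beams, k, penalty=None):
    """Extend every beam by every token and keep the k best candidates.

    beams holds (sequence, log_prob) pairs. The penalty only affects the
    ranking score; the stored log-prob stays unpenalized.
    """
    cands = []
    for seq, lp in beams:
        for tok, tok_lp in enumerate(toy_logprobs(seq)):
            rank = lp + tok_lp - (penalty[tok] if penalty else 0.0)
            cands.append((rank, seq + (tok,), lp + tok_lp))
    cands.sort(key=lambda c: c[0], reverse=True)
    return [(seq, lp) for _, seq, lp in cands[:k]]

def beam_search(beam_size, steps):
    """Standard beam search over the toy model."""
    beams = [((), 0.0)]
    for _ in range(steps):
        beams = expand(beams, beam_size)
    return beams

def diverse_beam_search(beam_size, group_size, diversity_lambda, steps):
    """Group-wise beam search with an assumed Hamming-style diversity penalty."""
    per_group = beam_size // group_size
    groups = [[((), 0.0)] for _ in range(group_size)]
    for _ in range(steps):
        used = Counter()  # tokens picked at this step by earlier groups
        for g in range(group_size):
            penalty = {t: diversity_lambda * used[t] for t in range(VOCAB)}
            groups[g] = expand(groups[g], per_group, penalty)
            used.update(seq[-1] for seq, _ in groups[g])
    return [beam for group in groups for beam in group]

# With one group the penalty is zero everywhere, so both searches agree,
# whatever value diversity_lambda takes.
same = diverse_beam_search(3, 1, 0.5, 4) == beam_search(3, 4)
```

In this sketch `diversity_lambda` has no effect when `group_size` is 1, which matches the reading that the hard-coded values above amount to plain beam search.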

YehLi commented 3 years ago

We do not use any diversity-augmented beam search strategy in our experiments. The outputs of the two beam search versions are almost the same. The small difference might be caused by the sorting step, and it can be ignored in most cases.
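One way a sorting step can produce such small differences is tie-breaking: when two candidates have exactly the same score, different but equally valid sorting routines may return them in a different order, so a different candidate survives the top-k cut. The sketch below only illustrates that general effect in plain Python; `top_k_stable` and `top_k_reversed` are hypothetical helpers, and it is an assumption, not something confirmed by the thread, that this is what happens in the repository's two code paths.

```python
def top_k_stable(scores, k):
    """Indices of the k highest scores; tied scores keep their original order."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]

def top_k_reversed(scores, k):
    """Sort ascending, then flip: tied scores end up in reversed order."""
    order = sorted(range(len(scores)), key=lambda i: scores[i])[::-1]
    return order[:k]

tied = [-1.2, -0.7, -0.7, -3.0]      # candidates 1 and 2 have equal scores
untied = [-1.2, -0.7, -0.9, -3.0]    # no equal scores

pick_a = top_k_stable(tied, 1)       # keeps candidate 1
pick_b = top_k_reversed(tied, 1)     # keeps candidate 2
agree = top_k_stable(untied, 2) == top_k_reversed(untied, 2)
```

Exact ties between floating-point log-probabilities are uncommon, which would be consistent with the two versions agreeing in most cases.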