ruotianluo / self-critical.pytorch

Unofficial PyTorch implementation of Self-critical Sequence Training for Image Captioning, and others.
MIT License
991 stars 278 forks

Shape mismatch for seqLogprobs in captioning/models/FCModels.py #248

Closed BetterZH closed 3 years ago

BetterZH commented 3 years ago

```python
def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
    # ...
    seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length, self.vocab_size + 1)  # defined here as 3-D
```

However, the decoding loop computes `output, state = self.core(xt, state)` and `logprobs = F.log_softmax(self.logit(output), dim=1)`. For greedy sampling it then does `sampleLogprobs, it = torch.max(logprobs.data, 1)`; for multinomial sampling it does `it = torch.multinomial(prob_prev, 1).to(logprobs.device)` followed by `sampleLogprobs = logprobs.gather(1, it)  # gather the logprobs at sampled positions`.
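To see where the per-step shapes come from, here is a minimal sketch of the two sampling paths quoted above. It assumes random logits in place of `self.logit(output)`, `prob_prev = exp(logprobs)`, and made-up sizes (`batch_size=4`, `vocab_size=10`); it is an illustration, not the repository's code. Greedy `torch.max` returns one value per row, while `gather` keeps a trailing dimension of size 1, which is why `.view(-1)` is needed to get shape (bt,).

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
batch_size, vocab_size = 4, 10

# Stand-in for self.logit(output): random logits over vocab_size + 1 tokens.
logits = torch.randn(batch_size, vocab_size + 1)
logprobs = F.log_softmax(logits, dim=1)

# Greedy path: torch.max over dim 1 returns (values, indices), each of shape (batch_size,).
greedy_logprobs, greedy_it = torch.max(logprobs, 1)

# Multinomial path (assumed prob_prev = exp(logprobs)): one index per row,
# shape (batch_size, 1); gather keeps that trailing dimension.
prob_prev = torch.exp(logprobs)
sample_it = torch.multinomial(prob_prev, 1)
sample_logprobs = logprobs.gather(1, sample_it)

print(greedy_logprobs.shape)           # torch.Size([4])
print(sample_logprobs.shape)           # torch.Size([4, 1])
# Flattening gives one log-prob per sequence in both cases.
print(sample_logprobs.view(-1).shape)  # torch.Size([4])
```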

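`sampleLogprobs` holds the log-probability of each sampled word, one value per sequence, i.e. shape (bt,) once flattened. So the assignment `seqLogprobs[:,t-1] = sampleLogprobs.view(-1)` only makes sense if `seqLogprobs` has shape [bt, len], which contradicts the 3-D definition above.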
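The contradiction can be checked directly. This sketch uses hypothetical sizes and tries the same assignment into a 3-D buffer and into a 2-D one. Broadcasting aligns trailing dimensions, so a (bt,) tensor cannot fill a (bt, vocab_size + 1) slice and PyTorch raises a `RuntimeError`, while the 2-D buffer accepts it.

```python
import torch

# Hypothetical sizes: batch, sequence length, vocabulary (+1 for the extra token).
bt, seq_length, vocab_size = 4, 5, 10
sampleLogprobs = torch.randn(bt)  # one log-prob per sequence, shape (bt,)

# 3-D buffer as defined in _sample: the slice [:, 0] has shape (bt, vocab_size + 1).
seq3d = torch.zeros(bt, seq_length, vocab_size + 1)
try:
    # Trailing dims compared: bt (4) vs vocab_size + 1 (11), so broadcasting fails.
    seq3d[:, 0] = sampleLogprobs.view(-1)
    assigned_3d = True
except RuntimeError as err:
    assigned_3d = False
    print("3-D assignment failed:", err)

# 2-D buffer: the slice [:, 0] has shape (bt,), which matches exactly.
seq2d = torch.zeros(bt, seq_length)
seq2d[:, 0] = sampleLogprobs.view(-1)
print(seq2d.shape)  # torch.Size([4, 5])
```

The 3-D case only fails loudly because of the chosen sizes: if `bt` were 1, or happened to equal `vocab_size + 1`, the assignment would broadcast silently and store the wrong values.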

What should the shape of `seqLogprobs` be here? The definition and the assignment seem inconsistent. Thanks for your help!

ruotianluo commented 3 years ago

The FC model is fairly old and I stopped using it much afterwards, so it may well be wrong. I'd suggest looking at attmodel instead.
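The reply only points to attmodel and does not show its code, so the following is not that implementation. It is one self-consistent sketch, assuming greedy decoding over precomputed per-step logits and no early stopping at the end token: keep the 3-D `seqLogprobs` but store the full log-softmax at each step, then recover the per-token [bt, len] log-probs with `gather` after the loop. The function name `sample_greedy` and its arguments are hypothetical.

```python
import torch
import torch.nn.functional as F

def sample_greedy(step_logits):
    """Hypothetical greedy decoder over precomputed per-step logits.

    step_logits: (bt, seq_length, vocab_size + 1), standing in for
    self.logit(output) at each step. Returns the sampled tokens (bt, seq_length),
    the full per-step log-distributions (3-D), and the sampled-token log-probs (2-D).
    """
    bt, seq_length, n_tokens = step_logits.shape
    seq = step_logits.new_zeros((bt, seq_length), dtype=torch.long)
    # Keep the 3-D shape, but store the whole log-softmax at every step.
    seqLogprobs = step_logits.new_zeros((bt, seq_length, n_tokens))
    for t in range(seq_length):
        logprobs = F.log_softmax(step_logits[:, t], dim=1)  # (bt, n_tokens)
        _, it = torch.max(logprobs, 1)                       # (bt,)
        seq[:, t] = it
        seqLogprobs[:, t] = logprobs                         # shapes match: (bt, n_tokens)
    # Pick out the log-prob of each sampled token: (bt, seq_length).
    sampled = seqLogprobs.gather(2, seq.unsqueeze(2)).squeeze(2)
    return seq, seqLogprobs, sampled

torch.manual_seed(0)
seq, full, sampled = sample_greedy(torch.randn(2, 3, 6))
print(full.shape, sampled.shape)  # torch.Size([2, 3, 6]) torch.Size([2, 3])
```

Storing the full distribution costs bt × len × (vocab_size + 1) values but keeps both views available; if only the sampled log-probs are needed, a 2-D [bt, len] buffer is the cheaper consistent choice.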