Closed guohanp closed 8 months ago
train.py and test.py each have the param --bsize which can be set, not sure this is what you're looking for?
哦我看到代码里对比损失这写的只能是1的batch
这个代码里写的只能1的batch size啊
可以尝试一下这个新的函数: ` def get_contrastive_loss(self, style, mask, pred,style_comparison):
style_feats = self.encode_with_small_intermediate(style)
fine_feats = self.encode_with_small_intermediate(pred)
loss_contra=[]
for j in range(style.shape[0]):
style_comparison_cur=style_comparison[j]
style_comparison_feats=[]
for i in range(self.compare_num):
style_comparison_feats.append(self.encode_with_small_intermediate(style_comparison_cur[i:i+1]))
style_comparison_feats=torch.cat(style_comparison_feats,dim=0)
cur_loss_contra=self.calc_contrastive_loss(style_feats[j].unsqueeze(0),style_comparison_feats,fine_feats[j].unsqueeze(0),mask[j].unsqueeze(0))
loss_contra.append(cur_loss_contra)
loss_contra=torch.mean(torch.stack(loss_contra,dim=0))
return loss_contra`
感谢您的提醒,我们已经在库里进行了更新,如果使用中有任何问题欢迎再进行反馈,我们会及时跟进。
请问有支持别的batch size的代码吗