bcmi / PHDiffusion-Painterly-Image-Harmonization

[ACM MM 2023] Code for our paper "Painterly Image Harmonization using Diffusion Model", ACM MM 2023.
Apache License 2.0

Is there code that supports other batch sizes? #3

Closed guohanp closed 8 months ago

guohanp commented 12 months ago

Is there code that supports other batch sizes?

ArtoriaKawaii commented 11 months ago

train.py and test.py each have a --bsize parameter that can be set; is that what you're looking for?

guohanp commented 11 months ago

Oh, I see that the contrastive loss in the code is written so it only works with a batch size of 1.

guohanp commented 11 months ago

[screenshot attached]

guohanp commented 11 months ago

As written, this code only supports a batch size of 1.

pokaaa commented 11 months ago

You can try this new function:

```python
def get_contrastive_loss(self, style, mask, pred, style_comparison):
    style_feats = self.encode_with_small_intermediate(style)
    fine_feats = self.encode_with_small_intermediate(pred)

    loss_contra = []
    # Loop over the batch dimension so batch sizes > 1 are supported.
    for j in range(style.shape[0]):
        style_comparison_cur = style_comparison[j]
        style_comparison_feats = []
        for i in range(self.compare_num):
            style_comparison_feats.append(
                self.encode_with_small_intermediate(style_comparison_cur[i:i + 1]))
        style_comparison_feats = torch.cat(style_comparison_feats, dim=0)
        # Compute the contrastive loss for this sample alone.
        cur_loss_contra = self.calc_contrastive_loss(
            style_feats[j].unsqueeze(0), style_comparison_feats,
            fine_feats[j].unsqueeze(0), mask[j].unsqueeze(0))
        loss_contra.append(cur_loss_contra)

    # Average the per-sample contrastive losses over the batch.
    loss_contra = torch.mean(torch.stack(loss_contra, dim=0))

    return loss_contra
```
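The general pattern above (apply a single-sample loss to each item in the batch, then average) can be sketched in isolation; this is a minimal standalone illustration with a hypothetical toy per-sample loss, not code from the repository:

```python
import torch

def batched_loss(per_sample_loss, batch_a, batch_b):
    """Lift a loss defined on single samples to a whole batch.

    per_sample_loss takes tensors of shape (1, ...) and returns a scalar;
    we loop over the batch dimension and average the per-sample results.
    """
    losses = []
    for j in range(batch_a.shape[0]):
        losses.append(per_sample_loss(batch_a[j:j + 1], batch_b[j:j + 1]))
    return torch.mean(torch.stack(losses, dim=0))

# Toy per-sample loss for demonstration only (mean squared error).
toy_loss = lambda a, b: ((a - b) ** 2).mean()

a = torch.zeros(4, 3)
b = torch.ones(4, 3)
print(batched_loss(toy_loss, a, b).item())  # 1.0
```

Looping per sample is slightly slower than a fully vectorized loss, but it keeps the single-sample logic unchanged, which matters here because the comparison features differ per sample.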

Thanks for the reminder. We have updated the repository accordingly; if you run into any problems while using it, feel free to report them and we will follow up promptly.