xchhuang / pytorch_sliced_wasserstein_loss

An unofficial PyTorch implementation of "A Sliced Wasserstein Loss for Neural Texture Synthesis" paper [CVPR 2021].

Loss cannot converge #1

Open Haosouth opened 2 years ago

Haosouth commented 2 years ago

Hello, thank you very much for writing a PyTorch version of sliced_wasserstein_loss. I tried to migrate it to my style transfer network, replacing the Gram matrix originally used to describe style features. Although the code runs successfully, the loss has not been able to converge. I have made some changes to the code, and I don't know if my understanding of it is wrong. I hope you can help me. Thank you very much. The following are my changes:

    def forward(self, input, target):
        self.update_slices(input)                   # re-draw random slicing directions
        loss = 0.0
        target_ = self.compute_target(target)       # project the style features
        for idx, l in enumerate(input):
            cur_l = self.compute_proj(l, idx, 1)    # project the current features
            loss += F.mse_loss(cur_l, target_[idx])
        return loss

Calling this loss:

slicing_torch = Slicing_torch(device, repeat_rate=1)
loss_style = slicing_torch(Ics_feats[0], style_feats[0])
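For context, the sliced Wasserstein distance between two feature maps can be sketched roughly as follows (a minimal standalone function, not the repo's `Slicing_torch` class; the name and signature are illustrative):

```python
import torch
import torch.nn.functional as F

def sliced_wasserstein(feat_a, feat_b, num_proj=32):
    # feat_*: (C, H, W) feature maps from the same VGG layer;
    # each spatial location is treated as a C-dimensional sample
    c = feat_a.shape[0]
    a = feat_a.reshape(c, -1)  # (C, N)
    b = feat_b.reshape(c, -1)
    # random unit directions in channel space
    proj = torch.randn(num_proj, c)
    proj = proj / proj.norm(dim=1, keepdim=True)
    # project onto each direction, then sort the 1D projections
    pa, _ = torch.sort(proj @ a, dim=1)
    pb, _ = torch.sort(proj @ b, dim=1)
    # mean squared distance between sorted projections
    return F.mse_loss(pa, pb)
```

Sorting the 1D projections is what makes this a (sliced) optimal-transport distance rather than a plain feature MSE.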
xchhuang commented 2 years ago

Hi, thanks for your interest.

I am wondering whether slicing_torch(Ics_feats[0], style_feats[0]) means you only use batch_size = 1? If so, this looks good to me.

Also, could you share the loss curve for the task of texture synthesis using your new forward function? This might be easier for debugging.

Haosouth commented 2 years ago

> Hi, thanks for your interest.
>
> I am wondering whether slicing_torch(Ics_feats[0], style_feats[0]) means you only use batch_size = 1? If so, this looks good to me.
>
> Also, could you share the loss curve for the task of texture synthesis using your new forward function? This might be easier for debugging.

Sorry, I don't understand why you mentioned batch_size = 1. In slicing_torch(Ics_feats[0], style_feats[0]), Ics_feats[0] is the output feature map of one VGG layer for the stylized image. My style loss eventually dropped to about 50 and couldn't go any lower.

xchhuang commented 2 years ago

I see. In my case, running texture synthesis on data/SlicedW/input.jpg in the repo gives me the following loss curve, and the final synthesis result is shown in the 1st row of the Sample Results section in the README.

Would you mind sharing your loss curve and final synthesis result on the same data?

Haosouth commented 2 years ago

(loss curve image attached) I am sorry for taking so long to reply; this is my loss curve.

xchhuang commented 2 years ago

No problem. Yes, the loss should not be that large.

From my understanding, you don't need to call self.update_slices() and self.compute_target() every time in forward. Instead, you only need to call them after each iteration of LBFGS; that's why I moved them outside forward.
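The pattern described here can be sketched with a toy example (the class and method names mirror the discussion but are illustrative, not the repo's exact API): the random projection directions and the projected target are refreshed once per LBFGS iteration, while the loss call inside the closure only projects the current features against the cached target.

```python
import torch
import torch.nn.functional as F

class SlicedLoss:
    """Toy sliced loss: directions/target cached, refreshed externally."""
    def __init__(self, channels, num_proj=16):
        self.channels = channels
        self.num_proj = num_proj
        self.update_slices()

    def update_slices(self):
        # draw new random unit directions in channel space
        d = torch.randn(self.num_proj, self.channels)
        self.dirs = d / d.norm(dim=1, keepdim=True)

    def compute_target(self, target_feat):
        # project and sort the fixed style features once per iteration
        self.target, _ = torch.sort(self.dirs @ target_feat, dim=1)

    def __call__(self, feat):
        proj, _ = torch.sort(self.dirs @ feat, dim=1)
        return F.mse_loss(proj, self.target)

# Refresh slices and target OUTSIDE the LBFGS closure:
feat = torch.randn(8, 64, requires_grad=True)  # optimized features (toy)
style = torch.randn(8, 64)                     # fixed style features (toy)
crit = SlicedLoss(channels=8)
opt = torch.optim.LBFGS([feat], max_iter=10, line_search_fn='strong_wolfe')

history = []
for _ in range(5):
    crit.update_slices()      # once per LBFGS iteration, not per forward
    crit.compute_target(style)
    before = crit(feat.detach()).item()

    def closure():
        opt.zero_grad()
        loss = crit(feat)     # forward only projects + compares
        loss.backward()
        return loss

    opt.step(closure)
    history.append((before, crit(feat.detach()).item()))
```

Calling update_slices/compute_target inside forward instead would change the optimization target on every closure evaluation, which can prevent LBFGS from converging.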

The issue may come from your new Slicing_torch class, but I am not sure without knowing more details.