rrmina / fast-neural-style-pytorch

Fast Neural Style Transfer implementation in PyTorch :art: :art: :art:

About content loss #5

Closed · UtopiaHu closed 5 years ago

UtopiaHu commented 5 years ago

Hi,

Thanks for this project.

In your code, the content loss is computed as: `content_loss = CONTENT_WEIGHT * MSELoss(content_features['relu2_2'], generated_features['relu2_2'])`

I think this might be wrong; it should be: `content_loss = CONTENT_WEIGHT * MSELoss(generated_features['relu2_2'], content_features['relu2_2'])`

Please correct me if I am wrong.

rrmina commented 5 years ago

Hi @UtopiaHu,

Thanks for the question. Actually, it doesn't matter whether you use the former or the latter. MSE stands for Mean Squared Error, i.e. the mean of the squared differences, and the square of (generated - content) is equal to the square of (content - generated), so the loss value is identical either way.
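A quick sanity check of that symmetry, as a minimal standalone sketch (the tensor shapes are illustrative, not the repo's):

```python
import torch
import torch.nn as nn

mse = nn.MSELoss()

# Stand-ins for the relu2_2 feature maps discussed above.
content = torch.randn(1, 128, 64, 64)
generated = torch.randn(1, 128, 64, 64)

# (generated - content)^2 == (content - generated)^2 element-wise,
# so the mean is the same regardless of argument order.
assert torch.allclose(mse(content, generated), mse(generated, content))
```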

UtopiaHu commented 5 years ago

Hi @rrmina, thank you for your reply.

The former will raise the following error: `AssertionError: nn criterions don't compute the gradient w.r.t. targets - please mark these tensors as not requiring gradients`. My PyTorch version is 0.4.0.

rrmina commented 5 years ago

@UtopiaHu

I see. I do not recall having that problem in versions 0.4.1 and 1.0.1. Seeing the assertion error now, you might be right: the two orderings give the same loss value, but the convention is `MSELoss(input, target)`, with the prediction first and the target second, and PyTorch 0.4.0 asserts that the target must not require grad. So the latter is the correct one.
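A self-contained sketch of the conventional ordering (the weight and tensor shapes below are stand-ins, not the repo's actual values):

```python
import torch
import torch.nn as nn

mse = nn.MSELoss()
CONTENT_WEIGHT = 1.0  # hypothetical weight, not the value used in this repo

# Stand-ins for the relu2_2 feature maps from the discussion above.
generated_relu2_2 = torch.randn(1, 128, 64, 64, requires_grad=True)
content_relu2_2 = torch.randn(1, 128, 64, 64)

# Convention: MSELoss(input, target). On PyTorch 0.4.0 the target must not
# require grad, which is why the swapped order raised the AssertionError;
# detaching the target makes that explicit.
content_loss = CONTENT_WEIGHT * mse(generated_relu2_2, content_relu2_2.detach())
content_loss.backward()
assert generated_relu2_2.grad is not None
```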

Thanks for reporting. I'll update the code to reflect the suggested change. :)

UtopiaHu commented 5 years ago

Hi, I have another question. When the VGG network is initialized, all of its parameters are set with `requires_grad=False`, which, as far as I understand, means the gradients of the losses with respect to the VGG parameters will not be computed when `backward()` is called. How, then, do the VGG losses propagate back to the output and the intermediate layers of the transformer net? I think that if we want to update the transformer net, the losses need to travel backward through VGG to reach the transformer's parameters. Maybe I am missing something about how PyTorch works. Thanks! :-)
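For concreteness, here is a minimal sketch of the situation I mean, using a toy stand-in for VGG (not the actual model code in this repo):

```python
import torch
import torch.nn as nn

# A tiny stand-in for the VGG feature extractor, not the repo's model.
vgg = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
for param in vgg.parameters():
    param.requires_grad = False  # freeze the VGG weights, as in the repo

# Pretend this is the transformer net's output.
generated = torch.randn(1, 3, 32, 32, requires_grad=True)

loss = vgg(generated).pow(2).mean()
loss.backward()

# The frozen parameters receive no gradient...
print(next(vgg.parameters()).grad)  # None
# ...but a gradient still flows through VGG's ops back to the input.
print(generated.grad is not None)   # True
```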