Closed — mseitzer closed this issue 4 years ago
Thanks for pointing this out. Funnily enough, due to a bug in PyTorch, I think this ends up taking the mean. Either way, I probably won't want to change the hyperparameter, but this is good to keep in mind.
Also, it's actually unclear whether the VAE loss should average or sum over pixels. MSE basically corresponds to the log-probability of a Gaussian with fixed variance. So, I think the proper probabilistic loss function would be to sum over the pixels, since that effectively treats all the pixels as independent. That said, it depends a bit on the implementation.
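For intuition, here is a small sketch (plain Python, my own illustration rather than code from the repo) of why the summed squared error matches the probabilistic view: with a fixed unit variance, the log-likelihood of an image under independent per-pixel Gaussians is, up to a constant, minus half the *sum* of squared errors over pixels.

```python
import math

def gaussian_log_prob(x, mu, sigma=1.0):
    # log N(x; mu, sigma^2) for a single pixel
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))

# treat each pixel as an independent Gaussian with fixed sigma = 1
target = [0.2, 0.7, 0.5]   # "pixels" of the target image
recon  = [0.1, 0.9, 0.4]   # reconstruction (the Gaussian means)

log_prob = sum(gaussian_log_prob(t, m) for t, m in zip(target, recon))
sse = sum((t - m) ** 2 for t, m in zip(target, recon))

# the summed log-prob is -0.5 * SSE minus a constant independent of the
# reconstruction, so maximizing likelihood = minimizing the *summed*
# (not averaged) squared error
const = len(target) * math.log(math.sqrt(2 * math.pi))
assert abs(log_prob - (-0.5 * sse - const)) < 1e-12
```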
I'll close the issue for now since I don't think there's an actionable response, but feel free to continue the discussion.
> Thanks for pointing this out. Funnily enough, due to a bug in PyTorch, I think this ends up taking the mean. Either way, I probably won't want to change the hyperparameter, but this is good to keep in mind.
At least on the PyTorch version I am using (1.2.0), the line I referenced computes the sum. This is consistent with your PyTorch issue, as `obs_distribution_params[0]` requires gradient.
> Also, it's actually unclear whether the VAE loss should average or sum over pixels. MSE basically corresponds to the log-probability of a Gaussian with fixed variance. So, I think the proper probabilistic loss function would be to sum over the pixels, since that effectively treats all the pixels as independent. That said, it depends a bit on the implementation.
Yes, this is something that bugged me for a long time. I agree with you that taking the sum over pixels is the correct thing to do, but that the mean should be taken over the batch dimension. Having the effective `beta` depend on the batch size seems suboptimal to me.
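A reduction along these lines — sum over pixels, mean over the batch — can be sketched as follows (plain Python, names are mine, not rlkit's):

```python
def reconstruction_loss(batch_recon, batch_target):
    """Sum squared error over pixels, then average over the batch.

    This keeps the loss (and hence its weighting against the KL term)
    independent of the batch size, while preserving the per-image
    "independent pixels" interpretation.
    """
    per_image = [
        sum((r - t) ** 2 for r, t in zip(recon, target))  # sum over pixels
        for recon, target in zip(batch_recon, batch_target)
    ]
    return sum(per_image) / len(per_image)                # mean over batch

# two images of three pixels each; each image has SSE = 1
recons  = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
targets = [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0]]
print(reconstruction_loss(recons, targets))  # (1 + 1) / 2 = 1.0
```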
Oh, does it sum over the batch dimension as well? If so, I agree that's quite bad!
Hi,
I found a bit of an obscure bug. The VAE reconstruction loss computed here
https://github.com/vitchyr/rlkit/blob/20ea0820eb89bddae7c6a5171038a005e472c3d0/rlkit/torch/vae/conv_vae.py#L241-L242
is unintentionally summed over instead of correctly averaged over (unintentionally, I think, because it uses `elementwise_mean`). This is because `input` and `target` passed to `F.mse_loss` are switched, and in combination with the deprecated reduction `elementwise_mean` instead of `mean`, PyTorch takes the sum instead of the mean. This can be checked here: https://github.com/pytorch/pytorch/blob/665feda15bc45d0f50326596ecde6f2d96ac6644/torch/nn/functional.py#L2668-L2672
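As far as I can tell from the linked lines, the gist is that when `target` requires grad, the legacy `mse_loss` path only checks against the exact string `'mean'`, so the deprecated `'elementwise_mean'` falls through to a sum. A pure-Python sketch of that branching (my paraphrase of the linked source, not a copy of it):

```python
def mse_reduction_sketch(sq_errors, reduction, target_requires_grad):
    # Paraphrase of the legacy F.mse_loss branch: when `target` requires
    # grad, the loss is computed elementwise and then reduced by a check
    # against the exact string 'mean' -- anything else (including the
    # deprecated 'elementwise_mean') falls through to a sum.
    if target_requires_grad:
        if reduction == 'none':
            return sq_errors
        total = sum(sq_errors)
        return total / len(sq_errors) if reduction == 'mean' else total
    # (the non-requires-grad path dispatches to a kernel that does handle
    # the deprecated string correctly; omitted here)
    raise NotImplementedError

errs = [1.0, 2.0, 3.0]
print(mse_reduction_sketch(errs, 'mean', True))              # 2.0
print(mse_reduction_sketch(errs, 'elementwise_mean', True))  # 6.0 (sum!)
```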
Changing the line so that `input` and `target` are passed in the correct order (and using the non-deprecated reduction `mean`) correctly computes the average. However, I suspect you might not want to change this, as your hyperparameters (in particular `beta`) are tuned to the sum, not the mean.
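For what it's worth, switching from sum to mean would in principle only require rescaling `beta` rather than re-tuning it: for `D` pixels, the mean is the sum divided by `D`, so the two objectives differ only by an overall constant factor (purely illustrative arithmetic, not code from the repo):

```python
D = 4
sq_errors = [0.5, 1.0, 0.25, 0.25]  # per-pixel squared errors, D pixels
kl = 0.8
beta = 20.0

# sum-based objective vs. mean-based objective with rescaled beta
sum_loss  = sum(sq_errors) + beta * kl
mean_loss = sum(sq_errors) / D + (beta / D) * kl

# the two objectives differ only by the constant factor D, so they have
# the same minimizers (gradients are simply scaled by D)
assert abs(sum_loss - D * mean_loss) < 1e-12
```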