rail-berkeley / rlkit

Collection of reinforcement learning algorithms
MIT License

MSE reconstruction loss in VAE training is summed over, not averaged over #119

Closed: mseitzer closed this issue 4 years ago.

mseitzer commented 4 years ago

Hi,

I found a bit of an obscure bug. The VAE reconstruction loss computed here

https://github.com/vitchyr/rlkit/blob/20ea0820eb89bddae7c6a5171038a005e472c3d0/rlkit/torch/vae/conv_vae.py#L241-L242

is, unintentionally I think (since it passes reduction='elementwise_mean'), summed instead of averaged. This happens because the input and target arguments to F.mse_loss are swapped, so the target (obs_distribution_params[0]) requires gradient. On that code path PyTorch only averages when reduction is exactly 'mean'; with the deprecated name elementwise_mean it takes the sum instead. This can be checked here:

https://github.com/pytorch/pytorch/blob/665feda15bc45d0f50326596ecde6f2d96ac6644/torch/nn/functional.py#L2668-L2672
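The branch can be sketched in plain Python. This is a hypothetical re-creation of the reduction logic as I read the linked source (PyTorch around 1.2), not the real `F.mse_loss`: flat lists stand in for tensors, the `target_requires_grad` flag stands in for `target.requires_grad`, and the name `mse_loss_reduction_sketch` is made up for illustration.

```python
def mse_loss_reduction_sketch(input_, target, reduction="mean", target_requires_grad=False):
    """Hypothetical re-creation of the reduction branch in PyTorch ~1.2 F.mse_loss.

    input_ and target are flat lists of floats standing in for tensors.
    """
    squared = [(i - t) ** 2 for i, t in zip(input_, target)]
    if reduction == "none":
        return squared
    if target_requires_grad:
        # Python fallback path: only the exact string "mean" averages;
        # anything else, including the deprecated "elementwise_mean", sums.
        return sum(squared) / len(squared) if reduction == "mean" else sum(squared)
    # Native path: the deprecated "elementwise_mean" is mapped to "mean".
    if reduction in ("mean", "elementwise_mean"):
        return sum(squared) / len(squared)
    if reduction == "sum":
        return sum(squared)
    raise ValueError(f"{reduction} is not a valid value for reduction")


inputs = [0.0, 0.0, 0.0, 0.0]  # observed pixels
recon = [1.0, 1.0, 1.0, 1.0]   # decoder output (requires grad in the VAE)

# Original call: arguments swapped, so the grad-requiring tensor is the target.
print(mse_loss_reduction_sketch(inputs, recon, "elementwise_mean", target_requires_grad=True))  # 4.0 (sum)
# Proposed fix: decoder output as input, observed pixels as target, reduction="mean".
print(mse_loss_reduction_sketch(recon, inputs, "mean"))  # 1.0 (mean)
```

Only the literal string 'mean' averages on the grad-requiring path, which is why swapping the arguments back and passing 'mean' gives the average.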

Changing the line to

log_prob = -1 * F.mse_loss(obs_distribution_params[0], inputs, reduction='mean')

correctly computes the average. However, I suspect you may not want to change this, since your hyperparameters (in particular beta) are tuned to the sum rather than the mean.
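Because a sum over N elements is N times their mean, switching the reduction without touching beta weakens the reconstruction term by a factor of N relative to the KL term. A minimal sketch of the compensating rescaling, assuming the loss has the form `recon + beta * kl` and that the KL term's reduction stays unchanged (both assumptions; the rlkit loss code is not quoted in this thread). `equivalent_beta_for_mean` and `vae_loss` are hypothetical helpers, and the numbers are toy values.

```python
def equivalent_beta_for_mean(beta_sum, num_elements):
    """Hypothetical helper: beta that keeps the recon/KL balance when the
    reconstruction term switches from a sum over num_elements to a mean.

    Assumes loss = recon + beta * kl, with kl reduced the same way in both cases.
    """
    # recon_sum = N * recon_mean, so recon_sum + b * kl = N * (recon_mean + (b / N) * kl)
    return beta_sum / num_elements


def vae_loss(recon, kl, beta):
    """Hypothetical beta-VAE style objective: reconstruction term plus weighted KL."""
    return recon + beta * kl


sq_errors = [0.5, 1.5, 2.0, 4.0]  # toy per-element squared errors (batch * pixels)
kl = 3.0
beta = 20.0
N = len(sq_errors)

loss_sum = vae_loss(sum(sq_errors), kl, beta)
loss_mean = vae_loss(sum(sq_errors) / N, kl, equivalent_beta_for_mean(beta, N))
print(loss_sum, loss_mean * N)  # 68.0 68.0
```

The rescaled loss is also N times smaller overall, so with plain SGD the learning rate would need adjusting as well.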

vitchyr commented 4 years ago

Thanks for pointing this out. Funnily enough, due to a bug in PyTorch, I think this ends up taking the mean. Either way, I probably won't change the hyperparameter, but this is good to keep in mind.

Also, it's actually unclear whether the VAE loss should average or sum over pixels. MSE basically corresponds to the log-prob of a Gaussian with a fixed variance. So I think the proper probabilistic loss function would sum over the pixels, since that effectively treats all the pixels as independent. That said, it depends a bit on the implementation.
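For a factorized Gaussian with fixed variance sigma^2 over D pixels, log p(x | mu) = -(1 / (2 sigma^2)) * sum_i (x_i - mu_i)^2 - (D / 2) * log(2 pi sigma^2), so maximizing it is the same as minimizing the summed squared error; the fixed sigma only rescales that term. A small numerical check of this identity follows; the function names and toy values are illustrative.

```python
import math


def gaussian_log_prob_factorized(x, mu, sigma):
    """Sum of per-pixel log N(x_i; mu_i, sigma^2), i.e. pixels treated as independent."""
    return sum(
        -0.5 * math.log(2 * math.pi * sigma ** 2) - (xi - mi) ** 2 / (2 * sigma ** 2)
        for xi, mi in zip(x, mu)
    )


def log_prob_from_sse(x, mu, sigma):
    """The same quantity written via the summed squared error (SSE) over pixels."""
    d = len(x)
    sse = sum((xi - mi) ** 2 for xi, mi in zip(x, mu))
    return -sse / (2 * sigma ** 2) - 0.5 * d * math.log(2 * math.pi * sigma ** 2)


x = [0.1, 0.5, 0.9, 0.3]   # toy "image"
mu = [0.0, 0.4, 1.0, 0.5]  # toy decoder mean
print(math.isclose(gaussian_log_prob_factorized(x, mu, 1.0), log_prob_from_sse(x, mu, 1.0)))  # True
```

With sigma^2 = 1/2 the scale factor 1 / (2 sigma^2) is 1, so the negative summed squared error equals the log-prob up to an additive constant.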

I'll close the issue for now since I don't think there's an actionable response, but feel free to continue the discussion.

mseitzer commented 4 years ago


At least on the PyTorch version I am using (1.2.0), the line I referenced computes the sum. This is consistent with your PyTorch issue, since obs_distribution_params[0], which is passed as the target, requires gradient.


Yes, this is something that has bugged me for a long time. I agree with you that summing over pixels is the correct thing to do, but I think the mean should be taken over the batch dimension. Having beta depend on the batch size seems suboptimal to me.
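A minimal sketch of the suggested reduction (sum over pixels within each sample, mean over the batch), contrasted with a sum over everything. Nested lists stand in for a batch of flattened images, and both function names are hypothetical. With the full sum the reconstruction term grows linearly with batch size, so its balance against a fixed beta shifts whenever the batch size changes (assuming the KL term is not reduced the same way).

```python
def recon_loss_sum_pixels_mean_batch(batch_recon, batch_target):
    """Sum squared error over the pixels of each sample, then average over the batch."""
    per_sample = [
        sum((r - t) ** 2 for r, t in zip(recon, target))
        for recon, target in zip(batch_recon, batch_target)
    ]
    return sum(per_sample) / len(per_sample)


def recon_loss_sum_all(batch_recon, batch_target):
    """Sum squared error over pixels *and* batch, like a full sum reduction."""
    return sum(
        (r - t) ** 2
        for recon, target in zip(batch_recon, batch_target)
        for r, t in zip(recon, target)
    )


sample_recon, sample_target = [1.0, 0.0, 2.0], [0.0, 0.0, 0.0]  # one flattened toy image
for batch_size in (8, 32):
    recon = [sample_recon] * batch_size
    target = [sample_target] * batch_size
    print(batch_size, recon_loss_sum_all(recon, target), recon_loss_sum_pixels_mean_batch(recon, target))
# 8 40.0 5.0
# 32 160.0 5.0
```

In PyTorch the per-sample-sum, batch-mean reduction can be written as `F.mse_loss(recon, target, reduction='none').flatten(start_dim=1).sum(dim=1).mean()`.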

vitchyr commented 4 years ago

Oh, does it sum over the batch dimension as well? If so, I agree that's quite bad!