masato-ka / airc-rl-agent

AI RC car agent that uses deep reinforcement learning on Jetson Nano
MIT License

question on loss function of VAE #44

Closed Ishihara-Masabumi closed 2 years ago

Ishihara-Masabumi commented 2 years ago

The loss function of the VAE is as follows:

    def loss_fn(self, images, reconst, mean, logvar):
        KL = -0.5 * torch.sum((1 + logvar - mean.pow(2) - logvar.exp()), dim=0)
        KL = torch.mean(KL)
        reconstruction = F.binary_cross_entropy(reconst.view(-1,38400), images.view(-1, 38400), reduction='sum') #size_average=False)
        return reconstruction + 5.0 * KL

I understand the "reconstruction" part of the loss, but what does "5.0 * KL" mean?

masato-ka commented 2 years ago

KL is the KL divergence term, which regularizes the latent space of the VAE by pulling the encoder's output distribution toward a standard normal prior. The factor 5.0 is the beta parameter of β-VAE, a coefficient that encourages disentanglement of the latent space. Disentanglement means that each element of the latent vector takes on an independent meaning. For example, one element could end up affecting only the background.
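
Written out, the objective looks roughly like the following. This is a sketch in standard β-VAE notation rather than a statement about how the repository reduces over the batch: μ corresponds to `mean`, log σ² to `logvar`, d is the latent dimension (a symbol introduced here), and β = 5.0 in `loss_fn`.

```latex
\mathcal{L}(x) = \underbrace{-\,\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big]}_{\text{reconstruction}}
  \;+\; \beta \, D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, \mathcal{N}(0, I)\big),
\qquad
D_{\mathrm{KL}} = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```

The expression inside `torch.sum` in `loss_fn` is the summand of this closed form. With β = 1 the loss reduces to the ordinary VAE objective; β > 1 weights the KL term more heavily.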

For details, see the β-VAE paper: https://openreview.net/pdf?id=Sy2fzU9gl
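
To make the effect of the coefficient concrete, here is a minimal plain-Python sketch of the same formula for a single sample. It is not the repository's code: the helper names `kl_to_standard_normal` and `beta_vae_loss` and the reconstruction value `10.0` are made up for illustration, and it uses `math` instead of PyTorch so it runs without extra dependencies.

```python
import math


def kl_to_standard_normal(mean, logvar):
    """Closed-form KL(N(mean, exp(logvar)) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mean, logvar))


def beta_vae_loss(reconstruction, mean, logvar, beta=5.0):
    """Reconstruction error plus the beta-weighted KL term (hypothetical helper)."""
    return reconstruction + beta * kl_to_standard_normal(mean, logvar)


# A latent code that already matches the prior (mean 0, variance 1) adds no penalty.
kl_at_prior = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])

# Shifting one latent mean away from 0 is penalized.
kl_shifted = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])

# The same deviation costs five times more with beta = 5.0 than with beta = 1.0.
loss_beta1 = beta_vae_loss(10.0, [1.0, 0.0], [0.0, 0.0], beta=1.0)
loss_beta5 = beta_vae_loss(10.0, [1.0, 0.0], [0.0, 0.0], beta=5.0)

print(kl_shifted, loss_beta1, loss_beta5)
```

A larger beta makes any departure from the standard normal prior more expensive relative to reconstruction error, which is the pressure the β-VAE paper associates with disentangled latent elements.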

masato-ka commented 2 years ago

May I close this question?

Ishihara-Masabumi commented 2 years ago

I don't fully understand it yet, but I will read the paper you recommended and work through it. You can go ahead and close this issue.