YannDubs / disentangling-vae

Experiments for understanding disentanglement in VAE latent representations

Why is tc_loss in bTCVAE negative? #60

Open sisodia-a opened 4 years ago

sisodia-a commented 4 years ago

https://github.com/YannDubs/disentangling-vae/blob/535bbd2e9aeb5a200663a4f82f1d34e084c4ba8d/results/btcvae_dsprites/train_losses.log#L5

sisodia-a commented 4 years ago

https://github.com/rtqichen/beta-tcvae/ computes, for minibatch weighted sampling:

    # minibatch weighted sampling
    logqz_prodmarginals = (logsumexp(_logqz, dim=1, keepdim=False) - math.log(batch_size * dataset_size)).sum(1)
    logqz = (logsumexp(_logqz.sum(2), dim=1, keepdim=False) - math.log(batch_size * dataset_size))

and for minibatch stratified sampling:

    # minibatch stratified sampling
    logiw_matrix = Variable(self._log_importance_weight_matrix(batch_size, dataset_size).type_as(_logqz.data))
    logqz = logsumexp(logiw_matrix + _logqz.sum(2), dim=1, keepdim=False)
    logqz_prodmarginals = logsumexp(logiw_matrix.view(batch_size, batch_size, 1) + _logqz, dim=1, keepdim=False).sum(1)
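Reading Chen's weighted-sampling lines as formulas may make the comparison easier. This is our reading of the code, in the notation of Chen et al. (2018), assuming `_logqz[i, j, d]` holds log q(z_{i,d} | x_j) for the i-th latent sample z_i and the j-th input x_j, with M = batch_size, N = dataset_size and D latent dimensions. In both lines the sum over the batch index j happens inside the log (a logsumexp over dim=1), not as a sum of log-densities:

```latex
\log q(z_i) \approx \log \sum_{j=1}^{M} \prod_{d=1}^{D} q(z_{i,d} \mid x_j) - \log(NM)

\log \prod_{d=1}^{D} q(z_{i,d}) \approx \sum_{d=1}^{D} \Big[ \log \sum_{j=1}^{M} q(z_{i,d} \mid x_j) - \log(NM) \Big]
```

The product over d in the first line is why the code sums `_logqz` over dim 2 before the logsumexp: the encoder is a factorized Gaussian, so log q(z_i | x_j) is the sum of the per-dimension log-densities.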

So in this codebase, shouldn't we do the same? For minibatch weighted sampling (is_mss=False):

    log_qz = (torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False) - math.log(batch_size * n_data))
    log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False) - math.log(batch_size * n_data)).sum(1)

and for minibatch stratified sampling (is_mss=True):

    log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)                   
    log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)                            
    log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size,batch_size,1)+mat_log_qz, dim=1, keepdim=False).sum(1)    
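The two estimators can be sketched framework-free in NumPy to see how they relate. This is a minimal sketch, not the repo's code: `logsumexp`, `mws` and `weighted` are hypothetical helpers, and the actual stratified importance-weight matrix is not reproduced here, so `weighted` accepts any (batch, batch) matrix of log weights. With a uniform log weight of -log(batch_size * n_data), the stratified formula collapses back to the weighted-sampling one, which is a handy sanity check when comparing implementations.

```python
import math

import numpy as np


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis`; the axis is dropped."""
    a_max = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(a_max, axis=axis) + np.log(np.sum(np.exp(a - a_max), axis=axis))


def mws(mat_log_qz, n_data):
    """Minibatch weighted sampling; mat_log_qz[i, j, d] = log q(z_{i,d} | x_j)."""
    batch_size = mat_log_qz.shape[0]
    log_norm = math.log(batch_size * n_data)
    log_qz = logsumexp(mat_log_qz.sum(axis=2), axis=1) - log_norm
    log_prod_qzi = (logsumexp(mat_log_qz, axis=1) - log_norm).sum(axis=1)
    return log_qz, log_prod_qzi


def weighted(mat_log_qz, log_iw_mat):
    """Same marginalisation, but with a (batch, batch) matrix of log weights."""
    batch_size = mat_log_qz.shape[0]
    log_qz = logsumexp(log_iw_mat + mat_log_qz.sum(axis=2), axis=1)
    log_prod_qzi = logsumexp(
        log_iw_mat.reshape(batch_size, batch_size, 1) + mat_log_qz, axis=1
    ).sum(axis=1)
    return log_qz, log_prod_qzi


# Sanity check: a uniform weight of 1/(batch_size * n_data) turns `weighted` into `mws`.
batch_size, n_data, latent_dim = 4, 1000, 3
rng = np.random.default_rng(0)
mat = rng.normal(size=(batch_size, batch_size, latent_dim))
uniform = np.full((batch_size, batch_size), -math.log(batch_size * n_data))
print(np.allclose(mws(mat, n_data)[0], weighted(mat, uniform)[0]))
print(np.allclose(mws(mat, n_data)[1], weighted(mat, uniform)[1]))
```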
YannDubs commented 4 years ago

Thanks @UserName-AnkitSisodia! I think you might be right (I am taking a sum instead of marginalizing in log space), but it's been a long time, so I'll have to double-check this weekend.

Did you test it with these changes?
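The difference between summing and marginalizing in log space can be illustrated with a few made-up numbers. This sketch ignores the 1/N dataset factor and does not reproduce the code from #43, which is not shown in this thread; it only contrasts the two operations on the densities of one latent sample under a batch of three inputs.

```python
import math

import numpy as np

# Made-up densities q(z_i | x_j) of one latent sample z_i under a batch of M = 3 inputs.
log_q = np.log(np.array([0.5, 0.25, 0.25]))
M = len(log_q)

# Marginalising in log space: log((1/M) * sum_j q(z_i | x_j)), computed stably.
m = log_q.max()
log_marginal = m + math.log(np.exp(log_q - m).sum()) - math.log(M)

# Summing the log-densities instead gives log(prod_j q(z_i | x_j)), a different quantity.
log_of_product = log_q.sum()

print(log_marginal, math.log(1 / 3))
print(log_of_product, math.log(0.03125))
```

The first value is the log of an average density; the second is the log of a product of densities and shrinks with every extra batch element.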

sisodia-a commented 4 years ago

Using some random matrices (code attached: temp.txt), I ran both your code and Ricky Chen's code to compare what is happening.

I found:

MWS: log_qz != logqz_ricky, and log_prod_qzi != logqz_prodmarginals_ricky
MSS: log_prod_qzi_mss == logqz_prodmarginals_ricky_mss, but log_qz_mss != logqz_ricky_mss
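A comparison in the same spirit can be sketched in NumPy. This is not the attached temp.txt; the function names and random inputs are assumptions. It checks the log-space weighted-sampling formulas against a slow reference that averages densities in probability space and only then takes logs, which is what the estimator is meant to compute.

```python
import math

import numpy as np


def mws_vectorised(mat_log_qz, n_data):
    """Log-space MWS, as in the proposed torch code but written with NumPy."""
    b = mat_log_qz.shape[0]

    def lse(a, axis):
        m = np.max(a, axis=axis, keepdims=True)
        return np.squeeze(m, axis=axis) + np.log(np.exp(a - m).sum(axis=axis))

    log_norm = math.log(b * n_data)
    log_qz = lse(mat_log_qz.sum(axis=2), 1) - log_norm
    log_prod_qzi = (lse(mat_log_qz, 1) - log_norm).sum(axis=1)
    return log_qz, log_prod_qzi


def mws_loops(mat_log_qz, n_data):
    """Reference: average q(z_i | x_j) over j in probability space, then take logs."""
    b, _, d = mat_log_qz.shape
    log_qz, log_prod_qzi = np.empty(b), np.empty(b)
    for i in range(b):
        joint = sum(math.exp(mat_log_qz[i, j].sum()) for j in range(b))
        log_qz[i] = math.log(joint / (b * n_data))
        log_prod_qzi[i] = sum(
            math.log(sum(math.exp(mat_log_qz[i, j, k]) for j in range(b)) / (b * n_data))
            for k in range(d)
        )
    return log_qz, log_prod_qzi


rng = np.random.default_rng(0)
mat = rng.normal(scale=0.5, size=(5, 5, 3))  # small values keep exp() well inside float range
vec, ref = mws_vectorised(mat, 1000), mws_loops(mat, 1000)
print(np.allclose(vec[0], ref[0]), np.allclose(vec[1], ref[1]))
```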

So, with your code, I get a -ve tc_loss when is_mss=True, and both a -ve mi_loss and a -ve tc_loss when is_mss=False. I ran it on the dSprites dataset with batch size 128.

Then I changed the _get_log_pz_qz_prodzi_qzCx function in your code to match Ricky Chen's code:

import math

import torch
from disvae.utils.math import (log_density_gaussian, log_importance_weight_matrix,
                               matrix_log_density_gaussian)


def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=False):
    batch_size, hidden_dim = latent_sample.shape

    # calculate log q(z|x)
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # calculate log p(z)
    # mean and log var is 0
    zeros = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

    # mat_log_qz[i, j, d] = log q(z_{i,d} | x_j)
    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    ## Ankit - modified: both branches below
    if is_mss:
        # minibatch stratified sampling
        log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
        log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)
        log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size, batch_size, 1) + mat_log_qz,
                                       dim=1, keepdim=False).sum(1)
    else:
        # minibatch weighted sampling
        log_qz = (torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False)
                  - math.log(batch_size * n_data))
        log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False)
                        - math.log(batch_size * n_data)).sum(1)

    return log_pz, log_qz, log_prod_qzi, log_q_zCx

With this change, all losses are +ve when is_mss=True, but with is_mss=False the dw_kl_loss term is -ve.
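Assuming mi_loss, tc_loss and dw_kl_loss are formed as batch means of differences of the four values returned by _get_log_pz_qz_prodzi_qzCx (the β-TCVAE decomposition of Chen et al.), the three terms telescope. In this sketch `decompose_kl` is a hypothetical helper, not the repo's loss class. Whatever estimator produces log_qz and log_prod_qzi, the three terms always add up to the ordinary KL estimate, so switching between MWS and MSS only moves value between them. That may be why fixing one term can push another below zero.

```python
import numpy as np


def decompose_kl(log_pz, log_qz, log_prod_qzi, log_q_zCx):
    """Split the KL estimate into MI, TC and dimension-wise KL (batch means)."""
    mi_loss = np.mean(log_q_zCx - log_qz)        # index-code mutual information
    tc_loss = np.mean(log_qz - log_prod_qzi)     # total correlation
    dw_kl_loss = np.mean(log_prod_qzi - log_pz)  # dimension-wise KL
    return mi_loss, tc_loss, dw_kl_loss


rng = np.random.default_rng(1)
log_pz, log_qz, log_prod_qzi, log_q_zCx = (rng.normal(size=8) for _ in range(4))
mi, tc, dw = decompose_kl(log_pz, log_qz, log_prod_qzi, log_q_zCx)
# log_qz and log_prod_qzi cancel in the sum, leaving mean(log q(z|x) - log p(z)).
print(np.isclose(mi + tc + dw, np.mean(log_q_zCx - log_pz)))
```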

YannDubs commented 4 years ago

Awesome, thanks for checking. A few comments:

1/ What do you mean by "+ve" and "-ve"? What is "ve"?

2/ Looking back at it, it seems that I actually had the correct code and then introduced the problem in a late-night push (#43).

Here's what I had before my changes:

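import math

import torch
from disvae.utils.math import (log_density_gaussian, log_importance_weight_matrix,
                               matrix_log_density_gaussian)


def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=False):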
    batch_size, hidden_dim = latent_sample.shape

    # calculate log q(z|x)
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # calculate log p(z)
    # mean and log var is 0
    zeros = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

    if not is_mss:
        log_qz, log_prod_qzi = _minibatch_weighted_sampling(latent_dist,
                                                            latent_sample,
                                                            n_data)

    else:
        log_qz, log_prod_qzi = _minibatch_stratified_sampling(latent_dist,
                                                              latent_sample,
                                                              n_data)

    return log_pz, log_qz, log_prod_qzi, log_q_zCx

def _minibatch_weighted_sampling(latent_dist, latent_sample, data_size):
    """
    Estimates log q(z) and the log (product of marginals of q(z_j)) with minibatch
    weighted sampling.

    Parameters
    ----------
    latent_dist : tuple of torch.tensor
        sufficient statistics of the latent dimension. E.g. for gaussian
        (mean, log_var) each of shape : (batch_size, latent_dim).

    latent_sample: torch.Tensor
        sample from the latent dimension using the reparameterisation trick
        shape : (batch_size, latent_dim).

    data_size : int
        Number of samples in the training set

    References
    ----------
       [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
       autoencoders." Advances in Neural Information Processing Systems. 2018.
    """
    batch_size = latent_sample.size(0)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False) -
                    math.log(batch_size * data_size)).sum(dim=1)
    log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False
                             ) - math.log(batch_size * data_size)

    return log_qz, log_prod_qzi

def _minibatch_stratified_sampling(latent_dist, latent_sample, data_size):
    """
    Estimates log q(z) and the log (product of marginals of q(z_j)) with minibatch
    stratified sampling.

    Parameters
    ----------
    latent_dist : tuple of torch.tensor
        sufficient statistics of the latent dimension. E.g. for gaussian
        (mean, log_var) each of shape : (batch_size, latent_dim).

    latent_sample: torch.Tensor
        sample from the latent dimension using the reparameterisation trick
        shape : (batch_size, latent_dim).

    data_size : int
        Number of samples in the training set

    References
    ----------
       [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
       autoencoders." Advances in Neural Information Processing Systems. 2018.
    """
    batch_size = latent_sample.size(0)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    log_iw_mat = log_importance_weight_matrix(batch_size, data_size).to(latent_sample.device)
    log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)
    log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size, batch_size, 1) +
                                   mat_log_qz, dim=1, keepdim=False).sum(1)

    return log_qz, log_prod_qzi
YannDubs commented 4 years ago

Which is (I believe) exactly what you tested.

sisodia-a commented 4 years ago

Yes, this makes the code exactly the same. With these changes, I get a negative dw_kl_loss term with _minibatch_weighted_sampling, while with _minibatch_stratified_sampling all loss terms are positive. I tested on dSprites.

YannDubs commented 4 years ago

And qualitatively, do you see any differences?

sisodia-a commented 4 years ago

I haven't tested that yet. I was just trying to work out from the math/code where the error comes from.

DianeBouchacourt commented 3 years ago

Has this issue been solved? Training on dSprites, I also get a negative TC loss.
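Separately from any bug, the minibatch estimate itself is not guaranteed to be non-negative, even though the true TC is a KL divergence and hence non-negative. Here is a hand-built sketch with hypothetical numbers, using the weighted-sampling formulas discussed in this thread and NumPy instead of torch. Sample z_0 is likely under x_0 in dimension 0 and under x_1 in dimension 1, but unlikely under any single input jointly, so its estimated log q(z_0) falls far below the estimated product of marginals.

```python
import math

import numpy as np


def lse(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis`; the axis is dropped."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(a - m).sum(axis=axis))


def mws_tc(mat_log_qz, n_data):
    """Per-sample MWS estimate of log q(z_i) - sum_d log q(z_{i,d})."""
    b = mat_log_qz.shape[0]
    log_norm = math.log(b * n_data)
    log_qz = lse(mat_log_qz.sum(axis=2), 1) - log_norm
    log_prod_qzi = (lse(mat_log_qz, 1) - log_norm).sum(axis=1)
    return log_qz - log_prod_qzi


# Hypothetical minibatch, mat[i, j, d] = log q(z_{i,d} | x_j), with 2 inputs and 2 latent dims.
mat = np.array([[[0.0, -100.0], [-100.0, 0.0]],
                [[0.0, 0.0], [0.0, 0.0]]])
tc = mws_tc(mat, n_data=1000)
print(tc)
```

A negative tc_loss on its own may therefore not prove the code is wrong, although it does not rule out a bug either.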

shi-yu-wang commented 2 years ago

I also got a negative loss with the dSprites data.

shi-yu-wang commented 2 years ago

It is the TC loss.