jsxlei / SCALE

Single-cell ATAC-seq analysis via Latent feature Extraction
MIT License

Question about the elbo_SCALE #18

Open caokai1073 opened 3 years ago

caokai1073 commented 3 years ago

Hi, SCALE is a very interesting and useful tool.
But I have a question about the calculation of the ELBO.

Why is gamma = q(c|x) = p(c|z) used to calculate the ELBO instead of q(z,c|x)? Does q(c|x) equal q(z,c|x) in Eq(z,c|x)[log p(x|z) + log p(z|c) + log p(c) - log q(z|x) - log q(c|x)] in loss.py?

import math
import torch
import torch.nn.functional as F

# binary_cross_entropy is provided elsewhere in loss.py

def elbo_SCALE(recon_x, x, gamma, c_params, z_params, binary=True):
    """
    L elbo(x) = Eq(z,c|x)[ log p(x|z) ] - KL(q(z,c|x)||p(z,c))
              = Eq(z,c|x)[ log p(x|z) + log p(z|c) + log p(c) - log q(z|x) - log q(c|x) ]
    """
    mu_c, var_c, pi = c_params          # GMM prior: component means, variances, mixing weights
    var_c += 1e-8                       # avoid log(0) / division by zero
    n_centroids = pi.size(1)
    mu, logvar = z_params               # parameters of q(z|x)
    # expand q(z|x) statistics to (batch, latent_dim, n_centroids) to broadcast against the components
    mu_expand = mu.unsqueeze(2).expand(mu.size(0), mu.size(1), n_centroids)
    logvar_expand = logvar.unsqueeze(2).expand(logvar.size(0), logvar.size(1), n_centroids)

    # log p(x|z): reconstruction term
    if binary:
        likelihood = -binary_cross_entropy(recon_x, x)
    else:
        likelihood = -F.mse_loss(recon_x, x)

    # log p(z|c): Gaussian log-density of z under each component, weighted by gamma
    logpzc = -0.5*torch.sum(gamma*torch.sum(math.log(2*math.pi) + \
                                           torch.log(var_c) + \
                                           torch.exp(logvar_expand)/var_c + \
                                           (mu_expand-mu_c)**2/var_c, dim=1), dim=1)

    # log p(c)
    logpc = torch.sum(gamma*torch.log(pi), 1)

    # log q(z|x), i.e. the negative entropy of q(z|x)
    qentropy = -0.5*torch.sum(1+logvar+math.log(2*math.pi), 1)

    # log q(c|x)
    logqcx = torch.sum(gamma*torch.log(gamma), 1)

    kld = -logpzc - logpc + qentropy + logqcx

    return torch.sum(likelihood), torch.sum(kld)
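
For reference, this is roughly how I picture it being called. The tensor shapes, the uniform pi, and the way the two returned terms are combined into a loss below are my own assumptions for illustration, not taken from SCALE's training code (and elbo_SCALE's helpers from loss.py are assumed to be in scope):

import torch

batch, n_features, latent_dim, n_centroids = 4, 100, 10, 3

x = torch.rand(batch, n_features)                                # input peaks scaled to [0, 1]
recon_x = torch.rand(batch, n_features)                          # decoder output (sigmoid-activated)
gamma = torch.softmax(torch.rand(batch, n_centroids), dim=1)     # q(c|x) = p(c|z), rows sum to 1
c_params = (torch.randn(latent_dim, n_centroids),                # mu_c
            torch.rand(latent_dim, n_centroids) + 0.5,           # var_c (positive)
            torch.full((1, n_centroids), 1.0/n_centroids))       # pi (uniform here)
z_params = (torch.randn(batch, latent_dim),                      # mu of q(z|x)
            torch.randn(batch, latent_dim))                      # logvar of q(z|x)

likelihood, kld = elbo_SCALE(recon_x, x, gamma, c_params, z_params, binary=True)
loss = -likelihood + kld                                         # negative ELBO to minimize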

Thanks!

jsxlei commented 3 years ago

Hi, thanks for your interest in SCALE.

  1. gamma = q(c|x) = p(c|z). gamma is an inference distribution that infers the cluster assignment c from the original data x. However, in the generative model c is inferred from the latent variable z, because only z is connected to c in the model (x -> z -> c); therefore we replace q(c|x) with p(c|z), which can also be regarded as an approximation (see the sketch after this list).

  2. q(z, c|x) = q(c|x) * q(z|x). This factorization holds because both c and z are inferred directly from x in the recognition model (z <- x -> c, different from the generative model of SCALE, x -> z -> c); thus c and z are conditionally independent given the observed x.
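
To make point 1 concrete, here is a minimal sketch of how gamma = p(c|z) can be computed from the GMM prior parameters (pi, mu_c, var_c) via Bayes' rule, p(c|z) proportional to p(c) p(z|c). This is only an illustration, not the exact SCALE code; the function name and tensor shapes are assumptions.

import math
import torch

def compute_gamma(z, pi, mu_c, var_c, eps=1e-10):
    """Illustrative p(c|z) for a mixture-of-diagonal-Gaussians prior.

    z:     (batch, latent_dim)        latent codes
    pi:    (1, n_centroids)           mixing weights p(c)
    mu_c:  (latent_dim, n_centroids)  component means
    var_c: (latent_dim, n_centroids)  component variances
    Returns gamma of shape (batch, n_centroids); each row sums to 1.
    """
    z = z.unsqueeze(2)                                 # (batch, latent_dim, 1) for broadcasting
    # log p(z|c) for every component, summed over latent dimensions
    log_pzc = -0.5 * torch.sum(math.log(2 * math.pi)
                               + torch.log(var_c)
                               + (z - mu_c) ** 2 / var_c, dim=1)   # (batch, n_centroids)
    log_joint = torch.log(pi + eps) + log_pzc          # log p(c) + log p(z|c)
    return torch.softmax(log_joint, dim=1)             # normalize over components: p(c|z)

Combined with point 2, any ELBO term that depends only on c satisfies Eq(z,c|x)[f(c)] = sum_c q(c|x) f(c) = sum_c gamma f(c), because the z part of q(z,c|x) integrates to 1; this is why gamma appears as the weight in front of log p(c), log q(c|x), and the per-component average of log p(z|c) in elbo_SCALE.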

I hope this answers your question.

caokai1073 commented 3 years ago

Thanks for your reply! But I am still a little confused. For example,

# log p(c)
logpc = torch.sum(gamma*torch.log(pi), 1)

I think it should be Eq(z,c|x)[log p(c)] = \sum_c \int q(z,c|x) log p(c) dz

Why did you use gamma instead of q(z,c|x) in this calculation?