YannDubs / disentangling-vae

Experiments for understanding disentanglement in VAE latent representations

Computing MIG and AAM for other datasets #63

Open DianeBouchacourt opened 3 years ago

DianeBouchacourt commented 3 years ago

I am trying to compute MIG and AAM for another dataset which has a different structure from dsprites, in the sense that the number of samples does not equal the product of the sizes of the latent factors. Thus, this line fails: https://github.com/YannDubs/disentangling-vae/blob/7b8285baa19d591cf34c652049884aca5d8acbca/disvae/evaluate.py#L141

since samples_zCx has shape (len(dataset), latent_dim) but len(dataset) != np.prod(lat_sizes). Any reason why you explicitly chose to use the product of the latent sizes, or should it be the length of the dataset?
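
For concreteness, dsprites satisfies the assumption that the reshape relies on, while my dataset does not (a minimal sketch, using the dsprites factor sizes):

import numpy as np

# dsprites contains every combination of its ground-truth factors,
# so the dataset length equals the product of the factor sizes
lat_sizes = np.array([3, 6, 40, 32, 32])  # shape, scale, orientation, posX, posY
assert np.prod(lat_sizes) == 737280  # == len(dataset) for dsprites
# the reshape to (*lat_sizes, latent_dim) only succeeds under this assumption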

Thanks!

YannDubs commented 3 years ago

Hi Diane,

Sorry, it's been a while so I don't remember all the design choices.

It's not clear to me what structure of data you have. Are you saying that you have multiple images for the same latents? Or that you don't have all combinations of latents?

Both cases would be doable but you'll have to generalize the computation of H[z|v] here:

https://github.com/YannDubs/disentangling-vae/blob/7b8285baa19d591cf34c652049884aca5d8acbca/disvae/evaluate.py#L299

right now the computation is very simple because it's easy to condition on v.
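
For what it's worth, here is one way it could be generalized using discretized codes (a sketch under assumed shapes, not code from this repo: z_discrete is [latent_dim, n_samples], v is [n_factors, n_samples] of known factor values):

import numpy as np

def discrete_conditional_entropy(z_discrete, v):
    """Estimate H[z_j|v_k] = sum_a p(v_k = a) * H[z_j | v_k = a] by
    grouping samples on each observed factor value, instead of
    reshaping to the Cartesian product of factor sizes."""
    n_codes, n_samples = z_discrete.shape
    n_factors = v.shape[0]
    H = np.zeros((n_codes, n_factors))
    for k in range(n_factors):
        values, counts = np.unique(v[k], return_counts=True)
        for value, count in zip(values, counts):
            group = z_discrete[:, v[k] == value]  # samples with v_k == value
            for j in range(n_codes):
                _, bin_counts = np.unique(group[j], return_counts=True)
                p = bin_counts / bin_counts.sum()
                H[j, k] -= (count / n_samples) * np.sum(p * np.log(p))
    return H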

Justin-Tan commented 3 years ago

I've done this for a custom dataset; a simplified version of what I used is below. I don't remember all the details either, but essentially you discretize samples from the approximate posterior, compute the mutual information between the discretized samples and the generative factors, then do the rest of the calculation as normal.

You probably also want to transform your generative factors to a reasonable range first by min-max scaling them, e.g. (a sketch using sklearn's minmax_scale; gen_factors is assumed to be an [n_samples, n_factors] array):
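
from sklearn.preprocessing import minmax_scale

# scale each generative factor to [0, 1] independently (column-wise);
# gen_factors is an assumed [n_samples, n_factors] array
gen_factors = minmax_scale(gen_factors, feature_range=(0, 1), axis=0)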

import numpy as np
from sklearn.metrics import mutual_info_score

def _histogram_discretize(target, num_bins=20):
    """Discretize each row of `target` (shape [num_codes, num_samples])
    into `num_bins` histogram bins."""
    discretized = np.zeros_like(target)
    for i in range(target.shape[0]):
        discretized[i, :] = np.digitize(target[i, :], np.histogram(
            target[i, :], num_bins)[1][:-1])
    return discretized

def discrete_mutual_info(z, v):
    """
    Compute the matrix of discrete mutual informations I(z_i; v_j).
    z: array-like
        Matrix of discretized latent codes, shape [LD, B] where LD is the
        latent dimension and B the number of samples. z is taken to be the
        mean value of the latent representation.
    v: array-like
        Matrix of generative factors, shape [N, B] for N factors.
    """
    num_codes = z.shape[0]
    num_factors = v.shape[0]
    mi = np.zeros([num_codes, num_factors])
    for i in range(num_codes):
        for j in range(num_factors):
            mi[i, j] = mutual_info_score(v[j, :], z[i, :])
    return mi

def discrete_entropy(v):
    """
    Compute the discrete entropy of each generative factor.
    v: array-like
        Matrix of underlying generative factors, shape [N, B]
        for N ground-truth factors
    """
    num_factors = v.shape[0]
    H = np.zeros(num_factors)
    for j in range(num_factors):
        # I(x; x) = H(x), so the self-MI of a factor is its entropy
        H[j] = mutual_info_score(v[j, :], v[j, :])
    return H

def estimate_MIG_discrete(qzCx_samples, gen_factors):
    """MIG from posterior samples qzCx_samples [B, LD] and the matching
    generative factors gen_factors [B, N]."""
    # transpose to [LD, B] *before* discretizing, so each code is binned
    # over the samples rather than each sample over the codes
    z_discrete = _histogram_discretize(np.transpose(qzCx_samples), num_bins=20)
    mi_discrete = discrete_mutual_info(z_discrete, np.transpose(gen_factors))
    H_v = discrete_entropy(np.transpose(gen_factors))
    # MIG: gap between the two codes most informative about each factor,
    # normalized by that factor's entropy
    sorted_mi_discrete = np.sort(mi_discrete, axis=0)[::-1]
    discrete_metric = np.mean(np.divide(
        sorted_mi_discrete[0, :] - sorted_mi_discrete[1, :], H_v))

    top_mi_normed_discrete = np.divide(sorted_mi_discrete[0, :], H_v)
    top_mi_idx_discrete = np.argsort(mi_discrete, axis=0)[::-1][0, :]

    return discrete_metric, top_mi_normed_discrete, top_mi_idx_discrete
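
A hypothetical usage example with random inputs, just to illustrate the expected shapes:

rng = np.random.default_rng(0)
qzCx_samples = rng.normal(size=(10000, 10))        # [n_samples, latent_dim] posterior means
gen_factors = rng.integers(0, 5, size=(10000, 3))  # [n_samples, n_factors] discrete factors
mig, top_mi_normed, top_mi_idx = estimate_MIG_discrete(qzCx_samples, gen_factors)
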
DianeBouchacourt commented 3 years ago

I have multiple images of the same latents (other factors are unknown and can differ between images, but I don't have ground truth for these; e.g. multiple images of the same person wearing glasses, where {glasses, identity} are known but the rest differ).

Thanks for the reply! I'll have a look at H[z|v] and at Justin's solution, and I'll let you know if I get stuck somewhere :)