YannDubs / disentangling-vae

Experiments for understanding disentanglement in VAE latent representations

Low MIG and AAM metrics #52

Open Justin-Tan opened 4 years ago

Justin-Tan commented 4 years ago

Hello,

Firstly, just wanted to state that this is a great repo with a very understandable code base!

I seem to be getting extremely low MIG / AAM scores (around 1e-3 to 1e-2) when training with any of the pretrained models, even using the recommended hyperparameters in the .ini file in the main directory. Is this something you noticed in your own tests?

Visual inspection of the traversals in dSprites seems to show that the network is learning quite disentangled representations (attached, with rows arranged in order of descending KL divergence from the Gaussian prior), so I am quite confused as to why the MIG score is so low.

Even introducing supervision (matching latent factors to generative factors), the maximum MIG score I have been able to attain is around 0.01, but AAM is much higher, at around 0.6 for the model that produced the attached latent traversals.

Cheers, Justin

[image: traversals]

YannDubs commented 4 years ago

The small MIG is definitely (and unfortunately) something we always saw in our experiments. Importantly, I got the same results when using the authors' implementation (https://github.com/rtqichen/beta-tcvae). This is one of the reasons we introduced AAM, which measures only disentanglement, rather than disentanglement plus the amount of information of v about z. I am surprised you get a small AAM, though.
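For concreteness, here is a minimal sketch (not this repo's implementation) of the gap MIG measures: given a matrix of mutual informations I(v_k; z_j) between each generative factor v_k and each latent z_j, MIG is the difference between the two latents most informative about each factor, normalised by that factor's entropy H(v_k) and averaged over factors.

```python
import numpy as np

def mig_from_mi(mi, factor_entropies):
    """mi: (n_factors, n_latents) mutual informations; factor_entropies: (n_factors,)."""
    sorted_mi = np.sort(mi, axis=1)[:, ::-1]                  # descending per factor
    gaps = (sorted_mi[:, 0] - sorted_mi[:, 1]) / factor_entropies
    return float(gaps.mean())
```

A latent code can have high mutual information with every factor and still score near zero here, since only the *gap* between the top two latents counts; that is why MIG conflates axis alignment with informativeness less than raw MI does, but still requires v to be captured by z at all.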

Here are the results we were getting :

[image: Screen Shot 2019-11-24 at 4 12 12 AM]

We see that increasing β by a small amount (from 1 to 4) greatly increases axis alignment (from 20% to 65%) due to the regularisation of the total correlation, while increasing β by a large amount (from 4 to 50) decreases axis alignment due to the penalisation of the dimension-wise KL. I.e., the effect is not monotonic.
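For reference, this tradeoff follows from the standard decomposition of the aggregate KL term used in β-TCVAE, where the total correlation (which encourages axis alignment when penalised) and the dimension-wise KL (which destroys information in each latent when penalised too hard) are scaled together:

```latex
\mathbb{E}_{p(x)}\big[\mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big)\big]
  = \underbrace{I_q(x; z)}_{\text{index-code MI}}
  + \underbrace{\mathrm{KL}\Big(q(z)\,\Big\|\,\textstyle\prod_j q(z_j)\Big)}_{\text{total correlation}}
  + \underbrace{\sum_j \mathrm{KL}\big(q(z_j)\,\|\,p(z_j)\big)}_{\text{dimension-wise KL}}
```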

Justin-Tan commented 4 years ago

Thanks for the confirmation, this is slightly puzzling as the authors of beta-TCVAE report MIG scores of O(0.1) in their paper. Other third-party implementations also report MIG of O(0.1), e.g. "Challenging Common Assumptions in the Unsupervised Learning of Disentangled Representations". I might try emailing the authors directly to see how they obtained their results.

[image: metrics]


YannDubs commented 4 years ago

Yes it is. If you get an answer or any insights, please post them here; I would be interested, and other people might be too.

And just to be clear, I have not tried rerunning the authors' code. I only tried using their MIG code to compute the MIG for our results :/ . I.e., the issue does not seem to come from the computation of MIG itself, but to be honest I have not spent much time on MIG, as it was a late addition before a deadline.

Justin-Tan commented 4 years ago

After some digging I am getting better results using the author's MIG calculation code - around 0.3-0.8 for most of my trained models. Perhaps the problem lies in shuffling the dataloader? I notice when I shuffle the dataloader I get a very low MIG (on dSprites).

import torch

# Load dataloader (dSprites must be iterated in its native order)
all_loader = (..., shuffle=False)

vae = model
vae.eval()

N = len(all_loader.dataset)     # number of data samples
K = vae.latent_dim              # number of latent dimensions
nparams = 2                     # (mu, logvar) per latent dimension

qzCx_params = torch.Tensor(N, K, nparams)

n = 0
with torch.no_grad():
    for x, _ in all_loader:
        batch_size = x.size(0)
        x = x.to(device, dtype=torch.float)
        qzCx_stats = torch.stack(vae.encoder(x)['continuous'], dim=2)
        qzCx_params[n:n + batch_size] = qzCx_stats.data
        n += batch_size

# Reshape so the leading axes index the known dSprites generative factors:
# 3 shapes x 6 scales x 40 orientations x 32 x-positions x 32 y-positions
qzCx_params = qzCx_params.view(3, 6, 40, 32, 32, K, nparams).to(device)

# Sample from the diagonal Gaussian posterior q(z|x) using the (mu, logvar) parameters
qzCx_samples = qzCx_sample(params=qzCx_params)

I think the reshape on the second-to-last line requires the dataset to be in its native order so that the generative factors line up correctly. It's not obvious that they should be, though; this is a quirk of the dSprites dataset.
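A toy illustration (not the repo's code) of why that `view(3, 6, 40, 32, 32, ...)` assumes the native ordering: dSprites enumerates images in lexicographic order of (shape, scale, orientation, posX, posY), so a plain reshape recovers the factor grid, but only if the loader did not shuffle.

```python
import numpy as np

n_shape, n_scale = 3, 6                  # two factors, for brevity
idx = np.arange(n_shape * n_scale)       # stand-in for per-image encoder outputs
grid = idx.reshape(n_shape, n_scale)     # factor-indexed view: grid[shape, scale]
assert grid[2, 5] == 2 * n_scale + 5     # lexicographic order holds

rng = np.random.default_rng(0)
shuffled = rng.permutation(idx).reshape(n_shape, n_scale)
# After shuffling, shuffled[s, c] no longer corresponds to factors (s, c), so any
# factor-conditional statistic (and hence MIG) is computed on mismatched data.
```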

YannDubs commented 4 years ago

Thanks for digging into it. What exactly do you mean by shuffling? We do not shuffle the test loader ( https://github.com/YannDubs/disentangling-vae/blob/a54b794dcf816a3892a1960b2aa1c9900cb09a16/main.py#L235 ), if that's your point.

BTW : I'm more than happy to accept PRs

Justin-Tan commented 4 years ago

Yeah, that's what I meant about the shuffling; thanks for confirming, so that is not the cause. I am still confused, because AFAIK this repository and the authors' code appear to be doing exactly the same thing when calculating MIG (on dSprites at least), yet this repository gives a much lower MIG when loading the same models. I'll look into it more over the weekend.

I also looked into the discrete estimation of MIG used in [1], Appendix C (essentially: discretise samples from z into ~20 bins and use sklearn to estimate the discrete MI between the latents z and the generative factors v). Unfortunately it does not agree with the MIG computed using the sampling-based estimation (it is consistently lower, and reasonably insensitive to the number of bins), irrespective of whether we use the mean of the latent representation or sample from q(z|x). So the jury is still out on how best to estimate MIG, I suppose.
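A hedged sketch of that binned estimator (the exact binning scheme in [1] may differ; the plug-in MI below computes the same quantity as `sklearn.metrics.mutual_info_score`, just in plain NumPy): discretise each latent into bins, then take the normalised gap between the two most informative binned latents per factor.

```python
import numpy as np

def discrete_mi(a, b):
    """Plug-in estimate of I(a; b) in nats from two integer label arrays."""
    joint = np.zeros((int(a.max()) + 1, int(b.max()) + 1))
    np.add.at(joint, (a, b), 1.0)                 # joint counts
    p = joint / joint.sum()
    outer = p.sum(axis=1, keepdims=True) @ p.sum(axis=0, keepdims=True)
    nz = p > 0
    return float((p[nz] * np.log(p[nz] / outer[nz])).sum())

def discrete_mig(latents, factors, n_bins=20):
    """latents: (N, K) floats sampled from q(z|x); factors: (N, F) integer labels."""
    binned = np.stack(
        [np.digitize(z, np.histogram(z, bins=n_bins)[1][:-1]) for z in latents.T],
        axis=1)
    gaps = []
    for f in factors.T:
        mi = np.array([discrete_mi(f, binned[:, j]) for j in range(binned.shape[1])])
        top2 = np.sort(mi)[::-1]
        gaps.append((top2[0] - top2[1]) / discrete_mi(f, f))   # H(v) = I(v; v)
    return float(np.mean(gaps))
```

With a latent that copies a factor exactly and another that carries no information, this returns 1.0; how closely it tracks the sampling-based estimator on real posteriors is exactly the disagreement discussed above.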

Edit: It seems the MIG scores reported in [1] are consistently lower anyway, around 0.2 for the best models, so perhaps this is expected.