thwjoy / ccvae

The official codebase for Capturing label characteristics in VAEs

Application to MNIST dataset #1

Open bloomingfield opened 2 years ago

bloomingfield commented 2 years ago

Hi Tom!

This is a really great paper, and a really cool way to apply VAEs. I've just started my own PhD, and I've been working towards reproducing your results from Table 6 in the paper, where this method is applied in the context of multi-class classification. I've made some headway, but I can't quite reach the same level of performance as you did! Would you be happy to also share the code that was used for that dataset? If it's easier to just share the Pyro code, I can work with that too.

Thanks so much! Nathaniel

thwjoy commented 2 years ago

Hi Nathaniel,

Thank you! To obtain the best results for multi-class classification you need to enumerate out y, i.e. sum over all of its values rather than sampling it in the unsupervised case. You can do this quite easily in Pyro using enumerate: https://pyro.ai/examples/ss-vae.html?highlight=semi%20supervised. What sort of accuracy are you getting? I can share the code with you if you wish, but it's a big mess and I'd rather keep it private if I can.

Best Wishes, Tom
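
To illustrate what "enumerating out y" means for an unlabelled example, here is a minimal sketch in plain Python. It is not the code from this repo or from Pyro: the function names, the three-class setup, and the numbers are all hypothetical, and the objective is the generic semi-supervised form (expected per-class ELBO plus the entropy of q(y|x)) rather than the exact CCVAE objective.

```python
import math
import random

def enumerated_objective(q_y, per_class_elbo):
    """Exact value: sum over every class y, weighted by q(y|x), plus the entropy of q."""
    expected = sum(q * elbo for q, elbo in zip(q_y, per_class_elbo))
    entropy = -sum(q * math.log(q) for q in q_y if q > 0.0)
    return expected + entropy

def sampled_objective(q_y, per_class_elbo, rng):
    """Noisy alternative: draw a single y ~ q(y|x) and evaluate only that class."""
    y = rng.choices(range(len(q_y)), weights=q_y, k=1)[0]
    return per_class_elbo[y] - math.log(q_y[y])

# Hypothetical numbers for one unlabelled image and three classes.
q_y = [0.7, 0.2, 0.1]
per_class_elbo = [-90.0, -120.0, -150.0]

rng = random.Random(0)
exact = enumerated_objective(q_y, per_class_elbo)
draws = [sampled_objective(q_y, per_class_elbo, rng) for _ in range(20000)]
mc_mean = sum(draws) / len(draws)
print(f"enumerated: {exact:.3f}  mean of sampled estimates: {mc_mean:.3f}")
```

Both functions target the same expectation, but each sampled estimate scatters widely around it, while the enumerated sum gives the exact value on every step. That lower variance is the likely reason enumeration helps most when few labels are available.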

bloomingfield commented 2 years ago

Hi Tom

Thanks for getting back to me so quickly! I'll keep in mind enumerating over the y labels, and no worries at all about your code. Maybe I could ask a few questions?

I also noticed that on line 103 of ccvae.py in the PyTorch version of this repo, you use both logqyzc (where the Z gradient is detached) and log_qyzc (where the Z gradient is not detached). Is this at odds with equation 8 in the paper and the Pyro repo? Or I might have missed something...

I'm currently getting ~85% accuracy with 0.4% labelled, ~96% with 6% labelled, and ~98% with 100% labelled on MNIST. I'm most likely using very different architectures to you though.

bloomingfield commented 2 years ago

Hi Tom

Thanks for your help with this! I've put together a best-guess Pyro version in a fork:

https://github.com/bloomingfield/ccvae

I've just upscaled the MNIST data so that the same networks and hyperparameters can be used as for CelebA. Are there any additional tricks that were used to get that last bit of accuracy? Particularly in the f=0.004 labelled case, this code seems to struggle (even with enumerating out y).
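
For a sense of scale in the f=0.004 case, here is a rough sketch of how a labelled subset of that size could be drawn. It assumes the standard 60,000-image MNIST training set and class-balanced sampling. The helper `labelled_subset` and the stand-in labels are hypothetical, and the split used in the paper or in either repo may be drawn differently.

```python
import random
from collections import defaultdict

def labelled_subset(labels, fraction, seed=0):
    """Pick a class-balanced set of indices covering roughly `fraction` of the data."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    # Same number of labelled examples for every class (an assumption of this sketch).
    per_class = max(1, round(fraction * len(labels) / len(by_class)))
    chosen = []
    for y in sorted(by_class):
        chosen.extend(rng.sample(by_class[y], min(per_class, len(by_class[y]))))
    return sorted(chosen)

# Stand-in labels: 60,000 items spread evenly over 10 classes (not real MNIST labels).
toy_labels = [i % 10 for i in range(60000)]
subset = labelled_subset(toy_labels, fraction=0.004)
print(len(subset), "labelled examples in total")
```

Under these assumptions each digit gets only a couple of dozen labelled images, so the classifier depends heavily on the unlabelled term of the objective.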