hsinyilin19 / ResNetVAE

Variational AutoEncoder + ResNet Transfer Learning

The resnet VAE does not converge #1

Open bmabir17 opened 4 years ago

bmabir17 commented 4 years ago

I have used this ResNet VAE to reconstruct images from a dataset (split into train and test sets), but the images are not being reconstructed at all. I trained the model for 50 epochs and tested it on a separate test set. Do I need to train it for more epochs? I have used the same dataset to train a conv2d autoencoder with 6 downscaling/upscaling layers, and it reconstructs properly.

hsinyilin19 commented 4 years ago

Thank you for your questions. First, may I ask what kind of dataset you're using? My observation was that if a dataset is too complicated, e.g. CIFAR-10, a VAE does not recover it that well. This is simply a limitation of VAEs. So, based on my prior experience, I would suggest training for more epochs, say 100 to 200, and perhaps adding more conv2d layers and nodes. This is the reason why I chose simpler image data, such as Olivetti faces and MNIST, for the demonstration. Please let me know if that works for you.
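For readers following along, one commonly cited reason VAEs reconstruct complex images poorly is the shape of the objective: a reconstruction term is traded off against a KL term that pulls each latent code toward a standard normal prior, which tends to produce blurry outputs. The NumPy sketch below shows that objective under stated assumptions: a squared-error reconstruction term and a hypothetical `beta` weight on the KL term. It is not the loss from the ResNetVAE repository, which may differ.

```python
import numpy as np

def reparameterize(mu, logvar, rng):
    # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I),
    # with sigma = exp(0.5 * logvar).
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps

def kl_divergence(mu, logvar):
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims
    # and averaged over the batch.
    kl = 0.5 * np.sum(mu**2 + np.exp(logvar) - 1.0 - logvar, axis=1)
    return kl.mean()

def vae_loss(x, x_recon, mu, logvar, beta=1.0):
    # Assumed reconstruction term: sum of squared errors per image, batch mean.
    # `beta` is a hypothetical weight on the KL term (beta=1 is the plain VAE).
    recon = np.sum((x - x_recon) ** 2, axis=tuple(range(1, x.ndim))).mean()
    return recon + beta * kl_divergence(mu, logvar)
```

Raising `beta` pushes latents closer to the prior at the cost of reconstruction detail; lowering it does the opposite.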

bmabir17 commented 4 years ago

Thank you so much for your quick response. Yes, my dataset is quite complicated. I am using this one: https://www.mvtec.com/company/research/datasets/mvtec-ad/ So should I replace the encoding FC layers with conv2d layers, or add more conv2d layers to the decoder, starting with 512 as input? If so, can you suggest how many conv2d layers would give a close enough reconstruction? Right now I can only see random pixels in the reconstructions.
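On the question of how many decoder layers are needed, part of the answer is plain shape arithmetic. The sketch below applies the transposed-convolution output-size formula from the PyTorch `ConvTranspose2d` documentation, assuming kernel 4, stride 2 and padding 1 (which doubles the spatial size per stage), a decoder that starts from a 512-channel 7x7 feature map (the size ResNet-18/34 produce for 224x224 input before pooling), and channels halved per stage. `plan_decoder` is a hypothetical helper; it only counts stages and says nothing about reconstruction quality.

```python
def conv_transpose_out(size, kernel=4, stride=2, padding=1, output_padding=0, dilation=1):
    # Output-size formula for a 2D transposed convolution, one spatial dim:
    # (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

def plan_decoder(start_size, start_channels, target_size, min_channels=32):
    """Hypothetical helper: list (channels, size) after each upsampling
    stage until the feature map reaches at least target_size."""
    stages = []
    size, channels = start_size, start_channels
    while size < target_size:
        size = conv_transpose_out(size)
        channels = max(channels // 2, min_channels)
        stages.append((channels, size))
    return stages

for ch, s in plan_decoder(7, 512, 224):
    print(f"{ch:4d} channels @ {s}x{s}")
```

Under these assumptions, going from 7x7 to 224x224 takes five doubling stages, followed by a final conv down to 3 output channels (not shown). If the target size is not a power-of-two multiple of the start size, the last stage overshoots and the output would need cropping or resizing.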

bmabir17 commented 4 years ago

I have trained your implementation for 200 epochs; now the only thing visible is a shape. The loss is around 0.923 using SSIM loss.
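To put the reported SSIM loss of 0.923 in context: if the loss was defined as 1 - SSIM (a common convention, but an assumption here), it corresponds to an SSIM of about 0.077, where 1 means identical images. The sketch below computes a simplified, whole-image SSIM in NumPy with the standard constants K1 = 0.01 and K2 = 0.03. Library implementations such as `skimage.metrics.structural_similarity` instead average SSIM over local sliding windows, so their values will differ.

```python
import numpy as np

def global_ssim(x, y, data_range=1.0):
    # Simplified SSIM over the whole image at once (no sliding window).
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mu_x, mu_y = x.mean(), y.mean()
    var_x, var_y = x.var(), y.var()
    cov_xy = ((x - mu_x) * (y - mu_y)).mean()
    num = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    den = (mu_x**2 + mu_y**2 + c1) * (var_x + var_y + c2)
    return num / den

def ssim_loss(x, y, data_range=1.0):
    # Assumed convention: loss = 1 - SSIM, so 0 means a perfect match.
    return 1.0 - global_ssim(x, y, data_range)
```

Because SSIM lies in [-1, 1], this loss lies in [0, 2]; a value near 1 indicates almost no structural agreement with the target.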

hsinyilin19 commented 4 years ago

Thank you for the information. I've looked at your dataset from MVTec, and I think it's quite complicated for a VAE to reconstruct. In particular, if you intend to use the VAE's encoder for anomaly detection, this would be very difficult from my understanding, since a VAE does not recover complex image data that well; this is a well-reported issue. Adding more layers (conv2d or FC) is possible, although I don't know specifically what kinds of layers would benefit your case (sorry for that, data science is an experimental science). Alternatively, I suggest you look into more advanced techniques (e.g. VQVAE).
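A common way to use an autoencoder for anomaly detection, assumed here rather than taken from this thread, is to score each image by its reconstruction error and flag images whose error exceeds a threshold fitted on normal data. This also illustrates why poor reconstructions hurt: if normal images already reconstruct badly, their scores overlap with those of defective images. The helper names and the 0.99 quantile below are hypothetical choices.

```python
import numpy as np

def anomaly_scores(images, reconstructions):
    # Per-image mean squared reconstruction error; higher means more anomalous.
    axes = tuple(range(1, images.ndim))
    return ((images - reconstructions) ** 2).mean(axis=axes)

def fit_threshold(normal_scores, quantile=0.99):
    # Hypothetical rule: flag anything above the chosen quantile of scores
    # observed on defect-free validation images.
    return float(np.quantile(normal_scores, quantile))

def predict_anomalous(scores, threshold):
    # Boolean mask: True where the image is flagged as anomalous.
    return scores > threshold
```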

bmabir17 commented 4 years ago

Thank you for your valuable insight into my problem. What I am trying to do is conduct some architectural experiments with various autoencoder networks. The VAE is just a baseline that I am trying to get results from, as close as possible. This is my current network: https://gist.github.com/bmabir17/990762d11cd587c05ddfa211d07829b6

hsinyilin19 commented 4 years ago

Thanks for sharing the gist link to your work; it's a good idea and a good starting point for anomaly detection. I would be very interested to know if you later come up with more advanced architectures.

bmabir17 commented 4 years ago

@hsinyilin19 thank you so much for pointing me towards VQVAE. I have found an implementation of it (https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py), and it gives good reconstructions on my dataset.
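What VQ-VAE changes relative to a VAE is the latent: instead of sampling from a Gaussian, each encoder output vector is snapped to its nearest entry in a learned codebook. The NumPy sketch below shows only that nearest-neighbour lookup. It is a generic illustration, not code from the linked rosinality repository, and it omits the straight-through gradient estimator and the codebook/commitment losses used during training.

```python
import numpy as np

def vector_quantize(z, codebook):
    """Map each latent vector to its nearest codebook entry.

    z:        (N, D) encoder outputs (e.g. flattened spatial positions)
    codebook: (K, D) embedding vectors
    Returns the quantized vectors (N, D) and the chosen indices (N,).
    """
    # Squared Euclidean distances via ||z||^2 - 2 z.e + ||e||^2, shape (N, K)
    dists = (
        np.sum(z**2, axis=1, keepdims=True)
        - 2.0 * z @ codebook.T
        + np.sum(codebook**2, axis=1)
    )
    indices = np.argmin(dists, axis=1)
    return codebook[indices], indices
```

Because the decoder only ever sees codebook vectors, the latent space is discrete, which is one reason VQ-VAE avoids the prior-matching pressure of the VAE's KL term.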

noamzilo commented 3 years ago

Hi, I noticed you said "I think it's quite complicated for VAE to reconstruct" and "this would be very difficult from my understanding. In a sense that VAE does not recover that well for complex image data, this is a well-reported issue".

What makes a dataset complicated for a VAE? Maybe I should abandon VAEs before I go down the road of trying to solve my problem with one :)

I am trying to do semi-supervised learning on 1D multi-channel data. How can I tell whether a VAE is the wrong choice?

hsinyilin19 commented 3 years ago

Hi, I think in general there is no standard or easy way to tell whether a VAE will work well before trying it. But if a dataset contains a wide variety of patterns and rich backgrounds (such as CIFAR-10), it is probably too complicated to recover well with a VAE.