yogeshbalaji / Generate_To_Adapt

Implementation of "Generate To Adapt: Aligning Domains using Generative Adversarial Networks"
https://arxiv.org/pdf/1704.01705.pdf

What params are needed to reproduce SVHN-->MNIST? #2

Closed: OswinGuai closed this issue 6 years ago

OswinGuai commented 6 years ago

I used the default params to train and evaluate GTA, but the accuracy only reaches 80%+ after 100 iterations. Please advise me on how to reach a better result. Thanks in advance.

yogeshbalaji commented 6 years ago

Hi, the code stores two models: model_best (the model with the best source validation performance) and the current checkpoint. The current checkpoint usually gives better results than model_best. Performance also fluctuates between runs, so perform multiple runs and take the average. In addition, stop training early, at around 40-60 epochs; we found that early stopping gives good performance. As a minor point, to exactly replicate the numbers reported in the paper, you need to use 1-channel images (although I don't think this makes a big difference). Even though the data I provided are grayscale images, the PyTorch dataloader loads them as 3-channel images, so you need to modify the dataloader to load them as grayscale.
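Below is a minimal sketch of one way to load the images as grayscale. It assumes the dataloader is built on torchvision's `ImageFolder`, which accepts a `loader` callable. The name `gray_loader` is hypothetical. It mirrors torchvision's default PIL loader but converts to mode `'L'` (one channel) instead of `'RGB'`.

```python
from PIL import Image


def gray_loader(path):
    """Open an image file and return it as a single-channel (mode 'L') PIL image.

    Hypothetical replacement for torchvision's default RGB loader.
    """
    with open(path, "rb") as f:
        img = Image.open(f)
        # convert() forces the pixel data to load, so the file can be closed afterwards
        return img.convert("L")
```

It would then be passed as `datasets.ImageFolder(root, transform=transform, loader=gray_loader)`. With 1-channel inputs, the normalization statistics and the first convolution of the networks would also need to expect 1 channel instead of 3.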

yogeshbalaji commented 6 years ago

Also, there was a small bug: the mean and std vectors had a single entry, but they need three entries because there are 3 input channels. I committed this change and verified that, after it, I was able to get the numbers reported in the paper. Please run the code and let me know if you are able to replicate the performance.
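To illustrate why single-entry statistics are a problem, here is a pure-Python sketch. It assumes per-channel normalization is implemented as a `zip` over channels, as we understand torchvision's `Normalize` was at the time. Because `zip` stops at the shortest input, a one-entry mean and std normalize only the first channel. The 0.5 values are placeholders, not the repository's actual statistics.

```python
def normalize_channels(image, mean, std):
    """Apply (x - mean[c]) / std[c] to each channel c of `image`.

    `image` is a list of channels, each a flat list of floats. As in the
    zip-based loop this sketch assumes, channels without a matching
    mean/std entry are silently left untouched.
    """
    out = [list(channel) for channel in image]
    for channel, m, s in zip(out, mean, std):
        for i, value in enumerate(channel):
            channel[i] = (value - m) / s
    return out


# A tiny 3-channel "image" with two pixels per channel.
rgb = [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]

# Buggy: one-entry statistics only normalize channel 0.
print(normalize_channels(rgb, mean=(0.5,), std=(0.5,)))
# prints [[-1.0, 1.0], [0.0, 1.0], [0.0, 1.0]]

# Fixed: one statistic per input channel.
print(normalize_channels(rgb, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
# prints [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0]]
```

In torchvision terms, this corresponds to using `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` rather than `transforms.Normalize((0.5,), (0.5,))` for 3-channel input.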

OswinGuai commented 6 years ago

@yogeshbalaji Thanks! I just ran the default params and scripts again, but haven't gotten a better result. Which version of PyTorch should I use?

OswinGuai commented 6 years ago

It seems to be OK now. One of my runs reached 91.5% under PyTorch 0.3.1. Under version 0.4, however, there are some warnings and the result seems wrong.

yogeshbalaji commented 6 years ago

The code I wrote is for an earlier version of PyTorch. I will fix the warnings. Glad that it worked for you!