NVlabs / DeepInversion

Official PyTorch implementation of Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion (CVPR 2020)

Reproducibility on CIFAR-10 #2

Closed: xq1970 closed this issue 4 years ago.

xq1970 commented 4 years ago

Hi, I followed Section 4.1 of the paper on CIFAR-10, but after 5k iterations (at convergence) I only got results like this: [image: output_00050_gpu_0]

Compared with the ImageNet code, I changed the following arguments: --r_feature=2.5e-2 --tv_l2 2e-4 --l2 2e-5 --lr 0.01 --setting_id=2 --random_label --main_loss_multiplier 0.25
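For reference, the objective from the paper that these flags weight can be written roughly as below. The mapping of flags to coefficients (--main_loss_multiplier to α_main, --r_feature to α_f, --tv_l2 to α_tv, --l2 to α_ℓ2) is an assumption based on the flag names, and the expectations over the training set are approximated by each BatchNorm layer's running mean and variance.

```latex
\mathcal{L}(\hat{x}) = \alpha_{\text{main}}\,\mathrm{CE}\big(f(\hat{x}), y\big)
  + \alpha_{f} \sum_{l} \Big( \big\lVert \mu_l(\hat{x}) - \mathbb{E}[\mu_l(x) \mid \mathcal{X}] \big\rVert_2
  + \big\lVert \sigma_l^2(\hat{x}) - \mathbb{E}[\sigma_l^2(x) \mid \mathcal{X}] \big\rVert_2 \Big)
  + \alpha_{tv}\,\mathcal{R}_{TV}(\hat{x}) + \alpha_{\ell_2}\,\lVert \hat{x} \rVert_2
```

With the arguments above this would give α_main = 0.25, α_f = 2.5×10⁻², α_tv = 2×10⁻⁴ and α_ℓ2 = 2×10⁻⁵.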

I chose these values so that the individual losses have magnitudes similar to those in the released ImageNet code. I used a pre-trained ResNet-34 with 95.5% accuracy, trained with https://github.com/mbsariyildiz/resnet-pytorch. I also tried the hyperparameters given in Section 4.1 of the paper, but the synthesized images did not improve.

I also set the image resolution to 32 and ran the inversion without Adaptive DeepInversion (ADI).
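The `loss var l2` entry in the log below is presumably the image-smoothness (total-variation) prior weighted by --tv_l2. A minimal NumPy sketch of an L2 total-variation penalty over horizontal, vertical and both diagonal neighbours, assuming batches shaped (N, C, H, W) such as 32×32 CIFAR-10 images; the name `tv_l2` is hypothetical and this is not claimed to be the repository's exact code.

```python
import numpy as np


def tv_l2(x: np.ndarray) -> float:
    """L2 total-variation prior for images shaped (N, C, H, W).

    Sums the L2 norms of differences between neighbouring pixels in four
    directions; smooth images score low, high-frequency noise scores high.
    """
    diffs = (
        x[:, :, :, :-1] - x[:, :, :, 1:],     # horizontal neighbours
        x[:, :, :-1, :] - x[:, :, 1:, :],     # vertical neighbours
        x[:, :, 1:, :-1] - x[:, :, :-1, 1:],  # anti-diagonal neighbours
        x[:, :, :-1, :-1] - x[:, :, 1:, 1:],  # diagonal neighbours
    )
    # With no axis/ord given, np.linalg.norm returns the 2-norm of the flattened array.
    return float(sum(np.linalg.norm(d) for d in diffs))


rng = np.random.default_rng(0)
noise = rng.standard_normal((4, 3, 32, 32))  # CIFAR-sized batch of pure noise
ramp = np.broadcast_to(np.linspace(0.0, 1.0, 32), (4, 3, 32, 32))  # smooth horizontal ramp
print(tv_l2(noise), tv_l2(ramp))
```

A noise batch scores far higher than the smooth ramp, which is why raising the TV weight pushes the synthesized images toward smoother textures.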

Are there any other changes that should be made in the code for CIFAR-10? Thank you for your help!

Here is the training log for the first 2k iterations. The printed losses are already multiplied by their coefficients, and I used the same pre-trained ResNet-34 as the verifier:

iteration 100: total loss 1.5887961387634277, loss_r_feature 0.8383388519287109, main criterion 0.20274938642978668, loss var l2 0.54662197265625, loss l2 0.001085866928100586, Verifier accuracy: 67.85713958740234
iteration 200: total loss 1.1957563161849976, loss_r_feature 0.6756953239440918, main criterion 0.00945814698934555, loss var l2 0.50958203125, loss l2 0.0010206340789794923, Verifier accuracy: 100.0
iteration 300: total loss 1.0136175155639648, loss_r_feature 0.5451815128326416, main criterion 0.002356192795559764, loss var l2 0.46513969726562504, loss l2 0.000940232162475586, Verifier accuracy: 100.0
iteration 400: total loss 0.8616381883621216, loss_r_feature 0.4419695377349854, main criterion 0.0011746699456125498, loss var l2 0.41763930664062504, loss l2 0.0008546984863281251, Verifier accuracy: 100.0
iteration 500: total loss 0.7258148789405823, loss_r_feature 0.354432487487793, main criterion 0.0014118807157501578, loss var l2 0.36920244140625, loss l2 0.0007680741882324219, Verifier accuracy: 100.0
iteration 600: total loss 0.6115573644638062, loss_r_feature 0.284948205947876, main criterion 0.004634861368685961, loss var l2 0.32129179687500004, loss l2 0.0006825406646728516, Verifier accuracy: 100.0
iteration 700: total loss 0.5235512256622314, loss_r_feature 0.24538052082061768, main criterion 0.000998093979433179, loss var l2 0.276566259765625, loss l2 0.000606338119506836, Verifier accuracy: 100.0
iteration 800: total loss 0.46245500445365906, loss_r_feature 0.22423591613769533, main criterion 0.0008074002689681947, loss var l2 0.2368692626953125, loss l2 0.0005424297332763672, Verifier accuracy: 100.0
iteration 900: total loss 0.4182308316230774, loss_r_feature 0.2140077829360962, main criterion 0.000744551420211792, loss var l2 0.2029859130859375, loss l2 0.0004926132965087891, Verifier accuracy: 100.0
iteration 1000: total loss 0.38464871048927307, loss_r_feature 0.20906946659088135, main criterion 0.0006658562342636287, loss var l2 0.1744591552734375, loss l2 0.00045424114227294926, Verifier accuracy: 100.0
iteration 1100: total loss 0.3591064214706421, loss_r_feature 0.2067859172821045, main criterion 0.0003871832450386137, loss var l2 0.151506787109375, loss l2 0.0004265190887451172, Verifier accuracy: 100.0
iteration 1200: total loss 0.3402350842952728, loss_r_feature 0.20384662151336672, main criterion 0.0008285769145004451, loss var l2 0.13515140380859375, loss l2 0.0004085089111328125, Verifier accuracy: 100.0
iteration 1300: total loss 0.32719239592552185, loss_r_feature 0.2029275894165039, main criterion 0.0004118723445571959, loss var l2 0.12345615234375, loss l2 0.0003967760086059571, Verifier accuracy: 100.0
iteration 1400: total loss 0.32120072841644287, loss_r_feature 0.20393846035003663, main criterion 0.0007077441550791264, loss var l2 0.11616383056640625, loss l2 0.00039069297790527346, Verifier accuracy: 100.0
iteration 1500: total loss 0.31567034125328064, loss_r_feature 0.20350909233093262, main criterion 0.00039401365211233497, loss var l2 0.11137900390625001, loss l2 0.00038823162078857425, Verifier accuracy: 100.0
iteration 1600: total loss 0.31221890449523926, loss_r_feature 0.20267820358276367, main criterion 0.000622493855189532, loss var l2 0.10852938232421876, loss l2 0.0003888259506225586, Verifier accuracy: 100.0
iteration 1700: total loss 0.31075066328048706, loss_r_feature 0.20341982841491701, main criterion 0.0003112043777946383, loss var l2 0.10662841796875, loss l2 0.00039120994567871096, Verifier accuracy: 100.0
iteration 1800: total loss 0.30613452196121216, loss_r_feature 0.199886953830719, main criterion 0.0005381234805099666, loss var l2 0.10531446533203126, loss l2 0.0003950079727172852, Verifier accuracy: 100.0
iteration 1900: total loss 0.30858731269836426, loss_r_feature 0.20306947231292727, main criterion 0.0002929923066403717, loss var l2 0.10482397460937501, loss l2 0.00040085716247558595, Verifier accuracy: 100.0
iteration 2000: total loss 0.30381497740745544, loss_r_feature 0.1986505150794983, main criterion 0.0004578616062644869, loss var l2 0.104299462890625, loss l2 0.00040712715148925784, Verifier accuracy: 100.0
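To see how the weighted terms in the log above balance each other, a small pure-Python parser can check that each printed total equals the sum of the four weighted terms and report each term's share of the total. The helper `check_entry` and its regex are hypothetical; the regex is written to accept both the raw `------------iteration N----------` form and an `iteration N:` form with commas. The two sample entries are copied from the log.

```python
import re

# One log entry; tolerant of dash or colon after the iteration number
# and of comma or whitespace separators between fields.
ENTRY = re.compile(
    r"iteration (?P<it>\d+)[-:\s]*"
    r"total loss (?P<total>[0-9.eE+-]+)[,\s]*"
    r"loss_r_feature (?P<r_feature>[0-9.eE+-]+)[,\s]*"
    r"main criterion (?P<main>[0-9.eE+-]+)[,\s]*"
    r"loss var l2 (?P<var_l2>[0-9.eE+-]+)[,\s]*"
    r"loss l2 (?P<l2>[0-9.eE+-]+)[,\s]*"
    r"Verifier accuracy: (?P<acc>[0-9.]+)"
)
TERMS = ("r_feature", "main", "var_l2", "l2")


def check_entry(text):
    """Return (iteration, total minus sum of terms, share of total per term)."""
    m = ENTRY.search(text)
    if m is None:
        raise ValueError("no log entry found")
    total = float(m["total"])
    values = {k: float(m[k]) for k in TERMS}
    residual = total - sum(values.values())
    shares = {k: v / total for k, v in values.items()}
    return int(m["it"]), residual, shares


LOG_100 = ("------------iteration 100---------- total loss 1.5887961387634277 "
           "loss_r_feature 0.8383388519287109 main criterion 0.20274938642978668 "
           "loss var l2 0.54662197265625 loss l2 0.001085866928100586 "
           "Verifier accuracy: 67.85713958740234")
LOG_2000 = ("iteration 2000: total loss 0.30381497740745544, "
            "loss_r_feature 0.1986505150794983, main criterion 0.0004578616062644869, "
            "loss var l2 0.104299462890625, loss l2 0.00040712715148925784, "
            "Verifier accuracy: 100.0")

for line in (LOG_100, LOG_2000):
    it, residual, shares = check_entry(line)
    print(it, f"residual={residual:.1e}", {k: round(s, 4) for k, s in shares.items()})
```

For both entries the residual is below 1e-6, so the total is just the sum of the weighted terms. The BN feature term grows from about 53% of the total at iteration 100 to about 65% at iteration 2000, while the main criterion drops from about 13% to about 0.15%.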

slala2121 commented 4 years ago

I also tried inverting a ResNet-34 with ~92% accuracy on CIFAR-10, using the same hyperparameters as in the paper. The loss has converged and the teacher network reaches 100% accuracy on the synthesized images, but the images are not comparable in quality to those in the paper (examples below).

Are there other changes that need to be made? Thanks.

Cat: [image: id_033]

Dog: [image: id_217]

Horse: [image: id_062]

Here are sample loss plots. There seems to be a trade-off between the l2 loss and the other losses as the iterations increase. Is this typical?

[attachment: history.pdf]

sccbhxc commented 4 years ago

Replying to slala2121's comment above: could you share the code with me?

hongxuyin commented 4 years ago

Hi, the hyperparameters given in the paper for CIFAR-10 are for the scheme without multi-resolution, i.e. setting_id 1 or 2. If your network or setting_id differs, a quick search over hyperparameters such as r_feature, lr and iterations should work.
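A quick search like the one suggested above can be as simple as an exhaustive grid. The sketch below is a hypothetical helper, not part of the repository: the candidate values are arbitrary placeholders rather than values recommended by the authors, and `score` is a caller-supplied function (for example, one that runs an inversion with the given config and returns a student's validation error, lower being better).

```python
from itertools import product
from typing import Callable, Dict, Iterator, List, Optional, Tuple

Config = Dict[str, float]


def grid_configs(grid: Dict[str, List[float]]) -> Iterator[Config]:
    """Yield every combination of the candidate values in `grid`."""
    keys = list(grid)
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


def grid_search(grid: Dict[str, List[float]],
                score: Callable[[Config], float]) -> Tuple[Optional[Config], float]:
    """Return the config with the lowest score and that score."""
    best_cfg, best_score = None, float("inf")
    for cfg in grid_configs(grid):
        s = score(cfg)
        if s < best_score:
            best_cfg, best_score = cfg, s
    return best_cfg, best_score


# Placeholder candidates only; centre them on whatever setting you start from.
GRID = {
    "r_feature": [0.01, 0.05, 0.1],
    "lr": [0.05, 0.1, 0.2],
    "iterations": [2000, 4000],
}
```

Each config is a plain dict, so it can be mapped onto whatever command-line arguments your inversion script exposes.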

pamolchanov commented 4 years ago

We have updated the code with an example of DeepInversion for ResNet-34 on CIFAR-10. We did not tune the hyperparameters, so there is room to improve the visual quality of the images.
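The term that separates DeepInversion from plain model inversion, and the one weighted by --r_feature, matches the per-channel batch statistics at each BatchNorm layer to that layer's running mean and variance. A framework-agnostic NumPy sketch of that regularizer follows; in PyTorch it would typically be computed in forward hooks on each `nn.BatchNorm2d`. The function names are hypothetical and this is an illustration of the paper's formula, not the code in the cifar10 folder.

```python
import numpy as np


def bn_feature_loss(feats: np.ndarray, running_mean: np.ndarray,
                    running_var: np.ndarray) -> float:
    """Per-layer term for BatchNorm inputs shaped (N, C, H, W).

    L2 distance between the batch mean and the running mean, plus the
    L2 distance between the batch variance and the running variance.
    """
    mean = feats.mean(axis=(0, 2, 3))
    var = feats.var(axis=(0, 2, 3))  # biased (population) variance per channel
    return float(np.linalg.norm(mean - running_mean) + np.linalg.norm(var - running_var))


def feature_regularizer(layer_feats, layer_stats) -> float:
    """Sum over layers; layer_stats holds (running_mean, running_var) pairs."""
    return sum(bn_feature_loss(f, m, v) for f, (m, v) in zip(layer_feats, layer_stats))


rng = np.random.default_rng(0)
feats = rng.standard_normal((8, 16, 8, 8))
stats = (np.zeros(16), np.ones(16))  # running stats of a hypothetical BN layer
print(feature_regularizer([feats], [stats]), feature_regularizer([feats + 2.0], [stats]))
```

Shifting the activations away from the stored statistics makes the penalty grow, which is what pulls the synthesized images toward the teacher's training distribution.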

pamolchanov commented 4 years ago

Replying to sccbhxc: please see the example in the cifar10 folder.