Yuantian013 / E2GAN

[ECCV 2020] "Off-Policy Reinforcement Learning for Efficient and Effective GAN Architecture Search" by Yuan Tian, Qin Wang, Zhiwu Huang, Wen Li, Dengxin Dai, Minghao Yang, Jun Wang, Olga Fink
MIT License

Training on 512x512 #1

Closed: ysig closed this issue 4 years ago

ysig commented 4 years ago

I tried to apply your code to a real-world dataset of 512x512x3 images. I made several modifications to the basic data-loading code and used the cal_fid_stat.py script from AutoGAN to generate FID statistics for my test set.

I modified the search.sh file to the following:

CUDA_VISIBLE_DEVICES=3 python -u search.py \
-gen_bs 16 \
-dis_bs 8 \
--dataset stl10 \
--bottom_width 4 \
--img_size 512 \
--gen_model shared_gan \
--dis_model shared_gan \
--controller controller \
--latent_dim 512 \
--gf_dim 512 \
--df_dim 256 \
--g_spectral_norm False \
--d_spectral_norm True \
--g_lr 0.0002 \
--d_lr 0.0002 \
--beta1 0.0 \
--beta2 0.9 \
--init_type xavier_uniform \
--n_critic 5 \
--val_freq 20 \
--ctrl_sample_batch 1 \
--shared_epoch 15 \
--grow_step1 15 \
--grow_step2 35 \
--max_search_iter 65 \
--ctrl_step 30 \
--random_seed 12345 \
--exp_name e2gan_search --data_path /home/user/data-E2GAN | tee search.log

I ran it and got the following error:

search progress:   0%|                                  | 0/100 [00:35<?, ?it/s]
Traceback (most recent call last):
  File "search.py", line 227, in <module>
    main()
  File "search.py", line 155, in main
    action = Agent.select_action([layer, last_R,0.01*last_fid] + last_state,Best)
  File "/home/user/E2GAN/search/sac.py", line 60, in select_action
    action1,action2,action3,action4, action5,action6,_,_,_,_,_, _ ,_,_, _, _ ,_,_,= self.policy.sample(state)
  File "/home/user/E2GAN/search/sac_model.py", line 117, in sample
    mean_1, log_std_1,mean_2, log_std_2,mean_3, log_std_3,mean_4, log_std_4,mean_5, log_std_5,mean_6, log_std_6= self.forward(state)
  File "/home/user/E2GAN/search/sac_model.py", line 84, in forward
    x = F.relu(self.linear1(state.cuda()))
  File "/home/user/miniconda3/envs/ganspace/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/user/miniconda3/envs/ganspace/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/user/miniconda3/envs/ganspace/lib/python3.7/site-packages/torch/nn/functional.py", line 1370, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: size mismatch, m1: [1 x 515], m2: [131 x 128] at /opt/conda/conda-bld/pytorch_1573049306803/work/aten/src/THC/generic/THCTensorMathBlas.cu:290

I suspect the error is related to the downsampling/upsampling convolutions, but I am not sure. Have you tried a 512px model before, or is there an obvious problem in the configuration of my script?

Thanks in advance!

qinenergy commented 4 years ago

Hi,

For E2GAN, we did not try the search at 512x512 resolution. To make it work at 512x512, you will probably need more cells, and you will need to modify the discriminator accordingly. You might also need to adjust the upsampling blocks so that the final output matches your resolution.
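As a rough check on how many extra stages a 512x512 output implies, the sketch below counts the 2x doublings needed to go from `--bottom_width` to `--img_size`. It assumes, purely for illustration, that every generator stage upsamples by exactly 2x; `num_doublings` is a hypothetical helper, not part of the E2GAN code.

```python
def num_doublings(bottom_width: int, img_size: int) -> int:
    """Count 2x upsampling steps needed to grow bottom_width to img_size.

    Hypothetical helper: assumes every stage doubles the spatial size.
    """
    if bottom_width <= 0 or img_size <= 0 or img_size % bottom_width != 0:
        raise ValueError("img_size must be a positive multiple of bottom_width")
    ratio = img_size // bottom_width
    # A power of two has exactly one bit set, so ratio & (ratio - 1) == 0.
    if (ratio & (ratio - 1)) != 0:
        raise ValueError("img_size / bottom_width must be a power of two")
    return ratio.bit_length() - 1


# Same bottom_width (4) as the script above, 32x32 vs. 512x512 output.
print(num_doublings(4, 32))   # 3
print(num_doublings(4, 512))  # 7
```

Under that assumption, a 32x32 output from a 4x4 base needs 3 doublings while 512x512 needs 7, which is consistent with needing more cells; a discriminator that halves the resolution per stage would need a matching number of downsampling stages.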

Regarding this specific error, you need to change line 115 of search/search.py from "Agent=SAC(131)" to "Agent=SAC(515)". This number is the number of channels in the state representation, i.e. the input width of the agent's first linear layer. It changed because of the new dataset and the new parameters you used in the script.
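The traceback shows why the value passed to `SAC(...)` has to match: `select_action` receives `[layer, last_R, 0.01*last_fid] + last_state`, three scalars followed by `last_state`, and `torch.addmm(bias, input, weight.t())` multiplies a `[1 x 515]` input by a `[131 x 128]` transposed weight, so the first layer expects 131 inputs while this run produces 515. The sketch below models only those shapes, not the SAC agent; `state_dim` and `check_linear_input` are hypothetical illustration helpers, not E2GAN functions.

```python
def state_dim(layer, last_R, last_fid, last_state):
    """Length of a state built like search.py's
    [layer, last_R, 0.01 * last_fid] + last_state (hypothetical helper)."""
    state = [layer, last_R, 0.01 * last_fid] + list(last_state)
    return len(state)


def check_linear_input(input_shape, weight_t_shape):
    """Mimic the matmul shape rule: input [n x k] times weight.t() [k x m]."""
    n, k_in = input_shape
    k_w, m = weight_t_shape
    if k_in != k_w:
        raise RuntimeError(f"size mismatch, m1: [{n} x {k_in}], m2: [{k_w} x {m}]")
    return (n, m)


# 515 - 3 = 512, so this run's last_state has 512 entries.
dim = state_dim(layer=0, last_R=0.0, last_fid=0.0, last_state=[0.0] * 512)
print(dim)  # 515

# SAC(131) gives a first layer with 131 inputs, so the product fails...
try:
    check_linear_input((1, dim), (131, 128))
except RuntimeError as err:
    print(err)  # size mismatch, m1: [1 x 515], m2: [131 x 128]

# ...while sizing the agent from the state itself lines up.
print(check_linear_input((1, dim), (dim, 128)))  # (1, 128)
```

In search.py this would amount to computing the length of the initial state and passing it to `SAC(...)` instead of a literal; whether that state is already available at line 115 is an assumption.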

Good luck with your own dataset. Qin

ysig commented 4 years ago

Thanks a lot for your very quick answer; I really appreciate it!