hlzhang109 / DDG

MIT License
57 stars 7 forks

cannot use resnet-50 model #11

Closed danni9594 closed 2 years ago

danni9594 commented 2 years ago

Hi,

I encountered bugs when I tried to switch from resnet-18 to resnet-50 on the PACS dataset using DDG (i.e., setting 'resnet18' to False in hparams_registry.py). The problem seems to be in the AdaINGen model. Could you please upload a version that can run resnet-50 without bugs so that we can reproduce the results in the paper? Thanks!

yfzhang114 commented 2 years ago

Can you share the detailed error or the log? It seems ok for me to use resnet-50, which is also the default backbone for PACS.

danni9594 commented 2 years ago

Hi,

Thank you for attending to my issue so promptly! When I run the following command as suggested in the readme file, without changing anything, the code runs just fine:

python train.py \
    --data-dir /my/datasets/path \
    --algorithm DDG \
    --dataset PACS \
    --stage 0

But that is because the default value of the 'resnet18' hyperparameter in hparams_registry.py is True, which means the command above uses resnet-18 as the backbone. When I change it to False (switching to resnet-50), I encounter the following error:

Traceback (most recent call last):
  File "train.py", line 239, in <module>
    step_vals = algorithm.update(minibatches_device, minibatches_device_neg, pretrain_model=algorithm_copy)
  File "/home/danni/DDG_orig/DDG/algorithms.py", line 1276, in update
    x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p = self.forward(images_a, images_b, pos_a, pos_b)
  File "/home/danni/DDG_orig/DDG/algorithms.py", line 1115, in forward
    x_ba = self.gen.decode(s_b, f_a)  # x_ba: generated from identity of a and style of b
  File "/home/danni/DDG_orig/DDG/networks.py", line 483, in decode
    adain_params_w = torch.cat((self.mlp_w1(ID1), self.mlp_w2(ID2), self.mlp_w3(ID3), self.mlp_w4(ID4)), 1)
  File "/home/danni/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/danni/DDG_orig/DDG/networks.py", line 662, in forward
    return self.model(x.view(x.size(0), -1))
  File "/home/danni/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/danni/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/danni/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/danni/DDG_orig/DDG/networks.py", line 959, in forward
    out = self.fc(x)
  File "/home/danni/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/danni/anaconda3/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/danni/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (6x512 and 2048x512)

It seems that something is wrong with the gen model, which is an instance of AdaINGen.
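For what it's worth, the mismatch can be reproduced in isolation. This is just a toy sketch (made-up layer sizes, not the actual AdaINGen code): a linear layer sized for 2048-dim features receives 512-dim features, exactly as in the traceback.

```python
import torch
import torch.nn as nn

# Toy reproduction of the crash: the MLP expects 2048-dim input
# (resnet-50 sized) but receives 512-dim features (resnet-18 sized).
mlp = nn.Linear(2048, 512)
feats = torch.randn(6, 512)
try:
    mlp(feats)
except RuntimeError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied (6x512 and 2048x512)

# Keeping the MLP input width in sync with the backbone's feature
# dimension (512 for resnet-18, 2048 for resnet-50) avoids the crash.
backbone_dim = 2048
mlp = nn.Linear(backbone_dim, 512)
out = mlp(torch.randn(6, backbone_dim))
assert out.shape == (6, 512)
```

So it looks like the AdaIN MLP input width and the backbone feature width are configured inconsistently when 'resnet18' is set to False.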

yfzhang114 commented 2 years ago

It seems ok now. But you do not need resnet-50 as the image encoder when training the GAN; the default image encoder for our GANs is resnet-18.

danni9594 commented 2 years ago

Hi Yifan,

Thanks for addressing the issue so quickly! The code runs perfectly fine with resnet50 now!

And regarding your suggestion that

you do not need resnet50 as an image encoder during training the GAN, the default image encoder for our GANs is resnet18.

I think the resnet18 hyperparameter is for the main backbone featurizer. In the DDG algorithm, it corresponds to the id_featurizer, which I believe is pre-trained in stage 0 and reused in stage 1, as seen at line 194 of the main train.py file:

state_dict = {k: v for k, v in pretext_model.items() if k in alg_dict.keys() and ('id_featurizer' in k or 'gen' in k)}
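To illustrate what that filter does, here is a minimal sketch with toy modules standing in for the real DDG components (the module names and sizes are placeholders, not the actual DDG classes): only the keys belonging to the pretrained id_featurizer and gen are carried over to the stage-1 model, while everything else keeps its fresh initialization.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the stage-0 (pretext) and stage-1 models.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.id_featurizer = nn.Linear(4, 4)
        self.gen = nn.Linear(4, 4)
        self.classifier = nn.Linear(4, 2)

stage0, stage1 = Toy(), Toy()
pretext_model = stage0.state_dict()
alg_dict = stage1.state_dict()

# Keep only keys that exist in the new model AND belong to the
# pretrained components -- the same filter as in train.py.
state_dict = {k: v for k, v in pretext_model.items()
              if k in alg_dict.keys() and ('id_featurizer' in k or 'gen' in k)}
alg_dict.update(state_dict)
stage1.load_state_dict(alg_dict)

# id_featurizer and gen were copied over; the classifier was not.
assert torch.equal(stage1.id_featurizer.weight, stage0.id_featurizer.weight)
assert not torch.equal(stage1.classifier.weight, stage0.classifier.weight)
```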

What's more, I've actually completed training DDG (stage 0 + stage 1) on PACS, VLCS, and TerraIncognita using resnet-18 (i.e., hparams['resnet18']=True). The performance is much worse than that reported in the DDG paper, and close to the resnet-18 numbers reported in other papers. Hence, I believe using resnet-50 is necessary to reproduce the results.

Will let you know if I encounter further problems. Thanks again for the very prompt troubleshooting!!!

yfzhang114 commented 2 years ago

It's been so long that I can't quite remember the implementation details. I re-ran the code, and you appear to be right. In addition, I have four suggestions for reproducing the results we reported.

  1. Check that your environment is set up correctly; reproducing the ERM result is a good way to verify this.
  2. Note that we use test-domain validation: we choose the model that maximizes accuracy on a validation set drawn from the same distribution as the test domain.
  3. The hyper-parameters should be tuned per dataset; the defaults are not the best for all datasets.
  4. Because of the additional GANs, we use a small batch size and enlarge the number of training steps. A larger batch size may supply more stable gradients, so you can try that instead of extending the training steps.
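
For clarity, test-domain ("oracle") validation in point 2 amounts to a simple selection rule over saved checkpoints. A minimal sketch with made-up step numbers and accuracies:

```python
# Hypothetical checkpoint records; "test_domain_val_acc" is the accuracy
# on a held-out split drawn from the test domain itself.
checkpoints = [
    {"step": 1000, "test_domain_val_acc": 0.71},
    {"step": 2000, "test_domain_val_acc": 0.78},
    {"step": 3000, "test_domain_val_acc": 0.75},
]

# Pick the checkpoint that maximizes accuracy on that split.
best = max(checkpoints, key=lambda c: c["test_domain_val_acc"])
print(best["step"])  # 2000
```

This is a more optimistic selection criterion than training-domain validation, which is why numbers under the two protocols are not directly comparable.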
danni9594 commented 2 years ago

Hi Yifan.

Thanks for the suggestions! I'm actually aware of most of them (i.e., that the reported results use oracle validation, and that a smaller batch size with a larger number of training steps is used). Unfortunately, due to resource constraints, I cannot increase the batch size for more stable optimization :P. So I just leave it as it is~

Anyway, I managed to reproduce roughly similar results to the paper. Thank you for all your support along the way! This is indeed solid and inspiring work. Thank you for your contribution to the community, and I wish you all the best in your future endeavors!