mit-han-lab / gan-compression

[CVPR 2020] GAN Compression: Efficient Architectures for Interactive Conditional GANs

Distill Problem #106

Closed: saijo0404 closed this issue 2 years ago

saijo0404 commented 2 years ago

I tried to train a pix2pix model on the edges2shoes-r dataset using train_full.sh.

```bash
#!/usr/bin/env bash
python distill.py --dataroot database/edges2shoes-r \
  --distiller resnet \
  --log_dir logs/pix2pix/edges2shoes-r/distill \
  --batch_size 4 \
  --restore_teacher_G_path logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_G.pth \
  --restore_pretrained_G_path logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_G.pth \
  --pretrained_netG resnet_9blocks \
  --teacher_netG resnet_9blocks \
  --student_netG resnet_9blocks \
  --restore_D_path logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_D.pth \
  --real_stat_path real_stat/edges2shoes-r_B.npz \
  --meta_path datasets/metas/edges2shoes-r/train1.meta
```

After training, I ran this script, but it fails with an AssertionError in `weight_transfer.py`, line 14, in `transfer_Conv2d`: `assert isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d))`. How can I solve this problem?
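For context, the failing line is a type guard, presumably at the top of a per-layer copy routine. The sketch below is a hypothetical, PyTorch-free stand-in: the `Conv2d`, `SuperConv2d` and `ResnetBlock` classes are placeholders rather than the repo's modules, and the slice copy is a simplification of real weight transfer. It shows that the guard passes for conv-to-conv pairs and fires as soon as a composite block is passed in.

```python
# Hypothetical stand-ins for torch.nn.Conv2d, the repo's SuperConv2d and
# ResnetBlock, so the sketch runs without PyTorch installed.
class Conv2d:
    def __init__(self, weight):
        self.weight = list(weight)


class SuperConv2d(Conv2d):
    """Stand-in for a searchable conv layer (subclassing Conv2d is an assumption)."""


class ResnetBlock:
    """Stand-in for a composite block that *contains* convs but is not one."""
    def __init__(self, *convs):
        self.children = list(convs)


def transfer_conv2d(m1, m2):
    # Same shape of guard as the assert quoted in the issue: the source must be
    # a plain conv and the target a conv or a super conv.
    assert isinstance(m1, Conv2d) and isinstance(m2, (Conv2d, SuperConv2d))
    # Simplified copy: fill the target with as many leading weights as fit.
    n = min(len(m1.weight), len(m2.weight))
    m2.weight[:n] = m1.weight[:n]


src, dst = Conv2d([1.0, 2.0, 3.0]), SuperConv2d([0.0, 0.0])
transfer_conv2d(src, dst)
print(dst.weight)  # [1.0, 2.0]

try:
    # A whole block reaching the conv-only routine trips the guard.
    transfer_conv2d(ResnetBlock(Conv2d([1.0])), ResnetBlock(Conv2d([0.0])))
except AssertionError:
    print("AssertionError: a block reached the conv handler")
```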

lmxyy commented 2 years ago

Could you provide some more information? What are the types of your m1 and m2?

saijo0404 commented 2 years ago

I tried printing the types of m1 and m2; the output looks like this:

```text
distiller [ResnetDistiller] was created
Load network at logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_G.pth
isinstance(netA, nn.DataParallel):  False
isinstance(netB, nn.DataParallel):  False
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)):  True
m1 type:  <class 'torch.nn.modules.conv.Conv2d'>
m2 type:  <class 'torch.nn.modules.conv.Conv2d'>
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)):  True
m1 type:  <class 'torch.nn.modules.conv.Conv2d'>
m2 type:  <class 'torch.nn.modules.conv.Conv2d'>
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)):  True
m1 type:  <class 'torch.nn.modules.conv.Conv2d'>
m2 type:  <class 'torch.nn.modules.conv.Conv2d'>
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)):  False
m1 type:  <class 'models.modules.resnet_architecture.resnet_generator.ResnetBlock'>
m2 type:  <class 'models.modules.resnet_architecture.resnet_generator.ResnetBlock'>
```

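Reading that log: the first three pairs are plain convolutions and pass, while the fourth pair is two `ResnetBlock`s. Their types still match each other, so the two networks agree; the problem is that a whole block was handed to the conv-only routine. A small hypothetical helper (stand-in classes again, not PyTorch or the repo's code) that scans logged pairs for the first one the guard would reject:

```python
# Hypothetical stand-ins; only their types matter for this check.
class Conv2d:
    pass


class SuperConv2d(Conv2d):
    pass


class ResnetBlock:
    pass


def first_bad_pair(pairs):
    """Return (index, m1 type name, m2 type name) for the first pair the conv
    guard would reject, or None if every pair passes."""
    for i, (m1, m2) in enumerate(pairs):
        ok = isinstance(m1, Conv2d) and isinstance(m2, (Conv2d, SuperConv2d))
        if not ok:
            return i, type(m1).__name__, type(m2).__name__
    return None


# Mirrors the log: three conv pairs, then a ResnetBlock pair.
pairs = [(Conv2d(), Conv2d())] * 3 + [(ResnetBlock(), ResnetBlock())]
print(first_bad_pair(pairs))  # (3, 'ResnetBlock', 'ResnetBlock')
```

The result points at a routing problem rather than an architecture mismatch: both sides are `ResnetBlock`, they just should not reach the conv handler.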
lmxyy commented 2 years ago

I see. This is a minor bug in weight_transfer.py caused by a typo. I've fixed it in this commit. Could you pull the latest commit and try again?
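The thread does not say what the typo was. As one hedged illustration of how such a slip can surface as this exact assertion, here is a sketch of a type-dispatched transfer that recurses into composite blocks. The handler table, `transfer`, `transfer_container` and the stand-in classes are all hypothetical; they are not the code in `weight_transfer.py`.

```python
# Hypothetical stand-ins for conv layers and container modules (no PyTorch).
class Conv2d:
    def __init__(self, weight):
        self.weight = list(weight)


class ResnetBlock:
    def __init__(self, *children):
        self.children = list(children)


class Sequential:
    def __init__(self, *children):
        self.children = list(children)


def transfer_conv2d(m1, m2):
    # Leaf handler: only conv-to-conv pairs are allowed here.
    assert isinstance(m1, Conv2d) and isinstance(m2, Conv2d)
    n = min(len(m1.weight), len(m2.weight))
    m2.weight[:n] = m1.weight[:n]


def transfer_container(m1, m2):
    # Container handler: walk children in lockstep and dispatch each pair.
    assert type(m1) is type(m2) and len(m1.children) == len(m2.children)
    for c1, c2 in zip(m1.children, m2.children):
        transfer(c1, c2)


# Handler table keyed by exact module type. Mapping ResnetBlock to
# transfer_conv2d here (a one-word slip) would reproduce the assertion.
HANDLERS = {
    Conv2d: transfer_conv2d,
    ResnetBlock: transfer_container,
    Sequential: transfer_container,
}


def transfer(m1, m2):
    handler = HANDLERS.get(type(m1))
    if handler is None:
        raise NotImplementedError(f"no transfer rule for {type(m1).__name__}")
    handler(m1, m2)


teacher = Sequential(Conv2d([1, 2]), ResnetBlock(Conv2d([3, 4]), Conv2d([5, 6])))
student = Sequential(Conv2d([0, 0]), ResnetBlock(Conv2d([0]), Conv2d([0])))
transfer(teacher, student)
print(student.children[1].children[1].weight)  # [5]
```

Keying handlers on the exact type keeps leaf copies strict, and an unknown module type raises instead of being skipped silently.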

lmxyy commented 2 years ago

I'll close this issue. Let me know if you run into any further problems!