elvisyjlin / AttGAN-PyTorch

AttGAN-PyTorch - Arbitrary Facial Attribute Editing: Only Change What You Want

Bug in Encoding and Decoding Parameters #25

Open · EoinKenny opened this issue 3 years ago

EoinKenny commented 3 years ago

Hi, I've been messing around with this code and found an error.

If I use these parameters in the AttGAN class:

    def __init__(self, enc_dim=64, enc_layers=6, enc_norm_fn='batchnorm', enc_acti_fn='lrelu',
                 dec_dim=64, dec_layers=6, dec_norm_fn='batchnorm', dec_acti_fn='relu',
                 n_attrs=1, shortcut_layers=1, inject_layers=0, img_size=128):

I get the following error:

-----------------------------------------------
RuntimeError  Traceback (most recent call last)
<ipython-input-274-d001aaf018ea> in <module>
----> 1 netG(imgs, a).shape

~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/MNIST/Final Experiments/Substitutability Test/senv/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

<ipython-input-267-7eb972155d92> in forward(self, x, a, mode)
     61         if mode == 'enc-dec':
     62             assert a is not None, 'No given attribute.'
---> 63             return self.decode(self.encode(x), a)
     64         if mode == 'enc':
     65             return self.encode(x)

<ipython-input-267-7eb972155d92> in decode(self, zs, a)
     47         z = torch.cat([zs[-1], a_tile], dim=1)
     48         for i, layer in enumerate(self.dec_layers):
---> 49             z = layer(z)
     50             if self.shortcut_layers > i:  # Concat 1024 with 512
     51                 print(z.shape, zs[len(self.dec_layers) - 2 - i].shape)

~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/MNIST/Final Experiments/Substitutability Test/senv/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

<ipython-input-43-fa89d761303c> in forward(self, x)
    189 
    190         def forward(self, x):
--> 191                 return self.layers(x)
    192 
    193 

~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/MNIST/Final Experiments/Substitutability Test/senv/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/MNIST/Final Experiments/Substitutability Test/senv/lib/python3.7/site-packages/torch/nn/modules/container.py in forward(self, input)
    115     def forward(self, input):
    116         for module in self:
--> 117             input = module(input)
    118         return input
    119 

~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/MNIST/Final Experiments/Substitutability Test/senv/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/MNIST/Final Experiments/Substitutability Test/senv/lib/python3.7/site-packages/torch/nn/modules/conv.py in forward(self, input, output_size)
    905         return F.conv_transpose2d(
    906             input, self.weight, self.bias, self.stride, self.padding,
--> 907             output_padding, self.groups, self.dilation)
    908 
    909 

RuntimeError: Given transposed=1, weight of size [1536, 1024, 4, 4], expected input[32, 2048, 4, 4] to have 1536 channels, but got 2048 channels instead

Any idea how to fix this?

It seems that if you use shortcut layers, you cannot have enc_layers/dec_layers larger than 5. Since I would like to train a version that encodes the image down to a 1D vector, this is a big problem for me.
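
Digging into the numbers, the error message seems to line up with the channel cap in attgan.py: feature widths are clamped to MAX_DIM = 64 * 16 = 1024, while the decoder builder assumes every skip feature has half the channels of the current deconv output (the n_in // 2 term). That assumption only holds while the widths are still doubling, i.e. for at most 5 layers. A minimal arithmetic sketch of this (assuming those two details of attgan.py; no PyTorch needed):

    MAX_DIM = 64 * 16  # 1024, the channel cap used in attgan.py

    def shortcut_channels(enc_dim=64, enc_layers=6, dec_dim=64, dec_layers=6):
        # channel width of each encoder feature map (doubling, capped at MAX_DIM)
        enc_dims = [min(enc_dim * 2 ** i, MAX_DIM) for i in range(enc_layers)]
        first_out = min(dec_dim * 2 ** (dec_layers - 1), MAX_DIM)  # 1st deconv output
        built = first_out + first_out // 2             # what the 2nd deconv is built for
        actual = first_out + enc_dims[dec_layers - 2]  # what decode() actually concats
        return built, actual

    print(shortcut_channels(enc_layers=5, dec_layers=5))  # (1536, 1536) -> fine
    print(shortcut_channels(enc_layers=6, dec_layers=6))  # (1536, 2048) -> this crash

If that reading is right, a fix would be to record the encoder widths when building the decoder and size its inputs with enc_dims[dec_layers - 2 - i] instead of n_in // 2.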

elvisyjlin commented 3 years ago

Hi @EoinKenny, the problem doesn't seem to be related to the shortcut layers. I tried training with 6 layers for both enc_layers and dec_layers and it worked, but it crashed when I changed n_attrs to 1. Why do you want to train a conditional GAN with a single condition? I mean, it is impossible to train the discriminator's classifier with only one label, because it cannot form a cross-entropy loss from a single label.

In attgan.py,

    def __init__(self, enc_dim=64, enc_layers=6, enc_norm_fn='batchnorm', enc_acti_fn='lrelu',
                 dec_dim=64, dec_layers=6, dec_norm_fn='batchnorm', dec_acti_fn='relu',
                 n_attrs=13, shortcut_layers=1, inject_layers=0, img_size=128):

The output is:

Namespace(attr_path='data/list_attr_celeba.txt', attrs=['Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows', 'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'], b_distribution='none', batch_size=32, beta1=0.5, beta2=0.999, data='CelebA', data_path='data/img_align_celeba', dec_acti='relu', dec_dim=64, dec_layers=5, dec_norm='batchnorm', dis_acti='lrelu', dis_dim=64, dis_fc_acti='relu', dis_fc_dim=1024, dis_fc_norm='none', dis_layers=5, dis_norm='instancenorm', enc_acti='lrelu', enc_dim=64, enc_layers=5, enc_norm='batchnorm', epochs=200, experiment_name='test', gpu=False, image_list_path='data/image_list.txt', img_size=128, inject_layers=0, lambda_1=100.0, lambda_2=10.0, lambda_3=1.0, lambda_gp=10.0, lr=0.0002, mode='wgan', multi_gpu=False, n_d=5, n_samples=16, num_workers=0, sample_interval=1000, save_interval=1000, shortcut_layers=1, test_int=1.0, thres_int=0.5)
Training images: 182000 / Validating images: 637
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [4, 64, 64, 64]           3,072
       BatchNorm2d-2            [4, 64, 64, 64]             128
         LeakyReLU-3            [4, 64, 64, 64]               0
       Conv2dBlock-4            [4, 64, 64, 64]               0
            Conv2d-5           [4, 128, 32, 32]         131,072
       BatchNorm2d-6           [4, 128, 32, 32]             256
         LeakyReLU-7           [4, 128, 32, 32]               0
       Conv2dBlock-8           [4, 128, 32, 32]               0
            Conv2d-9           [4, 256, 16, 16]         524,288
      BatchNorm2d-10           [4, 256, 16, 16]             512
        LeakyReLU-11           [4, 256, 16, 16]               0
      Conv2dBlock-12           [4, 256, 16, 16]               0
           Conv2d-13             [4, 512, 8, 8]       2,097,152
      BatchNorm2d-14             [4, 512, 8, 8]           1,024
        LeakyReLU-15             [4, 512, 8, 8]               0
      Conv2dBlock-16             [4, 512, 8, 8]               0
           Conv2d-17            [4, 1024, 4, 4]       8,388,608
      BatchNorm2d-18            [4, 1024, 4, 4]           2,048
        LeakyReLU-19            [4, 1024, 4, 4]               0
      Conv2dBlock-20            [4, 1024, 4, 4]               0
  ConvTranspose2d-21            [4, 1024, 8, 8]      16,990,208
      BatchNorm2d-22            [4, 1024, 8, 8]           2,048
             ReLU-23            [4, 1024, 8, 8]               0
ConvTranspose2dBlock-24            [4, 1024, 8, 8]               0
  ConvTranspose2d-25           [4, 512, 16, 16]      12,582,912
      BatchNorm2d-26           [4, 512, 16, 16]           1,024
             ReLU-27           [4, 512, 16, 16]               0
ConvTranspose2dBlock-28           [4, 512, 16, 16]               0
  ConvTranspose2d-29           [4, 256, 32, 32]       2,097,152
      BatchNorm2d-30           [4, 256, 32, 32]             512
             ReLU-31           [4, 256, 32, 32]               0
ConvTranspose2dBlock-32           [4, 256, 32, 32]               0
  ConvTranspose2d-33           [4, 128, 64, 64]         524,288
      BatchNorm2d-34           [4, 128, 64, 64]             256
             ReLU-35           [4, 128, 64, 64]               0
ConvTranspose2dBlock-36           [4, 128, 64, 64]               0
  ConvTranspose2d-37           [4, 3, 128, 128]           6,147
             Tanh-38           [4, 3, 128, 128]               0
ConvTranspose2dBlock-39           [4, 3, 128, 128]               0
================================================================
Total params: 43,352,707
Trainable params: 43,352,707
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 9.75
Forward/backward pass size (MB): 186.50
Params size (MB): 165.38
Estimated Total Size (MB): 361.63
----------------------------------------------------------------
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [4, 64, 64, 64]           3,072
    InstanceNorm2d-2            [4, 64, 64, 64]             128
         LeakyReLU-3            [4, 64, 64, 64]               0
       Conv2dBlock-4            [4, 64, 64, 64]               0
            Conv2d-5           [4, 128, 32, 32]         131,072
    InstanceNorm2d-6           [4, 128, 32, 32]             256
         LeakyReLU-7           [4, 128, 32, 32]               0
       Conv2dBlock-8           [4, 128, 32, 32]               0
            Conv2d-9           [4, 256, 16, 16]         524,288
   InstanceNorm2d-10           [4, 256, 16, 16]             512
        LeakyReLU-11           [4, 256, 16, 16]               0
      Conv2dBlock-12           [4, 256, 16, 16]               0
           Conv2d-13             [4, 512, 8, 8]       2,097,152
   InstanceNorm2d-14             [4, 512, 8, 8]           1,024
        LeakyReLU-15             [4, 512, 8, 8]               0
      Conv2dBlock-16             [4, 512, 8, 8]               0
           Conv2d-17            [4, 1024, 4, 4]       8,388,608
   InstanceNorm2d-18            [4, 1024, 4, 4]           2,048
        LeakyReLU-19            [4, 1024, 4, 4]               0
      Conv2dBlock-20            [4, 1024, 4, 4]               0
           Linear-21                  [4, 1024]      16,778,240
             ReLU-22                  [4, 1024]               0
      LinearBlock-23                  [4, 1024]               0
           Linear-24                     [4, 1]           1,025
      LinearBlock-25                     [4, 1]               0
           Linear-26                  [4, 1024]      16,778,240
             ReLU-27                  [4, 1024]               0
      LinearBlock-28                  [4, 1024]               0
           Linear-29                    [4, 13]          13,325
      LinearBlock-30                    [4, 13]               0
================================================================
Total params: 44,718,990
Trainable params: 44,718,990
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 62.19
Params size (MB): 170.59
Estimated Total Size (MB): 233.53
----------------------------------------------------------------
  0%|                       | 11/5687 [01:10<10:38:33,  6.75s/it, d_loss=-23.9, epoch=0, g_loss=103, iter=6]
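
For reference, here is a standalone shape check of the configuration shown in the Namespace above (a rough sketch; Generator and its forward(x, a) signature are taken from attgan.py in this repo):

    import torch
    from attgan import Generator  # the generator class defined in attgan.py

    netG = Generator(enc_dim=64, enc_layers=5, dec_dim=64, dec_layers=5,
                     n_attrs=13, shortcut_layers=1, inject_layers=0, img_size=128)
    x = torch.randn(4, 3, 128, 128)  # dummy batch of 128x128 RGB images
    a = torch.rand(4, 13)            # one intensity value per attribute
    print(netG(x, a).shape)          # torch.Size([4, 3, 128, 128])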
EoinKenny commented 3 years ago

@elvisyjlin Thanks for the response.

I should clarify some stuff. I am looking to adapt the code for explainable AI purposes, in particular to generate counterfactuals for a pre-trained CNN.

So my CNN has two outputs, one for class 1 and another for class 2.

The only code I am borrowing from this repo is the generator itself, plus the imports necessary to instantiate the class. So for my purposes I am currently working with a single attribute which represents the class I'd like to make counterfactuals for.
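
Concretely, I mean something like the following (a rough sketch of my own setup; cnn, netG, and the single-attribute encoding are hypothetical, not code from this repo):

    import torch

    # x: a batch of inputs; cnn: my pre-trained two-class classifier (hypothetical);
    # netG: the generator from this repo, instantiated with n_attrs=1
    with torch.no_grad():
        pred = cnn(x).argmax(dim=1)            # predicted class, 0 or 1
        a_cf = (1 - pred).float().view(-1, 1)  # flip it: the counterfactual target
        x_cf = netG(x, a_cf)                   # decode toward the opposite class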

Does that make sense? Was the error you got the same as mine? If so, is there a way to make the code run with the parameters I have in my original post?

Thank you.

elvisyjlin commented 3 years ago

Sorry for my late reply. I understand you're doing something related to explainable AI. If class 1 represents content and class 2 represents style, it sounds like a topic in disentanglement learning. To me, it is more like what DRIT or MUNIT does.

The original AttGAN paper says the model can learn from multiple sets of conditions: class 1 is smiling, gender, hair color, glasses, etc.; class 2 is the color of the glasses; class 3 is the hair style; and so on. That is slightly different from what I mentioned above. However, I did not implement this feature. If this is what you are looking for, please refer to the original paper or its authors. By the way, I'm just a passionate nobody who implemented the proposed network in PyTorch.