DSE-MSU / DeepRobust

A PyTorch adversarial library for attack and defense methods on images and graphs
MIT License
994 stars, 192 forks

[Bug] The CNN model is not working with MNIST data #71

Closed: zhjwy9343 closed this 3 years ago

zhjwy9343 commented 3 years ago

The CNN model has an FC layer whose input dimension is int(self.H/4) * int(self.W/4) * self.out_channel2, which is 7 * 7 * 64 for the MNIST image size. But when MNIST data passes through the CNN's conv layers, the size becomes 8 * 8 * 64, which causes a matrix size mismatch error.

YaxinLi0-0 commented 3 years ago

Hi! Thanks for reporting this to us. Could you provide more details of your error message? We can calculate the output size of each layer with the formula floor((n + 2p - k) / s) + 1, where n is the input size, p the padding, k the kernel size and s the stride. After the first conv layer, the tensor shape is 32 × 28 × 28. The pooling layer then reduces it to 32 × 14 × 14. After the second conv layer and pooling layer, we finally get 64 × 7 × 7. So I think there is no size mismatch.
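The layer-by-layer calculation above can be checked with a few lines of plain Python. This is a sketch under an assumed layer configuration (5×5 convolutions with stride 1 and padding 2, each followed by 2×2 max-pooling with stride 2), chosen because it reproduces the 28 → 14 → 7 sequence described in the comment; it is not read from DeepRobust's source.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # floor((size + 2 * padding - kernel) / stride) + 1, for conv and pooling alike
    return (size + 2 * padding - kernel) // stride + 1

# Assumed (hypothetical) layer stack: (name, out_channels or None, kernel, stride, padding)
LAYERS = [
    ("conv1", 32, 5, 1, 2),
    ("pool1", None, 2, 2, 0),
    ("conv2", 64, 5, 1, 2),
    ("pool2", None, 2, 2, 0),
]

def trace(channels=1, h=28, w=28):
    shapes = []
    for name, out_ch, k, s, p in LAYERS:
        if out_ch is not None:  # only conv layers change the channel count
            channels = out_ch
        h, w = conv_out(h, k, s, p), conv_out(w, k, s, p)
        shapes.append((name, (channels, h, w)))
    return shapes

for name, shape in trace():
    print(name, shape)
c, h, w = trace()[-1][1]
print("flattened:", c * h * w)
```

For a 28×28 MNIST input this prints (32, 28, 28), (32, 14, 14), (64, 14, 14) and (64, 7, 7), so the flattened input to fc1 is 3136.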

zhjwy9343 commented 3 years ago

Sorry for the confusion in my issue description.

Here is the error: RuntimeError: Error(s) in loading state_dict for Net: size mismatch for fc1.weight: copying a param with shape torch.Size([1024, 3136]) from checkpoint, the shape in current model is torch.Size([1024, 4096]).
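Because nn.Linear stores its weight as [out_features, in_features], the second number in each shape is fc1's input width, and it can be decoded back into a spatial grid size. A small sketch, assuming out_channel2 = 64 and a square feature map as discussed in this thread; grid_side is a hypothetical helper:

```python
import math

def grid_side(in_features, channels):
    """Recover the square spatial side length implied by fc1's input width."""
    cells, rem = divmod(in_features, channels)
    side = math.isqrt(cells)
    if rem or side * side != cells:
        raise ValueError(f"{in_features} is not channels * side * side")
    return side

out_channel2 = 64  # assumed, per the issue description
print("checkpoint:", grid_side(3136, out_channel2))     # 3136 = 7 * 7 * 64
print("current model:", grid_side(4096, out_channel2))  # 4096 = 8 * 8 * 64
```

So the checkpoint was built for a 7×7 feature map, while the installed model definition expects 8×8.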

This bug occurs at line #38 of deeprobust.image.netmodels.CNN.py, where fc1 is defined: self.fc1 = nn.Linear(8 * 8 * out_channel2, 1024)

But just as you calculated, the actual size should be 7 * 7, as in line #46, instead of 8 * 8, right? x = x.view(-1, int(self.H/4) * int(self.W/4) * self.out_channel2)

I copied the Net class out and changed self.fc1 to nn.Linear(7 * 7 * out_channel2, 1024), and the error went away.
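A slightly more robust variant of that workaround computes fc1's input size from the same expression used in view(), so the two cannot drift apart. The sketch below is hypothetical: the class name CNNSketch, the 5×5 kernels with padding 2, the 2×2 max-pooling and the 10-class fc2 head are assumptions chosen to match the shape arithmetic in this thread, not a copy of DeepRobust's CNN.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNSketch(nn.Module):
    """Hypothetical two-conv MNIST CNN; layer sizes are assumptions, not DeepRobust's file."""

    def __init__(self, in_channel1=1, out_channel1=32, out_channel2=64, H=28, W=28):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel1, out_channel1, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(out_channel1, out_channel2, kernel_size=5, stride=1, padding=2)
        # One shared value for fc1 and view(): two 2x2 pools shrink H and W by 4,
        # giving 7 * 7 * 64 = 3136 for a 28x28 input.
        self.flat_dim = (H // 4) * (W // 4) * out_channel2
        self.fc1 = nn.Linear(self.flat_dim, 1024)
        self.fc2 = nn.Linear(1024, 10)  # assumed 10-class output head

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = x.view(-1, self.flat_dim)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

model = CNNSketch()
logits = model(torch.zeros(2, 1, 28, 28))  # dummy batch of two MNIST-sized images
print(model.fc1.weight.shape, logits.shape)
```

Here fc1.weight has shape [1024, 3136], which matches the checkpoint shape reported in the error above.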

YaxinLi0-0 commented 3 years ago

I think there are two possibilities: (1) We updated the CNN file once. Previously the fc1 layer was defined as self.fc1 = nn.Linear(8 * 8 * out_channel2, 1024), and we fixed that in the latest version. Downloading and reinstalling the package may solve the problem.

(2) The error message shows that the checkpoint's parameter shapes don't match the current model file. Where did you get the model from? Perhaps it was trained with a different version of the model file, and the current file has since changed.
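One way to check possibility (2) before calling load_state_dict is to compare the checkpoint's parameter shapes with the current model's. A minimal sketch; shape_mismatches is a hypothetical helper, and it only compares keys present in both mappings (load_state_dict already reports missing and unexpected keys on its own):

```python
def shape_mismatches(checkpoint, model_state):
    """Return (name, checkpoint_shape, model_shape) for every shared key whose shapes differ.

    Values may be tensors (anything with a .shape) or plain shape tuples.
    """
    def shape_of(v):
        return tuple(getattr(v, "shape", v))

    diffs = []
    for name, value in checkpoint.items():
        if name in model_state and shape_of(value) != shape_of(model_state[name]):
            diffs.append((name, shape_of(value), shape_of(model_state[name])))
    return diffs

# Plain shape tuples mirroring the error message in this thread:
ckpt = {"fc1.weight": (1024, 3136), "fc1.bias": (1024,)}
current = {"fc1.weight": (1024, 4096), "fc1.bias": (1024,)}
print(shape_mismatches(ckpt, current))
```

With PyTorch, the same call works on real state dicts, e.g. shape_mismatches(torch.load(path, map_location="cpu"), model.state_dict()), assuming the checkpoint file holds a plain state dict rather than a wrapper dict.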