ZiangYan / deepdefense.pytorch

Implementation of our NeurIPS 2018 paper: Deep Defense: Training DNNs with Improved Adversarial Robustness

Can you add the inverse/backward for the batch norm layer? #2

Closed sndnyang closed 5 years ago

sndnyang commented 5 years ago

Hi, thank you for your excellent work! I have a question. You have implemented inverse layers for the Conv/Linear/Dropout/Pool layers, but it seems you forgot the batch norm layer, which is also widely used in NNs. Could you add an NN example with batch norm layers?

ZiangYan commented 5 years ago

Hi, thanks for your interest in our work.

Actually, we use a simple strategy for batch norm layers: we simply freeze all variables in the BN layers during both training and testing.

The ResNet-18 result reported in our paper was produced with this method.

The core code for this part should be something like:

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()  # always use the frozen running statistics, never batch statistics
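Note that m.eval() alone only switches the layer to its running statistics; the affine parameters gamma and beta are still updated by the optimizer. If those should stay fixed as well, a minimal sketch (assuming that is what "freeze all variables" means here, not necessarily the exact training script):

for m in model.modules():
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.eval()  # freeze the running mean/var
        for p in m.parameters():
            p.requires_grad = False  # keep gamma/beta fixed too
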
sndnyang commented 5 years ago

But in many cases we need to use batch norm, e.g. to compare fairly against other methods that use it. So I'm interested in an implementation like the one below.

For example, I use an MLP like:

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(784, 1200)
        self.bn_fc1 = nn.BatchNorm1d(1200)
        self.fc2 = nn.Linear(1200, 600)
        self.bn_fc2 = nn.BatchNorm1d(600)
        self.fc3 = nn.Linear(600, 10)

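For context, a minimal sketch of the matching forward pass (assumed, since it is not shown above; it records the ReLU masks that the inverse network consumes):

def forward(self, x):
    x = x.view(-1, 784)                      # flatten 28x28 MNIST images
    fc1_out = self.bn_fc1(self.fc1(x))
    self.relu1_mask = (fc1_out > 0).float()  # saved for the inverse pass
    relu1_out = torch.relu(fc1_out)
    fc2_out = self.bn_fc2(self.fc2(relu1_out))
    self.relu2_mask = (fc2_out > 0).float()
    relu2_out = torch.relu(fc2_out)
    return self.fc3(relu2_out)
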
How should the InverseMLP and its forward be written (based on your MLP model)? Here is my attempt:

class InverseMLP(nn.Module):
    def __init__(self):
        super(InverseMLP, self).__init__()
        self.transposefc3 = LinearTranspose(10, 600, bias=False)
        self.transposefc2 = LinearTranspose(600, 1200, bias=False)
        self.transposefc1 = LinearTranspose(1200, 784, bias=False)

    def forward(self, x, relu1_mask, relu2_mask):
        self.relu2_out = self.transposefc3(x)
        self.fc2_out = self.relu2_out * relu2_mask
        self.relu1_out = self.transposefc2(self.fc2_out)
        self.fc1_out = self.relu1_out * relu1_mask
        self.flat_out = self.transposefc1(self.fc1_out)
        self.input_out = self.flat_out.view(-1, 1, 28, 28)
        return self.input_out

Thank you!

ZiangYan commented 5 years ago

Hi, sorry for the late reply.

I've added an example with batch norm.

https://github.com/ZiangYan/deepdefense.pytorch/blob/48621f7d40c5c7f3470b59a77cc42a0e18e2f0bd/models/mnist.py#L258-L337
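The key observation: with frozen statistics, a BN layer is the affine map y = gamma * (x - mean) / sqrt(var + eps) + beta, whose Jacobian is the diagonal matrix diag(gamma / sqrt(var + eps)), and the transpose of a diagonal matrix is itself. So the inverse/backward of a frozen BN layer is just the same per-feature scaling. A minimal sketch of such a layer (BatchNorm1dTranspose is a hypothetical name for illustration; see the linked file for the actual implementation):

class BatchNorm1dTranspose(nn.Module):
    def __init__(self, bn):
        super(BatchNorm1dTranspose, self).__init__()
        self.bn = bn  # the frozen forward nn.BatchNorm1d being transposed

    def forward(self, x):
        # transpose of diag(gamma / sqrt(running_var + eps)) is itself
        scale = self.bn.weight / torch.sqrt(self.bn.running_var + self.bn.eps)
        return x * scale

In the InverseMLP above, bn_fc2's transpose would sit between the relu2_mask multiplication and transposefc2 (and likewise for bn_fc1), mirroring the reversed layer order of the forward pass.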

You can find the download URL for the reference model of this example in the README.

Thanks.