hendrycks / ss-ood

Self-Supervised Learning for OOD Detection (NeurIPS 2019)
MIT License

Reproducing robustness results for CIFAR-10 via auxiliary rotation task #4

Closed: nupurkmr9 closed this issue 5 years ago

nupurkmr9 commented 5 years ago

Hi,

I found your research paper very interesting.

However, when I implemented the method from your paper, I was unable to reproduce the CIFAR-10 results with the following configuration:

- Network: WRN-40-2
- Training loss: cross-entropy on the adversarial examples + 0.5 * rotation loss
- Loss used to craft the adversarial perturbation: cross-entropy(x, y) + rotation loss
- Optimizer: SGD with learning rate 0.1, momentum 0.9, Nesterov momentum, weight decay 0.0005, batch size 128
- Schedule: cosine annealing for 205 epochs

i.e.

    optimizer = torch.optim.SGD(
        [{'params': model.parameters()}, {'params': rotate_classifier.parameters()}],
        lr=0.1, nesterov=True, momentum=0.9, weight_decay=0.0005)
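
For reference, the setup above can be read as the following min-max objective. This is a sketch under assumptions: f is the WRN-40-2 classifier, g the rotation head (`rotate_classifier`), R_r rotates an image by r * 90 degrees, epsilon is the L-infinity budget (not stated in the thread), and the rotation loss averages over the four rotations. The author's code later in this thread differs slightly: it rotates first and then perturbs each rotated copy separately, rather than rotating a single perturbed image.

```latex
\min_{\theta}\; \mathbb{E}_{(x,y)}\Big[\mathrm{CE}\big(f_\theta(\tilde{x}), y\big) + 0.5\,\mathcal{L}_{\mathrm{rot}}(\tilde{x})\Big],
\qquad
\tilde{x} = \arg\max_{\|x' - x\|_\infty \le \epsilon}\Big[\mathrm{CE}\big(f_\theta(x'), y\big) + \mathcal{L}_{\mathrm{rot}}(x')\Big],

\mathcal{L}_{\mathrm{rot}}(x) = \frac{1}{4}\sum_{r=0}^{3}\mathrm{CE}\big(g_\theta(R_r\,x),\, r\big)
```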

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            205 * len(base_loader),
            1,  # since lr_lambda computes multiplicative factor
            1e-6 / 0.1))
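
The thread does not show `cosine_annealing`. A common definition (an assumption here, though it is consistent with the `1` and `1e-6 / 0.1` arguments above) decays the multiplier along half a cosine from `lr_max` to `lr_min`. The sketch below uses it to compute the effective learning rate at the start, middle and end of training, assuming `len(base_loader)` is 391 (50,000 CIFAR-10 training images at batch size 128, last partial batch kept).

```python
import math


def cosine_annealing(step, total_steps, lr_max, lr_min):
    """Hypothetical helper: half-cosine decay from lr_max (step 0) to lr_min (step total_steps)."""
    return lr_min + (lr_max - lr_min) * 0.5 * (1 + math.cos(step / total_steps * math.pi))


# LambdaLR multiplies the base lr (0.1) by this factor, so lr_max=1 and
# lr_min=1e-6 / 0.1 make the actual learning rate decay from 0.1 to 1e-6.
base_lr = 0.1
total_steps = 205 * 391  # assumed: 205 epochs x 391 batches per epoch
checkpoints = (0, total_steps // 2, total_steps)
lrs = [base_lr * cosine_annealing(s, total_steps, 1, 1e-6 / base_lr) for s in checkpoints]
print(lrs)
```

Because the scheduler is stepped once per batch, `total_steps` must count batches, not epochs; using `205` alone would finish the decay within the first epoch.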

I am getting the following results:

- Test accuracy: 72.4425
- Rotation accuracy: 80.5675
- Adversarial accuracy (PGD-10 on the cross-entropy loss only): 10.42

Could you please share the learning-rate schedule hyperparameters and the number of training epochs you used to obtain the results?

Thanks

hendrycks commented 5 years ago
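import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# net, adversary (a PGD instance), train_loader, optimizer, scheduler and state
# are assumed to be defined elsewhere in the training script.
def train():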
    net.train()  # enter train mode
    loss_avg = 0.0
    for bx, by in train_loader:
        curr_batch_size = bx.size(0)
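        # rotation labels 0, 1, 2, 3 (0, 90, 180, 270 degrees), B of each
        by_prime = torch.cat((torch.zeros(bx.size(0)), torch.ones(bx.size(0)),
                              2*torch.ones(bx.size(0)), 3*torch.ones(bx.size(0))), 0).long()
        # 5B images: one unrotated copy for classification, followed by
        # four rotated copies (0, 90, 180, 270 degrees) for the rotation head
        bx = bx.numpy()
        bx = np.concatenate((bx, bx, np.rot90(bx, 1, axes=(2, 3)),
                             np.rot90(bx, 2, axes=(2, 3)), np.rot90(bx, 3, axes=(2, 3))), 0)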
        bx = torch.FloatTensor(bx)
        bx, by, by_prime = bx.cuda(), by.cuda(), by_prime.cuda()

        adv_bx = adversary(net, bx, by, by_prime, curr_batch_size)

        # forward
        logits, pen = net(adv_bx * 2 - 1)

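        # backward
        optimizer.zero_grad()
        # classification loss on the unrotated (adversarial) copy
        loss = F.cross_entropy(logits[:curr_batch_size], by)
        # auxiliary rotation loss on the four rotated copies, weighted by 0.5
        loss += 0.5 * F.cross_entropy(net.module.rot_pred(pen[curr_batch_size:]), by_prime)
        loss.backward()
        optimizer.step()
        # since PyTorch 1.1, scheduler.step() should follow optimizer.step()
        scheduler.step()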

        # exponential moving average
        loss_avg = loss_avg * 0.9 + float(loss) * 0.1

    state['train_loss'] = loss_avg
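
To make the slicing in `train()` concrete, here is a numpy-only sketch with a toy batch (B = 2 and 4x4 images are made-up sizes for illustration). It shows that the concatenated batch holds 5B images: the first B feed the class logits with weight 1.0, and the remaining 4B, labelled 0, 1, 2, 3 in blocks of B, feed the rotation head with weight 0.5.

```python
import numpy as np

# Toy stand-in for one minibatch: B images, 3 channels, 4x4 pixels (sizes are illustrative)
B, C, H, W = 2, 3, 4, 4
bx = np.arange(B * C * H * W, dtype=np.float32).reshape(B, C, H, W)

# Same construction as train(): an unrotated copy for classification,
# then copies rotated by 0, 90, 180 and 270 degrees for the rotation head.
big = np.concatenate((bx, bx, np.rot90(bx, 1, axes=(2, 3)),
                      np.rot90(bx, 2, axes=(2, 3)), np.rot90(bx, 3, axes=(2, 3))), 0)
rot_labels = np.concatenate([np.full(B, r) for r in range(4)])  # plays the role of by_prime

cls_part = big[:B]  # goes to logits[:B] with class labels `by`
rot_part = big[B:]  # goes to rot_pred(pen[B:]) with rotation labels

print(big.shape, cls_part.shape, rot_part.shape, rot_labels)
```

This layout relies on square images (as in CIFAR-10); for H != W a 90-degree rotation swaps the spatial dimensions and the concatenation would fail.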

class PGD(nn.Module):
    def __init__(self, epsilon, num_steps, step_size, grad_sign=True, attack_rotations=True):
        super().__init__()
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.step_size = step_size
        self.attack_rotations = attack_rotations
        # grad_sign is accepted but unused: forward() always steps along sign(grad)

    def forward(self, model, bx, by, by_prime, curr_batch_size):
        """
        :param model: the classifier's forward method
        :param bx: batch of images
        :param by: true labels
        :return: perturbed batch of images
        """
        adv_bx = bx.detach()
        adv_bx += torch.zeros_like(adv_bx).uniform_(-self.epsilon, self.epsilon)

        for i in range(self.num_steps):
            adv_bx.requires_grad_()
            with torch.enable_grad():
                logits, pen = model(adv_bx * 2 - 1)
                loss = F.cross_entropy(logits[:curr_batch_size], by, reduction='sum')
                if self.attack_rotations:
                    loss += F.cross_entropy(model.module.rot_pred(pen[curr_batch_size:]), by_prime, reduction='sum')
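            grad = torch.autograd.grad(loss, adv_bx)[0]

            # ascend the loss along the sign of the input gradient
            adv_bx = adv_bx.detach() + self.step_size * torch.sign(grad.detach())

            # project back into the epsilon-ball around bx and the valid pixel range [0, 1]
            adv_bx = torch.min(torch.max(adv_bx, bx - self.epsilon), bx + self.epsilon).clamp(0, 1)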

        return adv_bx
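
To illustrate the attack loop in `PGD.forward` without a GPU or the WRN, here is a numpy sketch that swaps in a toy binary logistic-regression model, so the input gradient has a closed form. `pgd_linf` and `bce` are hypothetical helpers, and epsilon = 8/255, step size 2/255 and 10 steps are assumed values chosen to mirror a PGD-10 evaluation; the thread does not state the epsilon or step size. The loop keeps the same structure: random start, signed-gradient ascent, projection onto the epsilon-ball and [0, 1].

```python
import numpy as np


def pgd_linf(x, y, w, b, epsilon, step_size, num_steps, rng):
    """L-infinity PGD against a toy logistic-regression model p = sigmoid(x @ w + b)."""
    adv = x + rng.uniform(-epsilon, epsilon, size=x.shape)  # new array; x stays the clean center
    for _ in range(num_steps):
        p = 1.0 / (1.0 + np.exp(-(adv @ w + b)))
        # gradient of the summed binary cross-entropy w.r.t. the input: (p - y) * w
        grad = (p - y)[:, None] * w[None, :]
        adv = adv + step_size * np.sign(grad)
        # project onto the epsilon-ball around x, then onto the pixel range [0, 1]
        adv = np.clip(np.minimum(np.maximum(adv, x - epsilon), x + epsilon), 0.0, 1.0)
    return adv


def bce(x, y, w, b):
    """Summed binary cross-entropy of the toy model."""
    p = 1.0 / (1.0 + np.exp(-(x @ w + b)))
    return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))


rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, size=(8, 5))
y = (x.sum(axis=1) > 2.5).astype(float)
w = np.ones(5)
b = -2.5
adv = pgd_linf(x, y, w, b, epsilon=8 / 255, step_size=2 / 255, num_steps=10, rng=rng)
print(bce(x, y, w, b), bce(adv, y, w, b))
```

Because `x + noise` allocates a new array, the clean input used as the projection center is never modified. In PyTorch, by contrast, an in-place `+=` on `bx.detach()` writes through to `bx`, since `detach()` returns a tensor that shares its storage.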