MadryLab / robustness

A library for experimenting with, training and evaluating neural networks, with a focus on adversarial robustness.

[Reduce Runtime] Better utilize GPU resources for PGD with random restarts #98

Open SohamTamba opened 3 years ago

SohamTamba commented 3 years ago

Summary

I'm working on a model that suffers from gradient masking, so I have to use multiple (16) random restarts. I developed an implementation that is faster than the current one while using multiple restarts.

My experiments involve training a BatchNorm-free ResNet-26 on a V100 GPU on CIFAR-10 with a batch size of 64, 16 random restarts, and 20 PGD steps. My implementation takes 170 minutes per epoch, while the current implementation takes 250 minutes per epoch. This is a 1.47x speed-up.

The training curves of both implementations match, so I'm fairly confident my implementation is correct.

Please let me know if you guys are interested in a Pull Request.


Details

The current implementation for PGD with random restarts works like this:

Input: X, Y, model

output = X.clone()
for restart_i in range(num_random_restarts):
    noise = random_constrained_noise(shape=X.shape) if restart_i > 0 else 0  # no noise on the first restart
    pert_X = X + noise
    adv, is_misclass = run_pgd(pert_X, Y, model)
    output[is_misclass] = adv[is_misclass]  # keep successful adversarial examples
return output
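
For concreteness, since run_pgd and random_constrained_noise are placeholder names in the pseudocode above (not functions from this library), here is a minimal sketch of what run_pgd could look like: an untargeted L-infinity PGD loop, with illustrative eps and step_size values rather than the library's defaults.

import torch
import torch.nn.functional as F

def run_pgd(x, y, model, eps=8/255, step_size=2/255, steps=20):
    # Hypothetical stand-in for the run_pgd placeholder above:
    # untargeted L-infinity PGD, projected back into an eps-ball
    # around the original input after every step.
    x_orig = x.detach()
    adv = x.detach().clone()
    for _ in range(steps):
        adv.requires_grad_(True)
        loss = F.cross_entropy(model(adv), y)
        grad, = torch.autograd.grad(loss, adv)
        with torch.no_grad():
            adv = adv + step_size * grad.sign()             # ascend the loss
            adv = x_orig + (adv - x_orig).clamp(-eps, eps)  # project to the eps-ball
            adv = adv.clamp(0, 1)                           # stay a valid image
    with torch.no_grad():
        is_misclass = model(adv).argmax(dim=1) != y         # per-sample attack success
    return adv, is_misclass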

If the user has enough GPU memory, which will often be the case for CIFAR-10, then the following implementation improves GPU utilization by running all restarts in a single stacked batch:

B, C, H, W = X.shape
Y_stack = torch.stack([Y for _ in range(num_random_restarts)])
X_stack = X.unsqueeze(0)  # broadcasts over restarts; or torch.stack([X for _ in range(num_random_restarts)])

noise = random_constrained_noise(shape=[num_random_restarts, B, C, H, W])
noise[0, :, :, :, :] = 0  # first restart runs from the clean input
pert_X_stack = X_stack + noise

pert_X_stack = pert_X_stack.view(-1, C, H, W)
Y_stack = Y_stack.view(-1)

adv, is_misclass = run_pgd(pert_X_stack, Y_stack, model)
adv = adv.view(-1, B, C, H, W)
is_misclass = is_misclass.view(-1, B)
return_ind = is_misclass.int().argmax(dim=0)  # a misclassifying restart per sample (0 if none)
return adv[return_ind, torch.arange(B)]  # select per sample; adv[return_ind] alone would give shape (B, B, C, H, W)
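
The last two lines are the subtle part: the argmax over the restart dimension picks, for each sample, a restart that misclassified (index 0 if none did), and pairing return_ind with torch.arange(B) selects one restart per sample. A toy, model-free check of that indexing, with made-up shapes and flags:

import torch

R, B, C, H, W = 4, 3, 3, 2, 2
adv = torch.randn(R, B, C, H, W)
is_misclass = torch.tensor([[0, 1, 0],
                            [1, 0, 0],
                            [0, 1, 0],
                            [0, 0, 0]], dtype=torch.bool)

return_ind = is_misclass.int().argmax(dim=0)  # tensor([1, 0, 0])
out = adv[return_ind, torch.arange(B)]        # one restart per sample -> (B, C, H, W)

assert out.shape == (B, C, H, W)
assert torch.equal(out[0], adv[1, 0])  # sample 0: restart 1 misclassified
assert torch.equal(out[2], adv[0, 2])  # sample 2: no success, falls back to restart 0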

If num_random_restarts is so large that an input of size num_random_restarts x batch_size x C x H x W does not fit in GPU memory, then the user could additionally be allowed to specify a mini_num_restarts < num_random_restarts so that the GPU processes mini_num_restarts copies of the batch at a time, i.e. the input size is reduced to mini_num_restarts x batch_size x C x H x W. A sketch of this chunked variant follows.
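
A minimal sketch of how that chunking could look, reusing the placeholder helpers from above and keeping the first success found across chunks (mini_num_restarts is the option proposed here, not an existing library parameter):

import torch

def pgd_restarts_chunked(X, Y, model, num_random_restarts, mini_num_restarts):
    # Hypothetical chunked variant: run mini_num_restarts copies of the
    # batch per forward pass so the stacked input fits in GPU memory.
    B, C, H, W = X.shape
    output = X.clone()
    found = torch.zeros(B, dtype=torch.bool, device=X.device)
    cols = torch.arange(B, device=X.device)
    for start in range(0, num_random_restarts, mini_num_restarts):
        R = min(mini_num_restarts, num_random_restarts - start)
        noise = random_constrained_noise(shape=[R, B, C, H, W])
        if start == 0:
            noise[0] = 0  # first restart runs from the clean input
        pert = (X.unsqueeze(0) + noise).view(-1, C, H, W)
        adv, is_misclass = run_pgd(pert, Y.repeat(R), model)
        adv = adv.view(R, B, C, H, W)
        is_misclass = is_misclass.view(R, B)
        hit = is_misclass.any(dim=0)
        ind = is_misclass.int().argmax(dim=0)  # a successful restart per sample
        take = hit & ~found                    # only samples not yet broken
        output[take] = adv[ind[take], cols[take]]
        found |= hit
    return output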

My only concern is that this might negatively affect BatchNorm, since its batch statistics would be computed over num_random_restarts perturbed copies of each image rather than a single batch. But given how computationally intensive adversarial training is, this might be a worthwhile option to provide users.

Reference: https://github.com/MadryLab/robustness/blob/79d371fd799885ea5fe5553c2b749f41de1a2c4e/robustness/attacker.py#L235