Hi Weilin,
Thank you for trying the code. I can see that you are already achieving state-of-the-art results in terms of verified accuracy.
As you found out, the best model is indeed the large model on MNIST (although the medium model should come close to it at eps = 0.1). The code published here is, in fact, a subset of the code used in the paper (trimmed in order to be open-sourced) so there are subtle differences that might explain a small gap in error rate.
I just trained a "large" model with epsilon = 0.3 (epsilon train = 0.33; the paper used 0.4, but I believe 0.33 is a better trade-off in general) on my machine with the published IBP code and got a 1.69% nominal error rate and an 8.79% verified error rate (kappa set to 0.5 and without using a complete solver). See below for additional tricks if you need to squeeze out more performance.
The verified accuracy will definitely improve once you use a complete solver. As a reference, your result for eps = 0.1 matches exactly our IBP only bound (without using a MIP solver).
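For readers following along, it may help to see roughly what an "IBP only bound" computes: interval bound propagation pushes an axis-aligned box through each layer, and an example counts as verified when the resulting bounds alone rule out a change of label. A complete solver (e.g. MIP) can only tighten this. The pure-Python code below is a minimal illustration under that reading, not code from the repository; all function names are hypothetical, and the final check is a simple sound test rather than the exact specification used in the paper.

```python
"""Minimal interval bound propagation (IBP) sketch in pure Python.

Hypothetical helper names; an illustration, not the repository's code.
A box is a pair of lists (lower, upper) with one entry per unit.
"""


def affine_bounds(lower, upper, weights, bias):
    """Propagate the box [lower, upper] through y = W x + b.

    Uses the centre/radius form: centre_out = W c + b, radius_out = |W| r.
    `weights` is a list of rows, one row per output unit.
    """
    centre = [(l + u) / 2.0 for l, u in zip(lower, upper)]
    radius = [(u - l) / 2.0 for l, u in zip(lower, upper)]
    out_lower, out_upper = [], []
    for row, b in zip(weights, bias):
        mu = sum(w * c for w, c in zip(row, centre)) + b
        r = sum(abs(w) * rad for w, rad in zip(row, radius))
        out_lower.append(mu - r)
        out_upper.append(mu + r)
    return out_lower, out_upper


def relu_bounds(lower, upper):
    """ReLU is monotonic, so it can be applied to each bound separately."""
    return [max(l, 0.0) for l in lower], [max(u, 0.0) for u in upper]


def is_verified(lower_logits, upper_logits, label):
    """Sound but loose check: the true logit's lower bound must exceed
    every other logit's upper bound."""
    return all(lower_logits[label] > upper_logits[j]
               for j in range(len(upper_logits)) if j != label)


if __name__ == "__main__":
    x, eps = [0.5, 0.2], 0.1
    lo, hi = [v - eps for v in x], [v + eps for v in x]
    lo, hi = relu_bounds(*affine_bounds(lo, hi, [[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]))
    lo, hi = affine_bounds(lo, hi, [[0.0, 4.0], [1.0, 0.0]], [0.0, 0.0])
    print(is_verified(lo, hi, label=0))  # prints True
```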
Lowering the nominal error rate will involve tweaking the value of kappa (--nominal_xent_final and --verified_xent_final) slightly (0.5 is a good starting point, but you can improve results by modifying it). Additionally, if you really want to push IBP to the limit (and possibly improve upon the paper results), we found that adding a little bit of adversarial cross-entropy helps stabilize training (--attack_xent_init and --attack_xent_final).
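Kappa trades off two cross-entropy terms: kappa times the cross-entropy on the nominal logits plus (1 - kappa) times the cross-entropy on the worst-case logits allowed by the IBP bounds. A minimal pure-Python sketch of that combination follows. The helper names are hypothetical and this is not the repository's implementation; building the worst-case logits directly from per-class bounds and the exact form of the optional adversarial term are assumptions of this sketch.

```python
import math


def softmax_xent(logits, label):
    """Softmax cross-entropy for an integer label, computed stably."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[label]


def worst_case_logits(lower, upper, label):
    """Worst case allowed by the bounds: lowest true logit, highest others."""
    return [lower[j] if j == label else upper[j] for j in range(len(lower))]


def ibp_loss(nominal_logits, lower, upper, label, kappa,
             attack_logits=None, attack_weight=0.0):
    """kappa * nominal xent + (1 - kappa) * verified xent (+ optional attack xent)."""
    loss = kappa * softmax_xent(nominal_logits, label)
    loss += (1.0 - kappa) * softmax_xent(worst_case_logits(lower, upper, label), label)
    if attack_logits is not None:
        # Assumed form of the extra adversarial term behind --attack_xent_*.
        loss += attack_weight * softmax_xent(attack_logits, label)
    return loss
```

With kappa = 1 only the nominal term remains; lowering kappa puts more weight on the bound, which usually costs nominal accuracy.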
The last commit consists of cosmetic changes and should not affect results.
Regards, Sven
Hi Sven,
Thanks for your suggestions in the reply. I really appreciate it. I agree that the MNIST results were close to the reported ones and I wouldn't worry about it too much.
However, I am also struggling to reproduce the CIFAR10 results. I composed some training commands from your description in the paper and found that the small model architecture consistently produced the best results. Even so, the best results were not comparable to those in Table 3. Do you think I made some terrible mistakes in the experiment?
$ python train.py --dataset=cifar10 --model=small --output_dir=results/cifar10_small_eps2 --steps=350001 --batch_size=50 --learning_rate=1e-3,1e-4@200000,1e-5@250000,1e-6@300000 --warmup_steps=10000 --rampup_steps=150000 --epsilon=0.007843137254902
Results: nominal accuracy = 47.55%, verified = 18.39%, attack = 33.89%
$ python train.py --dataset=cifar10 --model=small --output_dir=results/cifar10_small_eps8 --steps=350001 --batch_size=50 --learning_rate=1e-3,1e-4@200000,1e-5@250000,1e-6@300000 --warmup_steps=10000 --rampup_steps=150000 --epsilon=0.031372549019608
Results: nominal accuracy = 41.81%, verified = 18.97%, attack = 27.73%
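The --learning_rate strings above describe a piecewise-constant schedule, and the two --epsilon values are 2/255 and 8/255. Below is a minimal sketch of how such a string can be interpreted, assuming the first value applies from step 0 and each value@step entry takes over from that step onwards. parse_lr_schedule and learning_rate_at are hypothetical names; the parsing in train.py may differ in details.

```python
def parse_lr_schedule(spec):
    """Parse '1e-3,1e-4@200000,...' into a sorted list of (start_step, lr)."""
    schedule = []
    for item in spec.split(","):
        if "@" in item:
            value, step = item.split("@")
            schedule.append((int(step), float(value)))
        else:
            schedule.append((0, float(item)))  # entry without a step starts at 0
    return sorted(schedule)


def learning_rate_at(schedule, step):
    """Return the value of the last entry whose start step is <= step."""
    lr = schedule[0][1]
    for start, value in schedule:
        if step >= start:
            lr = value
    return lr


if __name__ == "__main__":
    sched = parse_lr_schedule("1e-3,1e-4@200000,1e-5@250000,1e-6@300000")
    for step in (0, 199999, 200000, 260000, 350000):
        print(step, learning_rate_at(sched, step))
```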
Also, do you plan to make the reported results easily reproducible soon? I think this is essential if you want those results to serve as a useful baseline for other researchers.
Best, -Weilin
Hi Weilin,
The results are indeed far from ours.
Here are some things to keep in mind concerning CIFAR-10: by default, tf.keras.datasets.cifar10.load_data() returns images with pixel values between 0 and 255. These images need to be rescaled to [0, 1]. Additionally, CIFAR-10 requires normalization as well as random translations and flips to work well. The large model performs best for us.
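One subtlety when combining this preprocessing with IBP: epsilon is defined on images in [0, 1], so if normalization is part of the model, the perturbation box should be built before normalizing. The pure-Python sketch below shows the arithmetic for a single pixel, using hypothetical helper names. It assumes the per-channel mean/std quoted later in this thread and clips the box to [0, 1]; both are assumptions of the sketch, not a description of train.py. Dividing by std stretches the radius to epsilon / std in normalized units.

```python
# Assumed per-channel statistics (the values quoted later in this thread).
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)


def to_unit_range(pixel):
    """Map a uint8 pixel value (0..255, as returned by load_data()) to [0, 1]."""
    return pixel / 255.0


def normalize(value, channel):
    """Per-channel standardization applied after rescaling."""
    return (value - MEAN[channel]) / STD[channel]


def normalized_box(pixel, channel, epsilon):
    """Build the epsilon box in [0, 1] pixel space (clipped, by assumption),
    then map both ends through the normalization."""
    x = to_unit_range(pixel)
    lo, hi = max(x - epsilon, 0.0), min(x + epsilon, 1.0)
    return normalize(lo, channel), normalize(hi, channel)


if __name__ == "__main__":
    lo, hi = normalized_box(128, channel=0, epsilon=8 / 255)
    print(lo, hi, (hi - lo) / 2)  # half-width is (8 / 255) / 0.2023 here
```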
Just to add another datapoint, we can train the small model to reach 46% nominal accuracy, 25% verified accuracy and 30% PGD accuracy with epsilon = 8/255. Maybe it's easier for you to train the small model.
The results should already be reproducible from the paper description (if this is not the case, we will amend the paper accordingly). Perhaps you mean repeatable instead of reproducible. At this point, I have no immediate plans to modify the code in train.py. I am, however, interested in your results.
Hi Sven,
Thanks for pointing out the issue. I have corrected the result of epsilon 8/255.
I notice that you have implemented scaling to [0, 1] in build_dataset(). I will check whether data augmentation significantly improves the results.
Best, -Weilin
I ran some tests locally. Using data augmentation as suggested achieves (using the small network) a nominal accuracy of 43.55%, a verified accuracy of 26.01% (without MIP/LP verification) and an accuracy under PGD attack of 30.39%.
I've used the following command (after adding normalization and data augmentation), which gets away with fewer training steps:
python train.py --model=small --steps=200001 --warmup_steps=10000 --rampup_steps=120000 --batch_size=100 --epsilon=0.03137254901 --learning_rate=1e-3,1e-4@150000,1e-5@180000 --dataset=cifar10
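--warmup_steps and --rampup_steps control how epsilon (and kappa) move from their initial to their final values. Assuming a linear ramp (an assumption of this sketch; train.py may use a different shape), the schedule for the command above can be sketched as follows. linear_ramp is a hypothetical helper, and the kappa endpoints (1 going down to 0.5) are the values mentioned in this thread.

```python
def linear_ramp(step, warmup_steps, rampup_steps, start, end):
    """Hold `start` during warm-up, interpolate linearly to `end` over the
    ramp-up, then hold `end` (linear shape is an assumption)."""
    if step < warmup_steps:
        return start
    if step >= warmup_steps + rampup_steps:
        return end
    fraction = (step - warmup_steps) / rampup_steps
    return start + fraction * (end - start)


# Values from the command above: --warmup_steps=10000 --rampup_steps=120000.
WARMUP, RAMPUP, EPSILON = 10000, 120000, 8 / 255

if __name__ == "__main__":
    for step in (0, 10000, 70000, 130000, 200000):
        eps = linear_ramp(step, WARMUP, RAMPUP, 0.0, EPSILON)
        kappa = linear_ramp(step, WARMUP, RAMPUP, 1.0, 0.5)
        print(f"step {step:6d}: epsilon = {eps:.5f}, kappa = {kappa:.3f}")
```

Giving kappa a shorter ramp than epsilon is one simple way to decrease it more quickly.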
I am planning to commit a newer version of train.py to reflect the addition of data augmentation soon. I am now considering the case closed.
Hi @sgowal, could you please share more information on the normalization and data augmentation that you used? I'm struggling to reproduce your results.
I added the image normalization layer you provide in the repository, with the correct dataset mean and std (already divided by 255): mean=[0.491, 0.482, 0.447] and std=[0.247, 0.243, 0.262].
I'm also using random up/down and left/right flips, crop/zoom and rotation, but with the command you provided in the last post I'm only reaching this level of protection:
200000: loss = 1.6716492176055908, nominal accuracy = 44.62%, verified = 19.75%, attack = 27.70%
Thanks for your help.
~Marco
Hi Marco,
mean = (0.4914, 0.4822, 0.4465) and std = (0.2023, 0.1994, 0.2010)
Let me know if you have any other issues.
Thanks, less augmentation and the right std did the trick! I now achieve the performance you mentioned!
Unfortunately, I'm still not able to reproduce the results for the large network and eps=2./255 with the hyper-parameters suggested in the appendix of your paper:
python examples/train.py --model=large --steps=350001 --warmup_steps=10000 --rampup_steps=150000 --batch_size=50 --epsilon=0.00784313725 --learning_rate=1e-3,1e-4@200000,1e-5@250000,1e-6@300000 --dataset=cifar10
350000: loss = 1.1290245056152344, nominal accuracy = 61.89%, verified = 41.46%, attack = 50.21%
I've just launched an experiment with the faster schedule. I'll let you know how it goes.
You should be able to achieve better robustness even with the small model (although you're only 4% off). Can you try to change this line:
https://github.com/deepmind/interval-bound-propagation/blob/master/examples/train.py#L159
from FLAGS.epsilon, to FLAGS.epsilon * 1.1, so that training uses a slightly larger epsilon than testing.
Thanks, I'll give it a try. My concern is more the nominal accuracy, which is ~8% below the values in the paper, and that is a lot on CIFAR10 :-) Once the nominal accuracy is fixed, I guess the protection will be too.
Hi @sgowal, I like your paper. However, I also have trouble reproducing your results.
In the paper,
During training, we add random translations and flips, and normalize each image channel (using the channel statistics from the training set).
and you said
In particular, I only use random left/right flips and random cropping (pad with 4 pixels on each side and take a random 32x32 crop). There is no zooming, no rotations and no up/down flip.
So by translations, did you mean the random cropping? Can I implement it as follows (I'm using PyTorch)?
from torchvision import datasets, transforms
trainset_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
trainset = datasets.CIFAR10(root='./data/', train=True, download=True, transform=trainset_transforms)
testset = datasets.CIFAR10(root='./data/', train=False, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]))
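On the translation question: zero-padding by 4 pixels and taking a random 32x32 crop is the same as shifting the image by up to 4 pixels in each direction and filling the exposed border with zeros. A minimal single-channel sketch in pure Python, with hypothetical helper names; a real pipeline would apply the same offsets and flip to all three channels.

```python
import random


def pad(image, p):
    """Zero-pad a 2-D list of pixel values by p pixels on every side."""
    width = len(image[0]) + 2 * p
    top = [[0] * width for _ in range(p)]
    middle = [[0] * p + list(row) + [0] * p for row in image]
    bottom = [[0] * width for _ in range(p)]
    return top + middle + bottom


def random_crop(image, p, rng):
    """Pad by p, then take a random crop of the original size
    (a random translation of up to p pixels with zero fill)."""
    h, w = len(image), len(image[0])
    padded = pad(image, p)
    top = rng.randint(0, 2 * p)   # randint is inclusive on both ends
    left = rng.randint(0, 2 * p)
    return [row[left:left + w] for row in padded[top:top + h]]


def random_hflip(image, rng):
    """Mirror left/right with probability 0.5."""
    if rng.random() < 0.5:
        return [list(reversed(row)) for row in image]
    return [list(row) for row in image]


def augment(image, rng, p=4):
    return random_hflip(random_crop(image, p, rng), rng)


if __name__ == "__main__":
    rng = random.Random(0)
    img = [[r * 32 + c for c in range(32)] for r in range(32)]
    out = augment(img, rng)
    print(len(out), len(out[0]))  # 32 32
```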
Thank you in advance.
Hi Sungyoon-Lee,
The transformation in pytorch seems to be correct.
Here are a few additional details that just came to my attention.
Model architecture. I must apologize: the latest CIFAR-10 results (in v3 of the paper) are obtained with a slightly different network architecture (we plan on correcting that at the next update). The network that performs best is a truncated VGG:
32x32x3 images -> Conv2D+ReLU with 3x3x64 kernels and stride 1 (output is 32x32x64) -> Conv2D+ReLU with 3x3x64 kernels and stride 1 (output is 32x32x64) -> Conv2D+ReLU with 3x3x128 kernels and stride 2 (output is 16x16x128) -> Conv2D+ReLU with 3x3x128 kernels and stride 1 (output is 16x16x128) -> FC+ReLU with 512 outputs (output is 512) -> FC with 10 outputs.
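As a quick sanity check of the shapes listed above, and of how much of the network sits in the 512-unit FC layer, here is a short sketch. It assumes 'SAME' padding (which is what the stated output sizes imply) and one bias per output unit; the layer list is transcribed from the description above, and summarize is a hypothetical helper.

```python
import math

# (kernel size, filters, stride) for each Conv2D+ReLU layer described above.
CONV_LAYERS = [(3, 64, 1), (3, 64, 1), (3, 128, 2), (3, 128, 1)]
FC_LAYERS = [512, 10]  # FC+ReLU with 512 outputs, then FC with 10 outputs


def summarize(size=32, channels=3):
    """Return (output shape, parameter count) per layer (weights + biases)."""
    rows = []
    for kernel, filters, stride in CONV_LAYERS:
        params = kernel * kernel * channels * filters + filters
        size = math.ceil(size / stride)  # 'SAME' padding (assumed)
        channels = filters
        rows.append(((size, size, channels), params))
    features = size * size * channels  # flattened input to the first FC layer
    for units in FC_LAYERS:
        rows.append(((units,), features * units + units))
        features = units
    return rows


if __name__ == "__main__":
    rows = summarize()
    for shape, params in rows:
        print(shape, params)
    print("total:", sum(p for _, p in rows))
```

Under these assumptions the first FC layer holds roughly 16.8M of the ~17.0M parameters.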
There is one fewer Conv2D layer but a larger FC layer compared to the older architecture. However, 73% verified accuracy is achievable with the old large network on 8/255 (57% at 2/255) with the old schedule (see below).
Schedule. The schedule detailed in the v3 version is an old schedule. With this old schedule you should be able to reach 71-72% in verified accuracy with the newer model. The newer schedule trains for much longer and uses 32 TPUs. We train for 3200 epochs with a batch size of 1600. The total number of training steps is 100K. We decay the learning rate by 10x at steps 60K and 90K. We use warm-up and ramp-up durations of 5K and 50K steps, respectively.
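These schedule numbers are consistent with each other: CIFAR-10 has 50,000 training images, so 3200 epochs at a batch size of 1600 gives 3200 x 50,000 / 1600 = 100,000 steps. A small sketch of that arithmetic and of the step-wise 10x decay follows; the base learning rate is left as a parameter because it is not stated here, and the helper names are hypothetical.

```python
CIFAR10_TRAIN_IMAGES = 50_000


def total_steps(epochs, batch_size, dataset_size=CIFAR10_TRAIN_IMAGES):
    """Number of optimizer steps for a given number of epochs."""
    return epochs * dataset_size // batch_size


def decayed_lr(base_lr, step, boundaries=(60_000, 90_000), factor=10):
    """Divide the learning rate by `factor` at each boundary already passed."""
    passed = sum(1 for b in boundaries if step >= b)
    return base_lr / factor ** passed


if __name__ == "__main__":
    print(total_steps(3200, 1600))  # 100000
```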
Training epsilon. Slightly increasing the perturbation radius used for training relative to testing helps. As a rule of thumb, I use 1.1 x the testing radius, so 8/255 becomes 8.8/255.
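The rule of thumb is easy to check with exact fractions. The same factor also matches the MNIST setting mentioned at the top of this thread (testing at 0.3, training at 0.33). A tiny sketch, where train_epsilon is a hypothetical helper:

```python
from fractions import Fraction

TRAIN_FACTOR = Fraction(11, 10)  # the 1.1 x rule of thumb


def train_epsilon(test_epsilon):
    """Training radius as 1.1 x the testing radius."""
    return Fraction(test_epsilon) * TRAIN_FACTOR


if __name__ == "__main__":
    print(float(train_epsilon(Fraction(8, 255))))  # 8.8/255, about 0.0345
    print(train_epsilon(Fraction("0.3")))          # 33/100
```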
Kappa schedule. Some people have reported slightly better results by decreasing kappa from 1 to 0.5 more quickly.
I hope these help.
Thank you for your helpful explanation and scheduling techniques!
However, 73% verified accuracy is achievable with the old large network on 8/255 (57% at 2/255) with the old schedule (see below).
I think the verified accuracy is better when eps is smaller. Can you clarify this? Thanks
Sorry, I indeed meant verified error rate.
Oh, I got it. Thank you again.
Hi there,
Thanks for sharing the code. The technique you proposed in the paper sounds promising and the results are exciting to me! However, I have some difficulty in reproducing your results in Table 3 and I wonder if there's something wrong in my experimental setup.
It looks like the default program options in examples/train.py already configure everything for the MNIST training procedure, except for epsilon_train, which you present separately in Table 3.
I have tried all three models (small, medium and large), as you didn't mention which one to use in Table 3. I found that the large model always produced the best results. However, even the best results I got are consistently worse than those reported in Table 3, in terms of both test error and verified error.
For the test error, the best I can get is 1.12% (vs. 1.06% in the paper, where eps_train=0.2) and 2.31% (vs. 1.66% in the paper, where eps_train=0.4).
For the verified error, the best results I have produced are [2.93%, 5.60%, 9.24%, 16.69%] respectively for the epsilon values 0.1, 0.2, 0.3, 0.4, which is always slightly worse than the reported results in Table 3: [2.23%, 4.48%, 8.05%, 14.88%].
Could you please advise me on how to correctly reproduce your results? I believe an exact verifier may help to narrow the gap for the verified error, but it won't help to reduce the test error.
BTW, I used your initial commit version 15340d3 for the experiment. I don't know if the latest commit 5fa09e7 of yesterday will change the results.
Best, -Weilin