
This is the official implementation of the paper "ClusTR: Clustering Training for Robustness" (https://arxiv.org/abs/2006.07682).

Rethinking Clustering for Robustness - PyTorch implementation

This repository implements "ClusTR" in PyTorch. Running this code successfully reproduces the results in the manuscript.

ClusTR is a theoretically-motivated framework for training Deep Neural Networks (DNNs) for adversarial robustness.

ClusTR is based on a Clustering Loss, i.e. a loss that encourages clustering of semantically-similar instances in representation space, to induce semantics into the network's training. ClusTR harnesses this loss, together with standard DNN training techniques, to train robust networks without the need to train on adversarial examples.

As its Clustering Loss, ClusTR employs the Magnet Loss, introduced in "Metric Learning with Adaptive Density Discrimination" (Rippel et al., ICLR 2016). Since there is no official implementation of the Magnet Loss, we borrow parts of our Magnet Loss implementation from MagnetLossPyTorch (Vithursan Thangarasa, 2018, GitHub).
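For intuition, below is a minimal, self-contained PyTorch sketch of a Magnet-style loss. It is not the implementation used in this repository; the function name, tensor shapes, and default margin are illustrative assumptions.

import torch
import torch.nn.functional as F

def magnet_loss_sketch(embeddings, cluster_ids, class_ids, centroids,
                       centroid_classes, alpha=1.0, eps=1e-8):
    # Illustrative Magnet-style loss (after Rippel et al., ICLR 2016).
    # embeddings:       (N, d) mini-batch embeddings
    # cluster_ids:      (N,)   index of the cluster assigned to each example
    # class_ids:        (N,)   class label of each example
    # centroids:        (C, d) current cluster centroids (e.g. from k-means)
    # centroid_classes: (C,)   class label of each centroid
    # alpha:            margin between own-cluster and impostor distances
    d2 = torch.cdist(embeddings, centroids).pow(2)            # (N, C) squared distances
    own = d2.gather(1, cluster_ids.unsqueeze(1)).squeeze(1)   # distance to own centroid
    sigma2 = own.mean().clamp_min(eps)                        # variance estimate
    numerator = torch.exp(-own / (2 * sigma2) - alpha)        # own-cluster similarity
    impostor = centroid_classes.unsqueeze(0) != class_ids.unsqueeze(1)  # (N, C) mask
    denominator = (torch.exp(-d2 / (2 * sigma2)) * impostor).sum(1).clamp_min(eps)
    return F.relu(-torch.log(numerator / denominator)).mean() # hinged log-ratio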

Installation

This repository has several dependencies. These dependencies can be installed with Anaconda by creating an environment from the file utils/environment.yml. To download all requirements and create the environment, run:

cd utils

And then

conda env create -f environment.yml

These commands create an environment called pytorch. To activate this environment, run:

conda activate pytorch

Repository structure

This repository is organized into three main folders.

The main file for running training/testing in this repository is main_clustr.py. The arguments this file takes as input are established in utils/train_settings.py.
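As a rough guide, an argument parser covering the flags used in this README might look like the sketch below. This is not the actual utils/train_settings.py; the defaults and help strings are placeholders.

import argparse

def parse_args():
    # Illustrative parser; defaults are placeholders, not the repository's settings.
    parser = argparse.ArgumentParser(description="ClusTR training/evaluation")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="name of the experiment directory to create")
    parser.add_argument("--pretrained-path", type=str, default=None,
                        help="path to pre-trained weights")
    parser.add_argument("--epochs", type=int, default=25,
                        help="number of training epochs")
    parser.add_argument("--consistency-lambda", type=float, default=8.0,
                        help="coefficient of the TRADES consistency term")
    parser.add_argument("--evaluate-ckpt", type=str, default=None,
                        help="checkpoint to evaluate; skips training when set")
    parser.add_argument("--iterations", type=int, default=20,
                        help="number of PGD iterations at evaluation")
    parser.add_argument("--restarts", type=int, default=10,
                        help="number of PGD restarts at evaluation")
    parser.add_argument("--L", type=int, default=20,
                        help="number of closest clusters considered at evaluation")
    return parser.parse_args()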

Training ClusTR+QTRADES

To train ClusTR+QTRADES, pre-trained weights are required. These pre-trained weights are in the pretrained_weights directory.

To train, run

python main_clustr.py --checkpoint clustr_qtrades --pretrained-path pretrained_weights/resnet18.pt --epochs 25 --consistency-lambda 8

This command will run training of ClusTR+QTRADES on CIFAR10 for 25 epochs, starting from the pre-trained weights at pretrained_weights/resnet18.pt, with a coefficient of 8 for the TRADES loss, i.e. the λ in Equation (5) in the manuscript. The results will be saved in the experiments/clustr_qtrades directory.
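For reference, a minimal sketch of a TRADES-style consistency term is shown below. It assumes the model outputs per-class scores and uses illustrative values for epsilon, step size, and number of steps, so it should not be read as the repository's QTRADES implementation.

import torch
import torch.nn.functional as F

def trades_consistency(model, x, epsilon=8/255, step_size=2/255, steps=1):
    # Illustrative TRADES-style consistency term: the KL divergence between
    # predictions on a perturbed input and on the clean input. A full training
    # objective would combine it as: clustering_loss + lambda * consistency.
    with torch.no_grad():
        p_clean = F.softmax(model(x), dim=1)

    # Find the perturbation with (a few) gradient ascent steps on the KL divergence.
    x_adv = x + 0.001 * torch.randn_like(x)
    for _ in range(steps):
        x_adv = x_adv.detach().requires_grad_(True)
        kl = F.kl_div(F.log_softmax(model(x_adv), dim=1), p_clean,
                      reduction="batchmean")
        grad, = torch.autograd.grad(kl, x_adv)
        x_adv = x_adv.detach() + step_size * grad.sign()
        x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon).clamp(0, 1)

    # Consistency term that multiplies the lambda coefficient.
    return F.kl_div(F.log_softmax(model(x_adv), dim=1), p_clean,
                    reduction="batchmean")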

When the command finishes running, the experiments/clustr_qtrades directory will contain three files, among them the checkpoint.pth file used for evaluation below.

Please refer to the script utils/train_settings.py or run python main_clustr.py --help for details of the possible arguments to pass to the main_clustr.py script.

Evaluating ClusTR+QTRADES

To evaluate the checkpoint obtained from the training procedure, run

python main_clustr.py --checkpoint clustr_qtrades_pgd20 --evaluate-ckpt experiments/clustr_qtrades/checkpoint.pth --iterations 20 --restarts 10 --L 20

This evaluation procedure will consider the 20 closest clusters (the L parameter) and run PGD attacks (l-infinity norm-bounded attacks) with 20 iterations and 10 restarts to assess robustness. When the procedure finishes, there will be a CSV file called attack_results_ext.csv under the experiments/clustr_qtrades_pgd20 directory. This file holds the results of the PGD attack in two columns: epsilons and test_set_accs. Each row shows the resulting PGD accuracy at the corresponding value of epsilon (the strength of the attack).
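As a point of reference, the sketch below shows a generic l-infinity PGD attack with random restarts. It assumes the model returns per-class logits and does not model the repository's cluster-based (L-closest-clusters) classification.

import torch
import torch.nn.functional as F

def pgd_attack(model, x, y, epsilon, iterations=20, restarts=10, step_size=None):
    # Generic l-infinity PGD with random restarts (illustrative only).
    step_size = step_size if step_size is not None else 2.5 * epsilon / iterations
    best_adv = x.clone()
    not_yet_fooled = torch.ones(x.size(0), dtype=torch.bool, device=x.device)

    for _ in range(restarts):
        # Random start inside the epsilon ball, projected to the valid pixel range.
        x_adv = (x + torch.empty_like(x).uniform_(-epsilon, epsilon)).clamp(0, 1)
        for _ in range(iterations):
            x_adv = x_adv.detach().requires_grad_(True)
            loss = F.cross_entropy(model(x_adv), y)
            grad, = torch.autograd.grad(loss, x_adv)
            x_adv = x_adv.detach() + step_size * grad.sign()
            x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon).clamp(0, 1)

        # Keep the first perturbation that flips each sample's prediction.
        with torch.no_grad():
            fooled = model(x_adv).argmax(dim=1) != y
        keep = fooled & not_yet_fooled
        best_adv[keep] = x_adv[keep]
        not_yet_fooled &= ~fooled

    return best_adv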

You can also run a stronger PGD-100 attack with the command below (this may take a long time, depending on your GPU).

python main_clustr.py --checkpoint clustr_qtrades_pgd100 --evaluate-ckpt experiments/clustr_qtrades/checkpoint.pth --iterations 100 --restarts 10 --L 20

Citation

If you find our work useful, please consider citing it as

@misc{alfarra2020clustr,
    title={ClusTR: Clustering Training for Robustness},
    author={Motasem Alfarra and Juan C. Pérez and Adel Bibi and Ali Thabet and Pablo Arbeláez and Bernard Ghanem},
    year={2020},
    eprint={2006.07682},
    archivePrefix={arXiv}
}