MadryLab / robustness

A library for experimenting with, training and evaluating neural networks, with a focus on adversarial robustness.
MIT License

About MNIST dataset and input normalization. #69

Closed: qimingyudaowenti closed this issue 4 years ago

qimingyudaowenti commented 4 years ago

Thanks for your excellent work! I have three questions:

  1. I notice that robustness does not currently support the MNIST dataset. Are you considering adding it in the future?
  2. When dealing with MNIST, do we need input normalization? I'm worried because there is no normalization step in mnist_challenge, but there is one here: https://github.com/MadryLab/robustness/blob/89bdf8088a8f4bd4a8b86925a2801069ec281fee/robustness/attacker.py#L319
  3. Will the above normalization of adversarial images weaken the attack?
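Question 3 can be explored with a minimal PyTorch sketch, assuming (as the linked `attacker.py` line appears to do) that normalization is applied to the input inside the model's forward pass. `Normalize` and `fgsm` are hypothetical names written for this illustration, not the robustness API, and the mean/std of 0.5 are arbitrary. Since the gradient is taken with respect to the raw input, autograd differentiates through the normalization like any other layer, and the budget `eps` stays in raw [0, 1] pixel units whatever `mean` and `std` are.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Normalize(nn.Module):
    """Hypothetical normalization layer: (x - mean) / std per channel."""

    def __init__(self, mean, std):
        super().__init__()
        # Shape (C, 1, 1) so the statistics broadcast over (N, C, H, W).
        self.register_buffer("mean", torch.as_tensor(mean).view(-1, 1, 1))
        self.register_buffer("std", torch.as_tensor(std).view(-1, 1, 1))

    def forward(self, x):
        return (x - self.mean) / self.std


def fgsm(model, x, y, eps):
    """One L-inf FGSM step on the raw [0, 1] input.

    The gradient is taken w.r.t. x itself, so autograd differentiates
    through the normalization layer like any other layer."""
    x = x.clone().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    (grad,) = torch.autograd.grad(loss, x)
    return (x + eps * grad.sign()).clamp(0, 1).detach()


torch.manual_seed(0)
x = torch.rand(8, 1, 28, 28)  # MNIST-shaped batch, pixels in [0, 1]
y = torch.randint(0, 10, (8,))

# Illustrative statistics only, not MNIST-specific values.
model = nn.Sequential(
    Normalize([0.5], [0.5]),
    nn.Flatten(),
    nn.Linear(28 * 28, 10),
)
x_adv = fgsm(model, x, y, eps=0.3)
# Largest raw-pixel change: bounded by eps (up to float32 rounding).
print((x_adv - x).abs().max().item())
```

In this sketch, normalization changes the function being attacked but not the threat model: the adversary still perturbs unnormalized pixels.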

Looking forward to your answer!

andrewilyas commented 4 years ago

Hello, thanks for the issue! While we don't have immediate plans to support MNIST, it is really easy to add custom datasets! See the datasets.py file in robustness/ for examples. No guarantees that the code below works exactly as-is, but it should be something as simple as:

```python
import torch as ch
from torchvision import datasets

from robustness.datasets import DataSet
from robustness import cifar_models

class MNIST(DataSet):
    def __init__(self, data_path='/tmp/', **kwargs):
        ds_kwargs = {
            'num_classes': 10,
            # mean 0 / std 1 turns the input normalization into a no-op
            'mean': ch.tensor([0., 0., 0.]),
            'std': ch.tensor([1., 1., 1.]),
            # torchvision's MNIST dataset class does the actual loading
            'custom_class': datasets.MNIST,
            'label_mapping': None,
            'transform_train': ..., # TODO
            'transform_test': ... # TODO
        }
        # Let any user-supplied kwargs override the defaults above
        ds_kwargs = self.override_args(ds_kwargs, kwargs)
        super(MNIST, self).__init__('mnist', data_path, **ds_kwargs)

    def get_model(self, arch, pretrained):
        # Reuse the CIFAR-style architectures; pretrained weights are not supported
        if pretrained:
            raise ValueError('MNIST does not support pytorch_pretrained=True')
        return cifar_models.__dict__[arch](num_classes=self.num_classes)
```

Notice that we set the mean to 0 and the standard deviation to 1 for every channel, which effectively disables normalization, since (x - 0) / 1 = x.
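To make the note above concrete, here is a small sketch assuming the normalizer reshapes `mean`/`std` to `(C, 1, 1)` and computes `(x - mean) / std`; the `normalize` helper is a hypothetical stand-in, not the library's code. With mean 0 and std 1 the output equals the input. One side effect worth checking under that assumption: because the snippet keeps three-entry statistics, a single-channel MNIST batch of shape `(N, 1, 28, 28)` is broadcast to `(N, 3, 28, 28)`, i.e. the gray channel is replicated three times, which would line up with a CIFAR-style architecture if its first convolution expects three input channels.

```python
import torch


def normalize(x, mean, std):
    """Hypothetical stand-in: per-channel (x - mean) / std, with the
    statistics reshaped to (C, 1, 1) so they broadcast over (N, C, H, W)."""
    return (x - mean[:, None, None]) / std[:, None, None]


x = torch.rand(2, 1, 28, 28)  # grayscale MNIST-shaped batch

# Single-channel identity stats: the output is exactly the input.
out_1ch = normalize(x, torch.tensor([0.]), torch.tensor([1.]))
print(out_1ch.shape, torch.equal(out_1ch, x))  # torch.Size([2, 1, 28, 28]) True

# Three-channel identity stats, as in the snippet above:
# broadcasting expands the gray channel into three identical copies.
out_3ch = normalize(x, torch.tensor([0., 0., 0.]), torch.tensor([1., 1., 1.]))
print(out_3ch.shape)  # torch.Size([2, 3, 28, 28])
print(all(torch.equal(out_3ch[:, c:c + 1], x) for c in range(3)))  # True
```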

qimingyudaowenti commented 4 years ago

Thank you for your helpful reply!