qimingyudaowenti closed this issue 4 years ago.
Hello, thanks for the issue! While we don't have immediate plans to support MNIST, it is really easy to add custom datasets! See the datasets.py file in robustness/ for examples. No guarantees the below will work exactly as-is, but it should be something as simple as:
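import torch as ch
from torchvision import datasets

from robustness.datasets import DataSet
from robustness import cifar_models


class MNIST(DataSet):
    def __init__(self, data_path='/tmp/', **kwargs):
        ds_kwargs = {
            'num_classes': 10,
            # Mean 0 / std 1 per channel turns normalization into a no-op.
            'mean': ch.tensor([0., 0., 0.]),
            'std': ch.tensor([1., 1., 1.]),
            # torchvision's MNIST dataset class loads the data.
            'custom_class': datasets.MNIST,
            'label_mapping': None,
            'transform_train': ...,  # TODO: training transform pipeline
            'transform_test': ...,   # TODO: test transform pipeline
        }
        # Let keyword arguments passed by the caller override the defaults above.
        ds_kwargs = self.override_args(ds_kwargs, kwargs)
        super(MNIST, self).__init__('mnist', data_path, **ds_kwargs)

    def get_model(self, arch, pretrained):
        # No pretrained weights are available for this dataset.
        if pretrained:
            raise ValueError('MNIST does not support pytorch_pretrained=True')
        return cifar_models.__dict__[arch](num_classes=self.num_classes)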
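The self.override_args(ds_kwargs, kwargs) line is presumably what lets a caller replace any of these defaults (for example the two TODO transforms) when constructing the dataset, instead of editing the class. Below is a minimal stand-alone sketch of that merge pattern, assuming caller-supplied keyword arguments take precedence over the defaults. merge_dataset_args is a hypothetical name used only for illustration; it is not the robustness implementation.

```python
def merge_dataset_args(defaults, overrides):
    """Hypothetical stand-in for the override step: start from the class
    defaults and let any keyword argument passed by the caller replace them."""
    merged = dict(defaults)   # copy so the defaults dict is left untouched
    merged.update(overrides)  # caller-supplied values win on key collisions
    return merged


defaults = {
    'num_classes': 10,
    'label_mapping': None,
    'transform_train': None,  # placeholder standing in for the TODO transform
}
args = merge_dataset_args(defaults, {'transform_train': 'my_train_transform'})
print(args['transform_train'])      # my_train_transform
print(args['num_classes'])          # 10
print(defaults['transform_train'])  # None
```

The copy-then-update design keeps the class-level defaults intact, so each dataset instance gets its own argument dictionary.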
Notice that we set the mean to 0 and the standard deviation to 1 for every channel, which effectively disables normalization: each input value x maps to (x - 0) / 1 = x.
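To make the "disables normalization" point concrete, here is a small pure-Python sketch of per-channel normalization. It assumes the model's input normalizer computes (x - mean) / std separately for each channel; the normalize helper and the toy pixel values are illustrative and not taken from robustness.

```python
def normalize(image, mean, std):
    """Apply (x - mean[c]) / std[c] to every pixel of channel c.

    `image` is a list of channels, each a list of rows of floats."""
    return [
        [[(x - m) / s for x in row] for row in channel]
        for channel, m, s in zip(image, mean, std)
    ]


# Toy 3-channel, 2x2 image with values already in [0, 1].
image = [
    [[0.0, 0.5], [0.25, 1.0]],
    [[0.1, 0.2], [0.3, 0.4]],
    [[0.9, 0.8], [0.7, 0.6]],
]

identity = normalize(image, mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
print(identity == image)  # True: mean 0 and std 1 leave every value unchanged

shifted = normalize(image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
print(shifted[0][0])  # [-1.0, 0.0]
```

With any other mean/std the values change (the second call maps [0, 1] to [-1, 1]), which is why zeros and ones are the choice that leaves the inputs untouched.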
Thank you for your helpful reply!
Thanks for your excellent work! I have three questions:
Looking forward to your answer!