siebenkopf closed this issue 3 months ago.
There is an unnecessary complication in the dataset classes. Each dataset applies a normalization with mean and std, which maps the data into a range R that differs from [0,1]: https://github.com/melihcatal/advsecurenet/blob/eec54973826484ec22d401dd7aed17145a62fd3d/advsecurenet/datasets/base_dataset.py#L62 However, all attacks require inputs in the range [0,1], and some even enforce this via torch.clamp(.,0,1): https://github.com/melihcatal/advsecurenet/blob/eec54973826484ec22d401dd7aed17145a62fd3d/advsecurenet/attacks/fgsm.py#L58 Consequently, for adversarial training, the data first has to be unnormalized: https://github.com/melihcatal/advsecurenet/blob/eec54973826484ec22d401dd7aed17145a62fd3d/advsecurenet/defenses/adversarial_training.py#L118 Only then can the attack run, after which the data is normalized again.
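To make the round trip concrete, here is a minimal sketch of the unnormalize → attack → normalize sequence that the current design forces. It is not the advsecurenet code: `normalize`, `unnormalize`, `adversarial_batch` and `toy_attack` are hypothetical helpers, and the mean/std values are placeholders chosen so that [0,1] maps to [-2,2].

```python
import torch

# Hypothetical per-channel statistics; the real values come from each dataset class.
MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
STD = torch.tensor([0.25, 0.25, 0.25]).view(1, 3, 1, 1)


def normalize(x):
    # Dataset-style normalization: maps [0, 1] pixels into a range R != [0, 1].
    return (x - MEAN) / STD


def unnormalize(x):
    # Inverse of normalize: maps normalized data back to [0, 1].
    return x * STD + MEAN


def adversarial_batch(normalized_x, attack):
    # 1. Undo the dataset normalization so the attack sees [0, 1] inputs.
    x01 = unnormalize(normalized_x)
    # 2. Run an attack that assumes (and may clamp to) [0, 1].
    x_adv01 = attack(x01)
    # 3. Re-apply the normalization before the batch goes back into training.
    return normalize(x_adv01)


def toy_attack(x01, eps=0.1):
    # Hypothetical stand-in for a real attack: add a fixed perturbation,
    # then clamp to [0, 1] as torch.clamp(., 0, 1) does in the FGSM attack.
    return torch.clamp(x01 + eps, 0.0, 1.0)
```

Every adversarial batch pays for two extra transformations, and every attack silently depends on the caller having undone the dataset normalization first.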
A better option would have been to add the data normalization (which is a standard torch.nn.Module, cf. https://pytorch.org/vision/main/_modules/torchvision/transforms/transforms.html#Normalize) as the first layer of the network, instead of applying it at the end of the data loading.
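A rough sketch of that alternative, assuming a plain PyTorch classifier: the normalization is wrapped in an `nn.Sequential` in front of the model, so data loaders and attacks can stay entirely in [0,1]. `NormalizeLayer`, `with_normalization` and `fgsm` are hypothetical names used only for illustration; `torchvision.transforms.Normalize` could likely take the place of `NormalizeLayer`, since it is also an `nn.Module`.

```python
import torch
import torch.nn as nn


class NormalizeLayer(nn.Module):
    # Hypothetical layer; torchvision.transforms.Normalize would play the same role.
    def __init__(self, mean, std):
        super().__init__()
        # Buffers follow the model across devices and are saved in the
        # state_dict, but are not updated by the optimizer.
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x):
        return (x - self.mean) / self.std


def with_normalization(model, mean, std):
    # Normalization becomes the first layer of the network, so the model
    # itself accepts raw [0, 1] inputs.
    return nn.Sequential(NormalizeLayer(mean, std), model)


def fgsm(model, x01, y, eps):
    # Minimal FGSM sketch working directly on [0, 1] inputs: no
    # unnormalize/normalize round trip is needed around the attack.
    x01 = x01.clone().detach().requires_grad_(True)
    loss = nn.functional.cross_entropy(model(x01), y)
    (grad,) = torch.autograd.grad(loss, x01)
    return torch.clamp(x01 + eps * grad.sign(), 0.0, 1.0).detach()
```

With this layout, the dataset classes would only convert images to tensors in [0,1], and gradients flow through the normalization layer automatically.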