MadryLab / smoothed-vit

Certified Patch Robustness via Smoothed Vision Transformers
https://arxiv.org/abs/2110.07719
MIT License

A minor issue #1

Closed: ljb121002 closed this issue 2 years ago

ljb121002 commented 2 years ago

Hi, very solid work. But I think there is a typo in the class "MaskProcessor" in /src/utils/custom_models/preprocess.py. The line "ones_mask = torch.where(ones_mask.view(-1) > 0)[0]" chooses the patches that intersect the ablation column or block, and their indices lie in [0, 195]. To always keep the class token, the next line simply prepends a 0 to ones_mask. But since the class token occupies sequence position 0, I think we should first shift the range from [0, 195] to [1, 196] and then add the 0 :)
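A minimal sketch of the concern raised above, assuming a 14 x 14 grid of 196 patches with the class token at sequence position 0 (197 tokens in total). The mask values are hypothetical, and this is not the repository's code: it only shows what would happen if the patch indices were not shifted before prepending the 0.

```python
import torch

# Assumption: 14 x 14 patch grid (196 patches); class token sits at position 0.
GRID = 14

# Hypothetical ablation mask: keep only the first patch column.
ones_mask = torch.zeros(GRID, GRID)
ones_mask[:, 0] = 1

# Flattened indices of kept patches, in the range [0, 195].
patch_idx = torch.where(ones_mask.view(-1) > 0)[0]
print(patch_idx.min().item(), patch_idx.max().item())  # 0 182

# Prepending 0 WITHOUT shifting: position 0 (the class token) appears twice,
# and every patch index points one sequence position too early.
naive = torch.cat([torch.zeros(1, dtype=patch_idx.dtype), patch_idx])
print(naive[:3].tolist())  # [0, 0, 14]
```

With the unshifted indices, the class token is selected twice and the first kept patch (sequence position 1) is never selected.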

scoutsaachi commented 2 years ago

Hi!

The full line is ones_mask = torch.where(ones_mask.view(-1) > 0)[0] + 1. The trailing + 1 shifts the range to [1, 196] before the next line prepends the 0, so the code already does exactly what you described.
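The shift-then-prepend logic described in the reply can be sketched as follows. This is an illustration under assumptions, not the code in preprocess.py: the helper name `kept_token_indices`, the embedding size `DIM`, and the mask are hypothetical, and the 14 x 14 grid with a class token at position 0 is assumed.

```python
import torch

GRID = 14  # assumption: 14 x 14 patch grid (196 patches) + class token at 0
DIM = 8    # hypothetical small embedding size, for illustration only

def kept_token_indices(ones_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: class token position plus kept patch positions."""
    # Patch indices in [0, 195] become sequence positions in [1, 196].
    patch_pos = torch.where(ones_mask.view(-1) > 0)[0] + 1
    # Prepend position 0 so the class token is always kept.
    cls_pos = torch.zeros(1, dtype=patch_pos.dtype, device=patch_pos.device)
    return torch.cat([cls_pos, patch_pos])

# Hypothetical ablation mask: keep only the first patch column.
ones_mask = torch.zeros(GRID, GRID)
ones_mask[:, 0] = 1

idx = kept_token_indices(ones_mask)
print(idx[:3].tolist(), len(idx))  # [0, 1, 15] 15

# Gather the kept tokens from a (batch, 197, dim) token sequence.
tokens = torch.randn(2, GRID * GRID + 1, DIM)
kept = tokens[:, idx, :]
print(kept.shape)  # torch.Size([2, 15, 8])
```

Because of the + 1, the class token appears exactly once and each kept patch maps to its own sequence position.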

ljb121002 commented 2 years ago

Ohh yeah right, sorry about this!