Open haoweiz23 opened 3 years ago
Maybe because index 0 is the CLS token?
I think so. You probably found the line `last_map = last_map[:, :, 0, 1:]` in class `Part_Attention`. When I dropped the class token and changed that line to `last_map = last_map[:, :, 0]`, I got errors; after I commented out `part_inx = part_inx + 1`, they were fixed.
`last_map = last_map[:, :, 0, 1:]`, then `max_inx = last_map.max(2)`. The indices returned by `.max()` start from 0, but the slice `1:` skipped the first element (the CLS token). So, to project back to the original tensor, we need to add 1 to the indices.
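A minimal pure-Python sketch of this indexing issue (the actual code operates on PyTorch tensors; the values below are illustrative, not from TransFG):

```python
# Toy attention scores over 5 tokens; index 0 is the CLS token.
attn = [0.1, 0.4, 0.2, 0.9, 0.3]

# Skip the CLS token, mirroring last_map[:, :, 0, 1:].
patch_attn = attn[1:]  # scores for tokens 1..4

# The argmax over the slice starts counting at 0, so index 2 here
# actually refers to token 3 in the full, unsliced sequence.
part_inx = max(range(len(patch_attn)), key=patch_attn.__getitem__)

# Add 1 to recover the index into the original tensor,
# mirroring part_inx = part_inx + 1 in modeling.py.
part_inx = part_inx + 1

assert attn[part_inx] == max(patch_attn)
print(part_inx)  # prints 3
```

Without the `+ 1`, `part_inx` would point at the wrong token (or at the CLS position itself when the maximum is the first patch).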
https://github.com/TACJu/TransFG/blob/ff28b5842fe41dbbb03a3cf698888dd57c882b8a/models/modeling.py#L266
Why does the `part_inx` variable need to be incremented by 1?