zihangJiang / TokenLabeling

PyTorch implementation of "All Tokens Matter: Token Labeling for Training Better Vision Transformers"
Apache License 2.0

label_map does not do the same augmentation (random crop) as the input image #18

Closed: haooooooqi closed this issue 2 years ago

haooooooqi commented 3 years ago

Hi, thanks so much for the nice work! I am curious whether you could share some insight into how the label_map is processed. If I understand correctly, after we load an image and its corresponding label map, we should apply the same cropping/flip/resize to both. However, in https://github.com/zihangJiang/TokenLabeling/blob/aa438eff9b9fc2daa8c8b4cc6bfaa6e3721f995e/tlt/data/label_transforms_factory.py#L58-L73 it seems that only the image is cropped, while the label map does not receive the same crop. Wouldn't that make the label map no longer match the image?

Should we do something like this instead?

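        # ratio (assumed): original-image side length / label-map side length.
        # resized_crop slices tensors, so the scaled box must be cast to int.
        return torchvision_F.resized_crop(
                img, i, j, h, w, self.size, interpolation
        ), torchvision_F.resized_crop(
                label_map, int(i / ratio), int(j / ratio), int(h / ratio), int(w / ratio),
                self.size, interpolation
        )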

Thanks
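To make the per-sample idea concrete, here is a small stdlib-only sketch of mapping an image-space crop box onto a coarser label-map grid. It assumes the label map spans the whole original image; the helper name crop_box_to_label_map and the 224x224 image / 14x14 label-map sizes are illustrative assumptions, not taken from the repository. Each axis is scaled separately so a non-square image is handled too.

```python
def crop_box_to_label_map(i, j, h, w, image_hw, label_hw):
    """Hypothetical helper: map a crop box (top i, left j, height h, width w)
    given in original-image pixels onto a label map covering the same image."""
    img_h, img_w = image_hw
    lab_h, lab_w = label_hw
    sy = lab_h / img_h  # label-map rows per image pixel
    sx = lab_w / img_w  # label-map columns per image pixel
    return i * sy, j * sx, h * sy, w * sx


# Example: a 224x224 image with a 14x14 label map (one cell per 16x16 pixels).
print(crop_box_to_label_map(32, 48, 160, 128, (224, 224), (14, 14)))
# -> (2.0, 3.0, 10.0, 8.0)
print(crop_box_to_label_map(20, 0, 100, 100, (224, 224), (14, 14)))
# -> (1.25, 0.0, 6.25, 6.25)
```

The second box does not fall on cell boundaries, so any integer-indexed crop of the label map would shift it by a fraction of a cell relative to the image; a bilinear crop can sample the fractional box directly.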

zihangJiang commented 3 years ago

Thanks for your question. The coordinates of the random crop (i.e. i, j, h, w) are stored in the label map and are used later in mixup.py (https://github.com/zihangJiang/TokenLabeling/blob/aa438eff9b9fc2daa8c8b4cc6bfaa6e3721f995e/tlt/data/mixup.py#L48-L70). This lets us crop all the label maps with their given coordinates in parallel using the roi_align function, which is slightly faster than processing each label map individually as in your example.