Beckschen / TransUNet

This repository includes the official project of TransUNet, presented in our paper: TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation.
Apache License 2.0

Question about "patch_size" #30

Open benzyz1129 opened 3 years ago

benzyz1129 commented 3 years ago

Thanks for your work. I have some questions about the patch size of patch embedding when using CNN and Transformer as the encoder.

Section 3.2 of the paper says that, when a CNN-Transformer hybrid is used as the encoder, patch embedding is applied to 1x1 patches extracted from the CNN feature map instead of from the raw image.

As I understand it, regardless of the height and width of the feature map extracted by the CNN, the patch embedding should then be an nn.Conv2d with kernel_size=1 and stride=1.

Here is the code.

        if config.patches.get("grid") is not None:   # ResNet
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)

When img_size=512 and the configuration from get_r50_b16_config is applied, the output of the patch embedding is a tensor of shape (B, 1024, 16, 16). Its height and width are 1/32 of the original image size, not 1/16. You would therefore need 5 upsampling operations in total instead of 4, which differs from your implementation.
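The arithmetic behind the img_size=512 case can be checked with a small sketch. It mirrors the hybrid branch of the snippet above using plain integers only. The function name hybrid_patch_geometry is hypothetical, and the grid of (16, 16) is an assumption about what get_r50_b16_config sets, not something quoted in this thread.

```python
def hybrid_patch_geometry(img_size, grid_size, cnn_stride=16):
    """Mirror the hybrid branch of the quoted snippet with plain integers.

    hybrid_patch_geometry is a hypothetical helper, not part of TransUNet.
    cnn_stride=16 is the total CNN downsampling hard-coded in the snippet.
    """
    # Spatial size of the CNN feature map fed to the patch embedding.
    feat = (img_size[0] // cnn_stride, img_size[1] // cnn_stride)
    # kernel_size and stride of the patch-embedding convolution.
    patch_size = (feat[0] // grid_size[0], feat[1] // grid_size[1])
    if 0 in patch_size:
        raise ValueError(f"grid {grid_size} is larger than the feature map {feat}")
    # A convolution with kernel == stride == p maps a side of length n to n // p.
    tokens = (feat[0] // patch_size[0], feat[1] // patch_size[1])
    # Overall reduction from the input image to the token grid.
    downsample = (img_size[0] // tokens[0], img_size[1] // tokens[1])
    return {"feature_map": feat, "patch_size": patch_size,
            "token_grid": tokens, "downsample": downsample}


# Assumed grid of (16, 16) with a 512x512 input.
result = hybrid_patch_geometry((512, 512), (16, 16))
print(result)
```

Under that assumption the feature map is 32x32, the convolution uses a 2x2 kernel with stride 2, and the token grid is 16x16, i.e. 1/32 of the input side.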

Shouldn't the kernel_size and stride be 1 when using the CNN-Transformer hybrid as the encoder?

I would be very grateful if you could let me know whether I have misunderstood something.

ChenchenHu007 commented 3 years ago

I have the same question. According to the paper, when img_size = 224, patch_size = 224 // 16 // 16 = 0.

zlyx525 commented 3 years ago

In train.py, the grid_size is overridden. So, when img_size = 224, the grid_size is 224 // 16 = 14, and patch_size = 224 // 16 // 14 = 1.

    if args.vit_name.find('R50') != -1:
        config_vit.patches.grid = (int(args.img_size / args.vit_patches_size), int(args.img_size / args.vit_patches_size))
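To see the effect of that override, here is a minimal sketch that recomputes the grid the way the train.py line does and then applies the patch_size formula from the snippet in the first comment. The helper names grid_from_args and patch_size_from_grid are hypothetical, and vit_patches_size=16 is assumed to be the value passed on the command line.

```python
def grid_from_args(img_size, vit_patches_size):
    """Grid as recomputed by the train.py line quoted above (hypothetical helper)."""
    side = int(img_size / vit_patches_size)
    return (side, side)


def patch_size_from_grid(img_size, grid):
    """patch_size as computed in the hybrid branch of the embedding code."""
    return (img_size // 16 // grid[0], img_size // 16 // grid[1])


# vit_patches_size=16 is an assumption about the training arguments.
for img_size in (224, 512):
    grid = grid_from_args(img_size, 16)
    print(img_size, grid, patch_size_from_grid(img_size, grid))

# Without the override, a fixed (16, 16) grid at img_size=224 gives the zero patch size.
print(patch_size_from_grid(224, (16, 16)))
```

With the override the patch size comes out as (1, 1) for both input sizes, which matches the 1x1 patches described in the paper. The zero case only appears when the grid is left at a fixed (16, 16) for a 224 input.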