qubvel-org / segmentation_models.pytorch

Semantic segmentation models with 500+ pretrained convolutional and transformer-based backbones.
https://smp.readthedocs.io/
MIT License
9.73k stars 1.68k forks

Update MixVisionTransformer #975

Open brianhou0208 opened 2 weeks ago

brianhou0208 commented 2 weeks ago

Hi @qubvel, this PR adds support for different output strides in the MiT encoder. The original MiT encoder required extra parameters (H, W) to be passed along with the features during forward propagation, which is incompatible with the current self.get_stage format.

With this update, the MiT encoder can be used with PAN, DeepLabv3, and DeepLabv3+.

Update

  1. Replaced the original nn.LayerNorm with a layer norm that accepts both input shapes: (B, C, H, W) and (B, N, C).
  2. Features are now passed between stages in the (B, C, H, W) format.
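
To illustrate the first point, here is a minimal sketch of a layer norm that handles both layouts. It is not the code from this PR: the class name `ChannelLayerNorm` and the permute-based approach are assumptions made for illustration, built only on `nn.LayerNorm` and `torch.nn.functional.layer_norm`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelLayerNorm(nn.LayerNorm):
    """Hypothetical layer norm over channels for (B, C, H, W) or (B, N, C) input."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim == 4:
            # (B, C, H, W) -> (B, H, W, C) so that channels are the last dimension
            x = x.permute(0, 2, 3, 1)
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
            # Restore the (B, C, H, W) layout expected by the next stage
            return x.permute(0, 3, 1, 2)
        # (B, N, C): channels are already last, so this is plain LayerNorm
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)


B, C, H, W = 2, 32, 8, 8
norm = ChannelLayerNorm(C)

feature_map = torch.rand(B, C, H, W)
# The same values viewed as a token sequence of length N = H * W
tokens = feature_map.flatten(2).transpose(1, 2)

out_map = norm(feature_map)   # stays (B, C, H, W)
out_tokens = norm(tokens)     # stays (B, N, C)
print(tuple(out_map.shape), tuple(out_tokens.shape))
```

Because both branches normalize over the channel dimension at each spatial position, the two outputs hold the same values in different layouts, so a stage can accept either without being told (H, W) separately.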

Test Code

import torch
import segmentation_models_pytorch as smp


def get_features(name='resnet18', output_stride=32):
    """Print the shape of every feature map the encoder returns."""
    x = torch.rand(1, 3, 256, 256)
    backbone = smp.encoders.get_encoder(name, depth=5, output_stride=output_stride)
    backbone.eval()
    with torch.no_grad():  # only shapes are inspected, so skip gradient tracking
        features = backbone(x)
    print(name, output_stride, [tuple(f.shape) for f in features])


if __name__ == '__main__':
    torch.manual_seed(0)
    # Compare a CNN encoder, the MiT encoder, and a timm ("tu-") encoder
    for name in ('resnet18', 'mit_b0', 'tu-mobilenetv3_small_050'):
        for output_stride in (32, 16, 8):
            get_features(name, output_stride)

Output

resnet18 32 [(1, 3, 256, 256), (1, 64, 128, 128), (1, 64, 64, 64), (1, 128, 32, 32), (1, 256, 16, 16), (1, 512, 8, 8)]
resnet18 16 [(1, 3, 256, 256), (1, 64, 128, 128), (1, 64, 64, 64), (1, 128, 32, 32), (1, 256, 16, 16), (1, 512, 16, 16)]
resnet18 8 [(1, 3, 256, 256), (1, 64, 128, 128), (1, 64, 64, 64), (1, 128, 32, 32), (1, 256, 32, 32), (1, 512, 32, 32)]
mit_b0 32 [(1, 3, 256, 256), (1, 0, 128, 128), (1, 32, 64, 64), (1, 64, 32, 32), (1, 160, 16, 16), (1, 256, 8, 8)]
mit_b0 16 [(1, 3, 256, 256), (1, 0, 128, 128), (1, 32, 64, 64), (1, 64, 32, 32), (1, 160, 16, 16), (1, 256, 16, 16)]
mit_b0 8 [(1, 3, 256, 256), (1, 0, 128, 128), (1, 32, 64, 64), (1, 64, 32, 32), (1, 160, 32, 32), (1, 256, 32, 32)]
tu-mobilenetv3_small_050 32 [(1, 3, 256, 256), (1, 16, 128, 128), (1, 8, 64, 64), (1, 16, 32, 32), (1, 24, 16, 16), (1, 288, 8, 8)]
tu-mobilenetv3_small_050 16 [(1, 3, 256, 256), (1, 16, 128, 128), (1, 8, 64, 64), (1, 16, 32, 32), (1, 24, 16, 16), (1, 288, 16, 16)]
tu-mobilenetv3_small_050 8 [(1, 3, 256, 256), (1, 16, 128, 128), (1, 8, 64, 64), (1, 16, 32, 32), (1, 24, 32, 32), (1, 288, 32, 32)]
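
The spatial sizes above follow a simple pattern: stage k normally downsamples by 2^k, and the stride stops growing once it reaches output_stride. The sketch below reproduces that arithmetic in plain Python. The helper `expected_spatial_sizes` is hypothetical and only restates what the printed shapes show; it is not how the library computes them.

```python
def expected_spatial_sizes(input_size, depth=5, output_stride=32):
    """Return the side length of each feature map for a square input.

    Index 0 is the input itself; stage k uses stride min(2**k, output_stride).
    """
    sizes = [input_size]
    for stage in range(1, depth + 1):
        stride = min(2 ** stage, output_stride)
        sizes.append(input_size // stride)
    return sizes


for output_stride in (32, 16, 8):
    print(output_stride, expected_spatial_sizes(256, output_stride=output_stride))
# 32 [256, 128, 64, 32, 16, 8]
# 16 [256, 128, 64, 32, 16, 16]
# 8 [256, 128, 64, 32, 32, 32]
```

These values match the height and width reported for all three encoders, which is the point of the test: with this PR, mit_b0 follows the same output-stride behavior as resnet18 and the timm encoder.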