isaaccorley / torchseg

Segmentation models with pretrained backbones. PyTorch.
MIT License

DeepLabV3Plus is not compatible with encoder_depth=4 and swin models #54

Open Akshay1-6180 opened 4 months ago

Akshay1-6180 commented 4 months ago

So I was working with both swinv2_tiny_window8_256 and swinv2_base_window12to16_192to256 and noticed that they do not load with torchseg.DeepLabV3Plus:

import torchseg

model = torchseg.DeepLabV3Plus(
    "swinv2_base_window12to16_192to256",
    in_channels=1,
    classes=2,
    encoder_weights=True,
    encoder_depth=4,
    decoder_channels=256,
    encoder_output_stride=16,
    encoder_params={"img_size": 1024}  # need to define img size since swin is a ViT hybrid
)

In both cases it gives the same error. Here is a minimal reproduction:

import torch
import torchseg
model = torchseg.DeepLabV3Plus(
    "swinv2_tiny_window8_256",
    in_channels=1,
    classes=2,
    encoder_weights=True,
    encoder_depth=4,
    decoder_channels=256,
    encoder_output_stride=16,
    encoder_params={"img_size": 1024}  # need to define img size since swin is a ViT hybrid
)
dummy_im = torch.randn(4, 1, 1024, 1024)
out = model.encoder(dummy_im)
dummy_dec_out = model.decoder(*out)

It gives this error: RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 128 but got size 256 for tensor number 1 in the list. The same error also occurs with resnet50 when encoder_depth=4:

model = torchseg.DeepLabV3Plus(
    "resnet50",
    in_channels=1,
    classes=2,
    encoder_weights=True,
    encoder_depth=4,
    decoder_channels=256,
    encoder_output_stride=16,
    #encoder_params={"img_size": 1024}  # need to define img size since swin is a ViT hybrid
)

So I changed the encoder depth to 5 and resnet50 now works. But swin models have a maximum encoder depth of 4, which makes them incompatible with DeepLabV3Plus. Is there an easy fix for this?
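
A quick way to see where the shapes diverge is to print every encoder feature map before calling the decoder. This is just a diagnostic sketch reusing the swinv2_tiny_window8_256 settings from above:

import torch
import torchseg

model = torchseg.DeepLabV3Plus(
    "swinv2_tiny_window8_256",
    in_channels=1,
    classes=2,
    encoder_weights=True,
    encoder_depth=4,
    decoder_channels=256,
    encoder_output_stride=16,
    encoder_params={"img_size": 1024},
)
features = model.encoder(torch.randn(1, 1, 1024, 1024))
for i, f in enumerate(features):
    # negative indices are how the decoder addresses this list
    print(i - len(features), tuple(f.shape))

The decoder concatenates the upsampled ASPP output of features[-1] with the projection of features[-4], so those are the two shapes to compare.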

Akshay1-6180 commented 4 months ago

The issue lies in DeepLabV3PlusDecoder below: it gets solved when you change highres_in_channels = encoder_channels[-4] to encoder_channels[-3] and high_res_features = self.block1(features[-4]) to self.block1(features[-3]). Not sure if it's a good workaround; I'd love to hear others' opinions.

import torch
from torch import nn

# ASPP and SeparableConv2d are torchseg's own modules, defined next to this decoder
class DeepLabV3PlusDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels,
        out_channels=256,
        atrous_rates=(12, 24, 36),
        output_stride=16,
    ):
        super().__init__()
        if output_stride not in {8, 16}:
            raise ValueError(f"Output stride should be 8 or 16, got {output_stride}.")

        self.out_channels = out_channels
        self.output_stride = output_stride

        self.aspp = nn.Sequential(
            ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
            SeparableConv2d(
                out_channels, out_channels, kernel_size=3, padding=1, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

        scale_factor = 2 if output_stride == 8 else 4
        self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)

        highres_in_channels = encoder_channels[-3]  # Changed from -4 to -3
        highres_out_channels = 48  # proposed by authors of paper
        self.block1 = nn.Sequential(
            nn.Conv2d(
                highres_in_channels, highres_out_channels, kernel_size=1, bias=False
            ),
            nn.BatchNorm2d(highres_out_channels),
            nn.ReLU(),
        )
        self.block2 = nn.Sequential(
            SeparableConv2d(
                highres_out_channels + out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, *features):
        aspp_features = self.aspp(features[-1])
        aspp_features = self.up(aspp_features)
        high_res_features = self.block1(features[-3])  # Changed from -4 to -3

        concat_features = torch.cat([aspp_features, high_res_features], dim=1)
        fused_features = self.block2(concat_features)
        return fused_features
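
If you would rather not edit the installed package, the same change can be applied from user code. This is only a sketch: patch_highres_tap is a hypothetical helper, it assumes the encoder exposes out_channels and the decoder exposes aspp, up, block1, and block2 as in the snippet above, and the rebuilt block1 is freshly initialized (so it still needs training):

import torch
from torch import nn

def patch_highres_tap(model, highres_index=-3, highres_out_channels=48):
    decoder = model.decoder
    in_ch = model.encoder.out_channels[highres_index]
    # rebuild the 1x1 projection for the new high-res tap
    decoder.block1 = nn.Sequential(
        nn.Conv2d(in_ch, highres_out_channels, kernel_size=1, bias=False),
        nn.BatchNorm2d(highres_out_channels),
        nn.ReLU(),
    )

    def forward(*features):
        aspp_features = decoder.up(decoder.aspp(features[-1]))
        high_res_features = decoder.block1(features[highres_index])
        return decoder.block2(torch.cat([aspp_features, high_res_features], dim=1))

    # the instance attribute shadows the class forward when decoder(...) is called
    decoder.forward = forward
    return model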

Akshay1-6180 commented 4 months ago

It also gets resolved if the scale factor is 8, but I'm not sure of the far-reaching implications of this change for training; it needs to be empirically tested.

self.up = nn.UpsamplingBilinear2d(scale_factor=8)
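
Back-of-envelope arithmetic (inferred from the 128-vs-256 error message, not confirmed against torchseg internals) for a 1024x1024 input, assuming the swin stages come out at strides 4/8/16/32:

# features[-1]: 1024 / 32 = 32   -> ASPP output stays 32x32
# features[-4]: 1024 / 4  = 256  -> high-res branch stays 256x256
# features[-3]: 1024 / 8  = 128  -> high-res branch stays 128x128
# scale_factor=4: 32 * 4 = 128   -> matches features[-3], hence the -3 workaround
# scale_factor=8: 32 * 8 = 256   -> matches features[-4], hence this workaround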

Akshay1-6180 commented 4 months ago

But these changes would make it incompatible with encoder_depth=5, so there should be a way to handle the different depth cases.
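
One depth-agnostic option (a sketch of an alternative, not what torchseg currently does) would be to resize the ASPP output to the chosen high-res feature's spatial size at forward time, instead of hard-coding the scale factor:

import torch
import torch.nn.functional as F

def fuse(aspp_features, high_res_features, block2):
    # match the high-res feature's spatial size regardless of encoder depth
    aspp_features = F.interpolate(
        aspp_features,
        size=high_res_features.shape[2:],
        mode="bilinear",
        align_corners=False,
    )
    return block2(torch.cat([aspp_features, high_res_features], dim=1))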

Akshay1-6180 commented 3 months ago

@isaaccorley any idea on how to go about this?