isaaccorley / torchseg

Segmentation models with pretrained backbones. PyTorch.
MIT License

TorchSeg

TorchSeg is an actively maintained and up-to-date fork of the Segmentation Models PyTorch (smp) library.

Install

pip install torchseg

Updates

The goal of this fork is to 1) provide maintenance support for the original library and 2) add features relevant to modern semantic segmentation. Since the fork, this library has added the features summarized below:

Additionally, we have made the following changes to improve software standards:

Features

The main features of this library are:

Example Usage

At their core, TorchSeg models are plain torch nn.Modules. They can be created as follows:

import torchseg

model = torchseg.Unet(
    encoder_name="resnet50",
    encoder_weights=True,
    in_channels=3,
    classes=3,
)

TorchSeg has an encoder_params feature which passes additional parameters to timm.create_model() when defining an encoder backbone. You can specify different activations, normalization layers, and more, as shown below.

You can also pass a functools.partial callable as an activation/normalization layer. See the timm docs for more information on available activations and normalization layers. You can even use pretrained weights while changing the activations/normalizations!

model = torchseg.Unet(
    encoder_name="resnet50",
    encoder_weights=True,
    in_channels=3,
    classes=3,
    encoder_params={
      "act_layer": "prelu",
      "norm_layer": "layernorm"
    }
)
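As a minimal sketch of the functools.partial route mentioned above, the snippet below builds an activation factory with a non-default argument. It assumes, as the text above states, that timm accepts such a callable for act_layer; nn.LeakyReLU with negative_slope=0.1 is an illustrative choice, not a recommendation, and the torchseg.Unet call itself is left as a comment so the snippet runs with PyTorch alone.

```python
import functools

import torch
import torch.nn as nn

# A partial acts as a layer factory: each call builds a fresh LeakyReLU
# with negative_slope=0.1 instead of the PyTorch default of 0.01.
leaky_relu = functools.partial(nn.LeakyReLU, negative_slope=0.1)

# This dict would be forwarded to timm.create_model() through encoder_params, e.g.
# torchseg.Unet("resnet50", encoder_weights=True, classes=3, encoder_params=encoder_params)
encoder_params = {"act_layer": leaky_relu}

# Sanity check: the factory produces a working activation module.
act = encoder_params["act_layer"]()
out = act(torch.tensor([-1.0, 2.0]))
print(act, out)
```

Using a partial rather than an instance matters because the encoder needs to build many independent activation layers, one per block.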

Some models, like Swin and ConvNext, downsample by a factor of 4 in the first block (the stem) and then by 2 in each subsequent block, giving only 4 feature blocks (encoder_depth=4). As a result, the decoder output is half the input resolution. To get an output the same size as the input, pass head_upsampling=2, which upsamples once more before the segmentation head.

model = torchseg.Unet(
    "convnextv2_tiny",
    in_channels=3,
    classes=2,
    encoder_weights=True,
    encoder_depth=4,
    decoder_channels=(256, 128, 64, 32),
    head_upsampling=2
)

model = torchseg.Unet(
    "swin_tiny_patch4_window7_224",
    in_channels=3,
    classes=2,
    encoder_weights=True,
    encoder_depth=4,
    decoder_channels=(256, 128, 64, 32),
    head_upsampling=2,
    encoder_params={"img_size": 256}  # need to define img size since swin is a ViT hybrid
)

model = torchseg.Unet(
    "maxvit_small_tf_224",
    in_channels=3,
    classes=2,
    encoder_weights=True,
    encoder_depth=5,
    decoder_channels=(256, 128, 64, 32, 16),
    encoder_params={"img_size": 256}
)
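The halving described above can be checked with a little arithmetic. The sketch below assumes a simplified U-Net decoder in which each of the encoder_depth decoder blocks upsamples by exactly 2, starting from the deepest encoder feature used; decoder_output_stride is a hypothetical helper, not a TorchSeg function, and the reduction lists are copied from the encoder metadata shown under Encoders.

```python
def decoder_output_stride(reductions, encoder_depth):
    """Stride of the decoder output relative to the input (hypothetical helper).

    Assumes each of the `encoder_depth` decoder blocks upsamples by 2.
    The result is the extra factor head_upsampling would need to supply.
    """
    deepest = reductions[encoder_depth - 1]  # total downsampling of the deepest feature used
    return deepest // 2 ** encoder_depth


# Reductions taken from the timm feature metadata.
resnet50 = [2, 4, 8, 16, 32]
convnext_base = [4, 8, 16, 32]

print(decoder_output_stride(resnet50, 5))       # 1 -> output already matches the input
print(decoder_output_stride(convnext_base, 4))  # 2 -> pass head_upsampling=2
```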

TorchSeg supports pretrained ViT encoders from timm by extracting intermediate transformer block features specified by the encoder_indices and encoder_depth arguments.

You will also need to define scale_factors for upsampling the feature layers to the resolutions expected by the decoder. For a patch size of 16 and U-Net encoder_depth=5, this would be scale_factors=(8, 4, 2, 1, 0.5); for encoder_depth=4, scale_factors=(4, 2, 1, 0.5); for encoder_depth=3, scale_factors=(2, 1, 0.5); and so on.
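Where these numbers come from: a ViT with patch size 16 outputs every intermediate feature map at stride 16, while the decoder levels are assumed here to expect strides that double from level to level and end at 32. Dividing the patch size by each target stride reproduces the values above. vit_scale_factors is a hypothetical helper sketching that arithmetic, not part of TorchSeg.

```python
def vit_scale_factors(patch_size, encoder_depth, deepest_stride=32):
    """Per-level resize factors for ViT features (hypothetical helper).

    Every ViT block outputs tokens at stride `patch_size`; decoder level i is
    assumed to expect stride deepest_stride / 2 ** (encoder_depth - 1 - i).
    """
    strides = [deepest_stride // 2 ** (encoder_depth - 1 - i) for i in range(encoder_depth)]
    return tuple(patch_size / s for s in strides)


print(vit_scale_factors(16, 5))  # (8.0, 4.0, 2.0, 1.0, 0.5)
print(vit_scale_factors(16, 4))  # (4.0, 2.0, 1.0, 0.5)
print(vit_scale_factors(16, 3))  # (2.0, 1.0, 0.5)
```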

Another benefit of using timm is that when you pass a new img_size, timm automatically interpolates the ViT positional embeddings to fit the new image size, which changes the number of patch tokens.

import torch
import torchseg

model = torchseg.Unet(
    "vit_small_patch16_224",
    in_channels=8,
    classes=2,
    encoder_depth=5,
    encoder_indices=(2, 4, 6, 8, 10),  # which intermediate blocks to extract features from
    encoder_weights=True,
    decoder_channels=(256, 128, 64, 32, 16),
    encoder_params={  # additional params passed to timm.create_model and the vit encoder
        "scale_factors": (8, 4, 2, 1, 0.5), # resize scale_factors for patch size 16 and 5 layers
        "img_size": 256,  # timm automatically interpolates the positional embeddings to your new image size
    },
)
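To make the effect of img_size concrete: a square image of side img_size split into non-overlapping patches of side patch_size yields (img_size // patch_size) ** 2 patch tokens, so the positional-embedding grid must be resized to match. This sketch only does the arithmetic and ignores extra tokens such as a class token; num_patch_tokens is a hypothetical helper.

```python
def num_patch_tokens(img_size, patch_size):
    """Number of patch tokens for a square image (hypothetical helper)."""
    if img_size % patch_size != 0:
        raise ValueError("img_size should be a multiple of patch_size")
    grid = img_size // patch_size  # patches per side
    return grid * grid


print(num_patch_tokens(224, 16))  # 196 tokens on a 14x14 grid (the "224" in vit_small_patch16_224)
print(num_patch_tokens(256, 16))  # 256 tokens on a 16x16 grid (img_size=256 above)
```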

Models

Architectures (Decoders)

Encoders

TorchSeg relies entirely on the timm library for pretrained encoder support. This means that TorchSeg supports any timm model with features_only feature extraction. Additionally, we support any ViT model with a get_intermediate_layers method. In total, this covers 852 of 1017 (~84%) timm encoders, including ResNet, Swin, ConvNext, ViT, and more!

To list all supported encoders:

import torchseg

torchseg.list_encoders()

We have additionally extracted the feature extractor metadata of each timm model with features_only support at output_stride=32. This metadata provides the number of intermediate layers as well as the channels, module name, and downsampling reduction of each layer.

import torchseg

metadata = torchseg.encoders.TIMM_ENCODERS["convnext_base"]
print(metadata)

"""
{
   'channels': [128, 256, 512, 1024],
   'indices': (0, 1, 2, 3),
   'module': ['stages.0', 'stages.1', 'stages.2', 'stages.3'],
   'reduction': [4, 8, 16, 32],
}
"""

metadata = torchseg.encoders.TIMM_ENCODERS["resnet50"]
print(metadata)

"""
{
   'channels': [64, 256, 512, 1024, 2048],
   'indices': (0, 1, 2, 3, 4),
   'module': ['act1', 'layer1', 'layer2', 'layer3', 'layer4'],
   'reduction': [2, 4, 8, 16, 32]
}
"""

Models API

Input channels

timm encoders support pretrained weights with arbitrary input channels: when in_channels > 3, the pretrained weights of the first layer are repeated across the extra channels. For example, with in_channels=6 the RGB ImageNet weights in the initial layer are repeated as RGBRGB instead of randomly initializing the new channels; with in_channels=7 the pattern is RGBRGBR. Below is a diagram to visualize this method.
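The repetition scheme can be sketched in plain Python: input channel i of the new first layer reuses pretrained RGB channel i % 3. As we understand timm's adapt_input_conv, it also rescales the repeated weights by 3 / in_channels so the summed response keeps roughly the same magnitude; that scaling is our reading of timm, not something stated above. repeat_rgb_weights and the toy weights are hypothetical.

```python
def repeat_rgb_weights(rgb_weights, in_channels):
    """Expand one filter's first-layer weights from 3 to `in_channels` channels.

    rgb_weights: three values (one per R, G, B input channel).
    Channels are cycled R, G, B, R, ... and scaled by 3 / in_channels
    (assumed to mirror timm's adapt_input_conv for in_channels > 3).
    """
    scale = 3 / in_channels
    return [rgb_weights[i % 3] * scale for i in range(in_channels)]


labels = "RGB"
print("".join(labels[i % 3] for i in range(6)))  # RGBRGB
print("".join(labels[i % 3] for i in range(7)))  # RGBRGBR
print(repeat_rgb_weights([0.3, 0.6, 0.9], 6))
```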


Auxiliary Classifier

All models support an optional auxiliary classifier head through the use of aux_params. If aux_params is not None, the model produces a label output of shape (N, C) in addition to the mask output. The classification head consists of GlobalPooling->Dropout(optional)->Linear->Activation(optional) layers, which can be configured by aux_params as follows:

import torch
import torch.nn as nn
import torchseg

aux_params = dict(
    pooling='avg',             # one of 'avg', 'max'
    dropout=0.5,               # dropout ratio, default is None
    activation=nn.Sigmoid(),   # activation function, default is Identity
    classes=4,                 # define number of output labels
)
model = torchseg.Unet('resnet18', classes=4, aux_params=aux_params)
x = torch.randn(2, 3, 256, 256)  # batch of N=2 RGB images
mask, label = model(x)           # mask: (2, 4, 256, 256), label: (2, 4)

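For intuition, the pooling, dropout, linear, and activation chain described above can be written as a plain nn.Sequential. This is a sketch of the described layer order, not TorchSeg's own classification head; make_aux_head is a hypothetical helper, and the 512 input channels (resnet18's last stage) and 4 labels are assumptions chosen to match the aux_params example.

```python
import torch
import torch.nn as nn


def make_aux_head(in_channels, classes, pooling="avg", dropout=None, activation=None):
    """GlobalPooling -> Dropout (optional) -> Linear -> Activation (optional)."""
    if pooling == "avg":
        pool = nn.AdaptiveAvgPool2d(1)
    elif pooling == "max":
        pool = nn.AdaptiveMaxPool2d(1)
    else:
        raise ValueError(f"pooling must be 'avg' or 'max', got {pooling!r}")
    layers = [pool, nn.Flatten()]  # (N, C, H, W) -> (N, C)
    if dropout:
        layers.append(nn.Dropout(p=dropout))
    layers.append(nn.Linear(in_channels, classes))
    layers.append(activation if activation is not None else nn.Identity())
    return nn.Sequential(*layers)


# Stand-in for the deepest encoder feature map (512 channels for resnet18).
features = torch.randn(2, 512, 8, 8)
head = make_aux_head(512, classes=4, pooling="avg", dropout=0.5, activation=nn.Sigmoid())
head.eval()  # disable dropout for a deterministic check
label = head(features)
print(label.shape)  # torch.Size([2, 4])
```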
Depth

encoder_depth sets the number of downsampling operations in the encoder, so specifying a smaller depth makes the model lighter. The default is encoder_depth=5.

Note that some models, like ConvNext and Swin, only have 4 intermediate feature blocks, so these encoders require encoder_depth=4. The number of feature blocks for each encoder can be found in the metadata above.

model = torchseg.Unet('resnet50', encoder_depth=4, decoder_channels=(256, 128, 64, 32))  # one decoder channel entry per encoder level
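Depth also constrains the input size. In a U-Net, each encoder level downsamples and each decoder level upsamples by 2 before concatenating with the matching skip connection, so height and width should be divisible by the total downsampling of the deepest feature used (2 ** encoder_depth for an encoder like resnet50, whose first feature is at reduction 2). The hypothetical helpers below compute that divisor from the metadata reductions and the padding needed; whether TorchSeg itself checks or pads inputs is not covered here.

```python
def required_divisor(reductions, encoder_depth):
    """Total downsampling of the deepest feature used (hypothetical helper)."""
    return reductions[encoder_depth - 1]


def padding_to_multiple(size, divisor):
    """Pixels to add so that `size` becomes a multiple of `divisor`."""
    return (-size) % divisor


resnet50_reductions = [2, 4, 8, 16, 32]
for depth in (5, 4):
    divisor = required_divisor(resnet50_reductions, depth)
    # A 300-pixel side needs 20 extra pixels at depth 5, but only 4 at depth 4.
    print(depth, divisor, padding_to_multiple(300, divisor))
```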

Contribute

We welcome new contributions for modern semantic segmentation models, losses, and methods!

Install dev dependencies

For development, you can install the required dependencies using `pip install '.[all]'`.

Code Formatting/Linting

To format files, run ruff format. To check for linting errors, run ruff check.

Tests

To run tests, use pytest -ra