TorchSeg is an actively maintained and up-to-date fork of the Segmentation Models PyTorch (smp) library.
```
pip install torchseg
```
The goal of this fork is to 1) provide maintenance support for the original library and 2) add features relevant to modern semantic segmentation. Since the fork, this library has added the following features:

- Encoder support that previously covered ResNet, EfficientNet, etc., but now extends to include modern architectures like ConvNext, Swin, PoolFormer, MaxViT, and more!
- Support for pretrained Vision Transformer encoders such as ViT, DeiT, and FlexiViT!

Additionally, we have performed the following for improved software standards:

- Linting and static type checking with ruff and mypy
- Reduced required dependencies (torch, timm, and einops)
- Reuse of existing functionality from torchmetrics and timm in place of custom implementations
The main features of this library are described below.

TorchSeg models are, at their base, just torch `nn.Module`s. They can be created as follows:

```python
import torchseg

model = torchseg.Unet(
    encoder_name="resnet50",
    encoder_weights=True,
    in_channels=3,
    classes=3,
)
```
TorchSeg has an `encoder_params` feature which passes additional parameters to `timm.create_model()` when defining an encoder backbone. One can specify different activations, normalization layers, and more, as shown below. You can also pass a `functools.partial` callable as an activation/normalization layer; see the timm docs for the available activations and normalization layers. You can even use pretrained weights while changing the activations/normalizations!
```python
model = torchseg.Unet(
    encoder_name="resnet50",
    encoder_weights=True,
    in_channels=3,
    classes=3,
    encoder_params={
        "act_layer": "prelu",
        "norm_layer": "layernorm",
    },
)
```
Some models, like Swin and ConvNext, downsample by scale=4 in the first block (stem) and then by 2 in each subsequent block, with only `depth=4` blocks. This results in an output that is half the input size after the decoder. To get the same output size as the input, you can pass `head_upsampling=2`, which will upsample once more prior to the segmentation head.
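The resolution arithmetic behind this can be sketched in plain Python (the reductions below correspond to a stride-4 stem followed by three x2 stages; `head_upsampling` is modeled as one extra x2 upsample):

```python
# Resolution bookkeeping for a ConvNext/Swin-style encoder feeding a U-Net
# decoder. The stem downsamples by 4 and each later stage by 2, so with
# depth=4 the feature reductions are 4, 8, 16, 32.
def decoder_output_downsample(reductions, head_upsampling=1):
    scale = max(reductions)         # start from the deepest feature map
    for _ in reductions:            # one x2 upsample per decoder stage
        scale /= 2
    return scale / head_upsampling  # optional extra upsample before the head

reductions = [4, 8, 16, 32]
print(decoder_output_downsample(reductions))                     # 2.0 -> half resolution
print(decoder_output_downsample(reductions, head_upsampling=2))  # 1.0 -> full resolution
```

For comparison, a ResNet-style encoder with reductions `[2, 4, 8, 16, 32]` already yields 1.0 (full resolution) without `head_upsampling`.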
```python
model = torchseg.Unet(
    "convnextv2_tiny",
    in_channels=3,
    classes=2,
    encoder_weights=True,
    encoder_depth=4,
    decoder_channels=(256, 128, 64, 32),
    head_upsampling=2,
)
```
```python
model = torchseg.Unet(
    "swin_tiny_patch4_window7_224",
    in_channels=3,
    classes=2,
    encoder_weights=True,
    encoder_depth=4,
    decoder_channels=(256, 128, 64, 32),
    head_upsampling=2,
    encoder_params={"img_size": 256},  # need to define img_size since Swin is a ViT hybrid
)
```
```python
model = torchseg.Unet(
    "maxvit_small_tf_224",
    in_channels=3,
    classes=2,
    encoder_weights=True,
    encoder_depth=5,
    decoder_channels=(256, 128, 64, 32, 16),
    encoder_params={"img_size": 256},
)
```
TorchSeg supports pretrained ViT encoders from timm by extracting intermediate transformer block features, specified by the `encoder_indices` and `encoder_depth` arguments. You will also need to define `scale_factors` for upsampling the feature layers to the resolutions expected by the decoders. For U-Net `depth=5` this would be `scales=(8, 4, 2, 1, 0.5)`; for `depth=4`, `scales=(4, 2, 1, 0.5)`; for `depth=3`, `scales=(2, 1, 0.5)`; and so on.
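These scales follow a simple pattern: a patch-16 ViT outputs every block at reduction 16, and each extracted feature map must be resized to the reduction the decoder expects (up to 32 at the deepest level). A sketch of that computation (the helper name is ours, not part of the torchseg API):

```python
# A U-Net decoder of a given depth expects feature reductions ending at 32
# (e.g. depth=5 -> 2, 4, 8, 16, 32), while every ViT block sits at
# reduction patch_size. Each scale factor is patch_size / expected_reduction.
def vit_scale_factors(patch_size, depth):
    expected_reductions = [2 ** i for i in range(6 - depth, 6)]
    return tuple(patch_size / r for r in expected_reductions)

print(vit_scale_factors(16, 5))  # (8.0, 4.0, 2.0, 1.0, 0.5)
print(vit_scale_factors(16, 4))  # (4.0, 2.0, 1.0, 0.5)
print(vit_scale_factors(16, 3))  # (2.0, 1.0, 0.5)
```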
Another benefit of using timm is that, when you pass in a new `img_size`, timm automatically interpolates the ViT positional embeddings to work with your new image size, even though it produces a different number of patch tokens.
```python
import torch
import torchseg

model = torchseg.Unet(
    "vit_small_patch16_224",
    in_channels=8,
    classes=2,
    encoder_depth=5,
    encoder_indices=(2, 4, 6, 8, 10),  # which intermediate blocks to extract features from
    encoder_weights=True,
    decoder_channels=(256, 128, 64, 32, 16),
    encoder_params={  # additional params passed to timm.create_model and the ViT encoder
        "scale_factors": (8, 4, 2, 1, 0.5),  # resize scale factors for patch size 16 and 5 layers
        "img_size": 256,  # timm interpolates the positional embeddings to your new image size
    },
)
```
TorchSeg relies entirely on the timm library for pretrained encoder support. This means that TorchSeg supports any timm model with `features_only` feature-extraction functionality, as well as any ViT model with a `get_intermediate_layers` method. This amounts to 852/1017 (~84%) of the encoders in timm, including ResNet, Swin, ConvNext, ViT, and more!

To list the supported encoders:
```python
import torchseg

torchseg.list_encoders()
```
We have additionally pulled the feature-extractor metadata of each model with `features_only` support from timm at `output_stride=32`. This metadata provides information such as the number of intermediate layers, the channels of each layer, the layer names, and the downsampling reduction of each layer.
```python
import torchseg

metadata = torchseg.encoders.TIMM_ENCODERS["convnext_base"]
print(metadata)
"""
{
    'channels': [128, 256, 512, 1024],
    'indices': (0, 1, 2, 3),
    'module': ['stages.0', 'stages.1', 'stages.2', 'stages.3'],
    'reduction': [4, 8, 16, 32],
}
"""

metadata = torchseg.encoders.TIMM_ENCODERS["resnet50"]
print(metadata)
"""
{
    'channels': [64, 256, 512, 1024, 2048],
    'indices': (0, 1, 2, 3, 4),
    'module': ['act1', 'layer1', 'layer2', 'layer3', 'layer4'],
    'reduction': [2, 4, 8, 16, 32],
}
"""
```
- `model.encoder` - pretrained backbone to extract intermediate features
- `model.decoder` - network for processing the intermediate features to the original image resolution (Unet, DeepLabv3+, FPN)
- `model.segmentation_head` - final block producing the mask output (includes optional upsampling and activation)
- `model.classification_head` - optional block which creates a classification head on top of the encoder
- `model.forward(x)` - sequentially passes `x` through the model's encoder, decoder, and segmentation head (and classification head if specified)

timm encoders support the use of pretrained weights with an arbitrary number of input channels by repeating the weights across channels when `in_channels > 3`. For example, if `in_channels=6`, the RGB ImageNet pretrained weights in the initial layer are repeated as RGBRGB to avoid random initialization. For `in_channels=7` this results in RGBRGBR. Below is a diagram to visualize this method.
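The repetition can also be sketched with plain tensor ops (a simplified illustration of the idea; the helper below is ours, not the timm internal, which additionally rescales the repeated weights):

```python
import torch

def repeat_rgb_weights(weight, in_channels):
    """Tile pretrained RGB conv weights of shape (O, 3, k, k) across extra input channels."""
    repeats = -(-in_channels // weight.shape[1])  # ceil division
    return weight.repeat(1, repeats, 1, 1)[:, :in_channels]

rgb_stem = torch.randn(64, 3, 7, 7)  # e.g. a ResNet stem convolution
print(repeat_rgb_weights(rgb_stem, 6).shape)  # torch.Size([64, 6, 7, 7])  (RGBRGB)
print(repeat_rgb_weights(rgb_stem, 7).shape)  # torch.Size([64, 7, 7, 7])  (RGBRGBR)
```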
All models support an optional auxiliary classifier head through the use of `aux_params`. If `aux_params != None`, the model produces a `label` output of shape `(N, C)` in addition to the `mask` output. The classification head consists of GlobalPooling -> Dropout (optional) -> Linear -> Activation (optional) layers, which can be configured via `aux_params` as follows:
```python
import torch
import torch.nn as nn
import torchseg

aux_params = dict(
    pooling="avg",  # one of 'avg', 'max'
    dropout=0.5,  # dropout ratio, default is None
    activation=nn.Sigmoid(),  # activation function, default is Identity
    classes=4,  # number of output labels
)
model = torchseg.Unet("resnet18", classes=4, aux_params=aux_params)
x = torch.randn(1, 3, 224, 224)
mask, label = model(x)
```
Depth represents the number of downsampling operations in the encoder, so you can make your model lighter by specifying a smaller `depth`. Defaults to `depth=5`.

Note that some models, like ConvNext and Swin, only have 4 intermediate feature blocks, so these encoders require `encoder_depth=4`. This can be found in the metadata above.
```python
model = torchseg.Unet('resnet50', encoder_depth=4)
```
We welcome new contributions for modern semantic segmentation models, losses, and methods!

For development, you can install the required dependencies using `pip install '.[all]'`. To format files, run `ruff format`. To check for linting errors, run `ruff check`. To run the tests, use `pytest -ra`.