qubvel-org / segmentation_models.pytorch

Semantic segmentation models with 500+ pretrained convolutional and transformer-based backbones.
https://smp.readthedocs.io/
MIT License
9.58k stars 1.67k forks

Allow kwargs in TimmUniversalEncoder #954

Open DimitrisMantas opened 5 days ago

DimitrisMantas commented 5 days ago

In my opinion, the main appeal of using the timm versions of encoders such as ResNet-18, which is also available in torchvision, is that timm generally lets you pass additional configuration parameters to the model constructor, for example to enable anti-aliasing, attention, or stochastic depth.

However, smp does not support this at the moment. This is because TimmUniversalEncoder builds a fixed set of kwargs that it passes to timm, and its initializer does not accept any others.

An easy fix would be to let the initializer accept its own **kwargs and merge them into the local kwargs before calling timm.
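A minimal sketch of the proposed merge in plain Python. `build_timm_kwargs` is a hypothetical helper, not part of smp or timm; its defaults are copied from the encoder code quoted later in this thread, and the extra keyword arguments are merged last so that they override the defaults on collisions.

```python
def build_timm_kwargs(in_channels=3, depth=5, output_stride=32, pretrained=True, **kwargs):
    # Hypothetical helper: the defaults the encoder currently hard-codes
    defaults = dict(
        in_chans=in_channels,
        features_only=True,
        output_stride=output_stride,
        pretrained=pretrained,
        out_indices=tuple(range(depth)),
    )
    # User kwargs are merged last, so they win on key collisions
    return {**defaults, **kwargs}


merged = build_timm_kwargs(depth=4, pretrained=False, drop_path_rate=0.1)
print(merged)
```

Any extra key such as `drop_path_rate` then simply travels along to `timm.create_model(name, **merged)`.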

JulienMaille commented 4 days ago

Sorry for the OT comment but may I ask how you add anti-aliasing to ResNet through timm?

DimitrisMantas commented 4 days ago

The ResNet constructor accepts an aa_layer argument (https://github.com/huggingface/pytorch-image-models/blob/310ffa32c5758474b0a4481e5db1494dd419aa23/timm/models/resnet.py#L405), which you can set to timm.layers.BlurPool2d (https://github.com/huggingface/pytorch-image-models/blob/310ffa32c5758474b0a4481e5db1494dd419aa23/timm/layers/blur_pool.py#L20).

DimitrisMantas commented 4 days ago

This is what I'm suggesting: I think the super() call should stay as is, and the local kwargs should either be merged with the new kwargs or passed to timm.create_model individually.
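The two options behave differently when a user kwarg has the same name as one of the local defaults. A sketch using a hypothetical stub in place of `timm.create_model` (it only echoes its keyword arguments): merging the dicts first lets the user value win, while unpacking both dicts into the same call raises `TypeError` on a duplicate key.

```python
def create_model_stub(name, **kwargs):
    # Hypothetical stand-in for timm.create_model; just echoes its kwargs
    return name, kwargs


local_kwargs = dict(in_chans=3, features_only=True, pretrained=True)
user_kwargs = dict(pretrained=False, drop_path_rate=0.1)

# Option 1: merge first; user values override the local defaults
_, merged = create_model_stub("resnet18", **{**local_kwargs, **user_kwargs})

# Option 2: pass both individually; works while the keys do not overlap...
_, separate = create_model_stub("resnet18", **local_kwargs, drop_path_rate=0.1)

# ...but a key present in both raises TypeError at call time
try:
    create_model_stub("resnet18", **local_kwargs, **user_kwargs)
    collision_error = None
except TypeError as exc:
    collision_error = exc

print(merged, separate, collision_error)
```

Merging is the more forgiving choice if users should be able to override defaults such as `pretrained`; passing individually instead rejects such overrides loudly.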

JulienMaille commented 4 days ago

Have you tried something like this?

import timm
import torch.nn as nn


class TimmUniversalEncoder(nn.Module):
    def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32, **kwargs):
        super().__init__()

        # Initialize default kwargs
        default_kwargs = dict(
            in_chans=in_channels,
            features_only=True,
            output_stride=output_stride,
            pretrained=pretrained,
            out_indices=tuple(range(depth)),
        )

        # Update with any provided kwargs (user values override the defaults)
        default_kwargs.update(kwargs)

        # Build the timm feature extractor with the merged kwargs
        self.model = timm.create_model(name, **default_kwargs)

DimitrisMantas commented 4 days ago

Not yet, but it should work