isaaccorley / torchseg

Segmentation models with pretrained backbones. PyTorch.
MIT License

Replace ``encoder_weights: str | None = "imagenet"`` with ``pretrained: bool = True`` #49

Open DimitrisMantas opened 4 months ago

DimitrisMantas commented 4 months ago

The encoder_weights parameter in the model initializers is a bit ambiguous/strange, especially given its lax type hint.

At first glance, it looks like you can pass strings representing other weights (e.g., "coco", "instagram", "ssl", etc.), but it turns out that any non-null argument simply loads the timm weights for the given encoder. This is because timm expects a boolean flag here, and, by extension, torchseg.encoders.get_encoder() only checks whether the argument is non-null before forwarding it:

https://github.com/isaaccorley/torchseg/blob/47a4dd67eb503671c788740178dae8113bf55938/torchseg/encoders/__init__.py?plain=1#L62-L83
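In other words, the string value is never interpreted, only its null-ness. A minimal sketch of that collapse (an illustration of the behavior described above, not the actual torchseg source):

```python
from __future__ import annotations

def resolve_pretrained(encoder_weights: str | None) -> bool:
    # Hypothetical helper: timm.create_model() only accepts ``pretrained: bool``,
    # so whatever string is passed is discarded; "imagenet", "coco", or any other
    # non-null value all load the same single set of timm weights.
    return encoder_weights is not None

print(resolve_pretrained("imagenet"))  # True
print(resolve_pretrained("coco"))      # also True: the same weights are loaded
print(resolve_pretrained(None))        # False: random initialization
```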

I understand that this had some value in smp, which shipped multiple pretrained weights for some models, but timm has only one set per encoder, so it may make sense to change the API to reflect this.

DimitrisMantas commented 4 months ago

I just saw that, right above the block I referenced, the value of encoder_weights is actually used to look up MixTransformer weights. I understand these are pulled from smp, but they are all ImageNet weights anyway...

isaaccorley commented 4 months ago

This was left in for backwards compatibility. I agree with this though, particularly because with timm you load the weights through the encoder/model name, like vit_b16_224.mae. We can likely add a new pretrained arg and then deprecate the weights arg.