DimitrisMantas opened this issue 4 months ago
The `encoder_weights` parameter in the model initializers is a bit ambiguous, especially given its lax type hint. At first glance, it looks as though you could pass strings naming other weight sets (e.g., "coco", "instagram", "ssl", etc.), but it turns out that any non-null argument results in the timm weights for the particular encoder being loaded. This is because timm really expects a boolean flag for this, and, by extension, `torchseg.encoders.get_encoder()` only performs the checks required to satisfy that:

https://github.com/isaaccorley/torchseg/blob/47a4dd67eb503671c788740178dae8113bf55938/torchseg/encoders/__init__.py?plain=1#L62-L83

I understand that this had some value in smp, which shipped multiple weight sets for some models, but timm exposes only one per encoder, so maybe it would make sense to change our API to reflect this.
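To make the ambiguity concrete, here is a minimal sketch of the behavior being described, assuming a simplified stand-in for `get_encoder()`; the function name and the "coco" value are just illustrative, and the truthiness check paraphrases the linked code rather than quoting it:

```python
import timm

def get_encoder_sketch(name: str, weights: str | None = None):
    """Simplified stand-in for torchseg.encoders.get_encoder()."""
    # Any non-null value collapses to True before reaching timm, so
    # "coco", "ssl", and "imagenet" all load the same checkpoint.
    return timm.create_model(
        name,
        features_only=True,
        pretrained=weights is not None,  # only truthiness matters
    )

# Both calls load the identical timm (ImageNet) weights, even though
# the caller appears to be asking for different ones.
enc_a = get_encoder_sketch("resnet18", weights="coco")
enc_b = get_encoder_sketch("resnet18", weights="imagenet")
```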
I also just saw that, right above the block I referenced, the value of `encoder_weights` is actually used to look up MixTransformer weights. I understand these are pulled from smp, but they are all ImageNet weights anyway...

This was left in for backwards compatibility, but I agree with this, particularly because with timm you load the weights through the encoder/model name, e.g., `vit_b16_224.mae`. We can likely add a new `pretrained` arg and then deprecate the `weights` arg.
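A rough sketch of what that migration could look like; the `resolve_pretrained()` helper and its deprecation mechanics are hypothetical, one possible design rather than a committed API, and the `vit_base_patch16_224.mae` tag assumes a recent timm version:

```python
import warnings

import timm

def resolve_pretrained(
    pretrained: bool = False,
    encoder_weights: str | None = None,
) -> bool:
    # Hypothetical shim: honor the old arg for now, but steer callers
    # toward the new boolean, since only truthiness was ever used.
    if encoder_weights is not None:
        warnings.warn(
            "'encoder_weights' is deprecated; pass 'pretrained=True' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return True
    return pretrained

# With timm, the checkpoint itself is selected by the model tag
# (the ".mae" suffix here), so a boolean is all the API needs.
model = timm.create_model(
    "vit_base_patch16_224.mae",
    pretrained=resolve_pretrained(encoder_weights="imagenet"),
)
```

Since the tag on the model name already pins the checkpoint, a boolean `pretrained` would remove the false impression that different strings select different weights.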