MobileNet-V3 has a Conv2d layer after the pooling layer, so the default cut in `create_timm_body` chops off the last bit of the model.
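To illustrate, here is a name-only sketch of how the default cut chooses its index. It assumes fastai's check matches class names ending in `Pool1d`/`Pool2d`/`Pool3d` (the real `has_pool_type` also recurses into sub-modules, which this sketch skips). The list of top-level children is an assumed, simplified layout for a timm MobileNet-V3 built with `num_classes=0, global_pool=''`, and names may differ across timm versions. `is_pool_type`, `default_cut` and `kept` are hypothetical helpers, not library functions.

```python
import re

def is_pool_type(class_name: str) -> bool:
    # Assumed to mirror fastai's check: class name ends in Pool1d/Pool2d/Pool3d
    return re.search(r"Pool[123]d$", class_name) is not None

def default_cut(children):
    """Index of the last top-level child whose class looks like a pooling layer."""
    return next(i for i, (_, cls) in reversed(list(enumerate(children))) if is_pool_type(cls))

def kept(children, cut):
    """Names of the children that survive slicing with children[:cut]."""
    return [name for name, _ in children[:cut]]

# Assumed, simplified top-level children of a timm MobileNet-V3 (varies by timm version)
MNV3_CHILDREN = [
    ("conv_stem", "Conv2d"),
    ("bn1", "BatchNorm2d"),
    ("act1", "Hardswish"),
    ("blocks", "Sequential"),
    ("global_pool", "SelectAdaptivePool2d"),  # matches the Pool2d pattern
    ("conv_head", "Conv2d"),                  # sits after the pool, so the default cut drops it
    ("act2", "Hardswish"),
    ("flatten", "Identity"),
    ("classifier", "Identity"),
]

default_kept = kept(MNV3_CHILDREN, default_cut(MNV3_CHILDREN))
minus_one_kept = kept(MNV3_CHILDREN, -1)
print(default_kept)    # ['conv_stem', 'bn1', 'act1', 'blocks']
print(minus_one_kept)  # ['conv_stem', 'bn1', 'act1', 'blocks', 'global_pool', 'conv_head', 'act2', 'flatten']
```

With the default cut everything from `global_pool` onward is discarded, including `conv_head`; slicing with `-1` only drops the final (empty) classifier.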
Here's the last bit of the actual body:
The model is cut correctly if you pass in `cut=-1`, though. However, I don't think most users would know this beforehand, so maybe something like this would be nice:

```python
import torch.nn as nn
from timm import create_model
# Assumed import location of these fastai helpers (fastai v2)
from fastai.vision.learner import _update_first_layer, has_pool_type

def create_timm_body(arch:str, pretrained=True, cut=None, n_in=3):
    "Creates a body from any model in the `timm` library."
    # MobileNet-V3 has a conv layer after its pooling layer, so the default cut would drop it
    if 'mobilenetv3' in arch and cut is None:
        raise ValueError("Due to the special architecture of MobileNet-V3, you need to pass cut=-1 to use it")
    model = create_model(arch, pretrained=pretrained, num_classes=0, global_pool='', exportable=True)
    _update_first_layer(model, n_in, pretrained)
    if cut is None:
        # Default: cut at the last top-level child that contains a pooling layer
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
    elif callable(cut): return cut(model)
    else: raise TypeError("cut must be either integer or function")
```
Or you could also hard-code `cut=-1` for MobileNet-V3 and emit a warning that `cut` is being overridden if the user passes in a value.
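The hard-coded alternative could look roughly like this minimal sketch. `resolve_cut` is a hypothetical helper (not part of fastai) that would be called at the top of `create_timm_body`, before the `if cut is None:` branch. As a design choice, it skips the warning when the user already passed `-1`, since nothing is actually being overridden in that case.

```python
import warnings
from typing import Callable, Optional, Union

Cut = Optional[Union[int, Callable]]

def resolve_cut(arch: str, cut: Cut = None) -> Cut:
    """Hypothetical helper: force cut=-1 for MobileNet-V3, warning if a user value is overridden."""
    if "mobilenetv3" in arch:
        if cut is not None and cut != -1:
            warnings.warn(
                f"MobileNet-V3 has a conv layer after its pooling layer; overriding cut={cut!r} with cut=-1",
                UserWarning,
            )
        return -1
    # Other architectures keep whatever the user passed (None means "use the default cut")
    return cut
```

Usage inside the body function would be a single line, `cut = resolve_cut(arch, cut)`, replacing the `ValueError` check.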
A 1x1 conv layer is equivalent to a Linear layer applied at each spatial position, so apart from losing one layer of pretrained weights there's not really a difference here (especially since it comes after a pooling layer).
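A quick PyTorch check of that equivalence: copying a 1x1 `nn.Conv2d`'s weights into an `nn.Linear` gives the same output when the Linear layer is applied over the channel dimension at every position, and after global average pooling (spatial size 1x1) the two are identical on the flattened vector. The channel sizes and input shape are arbitrary, chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
c_in, c_out = 8, 16  # arbitrary sizes for illustration
conv = nn.Conv2d(c_in, c_out, kernel_size=1)
lin = nn.Linear(c_in, c_out)

# Copy the conv parameters into the linear layer: (c_out, c_in, 1, 1) -> (c_out, c_in)
with torch.no_grad():
    lin.weight.copy_(conv.weight.reshape(c_out, c_in))
    lin.bias.copy_(conv.bias)

x = torch.randn(2, c_in, 5, 5)  # (batch, channels, H, W)
y_conv = conv(x)
# Apply the Linear layer at every spatial position: channels last, then back
y_lin = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

# After global average pooling the spatial map is 1x1, so the conv is just a Linear layer
pooled = x.mean(dim=(2, 3), keepdim=True)   # (batch, c_in, 1, 1)
y_conv_pooled = conv(pooled).flatten(1)     # (batch, c_out)
y_lin_pooled = lin(pooled.flatten(1))       # (batch, c_out)

print(torch.allclose(y_conv, y_lin, atol=1e-5))                # True
print(torch.allclose(y_conv_pooled, y_lin_pooled, atol=1e-5))  # True
```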