walkwithfastai / walkwithfastai.github.io

Host for https://walkwithfastai.com

MobileNet-V3 Compatibility with timm #11

Closed: rsomani95 closed this issue 3 years ago

rsomani95 commented 3 years ago

MobileNet-V3 has a Conv2d layer after its pooling layer, so the default cut logic in create_timm_body (which cuts at the last child module containing a pooling layer) also cuts off that final Conv2d layer of the model.

Here's the last bit of the actual body:

[Screenshot (2020-11-05): the last layers of the model body]
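To see why the default cut drops that layer, here is a minimal sketch on a toy model. Assumptions: the toy nn.Sequential is a hypothetical stand-in loosely modelled on the tail of a timm MobileNet-V3 created with num_classes=0 (it is not the real layer list), and has_pool is a simplified re-implementation of fastai's has_pool_type, matching class names that end in Pool1d/Pool2d/Pool3d.

```python
import re

import torch
import torch.nn as nn


def has_pool(m: nn.Module) -> bool:
    # Simplified, hypothetical version of a "contains a pooling layer" check:
    # true if the module or any of its descendants has a class name ending in Pool1d/2d/3d
    if re.search(r"Pool[123]d$", type(m).__name__):
        return True
    return any(has_pool(c) for c in m.children())


# Hypothetical stand-in for the tail of MobileNet-V3 (not timm's real structure)
toy = nn.Sequential(
    nn.Conv2d(3, 16, 3),       # feature extractor ("blocks")
    nn.AdaptiveAvgPool2d(1),   # global pooling
    nn.Conv2d(16, 32, 1),      # 1x1 conv head that comes AFTER the pooling layer
    nn.Hardswish(),            # activation after the conv head
    nn.Identity(),             # classifier, which is an Identity when num_classes=0
)

children = list(toy.children())
# Default logic: cut at the last child that contains a pooling layer
default_cut = next(i for i, o in reversed(list(enumerate(children))) if has_pool(o))
body_default = nn.Sequential(*children[:default_cut])  # pooling AND the 1x1 conv head are dropped
body_minus1 = nn.Sequential(*children[:-1])            # only the Identity classifier is dropped

print(default_cut, len(body_default), len(body_minus1))
```

With the default cut only the first child survives, while cut=-1 keeps the pooling layer, the 1x1 conv head and its activation.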

The model is cut correctly if you pass cut=-1, but I don't think most users would know to do this beforehand. So maybe something like this would be nice:

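from fastai.vision.all import *                        # nn, has_pool_type
from fastai.vision.learner import _update_first_layer  # private fastai helper: adapts the stem to n_in channels
from timm import create_model

def create_timm_body(arch:str, pretrained=True, cut=None, n_in=3):
    "Creates a body from any model in the `timm` library."
    # The default cut (last pooling layer) would also drop MobileNet-V3's Conv2d head
    if 'mobilenetv3' in arch and cut is None:
        raise ValueError("Due to the special architecture of MobileNet-V3, you need to pass cut=-1 to use it")
    model = create_model(arch, pretrained=pretrained, num_classes=0, global_pool='', exportable=True)
    _update_first_layer(model, n_in, pretrained)
    if cut is None:
        # Default: cut at the last child module that contains a pooling layer
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
    elif callable(cut): return cut(model)
    else: raise ValueError("cut must be either an integer or a function")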

Alternatively, you could hard-code the cut for MobileNet-V3 and emit a warning that cut is being overridden if the user passes in a value.
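A minimal sketch of that alternative, assuming a hypothetical helper named resolve_cut (not part of fastai or timm) that would be called at the top of create_timm_body as cut = resolve_cut(arch, cut). It only warns when the user's value actually differs from the hard-coded -1; that choice is an assumption, not something settled in this thread.

```python
import warnings

MOBILENETV3_CUT = -1  # the cut reported in this issue to work for MobileNet-V3


def resolve_cut(arch: str, cut=None):
    "Hypothetical helper: force the MobileNet-V3 cut, warning if a user-supplied value is overridden."
    if "mobilenetv3" not in arch:
        return cut  # other architectures keep the user's value (or None for the default logic)
    if cut is not None and cut != MOBILENETV3_CUT:
        warnings.warn(f"cut={cut!r} is being overridden with cut={MOBILENETV3_CUT} for {arch}")
    return MOBILENETV3_CUT
```

This keeps the default path working for MobileNet-V3 without an error, at the cost of silently ignoring (apart from the warning) any custom cut a user passes for that architecture.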

muellerzr commented 3 years ago

A 1x1 conv layer is equivalent to a Linear layer (especially since this one comes after a pooling layer, where the feature map is 1x1), so besides losing one layer of pretrained weights there's not really a difference here.
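A quick PyTorch check of that equivalence, as a sketch: copy the weights of a 1x1 Conv2d into an nn.Linear and compare both on a globally pooled feature map. The channel counts and input shape are arbitrary assumptions, not MobileNet-V3's actual sizes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
c_in, c_out = 16, 32  # arbitrary channel sizes for illustration

conv = nn.Conv2d(c_in, c_out, kernel_size=1)
linear = nn.Linear(c_in, c_out)
with torch.no_grad():
    # Conv weight has shape (c_out, c_in, 1, 1); Linear expects (c_out, c_in)
    linear.weight.copy_(conv.weight.view(c_out, c_in))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, c_in, 7, 7)            # a batch of feature maps
pooled = nn.AdaptiveAvgPool2d(1)(x)       # (2, c_in, 1, 1) after global pooling
out_conv = conv(pooled).flatten(1)        # (2, c_out)
out_linear = linear(pooled.flatten(1))    # (2, c_out)

print(torch.allclose(out_conv, out_linear, atol=1e-5))
```

Because the weights can be mapped one-to-one, a Linear layer in the head can represent exactly what the dropped conv head computes; only its pretrained values are lost.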