martinsbruveris / tensorflow-image-models

TensorFlow port of PyTorch Image Models (timm) - image models with pretrained weights.
https://tfimm.readthedocs.io/en/latest/
Apache License 2.0

Adding convnext edge models #76

Open scrouzet opened 1 year ago

scrouzet commented 1 year ago

Many ConvNeXt variants are now available with pretrained weights in timm.

To make them loadable in tfimm, the only change needed is to add the following code at https://github.com/martinsbruveris/tensorflow-image-models/blob/b6742e455fe0d9a550f829917a8cef68000831b5/tfimm/architectures/convnext.py#L439:

@register_model
def convnext_atto():
    cfg = ConvNeXtConfig(
        name="convnext_atto",
        url="[timm]",
        embed_dim=(40, 80, 160, 320),
        nb_blocks=(2, 2, 6, 2),
    )
    return ConvNeXt, cfg

@register_model
def convnext_femto():
    cfg = ConvNeXtConfig(
        name="convnext_femto",
        url="[timm]",
        embed_dim=(48, 96, 192, 384),
        nb_blocks=(2, 2, 6, 2),
    )
    return ConvNeXt, cfg

@register_model
def convnext_pico():
    cfg = ConvNeXtConfig(
        name="convnext_pico",
        url="[timm]",
        embed_dim=(64, 128, 256, 512),
        nb_blocks=(2, 2, 6, 2),
    )
    return ConvNeXt, cfg

@register_model
def convnext_nano():
    cfg = ConvNeXtConfig(
        name="convnext_nano",
        url="[timm]",
        embed_dim=(80, 160, 320, 640),
        nb_blocks=(2, 2, 8, 2),
    )
    return ConvNeXt, cfg

I've tested it locally and it works perfectly. Thanks in advance!

martinsbruveris commented 1 year ago

Support for many more ConvNeXt models will be added in the next release (work is in progress in this PR). It will be a minor release, since adding support for ConvNeXt-V2 requires changes to the config, which breaks compatibility with previously saved models.