martinsbruveris / tensorflow-image-models

TensorFlow port of PyTorch Image Models (timm) - image models with pretrained weights.
https://tfimm.readthedocs.io/en/latest/
Apache License 2.0
287 stars 25 forks source link
imagenet tensorflow vision-transformer

TensorFlow Image Models

Test Status Documentation Status License Slack

Introduction

TensorFlow Image Models (tfimm) is a collection of image models with pretrained weights, obtained by porting architectures from timm to TensorFlow. The hope is that the number of available architectures will grow over time. For now, it contains vision transformers (ViT, DeiT, CaiT, PVT and Swin Transformers), MLP-Mixer models (MLP-Mixer, ResMLP, gMLP, PoolFormer and ConvMixer), various ResNet flavours (ResNet, ResNeXt, ECA-ResNet, SE-ResNet), the EfficientNet family (including AdvProp, NoisyStudent, Edge-TPU, V2 and Lite versions), MobileNet-V2, VGG, as well as the recent ConvNeXt. tfimm has now expanded beyond classification and also includes Segment Anything.

This work would not have been possible wihout Ross Wightman's timm library and the work on PyTorch/TensorFlow interoperability in HuggingFace's transformer repository. I tried to make sure all source material is acknowledged. Please let me know if I have missed something.

Usage

Installation

The package can be installed via pip,

pip install tfimm

To load pretrained weights, timm needs to be installed separately.

Creating models

To load pretrained models use

import tfimm

model = tfimm.create_model("vit_tiny_patch16_224", pretrained="timm")

We can list available models with pretrained weights via

import tfimm

print(tfimm.list_models(pretrained="timm"))

Most models are pretrained on ImageNet or ImageNet-21k. If we want to use them for other tasks we need to change the number of classes in the classifier or remove the classifier altogether. We can do this by setting the nb_classes parameter in create_model. If nb_classes=0, the model will have no classification layer. If nb_classes is set to a value different from the default model config, the classification layer will be randomly initialized, while all other weights will be copied from the pretrained model.

The preprocessing function for each model can be created via

import tensorflow as tf
import tfimm

preprocess = tfimm.create_preprocessing("vit_tiny_patch16_224", dtype="float32")
img = tf.ones((1, 224, 224, 3), dtype="uint8")
img_preprocessed = preprocess(img)

Saving and loading models

All models are subclassed from tf.keras.Model (they are not functional models). They can still be saved and loaded using the SavedModel format.

>>> import tesnorflow as tf
>>> import tfimm
>>> model = tfimm.create_model("vit_tiny_patch16_224")
>>> type(model)
<class 'tfimm.architectures.vit.ViT'>
>>> model.save("/tmp/my_model")
>>> loaded_model = tf.keras.models.load_model("/tmp/my_model")
>>> type(loaded_model)
<class 'tfimm.architectures.vit.ViT'>

For this to work, the tfimm library needs to be imported before the model is loaded, since during the import process, tfimm is registering custom models with Keras. Otherwise, we obtain the following output

>>> import tensorflow as tf
>>> loaded_model = tf.keras.models.load_model("/tmp/my_model")
>>> type(loaded_model)
<class 'keras.saving.saved_model.load.Custom>ViT'>

Models

The following architectures are currently available:

Loading pytorch models from HF hub

It is possible to load pre-trained model weights from the HF hub. See the huggingface-model-weights notebook for details. For this to work, it is important that the weight names and shapes on HF hub are compatible with one of the tfimm model configurations.

Profiling

To understand how big each of the models is, I have done some profiling to measure

The results can be found in the results/profiling_{k80, v100}.csv files.

For backpropagation, we use as loss the mean of model outputs

def backprop():
    with tf.GradientTape() as tape:
        output = model(x, training=True)
        loss = tf.reduce_mean(output)
        grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Contact

All things related to tfimm can be discussed via Slack.