TensorFlow Image Models (`tfimm`) is a collection of image models with pretrained
weights, obtained by porting architectures from `timm` to TensorFlow. The hope is
that the number of available architectures will grow over time. For now, it contains
vision transformers (ViT, DeiT, CaiT, PVT and Swin Transformers), MLP-Mixer models
(MLP-Mixer, ResMLP, gMLP, PoolFormer and ConvMixer), various ResNet flavours (ResNet,
ResNeXt, ECA-ResNet, SE-ResNet), the EfficientNet family (including AdvProp,
NoisyStudent, Edge-TPU, V2 and Lite versions), MobileNet-V2, VGG, as well as the recent
ConvNeXt. `tfimm` has now expanded beyond classification and also includes Segment
Anything.
This work would not have been possible without Ross Wightman's `timm` library and the
work on PyTorch/TensorFlow interoperability in HuggingFace's `transformers`
repository.
I tried to make sure all source material is acknowledged. Please let me know if I have
missed something.
The package can be installed via `pip`:

```shell
pip install tfimm
```
To load pretrained weights, `timm` needs to be installed separately.
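`timm` is itself available on PyPI, so it can be installed the same way:

```shell
pip install timm
```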
To load pretrained models, use

```python
import tfimm

model = tfimm.create_model("vit_tiny_patch16_224", pretrained="timm")
```
We can list available models with pretrained weights via

```python
import tfimm

print(tfimm.list_models(pretrained="timm"))
```
Most models are pretrained on ImageNet or ImageNet-21k. If we want to use them for other
tasks, we need to change the number of classes in the classifier or remove the
classifier altogether. We can do this by setting the `nb_classes` parameter in
`create_model`. If `nb_classes=0`, the model will have no classification layer. If
`nb_classes` is set to a value different from the default model config, the
classification layer will be randomly initialized, while all other weights will be
copied from the pretrained model.
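For example (a minimal sketch; the model name and class count are arbitrary):

```python
import tfimm

# Classification layer reinitialized for a 10-class task; all other
# weights are copied from the pretrained model.
model = tfimm.create_model("vit_tiny_patch16_224", pretrained="timm", nb_classes=10)

# Feature extractor: nb_classes=0 removes the classification layer entirely.
backbone = tfimm.create_model("vit_tiny_patch16_224", pretrained="timm", nb_classes=0)
```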
The preprocessing function for each model can be created via

```python
import tensorflow as tf
import tfimm

preprocess = tfimm.create_preprocessing("vit_tiny_patch16_224", dtype="float32")
img = tf.ones((1, 224, 224, 3), dtype="uint8")
img_preprocessed = preprocess(img)
```
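Since the preprocessing function operates on tensors, it can also be mapped over a
`tf.data` pipeline. A sketch with dummy data:

```python
import tensorflow as tf
import tfimm

preprocess = tfimm.create_preprocessing("vit_tiny_patch16_224", dtype="float32")

# Dummy dataset of uint8 images; apply the preprocessing batch-wise.
dataset = tf.data.Dataset.from_tensor_slices(tf.ones((8, 224, 224, 3), dtype="uint8"))
dataset = dataset.batch(4).map(preprocess)
```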
All models are subclassed from `tf.keras.Model` (they are not functional models).
They can still be saved and loaded using the `SavedModel` format.
```python
>>> import tensorflow as tf
>>> import tfimm
>>> model = tfimm.create_model("vit_tiny_patch16_224")
>>> type(model)
<class 'tfimm.architectures.vit.ViT'>
>>> model.save("/tmp/my_model")
>>> loaded_model = tf.keras.models.load_model("/tmp/my_model")
>>> type(loaded_model)
<class 'tfimm.architectures.vit.ViT'>
```
For this to work, the `tfimm` library needs to be imported before the model is loaded,
because during the import process `tfimm` registers custom models with Keras.
Otherwise, we obtain the following output:
```python
>>> import tensorflow as tf
>>> loaded_model = tf.keras.models.load_model("/tmp/my_model")
>>> type(loaded_model)
<class 'keras.saving.saved_model.load.Custom>ViT'>
```
The architectures currently available are those listed in the introduction above.
It is possible to load pretrained model weights from the HF hub. See the
huggingface-model-weights notebook for details. For this to work, it is important that
the weight names and shapes on the HF hub are compatible with one of the `tfimm` model
configurations.
To understand how big each of the models is, I have done some profiling to measure
`float32` and mixed-precision performance. The results can be found in the
`results/profiling_{k80, v100}.csv` files.
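For reference, mixed precision in Keras is enabled via the global policy; this is a
sketch of the setup, not necessarily the exact profiling harness used:

```python
import tensorflow as tf

# Compute in float16 while keeping variables in float32.
tf.keras.mixed_precision.set_global_policy("mixed_float16")
```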
For backpropagation, we use the mean of the model outputs as the loss:

```python
def backprop():
    # Record the forward pass so gradients can be computed.
    with tf.GradientTape() as tape:
        output = model(x, training=True)
        loss = tf.reduce_mean(output)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
```
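This snippet assumes `model`, `x`, and `optimizer` are already defined; for example, a
minimal, hypothetical setup:

```python
import tensorflow as tf
import tfimm

model = tfimm.create_model("vit_tiny_patch16_224")
x = tf.ones((1, 224, 224, 3), dtype="float32")  # dummy input batch
optimizer = tf.keras.optimizers.Adam()
```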
This repository is released under the Apache 2.0 license as found in the LICENSE file.
All things related to `tfimm` can be discussed via Slack.