martinsbruveris / tensorflow-image-models

TensorFlow port of PyTorch Image Models (timm) - image models with pretrained weights.
https://tfimm.readthedocs.io/en/latest/
Apache License 2.0

PVT model not training #71

Status: Open. Opened by ma7555 2 years ago.

ma7555 commented 2 years ago

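Describe the bug
The PVT model (pvt_tiny) does not train: over five epochs the loss stays at roughly 16 and the accuracy stays at about 0.001.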

To Reproduce
Steps to reproduce the behaviour:

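import tensorflow as tf
import tensorflow_datasets as tfds
import tfimm


def resize_normalize(x, y):
    # Resize to the 224x224 input size of pvt_tiny and scale pixels to [0, 1].
    x = tf.image.resize(x, (224, 224)) / 255.0
    return x, y


# ImageNet-v2 only provides a "test" split (10,000 images); it is used here
# as a small training set purely to reproduce the problem.
train_ds = tfds.load("imagenet_v2", split="test", as_supervised=True)
train_ds = train_ds.map(resize_normalize).batch(32)

# Randomly initialised PVT-Tiny (no pretrained weights).
model = tfimm.create_model("pvt_tiny", pretrained=False)

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(train_ds, epochs=5)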
Epoch 1/5
313/313 [==============================] - 67s 187ms/step - loss: 15.6424 - accuracy: 6.0000e-04
Epoch 2/5
313/313 [==============================] - 60s 191ms/step - loss: 16.2130 - accuracy: 0.0010
Epoch 3/5
313/313 [==============================] - 60s 191ms/step - loss: 16.2144 - accuracy: 0.0010
Epoch 4/5
313/313 [==============================] - 60s 191ms/step - loss: 16.2418 - accuracy: 0.0010
Epoch 5/5
313/313 [==============================] - 60s 191ms/step - loss: 16.2417 - accuracy: 0.0010
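The logged numbers can be put in context with a little arithmetic. With 1000 classes, chance-level accuracy is 0.001 and a uniform prediction costs ln(1000) ≈ 6.91, so a loss of about 16 is far worse than guessing uniformly. It is, however, close to -ln(1e-7) ≈ 16.12, where 1e-7 is the default Keras epsilon. The plain-Python sketch below is a rough model, under stated assumptions, of how a loss could get stuck there. It assumes that the string loss "sparse_categorical_crossentropy" (i.e. from_logits=False) clips the model outputs into [eps, 1 - eps] before taking the log, and it assumes, without having checked, that the pvt_tiny head emits raw logits rather than softmax probabilities. The helper names and the example logits are hypothetical.

```python
import math

NUM_CLASSES = 1000  # ImageNet-1k label space, as used by imagenet_v2
KERAS_EPSILON = 1e-7  # assumed default of tf.keras.backend.epsilon()


def chance_accuracy(num_classes):
    """Top-1 accuracy of uniform random guessing."""
    return 1.0 / num_classes


def uniform_prediction_loss(num_classes):
    """Cross-entropy when every class gets probability 1 / num_classes."""
    return math.log(num_classes)


def clipped_probability_loss(outputs, label, eps=KERAS_EPSILON):
    """Hypothetical model of the from_logits=False path.

    Assumption: outputs are clipped into [eps, 1 - eps], logged, and fed to a
    softmax cross-entropy, which amounts to -log(clipped[label] / sum(clipped)).
    """
    clipped = [min(max(o, eps), 1.0 - eps) for o in outputs]
    return -math.log(clipped[label] / sum(clipped))


# Hypothetical raw logits: the true class (0) scores negative and one other
# class scores positive. Read as probabilities, they clip to eps and 1 - eps.
logits = [-3.0] * NUM_CLASSES
logits[5] = 4.0

print(f"chance accuracy:     {chance_accuracy(NUM_CLASSES):.4f}")
print(f"uniform-guess loss:  {uniform_prediction_loss(NUM_CLASSES):.3f}")
print(f"-ln(epsilon):        {-math.log(KERAS_EPSILON):.3f}")
print(f"clipped-logits loss: {clipped_probability_loss(logits, 0):.3f}")
```

Running it prints a chance accuracy of 0.0010, a uniform-guess loss of 6.908, and a clipped loss of 16.118, the same value as -ln(epsilon). In this model, any output that is clipped sits in the flat region of the clip and would pass no gradient back to the network; if the assumptions hold, that would be consistent with the plateau in the log, but it has not been confirmed for tfimm's PVT.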

Expected behaviour
The model converges, i.e. the loss decreases over the epochs.


Also note that using the paper's learning rate of 1e-4 does not solve the problem either.