tensorflow / models

Models and examples built with TensorFlow

ResNet50 pretrained model for fine-tuning: the model does not converge #9926

Open gganduu opened 3 years ago

gganduu commented 3 years ago

Hi,

I'm using a pretrained ResNet50 model to train on my own data. The model does not converge: even though the training accuracy looks good in the log, the validation loss and accuracy do not improve during training.

I also tested the trained model on the training and validation sets; the accuracy is very poor (see below).

I also tried different TensorFlow versions, 1.15.0 and 2.4.0, and the problem is the same. Then I switched to a VGG model and it works fine (no convergence problem). Could you help with this issue?

My code is:

```python
base_model = tf.keras.applications.ResNet50(include_top=False)
base_model.trainable = False

model = tf.keras.models.Sequential([
    base_model,
    #tf.keras.layers.Conv2D(filters=num_cat, kernel_size=1),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(units=num_cat)
])
model.summary()
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
loss_func = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
eval_func = tf.keras.metrics.CategoricalAccuracy()

model.compile(
    optimizer=optimizer,
    loss=loss_func,
    metrics=[eval_func]
)
history = model.fit(train_ds, epochs=10, validation_data=val_ds)
model.save_weights('./checkpoints/final')
```

And the training log is:

```
Train on 156 steps, validate on 39 steps
Epoch 1/10
156/156 [==============================] - 183s 1s/step - loss: 1.8214 - categorical_accuracy: 0.5048 - val_loss: 3.5411 - val_categorical_accuracy: 0.0386
Epoch 2/10
156/156 [==============================] - 33s 211ms/step - loss: 0.4409 - categorical_accuracy: 0.8931 - val_loss: 3.6663 - val_categorical_accuracy: 0.0386
Epoch 3/10
156/156 [==============================] - 34s 219ms/step - loss: 0.2582 - categorical_accuracy: 0.9365 - val_loss: 3.8821 - val_categorical_accuracy: 0.0386
Epoch 4/10
156/156 [==============================] - 32s 203ms/step - loss: 0.1666 - categorical_accuracy: 0.9550 - val_loss: 3.9013 - val_categorical_accuracy: 0.0386
Epoch 5/10
156/156 [==============================] - 31s 201ms/step - loss: 0.1212 - categorical_accuracy: 0.9630 - val_loss: 4.2440 - val_categorical_accuracy: 0.0386
Epoch 6/10
156/156 [==============================] - 31s 201ms/step - loss: 0.0826 - categorical_accuracy: 0.9759 - val_loss: 4.2431 - val_categorical_accuracy: 0.0386
Epoch 7/10
156/156 [==============================] - 31s 198ms/step - loss: 0.0648 - categorical_accuracy: 0.9807 - val_loss: 4.3009 - val_categorical_accuracy: 0.0514
Epoch 8/10
156/156 [==============================] - 32s 205ms/step - loss: 0.0573 - categorical_accuracy: 0.9823 - val_loss: 4.3420 - val_categorical_accuracy: 0.0386
Epoch 9/10
156/156 [==============================] - 31s 196ms/step - loss: 0.0548 - categorical_accuracy: 0.9839 - val_loss: 4.4843 - val_categorical_accuracy: 0.0386
Epoch 10/10
156/156 [==============================] - 31s 200ms/step - loss: 0.0478 - categorical_accuracy: 0.9887 - val_loss: 4.7390 - val_categorical_accuracy: 0.0386
```

Run inference on training data and validation data:

```
156/156 [==============================] - 35s 227ms/step - loss: 4.7357 - categorical_accuracy: 0.0386
39/39 [==============================] - 11s 279ms/step - loss: 4.7326 - categorical_accuracy: 0.0386
Train Loss: 4.735676199961931; Train Acc: 0.03858520835638046
Val Loss: 4.732603843395527; Val Acc: 0.03858520835638046
```
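
With a frozen ImageNet backbone, this symptom pattern (training accuracy climbing while validation accuracy stays near chance) is often worth checking against the input preprocessing: the Keras docs specify `tf.keras.applications.resnet50.preprocess_input` as the expected transform for `tf.keras.applications.ResNet50`. A minimal sketch, assuming an `(image, label)` dataset with pixels in `[0, 255]`; this is not the code from the issue:

```python
# A sketch of the documented ResNet50 preprocessing, not the author's
# pipeline: preprocess_input converts RGB images in [0, 255] to the
# caffe-style, ImageNet-mean-subtracted BGR format the pretrained
# weights were trained on.
import tensorflow as tf

def preprocess(image, label):
    image = tf.image.resize(image, [224, 224])  # float32, still in [0, 255]
    image = tf.keras.applications.resnet50.preprocess_input(image)
    return image, label

# usage, assuming `ds` yields (image, label) pairs:
# ds = ds.map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
```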

saikumarchalla commented 3 years ago

@gganduu Closing this issue since it is not related to models. Please submit a new issue on the TensorFlow core repo. Thanks!

google-ml-butler[bot] commented 3 years ago

Are you satisfied with the resolution of your issue?

gganduu commented 3 years ago

@saikumarchalla Actually I raised this issue in core first, and they told me to raise it here. And I think this is a ResNet50 model problem, because VGG does not have a problem like this. Could you help with this?

ymodak commented 3 years ago

Can you please provide a reproducible example using dummy data to validate the issue? I have tried using the ResNet50 model with the flowers dataset and the model converges successfully.
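
For reference, a minimal sketch of that kind of sanity check, assuming the tf_flowers dataset from tensorflow_datasets; the exact dataset and hyperparameters used are not stated in the thread:

```python
# Hypothetical flowers-dataset check; dataset choice, split, and
# hyperparameters are assumptions, not the exact setup referenced above.
import tensorflow as tf
import tensorflow_datasets as tfds

(train_ds, val_ds), info = tfds.load(
    'tf_flowers', split=['train[:80%]', 'train[80%:]'],
    as_supervised=True, with_info=True)
num_classes = info.features['label'].num_classes

def prep(image, label):
    image = tf.image.resize(image, [224, 224])
    image = tf.keras.applications.resnet50.preprocess_input(image)
    return image, tf.one_hot(label, num_classes)

train_ds = train_ds.map(prep).batch(32)
val_ds = val_ds.map(prep).batch(32)

base = tf.keras.applications.ResNet50(include_top=False, weights='imagenet')
base.trainable = False  # freeze the pretrained backbone
model = tf.keras.Sequential([
    base,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(num_classes),
])
model.compile(optimizer='adam',
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['categorical_accuracy'])
model.fit(train_ds, epochs=5, validation_data=val_ds)
```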

gganduu commented 3 years ago

> Can you please provide a reproducible example using dummy data to validate the issue? I have tried using the ResNet50 model with the flowers dataset and the model converges successfully.

Thanks for your reply, ymodak! Please use this link to download the data: https://download.pytorch.org/tutorial/hymenoptera_data.zip My code:

```python
import tensorflow as tf
from pathlib import Path
import matplotlib.pyplot as plt


def argument_train(image, label):
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.image.central_crop(image, central_fraction=0.85)
    image = tf.image.resize(image, size=[224, 224])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    # image = tf.image.random_contrast(image, lower=0.0, upper=0.2)
    # tf.image.random_saturation(image, lower=0.0, upper=0.2)
    return image, label


def argument_val(image, label):
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.image.resize(image, size=[224, 224])
    return image, label


def decode_image(path, label):
    image = tf.io.read_file(path)
    image = tf.image.decode_image(image, channels=3)
    # must set the shape explicitly here, due to using decode_image: it
    # handles the gif format differently from the other formats,
    # or you can use a specific API instead, e.g. decode_jpeg
    image.set_shape([None, None, 3])
    return image, label


def gen_ds(root):
    label_names = [p.name for p in Path(root).iterdir() if p.is_dir()]
    label_dict = {n: i for i, n in enumerate(label_names)}
    image_paths = [str(p) for n in label_names for p in (Path(root)/n).glob('*.jpg')]
    image_labels = [label_dict.get(p.parent.name) for n in label_names for p in (Path(root)/n).glob('*.jpg')]
    label_encoded = tf.one_hot(image_labels, depth=len(label_names), dtype=tf.float32)
    ds = tf.data.Dataset.from_tensor_slices((image_paths, label_encoded))
    return ds.map(decode_image, num_parallel_calls=tf.data.experimental.AUTOTUNE), len(label_names)


if __name__ == '__main__':
    image_root = {
        'train': '../datasets/hymenoptera_data/train',
        'val': '../datasets/hymenoptera_data/val'
    }
    train_ds, num_cat = gen_ds(image_root.get('train'))
    val_ds, _ = gen_ds(image_root.get('val'))
    train_ds = train_ds.map(argument_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    val_ds = val_ds.map(argument_val, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    train_ds = train_ds.cache().shuffle(buffer_size=500).batch(8)
    val_ds = val_ds.cache().shuffle(buffer_size=500).batch(8)

    base_model = tf.keras.applications.ResNet50V2(include_top=False)
    base_model.trainable = False

    model = tf.keras.models.Sequential([
        base_model,
        # tf.keras.layers.Conv2D(filters=num_cat, kernel_size=1),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(units=num_cat)
    ])
    model.summary()
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
    loss_func = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    eval_func = tf.keras.metrics.CategoricalAccuracy()
    eval_func_train = tf.keras.metrics.CategoricalAccuracy()
    eval_func_val = tf.keras.metrics.CategoricalAccuracy()

    model.compile(
        optimizer=optimizer,
        loss=loss_func,
        metrics=[eval_func]
    )

    history = model.fit(train_ds, epochs=10, validation_data=val_ds)
    model.save_weights('./checkpoints/final')
    plt.subplot(211)
    plt.plot(history.history['loss'], label='Train Loss')
    plt.plot(history.history['val_loss'], label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss')
    plt.ylim(0)
    plt.legend(loc='best')
    plt.subplot(212)
    plt.plot(history.history['categorical_accuracy'], label='Train Accuracy')
    plt.plot(history.history['val_categorical_accuracy'], label='Val Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Accuracy')
    plt.ylim(min(history.history['val_categorical_accuracy']), 1)
    plt.legend(loc='best')
    plt.show()
```
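
One detail worth comparing against the standard Keras transfer-learning recipe: that recipe builds the head with the functional API and calls the frozen base with `training=False`, which pins its BatchNormalization layers to inference mode, and it applies the model family's documented preprocessing. For ResNet50V2 that is `tf.keras.applications.resnet_v2.preprocess_input`, which maps pixels to `[-1, 1]`, whereas the pipeline above feeds `[0, 1]` floats. A minimal sketch under those assumptions, not a confirmed fix for this issue:

```python
# A sketch of the standard Keras transfer-learning recipe; num_cat mirrors
# the code above (hymenoptera_data has two classes: ants and bees).
import tensorflow as tf

num_cat = 2

base_model = tf.keras.applications.ResNet50V2(include_top=False)
base_model.trainable = False

inputs = tf.keras.Input(shape=(224, 224, 3))
# the pipeline above produces [0, 1] floats; undo that scaling, then apply
# the documented [-1, 1] preprocessing for ResNet50V2
x = tf.keras.applications.resnet_v2.preprocess_input(inputs * 255.0)
# training=False keeps the frozen BatchNorm layers in inference mode,
# even if base_model is unfrozen later for fine-tuning
x = base_model(x, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(units=num_cat)(x)
model = tf.keras.Model(inputs, outputs)
```
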
ymodak commented 3 years ago

Have you considered switching to different optimizers, learning rates, and loss functions?
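
A minimal sketch of that kind of change, reusing the `model` built earlier in the thread; the Adam optimizer and the 1e-4 learning rate are illustrative choices, not a known fix:

```python
# Illustrative recompile with a different optimizer and a smaller learning
# rate; `model` is the Sequential model defined in the code above.
import tensorflow as tf

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)
```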

google-ml-butler[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

gganduu commented 3 years ago

Sure, I tried different optimizers and learning rates; I don't think it's caused by that. Have you tried my code and reproduced this issue?

sachinprasadhs commented 3 years ago

Since you have already tried different hyperparameter tuning and are still getting the convergence issue, one more thing you could do is shuffle the entire dataset before the train/test/val split to get a good distribution of your data, and try running the model again. If you still get convergence issues, it could be due to the reasons below.

  • ResNet is a comparatively deeper network than VGG; deeper networks tend to over-complicate things sometimes, and the more complex the network you use on simple data, the higher the chances of your model not converging.

Since ResNet is mainly used for complex data and to avoid the vanishing gradient problem, you can choose another network if the data is simple or if one is working well, like VGG in your case. It is not necessary that every network should perform well on your data, due to many reasons such as architectural differences between networks.
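
A minimal sketch of that shuffle-before-split step; the directory layout matches the hymenoptera_data download above, while the 80/20 ratio and the seed are illustrative:

```python
# Shuffle the full file list once, with a fixed seed, before carving out
# train/val splits; the ratio and paths are assumptions for illustration.
import random
from pathlib import Path

all_paths = sorted(Path('../datasets/hymenoptera_data').glob('*/*/*.jpg'))
random.seed(42)          # reproducible shuffle
random.shuffle(all_paths)

split = int(0.8 * len(all_paths))
train_paths, val_paths = all_paths[:split], all_paths[split:]
```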

google-ml-butler[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

gganduu commented 2 years ago

> Since you have already tried different hyperparameter tuning and are still getting the convergence issue, one more thing you could do is shuffle the entire dataset before the train/test/val split to get a good distribution of your data, and try running the model again. If you still get convergence issues, it could be due to the reasons below.
>
> • ResNet is a comparatively deeper network than VGG; deeper networks tend to over-complicate things sometimes, and the more complex the network you use on simple data, the higher the chances of your model not converging.
>
> Since ResNet is mainly used for complex data and to avoid the vanishing gradient problem, you can choose another network if the data is simple or if one is working well, like VGG in your case. It is not necessary that every network should perform well on your data, due to many reasons such as architectural differences between networks.

To be honest, I don't think it's because the data is complex. I use ResNet in PyTorch and it works fine. So I think it's a model problem, or maybe it's caused by some wrong configuration on my side. Actually, you could just copy my code and reproduce the same problem.