MatchLab-Imperial / keras_triplet_descriptor

Baseline to denoise + learn descriptors in N-HPatches

Triplet loss for descriptor model stays at 1 for InceptionV3 and ResNet #6

Open CliveWongTohSoon opened 5 years ago

CliveWongTohSoon commented 5 years ago

I encountered an issue when using triplet loss to train on the images with ResNet and InceptionV3. I believe I'm not the only one, according to this Keras issue.

A StackOverflow question here suggests that a wrong implementation of the triplet loss could be what keeps the loss from decreasing.
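
For reference, a common form of the triplet loss is the hinge max(0, d(a,p) - d(a,n) + margin). The NumPy sketch below assumes Euclidean distances and a margin of 1.0; the loss actually used in this repository may differ. It illustrates why a loss stuck at exactly 1 is a warning sign: if the network maps every patch to the same descriptor, both distances are 0 and the loss equals the margin, no matter how the triplets are sampled.

```python
import numpy as np


def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """Mean hinge triplet loss over a batch of descriptors.

    Sketch only: assumes Euclidean distances and margin=1.0, which may not
    match the loss used in keras_triplet_descriptor.
    """
    d_pos = np.linalg.norm(anchor - positive, axis=1)  # anchor-positive distances
    d_neg = np.linalg.norm(anchor - negative, axis=1)  # anchor-negative distances
    return float(np.mean(np.maximum(0.0, d_pos - d_neg + margin)))


rng = np.random.default_rng(0)
batch, dim = 64, 128

# Collapsed network: every patch gets the same descriptor, so
# d_pos == d_neg == 0 and the loss equals the margin.
collapsed = np.full((batch, dim), 1.0 / dim)
collapsed_loss = triplet_margin_loss(collapsed, collapsed, collapsed)

# Well-separated descriptors: positives lie near the anchor, negatives far away.
anchor = rng.normal(size=(batch, dim))
positive = anchor + 0.01 * rng.normal(size=(batch, dim))
negative = anchor + 10.0 * rng.normal(size=(batch, dim))
separated_loss = triplet_margin_loss(anchor, positive, negative)

print(collapsed_loss)  # 1.0
print(separated_loss)  # 0.0
```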

Below is how I build the descriptor model. I've tried both loading ImageNet weights and setting weights to None; both result in the loss hovering around 1, i.e. the network is not learning.

from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.applications.inception_v3 import InceptionV3
from keras.backend import tf as ktf
from keras.layers import Lambda, Input

def get_inception_v3_descriptor_model(shape):
  model = InceptionV3(include_top=False, input_shape=(75,75,1), weights=None)

  # Resize Input images to 75x75
  newInput = Input(batch_shape=(None, shape[0], shape[1], shape[2]))
  resizedImg = Lambda(lambda image: ktf.image.resize_images(image, (75, 75)))(newInput)
  newOutputs = model(resizedImg)
  model = Model(newInput, newOutputs)

#   for layer in model.layers[:]:
#       layer.trainable = True

  output = model.output
  output = GlobalAveragePooling2D()(output)

  # let's add a fully-connected layer
  output = Dense(1024, activation='relu')(output)
  # and a logistic layer of 128 descriptors
  output = Dense(128, activation='softmax')(output)
  return Model(model.input, output)

Could you advise on how to configure things properly in order to use ResNet and InceptionV3 on N-HPatches?

alopezgit commented 5 years ago

Hi,

You are using a softmax activation in the last layer, which is meant for classification. Maybe you wanted to use sigmoid instead. If you remove that, it should work!
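
To see why softmax is a poor fit for a descriptor: it forces the 128 outputs to be positive and to sum to 1, so every descriptor lies on the probability simplex and raising one component necessarily lowers the others. The NumPy sketch below compares that with a sigmoid head (independent values in (0, 1)) and with an L2-normalised linear output. The L2-normalised variant is a common choice in descriptor learning, swapped in here purely for illustration; it is not taken from this thread or the repository.

```python
import numpy as np


def softmax(z):
    # Subtract the row max for numerical stability.
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def l2_normalize(z, eps=1e-12):
    # Project each descriptor onto the unit hypersphere.
    return z / np.maximum(np.linalg.norm(z, axis=1, keepdims=True), eps)


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 128))  # pre-activation outputs of the last Dense layer

soft = softmax(logits)
sig = sigmoid(logits)
unit = l2_normalize(logits)

print(soft.sum(axis=1))              # every row sums to 1: components compete
print(sig.min() > 0, sig.max() < 1)  # components are independent, each in (0, 1)
print(np.linalg.norm(unit, axis=1))  # every row has unit norm
```

In Keras, the L2-normalised variant would roughly correspond to a linear Dense(128) followed by Lambda(lambda x: K.l2_normalize(x, axis=-1)), with from keras import backend as K; whether it trains better on N-HPatches is untested here.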

CliveWongTohSoon commented 5 years ago

After changing the last layer to sigmoid, the loss does decrease, but the validation loss stays at 1.0000:

Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.7025
100%|██████████| 100000/100000 [00:02<00:00, 36603.41it/s]
2000/2000 [==============================] - 811s 406ms/step - loss: 0.7025 - val_loss: 1.0000
100%|██████████| 10000/10000 [00:00<00:00, 69467.42it/s]
Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.6807
100%|██████████| 100000/100000 [00:02<00:00, 36599.69it/s]
2000/2000 [==============================] - 779s 389ms/step - loss: 0.6808 - val_loss: 1.0000

I still believe this is an issue with either the triplet loss or the Keras BatchNormalization bug mentioned previously.

alopezgit commented 5 years ago

Yes, I tried it, and it seems that the BatchNormalization moving mean does explode due to a problem in Keras. I will look into it to see if we can find a solution; however, it seems to be a hard bug to solve. If you want to use a large architecture for this problem, you can try another model, such as VGG16, that does not use Batch Normalization.
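
For context on the moving-average problem: during training, a BatchNormalization layer normalises each batch with that batch's own mean and variance, while at inference (which is what the validation pass uses) it normalises with exponential moving averages updated as moving = momentum * moving + (1 - momentum) * batch_stat, with a Keras default momentum of 0.99. If the moving statistics do not track the real activation statistics, the training loss can improve while the validation outputs are badly mis-normalised, which fits a decreasing loss alongside a val_loss pinned at 1.0000. The NumPy sketch below is a simplified single-feature illustration of the two modes (no gamma/beta), not the Keras implementation.

```python
import numpy as np


def batch_norm(x, moving_mean, moving_var, training, momentum=0.99, eps=1e-3):
    """Simplified BatchNormalization for a (batch, features) array, no gamma/beta.

    Returns the normalised output and the (possibly updated) moving statistics.
    momentum=0.99 and eps=1e-3 mirror the Keras defaults.
    """
    if training:
        mean, var = x.mean(axis=0), x.var(axis=0)
        # Exponential moving average update, applied only in training mode.
        moving_mean = momentum * moving_mean + (1 - momentum) * mean
        moving_var = momentum * moving_var + (1 - momentum) * var
    else:
        mean, var = moving_mean, moving_var
    return (x - mean) / np.sqrt(var + eps), moving_mean, moving_var


rng = np.random.default_rng(0)
# Activations whose real statistics (mean 50, std 10) are far from the
# initial moving statistics (mean 0, variance 1).
x = rng.normal(loc=50.0, scale=10.0, size=(256, 1))
moving_mean, moving_var = np.zeros(1), np.ones(1)

# A few training steps: each step closes only 1% of the gap to the batch statistics.
for _ in range(10):
    y_train, moving_mean, moving_var = batch_norm(x, moving_mean, moving_var, training=True)

y_eval, _, _ = batch_norm(x, moving_mean, moving_var, training=False)

print(y_train.mean())  # close to 0: training mode uses the batch statistics
print(y_eval.mean())   # far from 0: inference uses the lagging moving statistics
```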

CliveWongTohSoon commented 5 years ago

Running VGG16 also has the same issue:

from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.applications.vgg16 import VGG16
from keras.backend import tf as ktf
from keras.layers import Lambda, Input

def get_vgg16_descriptor_model(shape):
  model = VGG16(include_top=False, input_shape=(32,32,1), weights=None)

  output = model.output

  output = GlobalAveragePooling2D()(output)
  # let's add a fully-connected layer
  output = Dense(1024, activation='relu')(output)
  output = Dense(128, activation='sigmoid')(output)

  return Model(model.input, output)

Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 1.0000
100%|██████████| 100000/100000 [00:02<00:00, 41341.10it/s]
2000/2000 [==============================] - 479s 240ms/step - loss: 1.0000 - val_loss: 1.0000
Epoch 1/1
 120/2000 [>.............................] - ETA: 7:08 - loss: 1.0000

However, using ImageNet weights does allow training to occur:

from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.applications.vgg16 import VGG16
from keras.backend import tf as ktf
from keras.layers import Lambda, Input, Dropout

def get_vgg16_descriptor_model(shape):
  model = VGG16(include_top=False, input_shape=(32,32,3), weights='imagenet')

  newInput = Input(batch_shape=(None, shape[0], shape[1], shape[2]))
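  if shape[2] == 1:
    # Replicate the single grey channel so the input matches VGG16's 3 channels
    resizedImg = Lambda(lambda image: ktf.image.grayscale_to_rgb(image))(newInput)
  else:
    # Input already has 3 channels: pass it through unchanged
    resizedImg = newInput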
  newOutputs = model(resizedImg)
  model = Model(newInput, newOutputs)

  for layer in model.layers[:]:
      layer.trainable = False

  output = model.output
  output = GlobalAveragePooling2D()(output)
  # let's add a fully-connected layer
  output = Dense(1024, activation='relu')(output)
  output = Dropout(0.3)(output)
  output = Dense(128, activation='sigmoid')(output)

  return Model(model.input, output)

Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.6935
100%|██████████| 100000/100000 [00:02<00:00, 42694.65it/s]
2000/2000 [==============================] - 169s 84ms/step - loss: 0.6935 - val_loss: 0.7135
100%|██████████| 10000/10000 [00:00<00:00, 71184.37it/s]
Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.6824
100%|██████████| 100000/100000 [00:02<00:00, 41120.87it/s]

On a side note, after populating the training generator, we are left with only 1-2 GB of RAM, and the runtime often crashes in the middle of training. Is there any way to get around this? It's been really time-consuming to restart the runtime every time the RAM overloads.

alopezgit commented 5 years ago

Yes, it is always more stable to start with ImageNet weights. As for the RAM issue, you can delete the denoising generators, since you are not using them to train the descriptor; that should free a few GB of RAM.
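
A minimal sketch of freeing that memory, assuming the denoising generators are bound to notebook variables; the class and variable names below are hypothetical placeholders for whatever the notebook actually uses. Python only releases an object once no references to it remain, so every name pointing at a generator has to be deleted before running the garbage collector. The weakref is only there to check that the object was really released.

```python
import gc
import weakref

import numpy as np


class HypotheticalDenoiseGenerator:
    """Stand-in for a data generator that keeps all of its patches in memory."""

    def __init__(self, n_patches):
        self.patches = np.zeros((n_patches, 32, 32, 1), dtype=np.float32)


denoise_generator = HypotheticalDenoiseGenerator(10000)
denoise_val_generator = HypotheticalDenoiseGenerator(1000)
probe = weakref.ref(denoise_generator)  # lets us check whether it was freed

# Drop every reference to the generators the descriptor training does not need.
del denoise_generator, denoise_val_generator
gc.collect()

print(probe() is None)  # True once the generator has been released
```

In Jupyter or Colab, the output cache (_, Out) can also keep an object alive if it was ever displayed as a cell result, so it is worth not echoing the generators in a cell.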

CliveWongTohSoon commented 5 years ago

I'm not sure whether this is the right fix, but there is a temporary fix provided here, and I managed to get InceptionV3 and ResNet to train with weights initialised from ImageNet.

Run these two lines:

!pip install -U --force-reinstall --no-dependencies git+https://github.com/datumbox/keras@fork/keras2.2.4
!pip install -U --force-reinstall --no-dependencies git+https://github.com/datumbox/keras@bugfix/trainable_bn

Then initialise ResNet with weights from 'imagenet':

from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.applications.resnet50 import ResNet50
from keras.backend import tf as ktf
from keras.layers import Lambda, Input

def get_resnet_50_descriptor_model(shape):
  model = ResNet50(include_top=False, input_shape=(32,32,3), weights='imagenet')

  newInput = Input(batch_shape=(None, shape[0], shape[1], shape[2]))
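  if shape[2] == 1:
    # Replicate the single grey channel so the input matches ResNet50's 3 channels
    resizedImg = Lambda(lambda image: ktf.image.grayscale_to_rgb(image))(newInput)
  else:
    # Input already has 3 channels: pass it through unchanged
    resizedImg = newInput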
  newOutputs = model(resizedImg)
  model = Model(newInput, newOutputs)

  for layer in model.layers[:]:
      layer.trainable = False

  output = model.output
  output = GlobalAveragePooling2D()(output)
  # let's add a fully-connected layer
  output = Dense(256, activation='relu')(output)
  output = Dense(128, activation='sigmoid')(output)

  return Model(model.input, output)

His patch essentially makes frozen BN layers genuinely non-trainable, so the new output layers train correctly when the ResNet layers are set to be untrainable. It won't work if we make every layer trainable, as BN is still broken in Keras.
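
The behaviour described above can be sketched as follows. This is a simplified model of frozen BatchNormalization as described by the author of the patched fork installed above, not the actual Keras or patch code: in stock Keras 2.2.4, a BN layer with trainable=False still normalises with mini-batch statistics in training mode but with the moving (ImageNet) statistics at inference, whereas the patch makes a frozen layer use the moving statistics in both modes. With the patch, the features the new top layers see during training match the ones they see at validation time.

```python
import numpy as np


def frozen_bn(x, moving_mean, moving_var, training, patched, eps=1e-3):
    """Simplified frozen (trainable=False) BatchNormalization, no gamma/beta.

    patched=False: training mode normalises with the mini-batch statistics.
    patched=True:  the moving statistics are used in both modes.
    """
    if training and not patched:
        mean, var = x.mean(axis=0), x.var(axis=0)
    else:
        mean, var = moving_mean, moving_var
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
# Pretrained moving statistics (e.g. from ImageNet) that differ from the
# statistics of the new N-HPatches activations.
moving_mean, moving_var = np.array([0.5]), np.array([4.0])
x = rng.normal(loc=2.0, scale=1.0, size=(128, 1))

y_eval = frozen_bn(x, moving_mean, moving_var, training=False, patched=False)
y_train_stock = frozen_bn(x, moving_mean, moving_var, training=True, patched=False)
y_train_patched = frozen_bn(x, moving_mean, moving_var, training=True, patched=True)

# Stock behaviour: the head trains on features that differ from those seen at validation.
print(np.allclose(y_train_stock, y_eval))    # False
# Patched behaviour: training and validation see identical features.
print(np.allclose(y_train_patched, y_eval))  # True
```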

I can confirm this works ONLY if we initialise with weights='imagenet' AND set the model layers to be untrainable, i.e. for layer in model.layers[:]: layer.trainable = False.

Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.8775
100%|██████████| 100000/100000 [00:02<00:00, 34608.98it/s]
2000/2000 [==============================] - 237s 119ms/step - loss: 0.8775 - val_loss: 0.8078
100%|██████████| 10000/10000 [00:00<00:00, 69083.89it/s]
Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.7295
100%|██████████| 100000/100000 [00:02<00:00, 35442.07it/s]
2000/2000 [==============================] - 225s 113ms/step - loss: 0.7295 - val_loss: 0.7898
100%|██████████| 10000/10000 [00:00<00:00, 67747.75it/s]
Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.7096
100%|██████████| 100000/100000 [00:03<00:00, 27381.03it/s]
2000/2000 [==============================] - 226s 113ms/step - loss: 0.7096 - val_loss: 0.7715
100%|██████████| 10000/10000 [00:00<00:00, 67637.74it/s]
Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.7018
100%|██████████| 100000/100000 [00:02<00:00, 34707.58it/s]
2000/2000 [==============================] - 224s 112ms/step - loss: 0.7018 - val_loss: 0.7771
100%|██████████| 10000/10000 [00:00<00:00, 65814.06it/s]
Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.6984
100%|██████████| 100000/100000 [00:02<00:00, 36293.44it/s]
2000/2000 [==============================] - 225s 112ms/step - loss: 0.6985 - val_loss: 0.7963
100%|██████████| 10000/10000 [00:00<00:00, 66857.37it/s]
Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.6961
100%|██████████| 100000/100000 [00:02<00:00, 34383.91it/s]
2000/2000 [==============================] - 226s 113ms/step - loss: 0.6961 - val_loss: 0.8160
100%|██████████| 10000/10000 [00:00<00:00, 66138.90it/s]
Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.6944
100%|██████████| 100000/100000 [00:02<00:00, 35374.37it/s]
2000/2000 [==============================] - 225s 113ms/step - loss: 0.6944 - val_loss: 0.8105
100%|██████████| 10000/10000 [00:00<00:00, 66298.01it/s]
Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.6918
100%|██████████| 100000/100000 [00:02<00:00, 33619.40it/s]
2000/2000 [==============================] - 225s 113ms/step - loss: 0.6918 - val_loss: 0.8139
100%|██████████| 10000/10000 [00:00<00:00, 60034.41it/s]
Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.6915
100%|██████████| 100000/100000 [00:02<00:00, 34002.90it/s]
2000/2000 [==============================] - 225s 113ms/step - loss: 0.6915 - val_loss: 0.8104
100%|██████████| 10000/10000 [00:00<00:00, 62052.25it/s]
Epoch 1/1
1999/2000 [============================>.] - ETA: 0s - loss: 0.6870
100%|██████████| 100000/100000 [00:02<00:00, 34883.82it/s]
2000/2000 [==============================] - 226s 113ms/step - loss: 0.6870 - val_loss: 0.8012
100%|██████████| 10000/10000 [00:00<00:00, 66300.53it/s]

alopezgit commented 5 years ago

That is a hack that may sort of work, but freezing the network with ImageNet weights may not give you good results for this problem. I would try another approach.