junyongyou / triq

TRIQ implementation
MIT License

Same output for every input image #20

Closed: sulakshgupta988 closed this issue 2 years ago.

sulakshgupta988 commented 2 years ago
import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, Activation, BatchNormalization,
                                     MaxPooling2D, ZeroPadding2D, Rescaling)  # Rescaling: TF >= 2.6
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
# TriQImageQualityTransformer is defined in the TRIQ repository; import it from
# the models module of your checkout.


def create_triq_model(n_quality_levels,
                      input_shape=(None, None, 3),
                      backbone='resnet50',  # unused: the custom CNN below replaces the backbone
                      transformer_params=(2, 32, 8, 64),
                      maximum_position_encoding=193,
                      vis=False):
    chanDim = -1
    # define the model input
    inputs = Input(shape=input_shape)
    # scale raw [0, 255] pixels to [0, 1] (assumes the inputs are not already normalized)
    x = Rescaling(1./255)(inputs)
    filters = (32, 64, 128)
    # loop over the number of filters: CONV => RELU => BN => POOL
    for f in filters:
        x = Conv2D(f, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=chanDim)(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Conv2D(256, (3, 3), padding="same")(x)
    x = Activation("relu")(x)
    x = BatchNormalization(axis=chanDim)(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Conv2D(2048, (3, 3), padding="same")(x)
    x = Activation("relu")(x)
    x = BatchNormalization(axis=chanDim)(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    dropout_rate = 0.1

    transformer = TriQImageQualityTransformer(
        num_layers=transformer_params[0],
        d_model=transformer_params[1],
        num_heads=transformer_params[2],
        mlp_dim=transformer_params[3],
        dropout=dropout_rate,
        n_quality_levels=n_quality_levels,
        maximum_position_encoding=maximum_position_encoding,
        vis=vis
    )
    outputs = transformer(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.summary()
    return model

gpus = tf.config.list_physical_devices('GPU')
if gpus:  # avoid an IndexError on machines without a GPU
    tf.config.set_visible_devices(gpus[0], 'GPU')
input_shape = (564, 504, 3)
#model = create_triq_model(n_quality_levels=5, input_shape=input_shape, backbone='vgg16')
model = create_triq_model(n_quality_levels=1, input_shape=input_shape, backbone='resnet50')

# Adam(decay=...) is rejected by the Keras optimizers in TF >= 2.11; this schedule gives
# the same lr = 1e-3 / (1 + (1e-3 / 200) * step) as the legacy `decay` argument.
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=1e-3, decay_steps=1, decay_rate=1e-3 / 200)
opt = Adam(learning_rate=lr_schedule)
model.compile(loss="mean_squared_error", optimizer=opt)
# trainImagesX/trainY and valImagesX/valY are your own images and quality labels
model.fit(trainImagesX, trainY, validation_data=(valImagesX, valY),
          epochs=108, batch_size=16)

In the code above, I have modified the create_triq_model function so that it uses a custom CNN instead of ResNet50 or VGG16. The custom CNN's output shape is (18, 16, 2048), and this output is fed to TriQImageQualityTransformer.
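The (18, 16, 2048) shape can be checked by hand. This is a minimal sketch assuming Keras defaults: padding="same" convolutions keep the spatial size, MaxPooling2D(pool_size=(2, 2)) uses stride 2 with "valid" padding (so each pool computes floor((n - 2) / 2) + 1), and ZeroPadding2D(padding=(1, 1)) adds 2 rows and 2 columns. The helper names pool2 and custom_cnn_feature_size are hypothetical.

```python
def pool2(n):
    """Output length of MaxPooling2D(pool_size=2) with default stride 2 and 'valid' padding."""
    return (n - 2) // 2 + 1


def custom_cnn_feature_size(height, width):
    """Spatial size of the custom CNN's output (all convolutions use padding='same')."""
    h, w = height, width
    for _ in range(3):          # 32-, 64- and 128-filter blocks, each ending in a pool
        h, w = pool2(h), pool2(w)
    h, w = pool2(h), pool2(w)   # 256-filter block
    h, w = h + 2, w + 2         # ZeroPadding2D(padding=(1, 1))
    h, w = pool2(h), pool2(w)   # 2048-filter block
    return h, w


h, w = custom_cnn_feature_size(564, 504)
print(h, w, h * w)  # 18 16 288
```

So the transformer receives 18 x 16 = 288 spatial positions, each a 2048-dimensional feature vector.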

The issue is that, after training, the model predicts the same value for every input. I have experimented with various hyperparameters: different settings can produce different values, but for any particular setting the model returns the same output for every input image. Note also that if I replace the transformer with a plain artificial neural network (ANN), the network trains well.

Can you please suggest what I am doing wrong here?
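One way to put a number on "the same output for every input" is to compare the spread of the predictions with the spread of the labels, and the model's MSE with that of a constant predictor. Any constant prediction c has an MSE equal to the label variance plus (c - mean)^2, so a collapsed model can never beat the "always predict the mean label" baseline. A minimal NumPy sketch; the function name collapse_report is hypothetical, and preds would come from something like model.predict(valImagesX).

```python
import numpy as np


def collapse_report(preds, targets):
    """Summarize how close predictions are to a constant (collapsed) regressor."""
    preds = np.asarray(preds, dtype=np.float64).ravel()
    targets = np.asarray(targets, dtype=np.float64).ravel()
    return {
        "pred_std": float(preds.std()),            # ~0 when every image gets the same score
        "target_std": float(targets.std()),
        "mse": float(np.mean((preds - targets) ** 2)),
        # MSE of always predicting targets.mean(); equals the label variance
        "mean_baseline_mse": float(targets.var()),
    }


# Synthetic labels and a collapsed "model" that always outputs the mean label
rng = np.random.default_rng(0)
labels = rng.uniform(1.0, 5.0, size=100)
collapsed = np.full_like(labels, labels.mean())
report = collapse_report(collapsed, labels)
print(report)
```

If pred_std is far below target_std and mse is close to mean_baseline_mse on your validation set, the network has effectively learned only the average score.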

junyongyou commented 2 years ago

I suspect that the output of one of the intermediate layers becomes a constant value, e.g., 0. Maybe you can check the outputs of the intermediate layers. If the hyperparameters change, the model architecture also changes, and so do the intermediate outputs.
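A minimal sketch of that check, assuming a functional Keras model such as the one returned by create_triq_model: collect_activations builds a probe model that returns every layer's output, and find_constant_layers flags layers whose output barely changes across a batch of different images. Both function names and the tol threshold are hypothetical; TensorFlow is imported inside collect_activations so the NumPy part runs on its own.

```python
import numpy as np


def collect_activations(model, images):
    """Return {layer name: output} for every non-input layer of a functional Keras model."""
    import tensorflow as tf  # imported lazily; only this function needs TensorFlow

    layers = [layer for layer in model.layers
              if not isinstance(layer, tf.keras.layers.InputLayer)]
    probe = tf.keras.Model(inputs=model.inputs, outputs=[layer.output for layer in layers])
    outputs = probe.predict(images, verbose=0)
    if not isinstance(outputs, list):  # predict() returns a bare array for a single output
        outputs = [outputs]
    return {layer.name: out for layer, out in zip(layers, outputs)}


def find_constant_layers(activations, tol=1e-6):
    """Names of layers whose output varies by less than `tol` across the batch axis."""
    constant = []
    for name, act in activations.items():
        act = np.asarray(act, dtype=np.float64)
        if act.shape[0] < 2:
            raise ValueError("need at least two images to measure variation")
        # per-position standard deviation over the batch, then the largest one
        if float(act.std(axis=0).max()) < tol:
            constant.append(name)
    return constant


# Usage (needs TensorFlow and your trained model):
# acts = collect_activations(model, valImagesX[:8])
# print(find_constant_layers(acts))
```

The first layer name in the returned list is usually the most informative: everything after it inherits the constant input.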

Also note that, in my experience, an IQA model depends heavily on a pretrained network, e.g., ResNet50 pretrained on ImageNet. If you use a custom network, you have probably not pretrained it on a large-scale database, which can definitely hurt performance.

sulakshgupta988 commented 2 years ago

Thanks a lot. I will use your suggestions.

Even when I use ResNet50 as the backbone, as you did, the same problem occurs. Do you have any thoughts on this? Could hyperparameter tuning solve it, or might some other problem be causing this behavior?

junyongyou commented 2 years ago

I don't fully understand your problem. Do you mean that you are using exactly the same code as mine and still get the same output for all your input images?

sulakshgupta988 commented 2 years ago

Yes, exactly.

junyongyou commented 2 years ago

Can you try my trained weights (TRIQ.h5) with image_quality_prediction.py on your images and see how it works? Do you still get the same quality output for all your images?

junyongyou commented 2 years ago

I think another potential cause is x = Rescaling(1./255)(inputs). This layer scales the pixel values to [0, 1], but if you are using my generator, the images are already normalized before they reach the model. Applying both can shrink your input values to nearly 0. You can check this as well.
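The effect can be reproduced with plain NumPy. This sketch assumes, for illustration only, that the generator standardizes each channel to zero mean and unit variance (the TRIQ generator's actual normalization may differ). Dividing such data by 255 again, as Rescaling(1./255) does, shrinks its standard deviation from about 1 to about 1/255 (roughly 0.0039). The helpers standardize and rescale are hypothetical.

```python
import numpy as np


def standardize(images):
    """Hypothetical generator-side normalization: per-channel zero mean, unit variance."""
    images = images.astype(np.float64)
    mean = images.mean(axis=(0, 1, 2), keepdims=True)
    std = images.std(axis=(0, 1, 2), keepdims=True)
    return (images - mean) / (std + 1e-8)


def rescale(images):
    """The same arithmetic as the Keras layer Rescaling(1./255)."""
    return images.astype(np.float64) / 255.0


rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(4, 32, 32, 3)).astype(np.uint8)  # fake 8-bit images

raw_only = rescale(raw)               # what the model expects: pixels in [0, 1]
double = rescale(standardize(raw))    # generator normalization followed by Rescaling
print(f"std after Rescaling only: {raw_only.std():.4f}")
print(f"std after generator + Rescaling: {double.std():.6f}")
```

Printing the min, max, and standard deviation of one real batch right before model.fit is a cheap way to see which of the two regimes your pipeline is in.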