idiap / attention-sampling

This Python package enables the training and inference of deep learning models for very large data, such as megapixel images, using attention sampling.

Extracting weird patches #6

Closed: andersbhc-mmmi closed this issue 4 years ago

andersbhc-mmmi commented 5 years ago

Hi!

As I have been using this implementation for my own classification tasks, I have noticed a strange trend. In the first couple of epochs, the patches are taken from diffuse and widely differing areas of the training images (which is fine, and expected). In later epochs, as training starts to converge, the attention model focuses on extracting patches from one particular edge in the images, where the background (black) meets the real content of the image. I can see this both from the attention maps, which appear white in this region, and from the patches that are extracted. This does not make any sense to me, as that particular edge does not reveal any important information for the classification task, yet the algorithm still achieves >80% accuracy from those patches. Also, it is always the same edge, even though there are multiple similar edges in the training images. Is this something you have experienced before? Is there anything you think I could be doing wrong?

Thanks in advance, Anders

angeloskath commented 5 years ago

Hi,

Just to make sure that I understand: the attention focuses on parts of the image at the edge where there is some artifact. This is basically attention overfitting. The black-to-content transition causes large activations which the attention has not yet learned to ignore, or they simply push the feature network into using them, which creates a positive feedback loop that ends up using only those patches.

Does it achieve >80% on the test set as well?

I have seen this before, and I usually deal with it by tweaking a few things, most importantly the regularization strength of the attention distribution.

Let me know if you need more help or if anything I mentioned above does not make sense.

Cheers, Angelos

andersbhc-mmmi commented 5 years ago

Yes, it does.

You are correct, the attention model is focusing on the edge between background (black) and foreground (real content, not black).

I will try some of the points you have mentioned. I already tried tweaking the regularization strength for the attention model, which controls its exploration vs. exploitation trade-off. With more regularization, the model distributed the patches more widely (more exploration), but eventually it ended up focusing on the same edges.

I also tried saving the model's weights with the Keras ModelCheckpoint callback and loading them again after re-initializing the model, but then the accuracy drops significantly. Here is my code for building the models and saving the weights:

# Imports: standalone Keras plus layers/utilities from the
# attention-sampling package (ats)
import os

from keras.applications.vgg16 import VGG16
from keras.callbacks import CSVLogger, ModelCheckpoint
from keras.layers import Conv2D, Dense, Dropout, Input
from keras.models import Model, Sequential

from ats.core import attention_sampling
from ats.utils.layers import L2Normalize, ResizeImages, SampleSoftmax
from ats.utils.regularizers import multinomial_entropy

# Defining the attention model: a small fully convolutional network that
# outputs a spatial probability distribution over the low-resolution image
def getAttentionModel(input_shape):
    attention = Sequential([
        Conv2D(8, kernel_size=3, activation="relu", padding="same",
               input_shape=input_shape),
        Conv2D(16, kernel_size=3, activation="relu", padding="same"),
        Conv2D(32, kernel_size=3, activation="relu", padding="same"),
        Conv2D(64, kernel_size=3, activation="relu", padding="same"),
        Conv2D(128, kernel_size=3, activation="relu", padding="same"),
        Conv2D(1, kernel_size=3, padding="same"),
        SampleSoftmax(squeeze_channels=True, smooth=1e-5)
    ])

    return attention

# Defining the feature extraction model that is applied to each patch
def getVGGModel(input_shape, pre_trained=True):
    if pre_trained:
        model = VGG16(include_top=False, weights="imagenet",
                      input_shape=input_shape, pooling="max")
    else:
        model = VGG16(include_top=False, weights=None,
                      input_shape=input_shape, pooling="max")

    # We only fine-tune the last conv layer and the dense layers
    for layer in model.layers[:15]:
        layer.trainable = False

    last = model.output
    x = L2Normalize()(last)
    x = Dropout(0.5)(x)
    x = Dense(256, activation="relu")(x)
    x = Dropout(0.5)(x)

    feature_model = Model(inputs=model.input, outputs=x)

    return feature_model

# Making the attention sampling network
def get_model(outputs, width, height, scale, n_patches, patch_size, reg):
    # Define the shapes
    x_in = Input(shape=(height, width, 3))
    x_high = x_in
    x_low = ResizeImages((int(height*scale), int(width*scale)))(x_high)
    shape_high = (height, width, 3)
    shape_low = (int(height*scale), int(width*scale), 3)

    # Make the attention and feature models
    attention = getAttentionModel(shape_low)
    feature = getVGGModel(shape_high, pre_trained=True)

    # Build the attention sampling network
    features, attention, patches = attention_sampling(
        attention,
        feature,
        patch_size,
        n_patches,
        replace=False,
        attention_regularizer=multinomial_entropy(reg)
    )([x_low, x_high])
    y = Dense(outputs, activation="softmax")(features)

    # Return the classifier and a second model that exposes the attention
    # map, the sampled patches and the low-resolution input for inspection
    return (
        Model(inputs=x_in, outputs=[y]),
        Model(inputs=x_in, outputs=[attention, patches, x_low])
    )

# Instantiating the models
model, att_model = get_model(
    outputs=2,
    width=1444,
    height=1184,
    scale=0.2,
    n_patches=args["n_patches"],
    patch_size=args["patch_size"],
    reg=args["regularizer_strength"]
)

# Instantiating callbacks, including ModelCheckpoint, which saves the
# weights after each epoch (lr_sched, AttentionSaver, training_set and
# args are defined elsewhere in my script)
callbacks = [
    lr_sched,
    AttentionSaver(args["output"], att_model, training_set),
    ModelCheckpoint(
        os.path.join(args["output"], "weights.{epoch:02d}.h5"),
        save_weights_only=True
    ),
    CSVLogger(filename=os.path.join(args["output"], "train_history.csv"))
]

Then I fit the model.
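For reference, the fitting call is essentially the following (a sketch: the loss, optimizer, training arrays, and batch size/epoch values are placeholders, not my exact settings):

model.compile(
    loss="categorical_crossentropy",  # two softmax outputs
    optimizer="adam",                 # placeholder optimizer
    metrics=["accuracy"]
)
model.fit(
    x_train, y_train,                 # placeholder training arrays
    batch_size=8,                     # placeholder values
    epochs=50,
    callbacks=callbacks
)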


When I then recreate the models in the same way as above and load the weights:

model.load_weights(args["weights_path"])

Then evaluating the model on the same data as before yields a low accuracy.

angeloskath commented 5 years ago

Hi,

The fact that the test accuracy does not drop means that the patches are informative, so it will be harder to get rid of them.

Regarding saving and loading, that is weird. I just copy/pasted your code in a shell, and saving and loading work fine for me. I would start by comparing outputs for a single image: load your code in a shell if possible, train for a few gradient updates, then save, reload, and check that the attention map is exactly the same and that the two models give approximately equal predictions for the same image. It helps to check the deterministic parts of the code first, because they should match exactly!
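Concretely, something along these lines (a rough sketch; x stands for a single preprocessed image of shape (1, height, width, 3), and the file name is just an example):

import numpy as np

# ... after a few gradient updates ...
model.save_weights("debug_weights.h5")

# Recreate the models exactly as before and reload the weights; model2 and
# att_model2 share layers, so loading into one updates both
model2, att_model2 = get_model(
    outputs=2, width=1444, height=1184, scale=0.2,
    n_patches=args["n_patches"], patch_size=args["patch_size"],
    reg=args["regularizer_strength"]
)
model2.load_weights("debug_weights.h5")

# The attention map is deterministic given the weights, so it must match
# exactly (up to floating point)
a1 = att_model.predict(x)[0]
a2 = att_model2.predict(x)[0]
assert np.allclose(a1, a2)

# The predictions involve random patch sampling, so average a few forward
# passes and only expect approximately equal results
y1 = np.mean([model.predict(x) for _ in range(10)], axis=0)
y2 = np.mean([model2.predict(x) for _ in range(10)], axis=0)
print(np.abs(y1 - y2).max())  # should be small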

Let me know if you still have problems saving and loading the models or if I can help with the weird patches being selected.

If you really want to get rid of them no matter what, you could generate a mask and apply it to the attention so that patches that contain some black are never selected.
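For example, something like this (a rough sketch, not part of the library: it rebuilds your attention model with the functional API and pushes the pre-softmax logits to a very negative value wherever the low-resolution image is nearly black; the 0.05 threshold is arbitrary, and you would probably want to dilate the mask by the patch size so that no patch overlaps black at all):

from keras import backend as K
from keras.layers import Conv2D, Input, Lambda
from keras.models import Model
from ats.utils.layers import SampleSoftmax

def getMaskedAttentionModel(input_shape):
    x_low = Input(shape=input_shape)

    a = Conv2D(8, kernel_size=3, activation="relu", padding="same")(x_low)
    a = Conv2D(16, kernel_size=3, activation="relu", padding="same")(a)
    a = Conv2D(32, kernel_size=3, activation="relu", padding="same")(a)
    a = Conv2D(64, kernel_size=3, activation="relu", padding="same")(a)
    a = Conv2D(128, kernel_size=3, activation="relu", padding="same")(a)
    a = Conv2D(1, kernel_size=3, padding="same")(a)

    def mask_logits(inputs):
        logits, image = inputs
        # 1 where any channel is above the threshold, 0 where (near) black
        mask = K.cast(K.max(image, axis=-1, keepdims=True) > 0.05, "float32")
        # masked positions get a huge negative logit, i.e. ~zero probability
        return logits * mask - 1e9 * (1.0 - mask)

    a = Lambda(mask_logits)([a, x_low])
    a = SampleSoftmax(squeeze_channels=True, smooth=1e-5)(a)

    return Model(inputs=x_low, outputs=a)

Depending on how much the smooth term redistributes probability mass, a black patch may still be sampled very rarely rather than strictly never, but in practice that should be enough to break the feedback loop.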

Cheers, Angelos

andersbhc-mmmi commented 5 years ago

Hi Angelos,

So I finally made saving and loading work, and I can now reproduce the results after saving and loading the model and its weights. I used the built-in to_json() and model_from_json() methods along with ModelCheckpoint to save the weights, and load_weights() to load them in my test script. I also had to define get_config() methods on the custom layers to make it work.
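In case it helps anyone else, the pattern was roughly this (a sketch with a stand-in custom layer; the real layers are the ones from the package, and smooth is just an example constructor argument):

from keras.layers import Layer
from keras.models import model_from_json

class MyCustomLayer(Layer):  # stands in for e.g. SampleSoftmax, L2Normalize
    def __init__(self, smooth=0.0, **kwargs):
        self.smooth = smooth
        super(MyCustomLayer, self).__init__(**kwargs)

    def get_config(self):
        # Without this, model_from_json() cannot rebuild the layer with
        # the same constructor arguments
        config = super(MyCustomLayer, self).get_config()
        config["smooth"] = self.smooth
        return config

# In the training script: save the architecture once, weights per epoch
with open("model.json", "w") as f:
    f.write(model.to_json())

# In the test script: rebuild the architecture, then load the weights
with open("model.json") as f:
    model = model_from_json(
        f.read(),
        custom_objects={"MyCustomLayer": MyCustomLayer}
    )
model.load_weights(args["weights_path"])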

I'm still confused as to how the patches containing the edges help in the classification. I'll continue tweaking some hyperparameters and see if I can make it better.

Thank you for your help!

angeloskath commented 5 years ago

This sounds awesome! Do you mind sharing your additions? I would gladly merge a pull request.

Let me know if I can do anything more.

Angelos

andersbhc-mmmi commented 5 years ago

Hi,

Yeah, definitely. It's not much, but I'll share it for sure.

Anders