sayakpaul / Handwriting-Recognizer-in-Keras

This project shows how to build a simple handwriting recognizer in Keras with the IAM dataset.
Apache License 2.0

Incorporating self-attention #3

Open sayakpaul opened 3 years ago

sayakpaul commented 3 years ago

I also wanted to incorporate self-attention into the model to make the example a bit more interesting and fun.

Here's how I am doing it currently:

...
# Second conv block.
x = keras.layers.Conv2D(
    64,
    (3, 3),
    activation="relu",
    kernel_initializer="he_normal",
    padding="same",
    name="Conv2",
)(x)
x = keras.layers.MaxPooling2D((2, 2), name="pool2")(x)

# Self-attention.
attended_outputs, attention_scores = keras.layers.Attention(
    use_scale=True, dropout=0.2, name="attention"
)([x, x], return_attention_scores=True)

# We have used two max-pooling layers, each with pool size and strides of 2.
# Hence, the downsampled feature maps are 4x smaller. The number of
# filters in the last layer is 64. Reshape accordingly before
# passing the output to the RNN part of the model.
new_shape = ((image_width // 4), (image_height // 4) * 64)
x = keras.layers.Reshape(target_shape=new_shape, name="reshape")(attended_outputs)
x = keras.layers.Dense(64, activation="relu", name="dense1")(x)
x = keras.layers.Dropout(0.2)(x)
...
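To see what the `[x, x]` call attends over, here is a minimal NumPy sketch of dot-product self-attention. It assumes that, as `keras.layers.Attention` appears to do in tf.keras 2.x, the scores are a batched matrix product over the last two axes, multiplied by a single learned scalar (initialized to 1) when `use_scale=True`, followed by a softmax over the key axis; dropout is omitted because it only applies during training. The 128x32 input size is an illustrative assumption, not taken from the thread. Under these assumptions a 4D feature map is attended along its third axis, separately for each position on the second axis, and the output keeps the input shape, so the `Reshape` target is unchanged.

```python
import numpy as np


def dot_product_self_attention(x, scale=1.0):
    """Self-attention with query = key = value = x (sketch; no dropout).

    Scores are a batched matmul over the last two axes, so for a 4D input
    (batch, A, B, C) each of the A slices attends over its own B positions.
    """
    scores = scale * np.matmul(x, np.swapaxes(x, -1, -2))  # (..., B, B)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return np.matmul(weights, x), weights


# Illustrative sizes (assumed): 128x32 input, two 2x2 poolings, 64 filters.
image_width, image_height, filters = 128, 32, 64
rng = np.random.default_rng(0)
x = rng.standard_normal((2, image_width // 4, image_height // 4, filters))

attended_outputs, attention_scores = dot_product_self_attention(x)
print(attended_outputs.shape)  # (2, 32, 8, 64)
print(attention_scores.shape)  # (2, 32, 8, 8)

# The output keeps the feature-map shape, so the same Reshape target applies.
new_shape = ((image_width // 4), (image_height // 4) * filters)
reshaped = attended_outputs.reshape((x.shape[0],) + new_shape)
print(reshaped.shape)  # (2, 32, 512)
```

With these sizes each attention map is only 8x8, i.e. attention runs over the 8 positions of the downsampled second image axis rather than over the whole feature map.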

The visualization looks like so:

[Image: visualization not preserved]

Here's the Colab. Note that the results are from 10 epochs of training.

Wanted to get your thoughts. @AakashKumarNain

AakashKumarNain commented 3 years ago

This looks good but we need to validate two things:

  1. How much improvement does it provide?
  2. Should we increase the depth of the network?

For the first one, we need to add a metric; the most popular one is edit distance.
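Edit (Levenshtein) distance counts the minimum number of single-token insertions, deletions, and substitutions needed to turn one sequence into another. A minimal pure-Python sketch of the standard dynamic-programming recurrence, for intuition only (the evaluation code later in this thread uses `tf.edit_distance` instead):

```python
def edit_distance(source, target):
    """Levenshtein distance between two sequences (strings or token lists)."""
    # Row for the empty source prefix: turning "" into target[:j] costs j.
    previous = list(range(len(target) + 1))
    for i, s in enumerate(source, start=1):
        current = [i]  # turning source[:i] into "" costs i deletions
        for j, t in enumerate(target, start=1):
            current.append(min(
                previous[j] + 1,             # deletion
                current[j - 1] + 1,          # insertion
                previous[j - 1] + (s != t),  # substitution (free if equal)
            ))
        previous = current
    return previous[-1]


print(edit_distance("kitten", "sitting"))  # 3
```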

sayakpaul commented 3 years ago

I can look into incorporating the edit distance metric. Funnily, it goes by different names in different bodies of literature.

After that, would you like to experiment with point 2? To keep the example brief, I suggest making the version with edit distance and self-attention a separate example. But I am open to your thoughts.

AakashKumarNain commented 3 years ago

Yes, that sounds good to me. Let's push that first; we can always make another PR for the add-ons.

sayakpaul commented 3 years ago

@AakashKumarNain since the prediction model is different from the main training model, here's how I am envisioning the evaluation with edit distance.

We train the model as it is and then extract the prediction model. After that, we run the edit distance evaluation. The sample code I have in mind is (just for a single batch):

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Assumes `test_ds`, `prediction_model`, and `max_len` are defined earlier in
# the example, and that `test_ds` yields dicts with "image" and "label" keys.

# Get a single batch and convert its labels to sparse tensors.
test_batch = next(iter(test_ds))
sparse_labels = tf.cast(
    tf.sparse.from_dense(test_batch["label"]), dtype=tf.int64
)

# Make predictions on the images and convert them to sparse tensors.
predictions = prediction_model.predict(test_batch["image"])
input_len = np.ones(predictions.shape[0]) * predictions.shape[1]
predictions_decoded = keras.backend.ctc_decode(
    predictions, input_length=input_len, greedy=True
)[0][0][:, :max_len]
sparse_predictions = tf.cast(
    tf.sparse.from_dense(predictions_decoded), dtype=tf.int64
)

# Compute individual edit distances and average them out.
edit_distances = tf.edit_distance(
    sparse_predictions, sparse_labels, normalize=False
)
mean_edit_distance = tf.reduce_mean(edit_distances)

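For intuition about what the evaluation snippet computes, here is a NumPy/pure-Python sketch of greedy (best-path) CTC decoding followed by a mean, unnormalized edit distance. It assumes, as with the defaults of `tf.nn.ctc_greedy_decoder`, that the blank is the last class index, and it pads decoded sequences with `-1` the way `keras.backend.ctc_decode` does; the toy probabilities and labels are made up. The sketch strips the `-1` padding before comparing. One detail worth checking in the TensorFlow version: `tf.sparse.from_dense` keeps every non-zero entry, so any `-1` (or other non-zero padding value) would be counted as a token there.

```python
import numpy as np


def greedy_ctc_decode(probs, blank=None):
    """Best-path CTC decoding for probabilities shaped (batch, time, classes).

    Takes the argmax at each time step, merges consecutive repeats, drops the
    blank, and pads every sequence with -1 to the longest one in the batch.
    Assumption: the blank is the last class unless `blank` is given.
    """
    if blank is None:
        blank = probs.shape[-1] - 1
    decoded = []
    for best_path in probs.argmax(axis=-1):
        tokens, prev = [], None
        for label in best_path:
            if label != prev and label != blank:
                tokens.append(int(label))
            prev = label
        decoded.append(tokens)
    width = max((len(t) for t in decoded), default=0)
    return np.array([t + [-1] * (width - len(t)) for t in decoded])


def edit_distance(source, target):
    """Levenshtein distance between two token sequences."""
    previous = list(range(len(target) + 1))
    for i, s in enumerate(source, start=1):
        current = [i]
        for j, t in enumerate(target, start=1):
            current.append(min(previous[j] + 1, current[j - 1] + 1,
                               previous[j - 1] + (s != t)))
        previous = current
    return previous[-1]


# Toy batch (made up): 2 samples, 5 time steps, classes 0-2 plus blank 3.
best_paths = np.array([[0, 0, 3, 1, 1], [2, 3, 2, 3, 3]])
probs = np.eye(4)[best_paths] * 0.6 + 0.1  # 0.7 on the path, 0.1 elsewhere
decoded = greedy_ctc_decode(probs)

# Strip the -1 padding, then compare with (hypothetical) ground-truth labels.
predictions = [[t for t in row if t != -1] for row in decoded.tolist()]
labels = [[0, 1, 2], [2, 2]]
distances = [edit_distance(p, truth) for p, truth in zip(predictions, labels)]
mean_edit_distance = float(np.mean(distances))
print(decoded.tolist())    # [[0, 1], [2, 2]]
print(mean_edit_distance)  # 0.5
```

Averaging raw counts mirrors `normalize=False`; with `normalize=True`, `tf.edit_distance` would instead divide each distance by the length of the truth sequence.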
AakashKumarNain commented 3 years ago

@sayakpaul why not add a callback for metric evaluation during training as well?

sayakpaul commented 3 years ago

@AakashKumarNain does this work?

https://colab.research.google.com/gist/sayakpaul/dc7439ee9421f5a994e6a75e7b0e624a/handwriting_recognition.ipynb

AakashKumarNain commented 3 years ago

@sayakpaul I think this can be made simpler. I will try refining it today.

sayakpaul commented 3 years ago

@AakashKumarNain please share what you have in mind. I can also work on further simplifying it from there. But simplification should not come at the cost of readability, IMO.

AakashKumarNain commented 3 years ago

Instead of defining a prediction model every time we hit the callback, can't we just make a shallow copy of the model weights and reuse that?
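A framework-agnostic sketch of the idea being discussed: if the inference model holds references to the training model's weight arrays (rather than copies), an end-of-epoch evaluation hook sees every update without rebuilding anything. The classes below are hypothetical stand-ins, not Keras APIs, and the toy regression problem is made up. In Keras, the analogous mechanism is that a functional model built from the trained model's intermediate layers shares those layers' variables.

```python
import numpy as np


class TrainingModel:
    """Hypothetical stand-in for the training model (the one with the CTC loss)."""

    def __init__(self, dim):
        self.weights = np.zeros(dim)

    def train_step(self, x, y, lr=0.1):
        # Gradient of 0.5 * ||x @ w - y||^2, applied in place so that any
        # object holding a reference to self.weights sees the update.
        grad = x.T @ (x @ self.weights - y)
        self.weights -= lr * grad


class PredictionModel:
    """Hypothetical inference view that shares (does not copy) the weights."""

    def __init__(self, training_model):
        self.weights = training_model.weights

    def predict(self, x):
        return x @ self.weights


class EvalCallback:
    """Hypothetical end-of-epoch hook that records a validation metric."""

    def __init__(self, predictor, x_val, y_val):
        self.predictor, self.x_val, self.y_val = predictor, x_val, y_val
        self.history = []

    def on_epoch_end(self, epoch):
        error = np.mean(np.abs(self.predictor.predict(self.x_val) - self.y_val))
        self.history.append(float(error))


x, y = np.eye(2), np.array([1.0, 2.0])
model = TrainingModel(dim=2)
callback = EvalCallback(PredictionModel(model), x, y)  # built once, before training
for epoch in range(5):
    model.train_step(x, y)
    callback.on_epoch_end(epoch)
print(callback.history)  # error shrinks each epoch
```

The design choice here is that the prediction view is created once; copying or reloading weights per callback would instead require re-synchronizing them each time.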

sayakpaul commented 3 years ago

Do you mean initialize a prediction model class and load the updated weights every time the callback is hit? But that would still require subclassing the main model (that contains the CTC layer) and then extracting the weights, no?

AakashKumarNain commented 3 years ago

Yes, that is true!

sayakpaul commented 3 years ago

Yeah, so that does not bring much improvement over the current approach, IMO. But please correct me if I am missing something.

AakashKumarNain commented 3 years ago

Agreed. I will review it once more in the evening and let you know, and then we can proceed with it.

AakashKumarNain commented 3 years ago

@sayakpaul check this out: https://drive.google.com/file/d/1_aixkxpDKlDQe2FICtPzmKqnHUYxnTNd/view?usp=sharing

sayakpaul commented 3 years ago

Very nice.

@AakashKumarNain feel free to incorporate it in the PR directly. Neat.

Also, WDYT about the self-attention part? Should we cover that in a follow-up example?

AakashKumarNain commented 3 years ago

Yes, we can push this one for now. For self-attention, we will make another example

sayakpaul commented 3 years ago

SGTM. Curious to see the results with SA 🧐

AakashKumarNain commented 3 years ago

@sayakpaul can you push the changes with the edit distance callback? It would be good if only one of us pushes changes to that PR. It will be less cluttered, IMO.

sayakpaul commented 3 years ago

Okay sir, will do. 💻

AakashKumarNain commented 3 years ago

Thanks a lot 🍻

sayakpaul commented 3 years ago

@AakashKumarNain just wanted to circle back to this part of the blog post. Happy to help with anything you might need.

AakashKumarNain commented 3 years ago

@sayakpaul I got busy with work. Will get back to this soon

AakashKumarNain commented 3 years ago

@sayakpaul I ran some experiments today. Although attention did provide some improvement, it isn't that large. I will try to showcase it side by side in a Colab soon.

sayakpaul commented 3 years ago

Okay. Maybe we need to reformulate how it's being used currently.

sayakpaul commented 3 years ago

@AakashKumarNain I have been thinking more about this lately.

Since the characters inside the images are presented in a sort of tight-knit manner, I doubt that incorporating self-attention would provide that extra boost. I doubt it will help the model learn contextual dependencies any better than the CNN part of our model already does.

Happy to brainstorm more and design experiments, though.
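One rough way to quantify how much context the convolutional part already sees is its receptive field along one spatial axis. A minimal sketch of the standard recurrence (each layer widens the field by `(kernel - 1) * jump`, where `jump` is the product of the earlier strides), assuming the first conv block mirrors the second one, i.e. a 3x3 same-padded convolution followed by 2x2 max pooling with stride 2; the first block is elided in the posted snippet, so this stack is an assumption:

```python
def receptive_field(layers):
    """Receptive field and output stride ("jump") of a stack of layers.

    layers: (kernel_size, stride) pairs along one spatial axis, in order.
    """
    field, jump = 1, 1
    for kernel, stride in layers:
        field += (kernel - 1) * jump  # widen by the kernel span in input pixels
        jump *= stride                # spacing between adjacent output positions
    return field, jump


# Assumed stack: Conv 3x3 -> MaxPool 2x2 -> Conv 3x3 -> MaxPool 2x2.
stack = [(3, 1), (2, 2), (3, 1), (2, 2)]
print(receptive_field(stack))  # (10, 4)
```

Under these assumptions, each position of the final feature map sees a 10x10 pixel patch, and neighbouring positions are 4 pixels apart.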

AakashKumarNain commented 3 years ago

@sayakpaul yes, in my experiments both models (with and without self-attention) perform almost the same. Let's take this offline and discuss the next steps.