keras-team / keras-io

Keras documentation, hosted live at keras.io
Apache License 2.0

Updating OCR model for reading Captchas Keras 3 example (TF-Only) #1788

Closed sitamgithub-MSIT closed 6 months ago

sitamgithub-MSIT commented 6 months ago

This PR updates the OCR model for reading Captchas example to Keras 3 (TF-only backend). Many TF ops are replaced with their `keras.ops` equivalents.
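To give a feel for the migration, most of these replacements are one-to-one, with only small spelling differences (`perm=` becomes `axes=`, TF dtype objects become dtype strings). A minimal sketch of the pattern, not taken from the PR (the input array is made up):

```python
import numpy as np
import tensorflow as tf
from keras import ops

x = np.random.rand(2, 3, 4).astype("float32")

# TF-specific spelling: `perm=` argument and TF dtype objects.
y_tf = tf.cast(tf.transpose(x, perm=[1, 0, 2]), tf.float64)

# Backend-agnostic Keras spelling: `axes=` argument and string dtypes.
y_keras = ops.cast(ops.transpose(x, axes=[1, 0, 2]), dtype="float64")

# On the TF backend, both paths should produce identical values.
np.testing.assert_allclose(np.array(y_tf), np.array(y_keras))
```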

For reference, here is the Colab notebook: https://colab.research.google.com/drive/1vCDb45wLmSI3iBI2_BfDDYDSztgxZ4Qp?usp=sharing

cc: @fchollet @divyashreepathihalli

Here is the Git diff for the changed files:

Changes:

```diff
diff --git a/examples/vision/captcha_ocr.py b/examples/vision/captcha_ocr.py
index 3a2b8e96..06115b0f 100644
--- a/examples/vision/captcha_ocr.py
+++ b/examples/vision/captcha_ocr.py
@@ -35,6 +35,7 @@ from collections import Counter
 
 import tensorflow as tf
 import keras
+from keras import ops
 from keras import layers
 
 """
@@ -109,9 +110,9 @@ def split_data(images, labels, train_size=0.9, shuffle=True):
     # 1. Get the total size of the dataset
     size = len(images)
     # 2. Make an indices array and shuffle it, if required
-    indices = np.arange(size)
+    indices = ops.arange(size)
     if shuffle:
-        np.random.shuffle(indices)
+        keras.random.shuffle(indices)
     # 3. Get the size of training samples
     train_samples = int(size * train_size)
     # 4. Split data into training and validation sets
@@ -132,10 +133,10 @@ def encode_single_sample(img_path, label):
     # 3. Convert to float32 in [0, 1] range
     img = tf.image.convert_image_dtype(img, tf.float32)
     # 4. Resize to the desired size
-    img = tf.image.resize(img, [img_height, img_width])
+    img = ops.image.resize(img, [img_height, img_width])
     # 5. Transpose the image because we want the time
     # dimension to correspond to the width of the image.
-    img = tf.transpose(img, perm=[1, 0, 2])
+    img = ops.transpose(img, axes=[1, 0, 2])
     # 6. Map the characters in label to numbers
     label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
     # 7. Return a dict as our model is expecting two inputs
@@ -184,13 +185,13 @@ plt.show()
 
 
 def ctc_batch_cost(y_true, y_pred, input_length, label_length):
-    label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)
-    input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)
-    sparse_labels = tf.cast(ctc_label_dense_to_sparse(y_true, label_length), tf.int32)
+    label_length = ops.cast(ops.squeeze(label_length, axis=-1), dtype="int32")
+    input_length = ops.cast(ops.squeeze(input_length, axis=-1), dtype="int32")
+    sparse_labels = ops.cast(ctc_label_dense_to_sparse(y_true, label_length), dtype="int32")
 
-    y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon())
+    y_pred = ops.log(ops.transpose(y_pred, axes=[1, 0, 2]) + keras.backend.epsilon())
 
-    return tf.expand_dims(
+    return ops.expand_dims(
         tf.compat.v1.nn.ctc_loss(
             inputs=y_pred, labels=sparse_labels, sequence_length=input_length
         ),
@@ -199,41 +200,41 @@ def ctc_batch_cost(y_true, y_pred, input_length, label_length):
 
 
 def ctc_label_dense_to_sparse(labels, label_lengths):
-    label_shape = tf.shape(labels)
-    num_batches_tns = tf.stack([label_shape[0]])
-    max_num_labels_tns = tf.stack([label_shape[1]])
+    label_shape = ops.shape(labels)
+    num_batches_tns = ops.stack([label_shape[0]])
+    max_num_labels_tns = ops.stack([label_shape[1]])
 
     def range_less_than(old_input, current_input):
-        return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill(
+        return ops.expand_dims(ops.arange(ops.shape(old_input)[1]), 0) < tf.fill(
             max_num_labels_tns, current_input
         )
 
-    init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool)
+    init = ops.cast(tf.fill([1, label_shape[1]], 0), dtype="bool")
     dense_mask = tf.compat.v1.scan(
         range_less_than, label_lengths, initializer=init, parallel_iterations=1
     )
     dense_mask = dense_mask[:, 0, :]
 
-    label_array = tf.reshape(
-        tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape
+    label_array = ops.reshape(
+        ops.tile(ops.arange(0, label_shape[1]), num_batches_tns), label_shape
     )
     label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask)
 
-    batch_array = tf.transpose(
-        tf.reshape(
-            tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns),
+    batch_array = ops.transpose(
+        ops.reshape(
+            ops.tile(ops.arange(0, label_shape[0]), max_num_labels_tns),
             tf.reverse(label_shape, [0]),
         )
     )
     batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask)
 
-    indices = tf.transpose(
-        tf.reshape(tf.concat([batch_ind, label_ind], axis=0), [2, -1])
+    indices = ops.transpose(
+        ops.reshape(ops.concatenate([batch_ind, label_ind], axis=0), [2, -1])
    )
 
     vals_sparse = tf.compat.v1.gather_nd(labels, indices)
 
     return tf.SparseTensor(
-        tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64)
+        ops.cast(indices, dtype="int64"), vals_sparse, ops.cast(label_shape, dtype="int64")
     )
 
 
@@ -245,12 +246,12 @@ class CTCLayer(layers.Layer):
     def call(self, y_true, y_pred):
         # Compute the training-time loss value and add it
         # to the layer using `self.add_loss()`.
-        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
-        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
-        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
+        batch_len = ops.cast(tf.shape(y_true)[0], dtype="int64")
+        input_length = ops.cast(tf.shape(y_pred)[1], dtype="int64")
+        label_length = ops.cast(tf.shape(y_true)[1], dtype="int64")
 
-        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
-        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
+        input_length = input_length * ops.ones(shape=(batch_len, 1), dtype="int64")
+        label_length = label_length * ops.ones(shape=(batch_len, 1), dtype="int64")
 
         loss = self.loss_fn(y_true, y_pred, input_length, label_length)
         self.add_loss(loss)
@@ -355,10 +356,10 @@ and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io
 
 
 def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
-    input_shape = tf.shape(y_pred)
+    input_shape = ops.shape(y_pred)
     num_samples, num_steps = input_shape[0], input_shape[1]
-    y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon())
-    input_length = tf.cast(input_length, tf.int32)
+    y_pred = ops.log(ops.transpose(y_pred, axes=[1, 0, 2]) + keras.backend.epsilon())
+    input_length = ops.cast(input_length, dtype="int32")
 
     if greedy:
         (decoded, log_prob) = tf.nn.ctc_greedy_decoder(
```
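Note that some pieces intentionally stay as raw TF calls (`tf.fill`, `tf.compat.v1.scan`, `tf.SparseTensor`, and the CTC loss/decoder kernels), which is why this port targets the TF backend only. A hedged sketch of that hybrid pattern, following the `ctc_decode` changes above (the shapes are made up, and it assumes the TF backend):

```python
import tensorflow as tf
import keras
from keras import ops

# Dummy batch: 4 samples, 10 timesteps, 5 classes (the last class is the CTC blank).
y_pred = tf.random.uniform((4, 10, 5))

# The tensor massaging uses backend-agnostic `keras.ops`...
y_pred_t = ops.log(ops.transpose(y_pred, axes=[1, 0, 2]) + keras.backend.epsilon())
input_length = ops.cast(tf.fill([4], 10), dtype="int32")

# ...while the decoder itself remains a TF kernel, keeping the example TF-only.
(decoded,), log_prob = tf.nn.ctc_greedy_decoder(y_pred_t, sequence_length=input_length)
print(tf.sparse.to_dense(decoded).numpy())  # decoded label indices per sample
```

(Recent Keras 3 releases also expose `keras.ops.ctc_loss` and `keras.ops.ctc_decode`, so a fully backend-agnostic version may be possible depending on the installed version.)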
sitamgithub-MSIT commented 6 months ago

> Looks good, thank you. Please add the generated files.

The generated .md and .ipynb files have been added. I also tried the same approach on the handwritten OCR example; model training with Keras 3 succeeded there as well, so I will open that PR shortly.