WizzWriters / spelldraw-app

Free and open-source whiteboard.

Spatial transformer implementation :car: #73

Closed ventus550 closed 1 year ago

ventus550 commented 1 year ago

image

ventus550 commented 1 year ago

https://arxiv.org/pdf/1506.02025.pdf
https://colab.research.google.com/github/tulasiram58827/ocr_tflite/blob/main/colabs/KERAS_OCR_TFLITE.ipynb#scrollTo=DaRONZgExrQL

ventus550 commented 1 year ago

https://www.tensorflow.org/api_docs/python/tfm/vision/spatial_transform_ops ?

ventus550 commented 1 year ago

https://groups.google.com/a/tensorflow.org/g/tfjs/c/_GnkC7CNA1s

ventus550 commented 1 year ago

Example of using custom layers

First, define the custom Keras layer in Python:

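import tensorflow as tf
from tensorflow import keras

class ReshapeLayer(keras.layers.Layer):
    """Reshapes each sample in the batch to `target_shape`."""

    def __init__(self, target_shape, **kwargs):
        super().__init__(**kwargs)
        self.target_shape = tuple(target_shape)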

    def call(self, inputs):
        return tf.reshape(inputs, [-1, *self.target_shape])

    def get_config(self):
        config = super().get_config()
        config['target_shape'] = self.target_shape
        return config

Then build and save the model:

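# `inputs`/`outputs` avoid shadowing the built-in `input`
inputs = keras.layers.Input(shape=(2, 2))
outputs = ReshapeLayer((1, 4))(inputs)
model = keras.models.Model(inputs=inputs, outputs=outputs)
# Placeholder file name (an assumption); the saved model still has to be converted for TensorFlow.js
model.save("model.h5")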

Now define and register a mirrored version of the custom layer in TensorFlow.js:

class ReshapeLayer extends tf.layers.Layer {
  constructor(config) {
    super(config);
    /* Must be the camel-cased name of the Python argument (target_shape) */
    this.targetShape = config.targetShape;
  }

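  call(inputs) {
    /* Layers receive their inputs as an array of tensors (or a single tensor) */
    const input = Array.isArray(inputs) ? inputs[0] : inputs;

    /* Reshape every sample, keeping the batch dimension */
    return tf.reshape(input, [-1, ...this.targetShape]);
  }

  computeOutputShape(inputShape) {
    /* Needed because the output shape differs from the input shape */
    return [inputShape[0], ...this.targetShape];
  }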

  static get className() {
    /* Must match the class name of the Python layer being mirrored */
    return 'ReshapeLayer';
  }

  getConfig() {
    const config = super.getConfig();
    config.targetShape = this.targetShape;
    return config;
  }
}

/* Register custom reshape layer with TensorFlow.js serialization system */
tf.serialization.registerClass(ReshapeLayer);

The model can now be loaded with tf.loadLayersModel.
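The comment in the JavaScript constructor says the config key must be the camel-cased name of the Python argument, so `target_shape` from `get_config` becomes `targetShape`. A minimal sketch of that naming rule follows. `snake_to_camel` and `camelize_config` are hypothetical helpers written for illustration only; they are not part of Keras or TensorFlow.js.

```python
def snake_to_camel(name: str) -> str:
    """Convert a snake_case identifier to camelCase (illustrative helper)."""
    head, *rest = name.split("_")
    # Upper-case the first letter of every part after the first one
    return head + "".join(part[:1].upper() + part[1:] for part in rest)


def camelize_config(config: dict) -> dict:
    """Rename the top-level keys of a layer config; values are left untouched."""
    return {snake_to_camel(key): value for key, value in config.items()}


# Config as the Python ReshapeLayer.get_config would report it (only the custom key shown)
python_config = {"target_shape": (1, 4)}
js_config = camelize_config(python_config)
print(js_config)
```

If the JavaScript layer reads a key that does not follow this rule, for example `config.target_shape`, it would presumably get `undefined` at load time.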

ventus550 commented 1 year ago

https://github.com/ventus550/SpatialTransformer :triumph:

ventus550 commented 1 year ago

It works, but it is not very useful for the handwriting recognition task.