OpenMined / PySyft-TensorFlow

SOON TO BE DEPRECATED - The TensorFlow bindings for PySyft
Apache License 2.0

Switch to SavedModel for tf.keras.models.Model serde #30

Closed jvmncs closed 4 years ago

jvmncs commented 4 years ago

Switches model serialization over to the SavedModel format, as opposed to hdf5.

With SavedModel, we can handle more general models, e.g.

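import tensorflow as tf

class CustomModel(tf.keras.Model):
    def __init__(self):
        # Keras requires the base initializer to run before attributes are assigned.
        super().__init__()
        self.layer = tf.keras.layers.Dense(3)
        self.scale = tf.Variable(2.)

    def call(self, x):
        # Apply the dense layer, then scale the result by a trainable variable.
        return self.layer(x) * self.scale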

or

class CustomModule(tf.Module):
    def __init__(self):
        # Call the base initializer so the module gets its name and name scope.
        super().__init__()
        self.layer = tf.keras.layers.Dense(3)
        self.scale = tf.Variable(2.)

    @tf.function
    def call(self, x):
        return self.layer(x) * self.scale

The hdf5 format cannot always handle these custom models, so we prefer SavedModel. However, this implementation is much slower than the hdf5 one, because serialization currently has to walk the entire SavedModel directory structure, and that walk happens in pure Python. There is probably a more efficient way to serialize the SavedModel directory, but we prefer to go forward with this in the meantime. A future alternative would be a user-facing flag that switches between the h5py and SavedModel simplifiers/detailers, with h5py as the default for speed.
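For illustration only, here is a minimal sketch of one way a SavedModel directory could be turned into a single bytes payload and restored, by packing it into an in-memory zip archive. This is not the code in this PR: `pack_directory` and `unpack_directory` are hypothetical helper names, the sketch uses only the Python standard library, and whether it is actually faster than the current walk would need to be measured.

```python
import io
import os
import zipfile


def pack_directory(root_dir):
    """Pack every file under root_dir into an in-memory zip; return its bytes.

    Hypothetical helper, not part of PySyft-TensorFlow.
    """
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as archive:
        for current_dir, subdirs, filenames in os.walk(root_dir):
            # Keep empty directories too (a SavedModel export can contain
            # an empty assets/ folder).
            if not subdirs and not filenames and current_dir != root_dir:
                rel_dir = os.path.relpath(current_dir, root_dir)
                archive.writestr(rel_dir.replace(os.sep, "/") + "/", b"")
            for filename in filenames:
                full_path = os.path.join(current_dir, filename)
                # Store paths relative to root_dir so the archive is relocatable.
                archive.write(full_path, os.path.relpath(full_path, root_dir))
    return buffer.getvalue()


def unpack_directory(payload, target_dir):
    """Recreate the packed directory tree under target_dir; return target_dir."""
    with zipfile.ZipFile(io.BytesIO(payload)) as archive:
        archive.extractall(target_dir)
    return target_dir
```

On the sending side, the payload would be built from the directory written by `tf.saved_model.save`; on the receiving side, the directory restored by `unpack_directory` could be passed to `tf.saved_model.load`.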