OpenMined / PySyft-TensorFlow

SOON TO BE DEPRECATED - The TensorFlow bindings for PySyft
Apache License 2.0

Send custom model without having to call model.predict(dummy_data) first #34

Open yanndupis opened 4 years ago

yanndupis commented 4 years ago

In the Part 2 tutorial, custom models (subclasses of tf.keras.models.Model) require us to run model.predict(dummy_data) before sending the model to the worker. This sets the input shape, which tf.keras.models.save_model requires.

Ideally, we would like to remove this step, or at least require only a call to model(dummy_data) before sending the model. You can find more information in this conversation.

arshjot commented 4 years ago

We can set the input shape while defining the model as shown below:

```python
import tensorflow as tf

class CustomModel(tf.keras.Model):

    def __init__(self, num_classes=10):
        super(CustomModel, self).__init__(name='custom_model')
        self.num_classes = num_classes

        self.flatten = tf.keras.layers.Flatten()
        self.dense_1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense_2 = tf.keras.layers.Dense(num_classes, activation='softmax')

        # Set the input shape up front so save_model can serialize the model
        # (note: _set_inputs is an internal Keras method)
        self._set_inputs(tf.TensorSpec(shape=[None, 28, 28], dtype=tf.float32))

    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.dense_1(x)
        return self.dense_2(x)

model = CustomModel(10)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# bob, x_train_ptr and y_train_ptr are created as in the Part 2 tutorial
model_ptr = model.send(bob)
model_ptr.fit(x_train_ptr, y_train_ptr, epochs=2, validation_split=0.2)
```

Alternatively, we can simply replace model.predict(dummy_data) with model._set_inputs(tf.TensorSpec(shape=[None, 28, 28], dtype=tf.float32)).

Would either of these be a satisfactory solution?

jvmncs commented 4 years ago

Hmm, I don't think this is ideal, since _set_inputs is meant to be internal and not exposed to users. Then again, for the tutorial I do like placing that call in the constructor a bit more than calling model.predict(x). I just reviewed the conversation @yanndupis and I had in the original PR: if Keras explicitly requires its users to call fit, predict, or _set_inputs, then I think it's okay for us to expect the same.

The only thing left to change here would be to handle this case a bit more cleanly in model.send(bob). It would be great to raise our own error that explains the problem and redirects the user, since a user might not realize that sending a model calls save_model under the hood, and the resulting error could be confusing.
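A minimal sketch of the custom error suggested above. `ModelNotBuiltError`, `check_model_sendable` and `send_model` are hypothetical names, not part of PySyft-TensorFlow. The check assumes that `model.inputs` stays None until fit, predict, or _set_inputs has run, which is our understanding of how subclassed tf.keras.Model objects behaved in TF 2.0. A stand-in class replaces a real Keras model so the snippet runs without TensorFlow or PySyft.

```python
class ModelNotBuiltError(Exception):
    """Hypothetical error: the model's input signature has not been recorded,
    so tf.keras.models.save_model (called by model.send) cannot serialize it."""


# Hypothetical hint text pointing the user at the fix.
SEND_HINT = (
    "Sending a model calls tf.keras.models.save_model, which needs the model's "
    "input shape. Call model.predict(dummy_data) or model.fit(...) before "
    "model.send(worker)."
)


def check_model_sendable(model):
    """Raise ModelNotBuiltError if `model` has no recorded inputs.

    Assumption: `inputs` is None until Keras has recorded the input signature.
    """
    if getattr(model, "inputs", None) is None:
        raise ModelNotBuiltError(SEND_HINT)
    return model


def send_model(model, worker):
    """Hypothetical wrapper: validate the model first, then delegate to model.send."""
    check_model_sendable(model)
    return model.send(worker)


if __name__ == "__main__":
    class FakeModel:
        """Stand-in for a subclassed Keras model (assumption, not a real API)."""
        inputs = None

        def send(self, worker):
            return f"pointer to model on {worker}"

    m = FakeModel()
    try:
        send_model(m, "bob")
    except ModelNotBuiltError as err:
        print("error:", err)

    m.inputs = ["input_1"]  # simulate the inputs recorded by predict/fit
    print(send_model(m, "bob"))  # pointer to model on bob
```

Checking up front fails before any serialization is attempted. Another option would be to catch the exception raised by save_model inside send and re-raise it with the same hint.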