tensorflow / tensor2tensor

Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
Apache License 2.0

RFC: T2T V2 Keras sequential and functional API #1465

Open lgeiger opened 5 years ago

lgeiger commented 5 years ago

It's great to see Tensor2Tensor V2 moving to Keras models and layers 🎉

Currently Tensor2Tensor V2 only supports the Keras model-subclassing API. This is great for complex models and custom training loops, but it adds a lot of boilerplate for simple feed-forward models.
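For comparison, here is a rough sketch of what the same FC-ReLU stack looks like in the subclassing style. This is not T2T's actual implementation, just an illustration of the extra boilerplate (explicit `__init__`/`call` methods and manual layer bookkeeping):

```python
import tensorflow as tf

# Illustrative only: a subclassed version of the same FC-ReLU model,
# to show the boilerplate the Sequential/functional styles avoid.
class BasicFcReluSubclassed(tf.keras.Model):
  def __init__(self, num_hidden_layers=2, hidden_size=64, dropout=0.1,
               num_output_classes=10):
    super(BasicFcReluSubclassed, self).__init__()
    self._flatten = tf.keras.layers.Flatten()
    self._hidden = []
    for _ in range(num_hidden_layers):
      self._hidden.append(
          tf.keras.layers.Dense(hidden_size, activation="relu"))
      self._hidden.append(tf.keras.layers.Dropout(rate=dropout))
    self._logits = tf.keras.layers.Dense(num_output_classes, activation=None)

  def call(self, inputs, training=False):
    # Flatten, rescale, then run through the hidden stack.
    x = tf.cast(self._flatten(inputs), tf.float32) / 255.0
    for layer in self._hidden:
      x = layer(x, training=training)
    return self._logits(x)
```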

It would be great to support the Keras sequential and functional APIs too. I think this would lower the barrier to entry for adding new models.

The BasicFcRelu example model would look like this when implemented in the sequential and functional styles:

import gin
import tensorflow as tf

@gin.configurable(whitelist=["num_hidden_layers", "hidden_size", "dropout"])
def basic_fc_relu_seq_api(features_info=None, input_names=None,
                          target_names=None, num_hidden_layers=2,
                          hidden_size=64, dropout=0.1):
  input_name = input_names[0]
  input_shape = features_info[input_name].shape
  num_output_classes = features_info[target_names[0]].num_classes

  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Flatten(input_shape=input_shape, name=input_name))
  model.add(tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) / 255.0))
  for _ in range(num_hidden_layers):
    model.add(tf.keras.layers.Dense(hidden_size, activation="relu"))
    model.add(tf.keras.layers.Dropout(rate=dropout))
  model.add(tf.keras.layers.Dense(num_output_classes, activation=None))
  return model

@gin.configurable(whitelist=["num_hidden_layers", "hidden_size", "dropout"])
def basic_fc_relu_func_api(features_info=None, input_names=None,
                           target_names=None, num_hidden_layers=2,
                           hidden_size=64, dropout=0.1):
  input_name = input_names[0]
  input_shape = features_info[input_name].shape
  num_output_classes = features_info[target_names[0]].num_classes

  input_ = tf.keras.layers.Input(shape=input_shape, name=input_name)
  output = tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) / 255.0)(input_)
  output = tf.keras.layers.Flatten()(output)
  for _ in range(num_hidden_layers):
    output = tf.keras.layers.Dense(hidden_size, activation="relu")(output)
    output = tf.keras.layers.Dropout(rate=dropout)(output)
  output = tf.keras.layers.Dense(num_output_classes, activation=None)(output)
  return tf.keras.models.Model(inputs=[input_], outputs=[output])

Note that this example currently fails because the dataset input is not passed correctly to the model function. I'd also prefer the input setup to be handled outside of the model function, though I'm not sure what the right abstraction would be.
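One possible shape for moving the input setup out of the model function is a small builder that constructs the `Input` layers from the feature metadata and hands plain tensors to a body function. The `FeatureInfo` structure and the body-function signature below are assumptions for illustration, not T2T's actual interface:

```python
import collections

import tensorflow as tf

# Hypothetical stand-in for T2T's feature metadata.
FeatureInfo = collections.namedtuple("FeatureInfo", ["shape", "num_classes"])

def build_model(body_fn, features_info, input_names, target_names):
  """Handles input setup, then delegates the layer wiring to body_fn."""
  inputs = [tf.keras.layers.Input(shape=features_info[name].shape, name=name)
            for name in input_names]
  num_classes = features_info[target_names[0]].num_classes
  outputs = body_fn(inputs, num_classes)
  return tf.keras.models.Model(inputs=inputs, outputs=outputs)

def fc_relu_body(inputs, num_output_classes, num_hidden_layers=2,
                 hidden_size=64, dropout=0.1):
  """The model body only wires layers; it never touches feature metadata."""
  x = tf.keras.layers.Lambda(lambda t: tf.cast(t, tf.float32) / 255.0)(inputs[0])
  x = tf.keras.layers.Flatten()(x)
  for _ in range(num_hidden_layers):
    x = tf.keras.layers.Dense(hidden_size, activation="relu")(x)
    x = tf.keras.layers.Dropout(rate=dropout)(x)
  return tf.keras.layers.Dense(num_output_classes, activation=None)(x)

# Example wiring with dummy feature metadata.
features_info = {"inputs": FeatureInfo(shape=(28, 28), num_classes=None),
                 "targets": FeatureInfo(shape=(), num_classes=10)}
model = build_model(fc_relu_body, features_info, ["inputs"], ["targets"])
```

This keeps the gin-configurable body functions free of any dataset plumbing, though the right abstraction is still an open question.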

@lukaszkaiser What do you think about adding support for this?

lgeiger commented 5 years ago

@afrozenator @lukaszkaiser Do you have a roadmap for supporting TensorFlow 2.0, and would this interface align with your vision?

lukaszkaiser commented 5 years ago

@lgeiger : I think you're right, but currently some Keras APIs just don't work with some things (e.g., tracing is broken for more complex code in the functional style, and distribution strategies have bugs with subclassed models). So we're basically waiting for TF to fix these bugs before we can add more support.

While we wait, we've made a new version of T2T on JAX that's coming along nicely: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/trax/README.md

If that sounds interesting to you, I opened an RFC issue to get comments about trax there: https://github.com/tensorflow/tensor2tensor/issues/1478