google / nitroml

NitroML is a modular, portable, and scalable model-quality benchmarking framework for Machine Learning and Automated Machine Learning (AutoML) pipelines.
Apache License 2.0

Keras Trainer and AutoData adapter #32

Closed cweill closed 4 years ago

cweill commented 4 years ago

Examples:

Trainer

# Lint as: python3
import os

import tensorflow.compat.v2 as tf
import tensorflow.google.compat.v2 as tfg

from data_provider import DataProvider
from tfx.components.trainer import executor as trainer_executor

def get_hparams() -> tfg.HParams:
  """Defines the set of hyper parameters recognized by this model.

  NOTE: These hyperparameters are for illustration only and should be tuned.

  Returns:
    A tfg.HParams instance.
  """
  return tfg.HParams(
      train_batch_size=128,
      eval_batch_size=128,
  )

# TFX Trainer will call this function.
def run_fn(fn_args: trainer_executor.TrainerFnArgs):
  """Train a Keras Model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
      - train_files: A list of uris for train files.
      - transform_output: An optional single uri for transform graph produced by
        TFT. Will be None if not specified.
      - serving_model_dir: A single uri for the output directory of the serving
        model.
      - eval_model_dir: A single uri for the output directory of the eval model.
        Note that this is for Estimators only; Keras doesn't require it for TFMA.
      - eval_files:  A list of uris for eval files.
      - schema_file: A single uri for schema file.
      - train_steps: Number of train steps.
      - eval_steps: Number of eval steps.
      - base_model: Base model that will be used for this training job.
      - hyperparameters: An optional kerastuner.HyperParameters config.
  """

  model_hparams = get_hparams()

  data_provider = DataProvider(
      transform_graph_dir=fn_args.transform_output)

  feature_columns = data_provider.get_numeric_feature_columns(
  ) + data_provider.get_embedding_feature_columns()
  input_layers = data_provider.get_input_layers()

  # All input_layers must be consumed for the Keras Model to work.
  assert len(feature_columns) >= len(input_layers)

  x = tf.keras.layers.DenseFeatures(feature_columns)(input_layers)
  for numnodes in [64, 64]:
    x = tf.keras.layers.Dense(numnodes)(x)
  output = tf.keras.layers.Dense(
      data_provider.forecast_horizon, activation=None, name='logits')(
          x)

  model = tf.keras.Model(input_layers, output)
  model.compile(
      loss=tf.keras.losses.MeanSquaredError(),
      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
      metrics=[
          tf.keras.metrics.RootMeanSquaredError(),
          tf.keras.metrics.MeanSquaredError(),
          tf.keras.metrics.MeanAbsoluteError(),
          tf.keras.metrics.MeanSquaredLogarithmicError(),
      ])
  model.summary()

  # This log path might change in the future.
  log_dir = os.path.join(os.path.dirname(fn_args.serving_model_dir), 'logs')
  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=log_dir, update_freq='batch')

  train_dataset = data_provider.get_dataset(
      file_pattern=fn_args.train_files,
      batch_size=model_hparams.train_batch_size,
      num_epochs=None,
      shuffle=True)
  eval_dataset = data_provider.get_dataset(
      file_pattern=fn_args.eval_files,
      batch_size=model_hparams.eval_batch_size,
      num_epochs=1,
      shuffle=False)

  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      epochs=1,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps,
      callbacks=[tensorboard_callback])

  signatures = {
      'serving_default':
          data_provider.get_serve_tf_examples_fn(model).get_concrete_function(
              tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')),
  }
  model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)
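
For context, here is a rough sketch of how a module like the one above might be wired into a TFX pipeline so that the TFX Trainer calls run_fn. The module_file path, the step counts, and the upstream component handles (transform, schema_gen) are placeholders, not part of this PR, and the exact import paths vary across TFX versions; this follows the style of the ~0.2x TFX Keras examples.

# Hypothetical pipeline wiring; transform and schema_gen are assumed to be
# upstream Transform and SchemaGen components defined elsewhere.
from tfx.components import Trainer
from tfx.components.base import executor_spec
from tfx.components.trainer.executor import GenericExecutor
from tfx.proto import trainer_pb2

trainer = Trainer(
    module_file='path/to/trainer_module.py',  # the file defining run_fn above
    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=1000),
    eval_args=trainer_pb2.EvalArgs(num_steps=100))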

DataProvider/AutoData adapter

# Lint as: python3
"""An data provider for Keras Models.

The consumed artifacts include:
 * Dataset schema.
 * Dataset statistics.
 * TensorFlow Transform outputs.
"""

from typing import Any, Dict, List, Optional, Text

import tensorflow.compat.v2 as tf
import tensorflow_transform as tft

from tensorflow_metadata.proto.v0 import schema_pb2

FeatureColumn = Any

class DataProvider():
  """Creates feature columns and specs from TFX artifacts."""

  def __init__(self, transform_graph_dir: Text):
    """Initializes the DataProvider from TFX artifacts.

    Args:
      transform_graph_dir: Path to the TensorFlow Transform graph artifacts.
    """

    # Parse transform.
    self._tf_transform_output = tft.TFTransformOutput(transform_graph_dir)

    # Parse schema.
    self._dataset_schema = self._tf_transform_output.transformed_metadata.schema

  @property
  def raw_label_keys(self) -> List[Text]:
    """The raw label key as defined in the ProblemStatement."""

    # TODO(weill): Make this label configurable.
    return ['future_sales']

  @property
  def transformed_label_keys(self) -> List[Text]:
    """The label key after applying TensorFlow Transform to the Examples."""

    return self.raw_label_keys

  @property
  def forecast_horizon(self) -> int:
    """The int forecast horizon for future sales."""

    # 28 days.
    return 28

  def get_input_layers(self) -> Dict[Text, tf.keras.layers.Input]:
    """Returns input layers for a Keras Model."""

    feature_spec = self._tf_transform_output.transformed_feature_spec().copy()
    feature_spec.pop(self.transformed_label_keys[0])
    input_layers = {}
    for colname, spec in feature_spec.items():
      input_layers[colname] = tf.keras.layers.Input(
          name=colname, shape=spec.shape, dtype=spec.dtype)
    return input_layers

  def get_numeric_feature_columns(self,
                                  include_integer_columns: bool = False
                                 ) -> List[FeatureColumn]:
    """Creates a set of feature columns.

    Args:
      include_integer_columns: Whether integer columns in the examples should be
        included in the numeric columns as floats.

    Returns:
      A list of feature columns.
    """

    numeric_columns = []
    for feature in self._dataset_schema.feature:

      feature_name = feature.name
      if feature_name in self.raw_label_keys:
        continue

      feature_storage_type = _get_feature_storage_type(self._dataset_schema,
                                                       feature_name)

      if feature_storage_type == tf.int64 and not include_integer_columns:
        continue

      # NOTE: Int features are treated as both numerical and categorical. For
      # example, MNIST stores its features as int16, but they are continuous.
      if feature_storage_type == tf.float32 or feature_storage_type == tf.int64:

        # Numerical feature.
        dim = _get_feature_dim(self._dataset_schema, feature_name)

        # Numerical feature normalized in 0-1.
        current_feature = tf.feature_column.numeric_column(
            feature_name, shape=(dim,), dtype=feature_storage_type)
        numeric_columns.append(current_feature)
    return numeric_columns

  def get_sparse_categorical_feature_columns(
      self, include_integer_columns: bool = True) -> List[FeatureColumn]:
    """Creates a set of sparse categorical feature columns.

    Args:
      include_integer_columns: Whether integer columns in the examples should be
        included in the categorical columns.

    Returns:
      A list of feature columns.
    """

    feature_columns = []
    for feature in self._dataset_schema.feature:

      feature_name = feature.name
      if feature_name in self.raw_label_keys:
        continue

      feature_storage_type = _get_feature_storage_type(self._dataset_schema,
                                                       feature_name)

      if feature_storage_type == tf.float32:
        continue

      if feature_storage_type == tf.int64:
        if not include_integer_columns:
          continue

        # Categorical or categorical-set feature stored as integer(s).
        num_buckets = (
            self._tf_transform_output.num_buckets_for_transformed_feature(
                feature_name))
        new_feature_column = tf.feature_column.categorical_column_with_identity(
            feature_name, num_buckets=num_buckets)
      else:
        # Note TFT automatically converts string columns to int columns.
        raise ValueError('Unsupported dtype.')
      feature_columns.append(new_feature_column)
    return feature_columns

  def get_embedding_feature_columns(self,
                                    include_integer_columns: bool = True
                                   ) -> List[FeatureColumn]:
    """Creates a set of embedding feature columns.

    Args:
      include_integer_columns: Whether integer columns in the examples should be
        included in the embeddings.

    Returns:
      A list of feature columns.
    """

    return [
        tf.feature_column.embedding_column(column, dimension=10) for column in
        self.get_sparse_categorical_feature_columns(include_integer_columns)
    ]

  def get_dataset(self,
                  file_pattern: Text,
                  batch_size: int,
                  num_epochs: Optional[int] = None,
                  shuffle: Optional[bool] = True,
                  shuffle_buffer_size: int = 10000,
                  shuffle_seed: Optional[int] = None,
                  prefetch_buffer_size: Optional[int] = None,
                  reader_num_threads: Optional[int] = None,
                  parser_num_threads: Optional[int] = None,
                  sloppy_ordering: bool = False,
                  drop_final_batch: bool = False) -> tf.data.Dataset:
    """Returns an input_fn that returns a `tf.data.Dataset` from Examples.

    Args:
      file_pattern: List of files or patterns of file paths containing Example
        records. See tf.io.gfile.glob for pattern rules.
      batch_size: An int representing the number of records to combine in a
        single batch.
      num_epochs: Integer specifying the number of times to read through the
        dataset. If None, cycles through the dataset forever. Defaults to None.
      shuffle: A boolean, indicates whether the input should be shuffled.
        Defaults to True.
      shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity
        ensures better shuffling but would increase memory usage and startup
        time.
      shuffle_seed: Randomization seed to use for shuffling.
      prefetch_buffer_size: Number of feature batches to prefetch in order to
        improve performance. Recommended value is the number of batches consumed
        per training step. Defaults to auto-tune.
      reader_num_threads: Number of threads used to read Example records. If >1,
        the results will be interleaved. Defaults to 1.
      parser_num_threads: Number of threads to use for parsing Example tensors
        into a dictionary of Feature tensors. Defaults to 2.
      sloppy_ordering: If True, reading performance will be improved at the cost
        of non-deterministic ordering. If False, the order of elements produced
        is deterministic prior to shuffling (elements are still randomized if
        shuffle=True. Note that if the seed is set, then order of elements after
        shuffling is deterministic). Defaults to False.
      drop_final_batch: If True, and the batch size does not evenly divide the
        input dataset size, the final smaller batch will be dropped. Defaults to
        False.

    Returns:
      A `tf.data.Dataset` of (features, labels) tuples.
    """

    # Since we're not applying the transform graph here, we're using Transform
    # materialization.
    feature_spec = self._tf_transform_output.transformed_feature_spec().copy()

    def _pop_labels(features):
      label_keys = self.transformed_label_keys
      labels = []
      for key in label_keys:
        labels.append(features.pop(key))
      return features, tf.concat(labels, axis=1)

    def _gzip_reader_fn(files):
      return tf.data.TFRecordDataset(files, compression_type='GZIP')

    dataset = tf.data.experimental.make_batched_features_dataset(
        file_pattern,
        batch_size,
        feature_spec,
        reader=_gzip_reader_fn,
        num_epochs=num_epochs,
        shuffle=shuffle,
        shuffle_buffer_size=shuffle_buffer_size,
        shuffle_seed=shuffle_seed,
        prefetch_buffer_size=prefetch_buffer_size,
        reader_num_threads=reader_num_threads,
        parser_num_threads=parser_num_threads,
        sloppy_ordering=sloppy_ordering,
        drop_final_batch=drop_final_batch)
    return dataset.map(_pop_labels)

  def get_serve_tf_examples_fn(self, model: tf.keras.Model):
    """Returns a function that parses a serialized tf.Example and applies TFT."""

    model.tft_layer = self._tf_transform_output.transform_features_layer()

    @tf.function
    def serve_tf_examples_fn(serialized_tf_examples):
      """Returns the output to be used in the serving signature."""
      feature_spec = self._tf_transform_output.raw_feature_spec()
      feature_spec.pop(self.transformed_label_keys[0])
      parsed_features = tf.io.parse_example(serialized_tf_examples,
                                            feature_spec)
      transformed_features = model.tft_layer(parsed_features)
      return model(transformed_features)

    return serve_tf_examples_fn

def _get_feature_storage_type(schema: schema_pb2.Schema,
                              feature_name: Text) -> tf.dtypes.DType:
  """Get the storage type of at tf.Example feature."""

  for feature in schema.feature:
    if feature.name == feature_name:
      if feature.type == schema_pb2.FeatureType.BYTES:
        return tf.string
      if feature.type == schema_pb2.FeatureType.FLOAT:
        return tf.float32
      if feature.type == schema_pb2.FeatureType.INT:
        return tf.int64
  raise ValueError('Feature not found: {}'.format(feature_name))

def _get_feature_dim(schema: schema_pb2.Schema, feature_name: Text) -> int:
  """Get the dimension of the tf.Example feature."""

  for feature in schema.feature:
    if feature.name == feature_name:
      return feature.shape.dim[0].size
  raise ValueError('Feature not found: {}'.format(feature_name))
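
For completeness, a minimal sketch of how the adapter could be exercised on its own, outside the Trainer. The transform graph directory and file pattern below are placeholders, not paths produced by this PR.

# Hypothetical standalone usage of the DataProvider above.
provider = DataProvider(
    transform_graph_dir='/tmp/pipeline/Transform/transform_graph/1')  # placeholder
feature_columns = (
    provider.get_numeric_feature_columns() +
    provider.get_embedding_feature_columns())
input_layers = provider.get_input_layers()
dataset = provider.get_dataset(
    file_pattern=['/tmp/pipeline/Transform/transformed_examples/train/*.gz'],  # placeholder
    batch_size=64,
    num_epochs=1,
    shuffle=False)
for features, labels in dataset.take(1):
  print(sorted(features.keys()), labels.shape)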