eaplatanios / tensorflow_scala

TensorFlow API for the Scala Programming Language
http://platanios.org/tensorflow_scala/
Apache License 2.0

Extension to Estimator API: Output shape different from prediction shape #82

Closed: mandar2812 closed this issue 6 years ago

mandar2812 commented 6 years ago

I am experimenting with some self-implemented loss functions, such as RBFWeightedSWLoss.

Problem Definition

The key idea behind these loss functions is that the shape of the output prediction might not match the shape of the output labels.

In this particular example (RBFWeightedSWLoss), the model output is a 2-dimensional vector: 1) the target prediction and 2) the time index at which it will be observed.

The targets consist of sliding time windows of the target signal y(t).

The loss is then computed as a kernel-weighted average, as specified in the source.
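
To make the shape mismatch concrete, here is a minimal plain-Scala sketch of one possible kernel-weighted sliding-window loss. The exact formula of RBFWeightedSWLoss lives in the linked source and is not reproduced here. The Gaussian (RBF) weights w_k = exp(-(k - τ)² / (2σ²)) centered on the predicted time index τ, the bandwidth `sigma`, and the names `RbfWindowLossSketch`, `sampleLoss` and `batchLoss` are all assumptions made for illustration.

```scala
object RbfWindowLossSketch {
  // Hypothetical per-sample loss: prediction = (value, timeIndex), label = a window of y(t).
  // Each window position k gets an RBF weight centered on the predicted time index (assumed form).
  def sampleLoss(value: Double, timeIndex: Double, window: Array[Double], sigma: Double): Double = {
    val weights = window.indices.map { k =>
      val d = k - timeIndex
      math.exp(-(d * d) / (2.0 * sigma * sigma))
    }
    // Kernel-weighted average of squared errors between the predicted value and the window.
    val weightedSq = window.indices.map(k => weights(k) * math.pow(value - window(k), 2)).sum
    weightedSq / weights.sum
  }

  // Batch loss: predictions have shape [n, 2] while labels have shape [n, w]; the shapes differ.
  def batchLoss(predictions: Array[Array[Double]], labels: Array[Array[Double]], sigma: Double): Double = {
    require(predictions.length == labels.length, "batch sizes must match")
    val losses = predictions.zip(labels).map { case (p, y) =>
      require(p.length == 2, "each prediction must be (value, timeIndex)")
      sampleLoss(p(0), p(1), y, sigma)
    }
    losses.sum / losses.length
  }

  def main(args: Array[String]): Unit = {
    val predictions = Array(Array(1.0, 2.0), Array(0.5, 0.0))
    val labels = Array(Array(0.0, 0.0, 1.0, 0.0), Array(0.5, 0.5, 0.5, 0.5))
    println(batchLoss(predictions, labels, sigma = 1.0))
  }
}
```

Predictions of shape [n, 2] and labels of shape [n, w] still produce a well-defined scalar, which is why the two shapes need not match for the loss itself.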

State of Affairs

Currently, the Estimator API throws an exception in this case:

    java.lang.IllegalArgumentException: Expected output shapes compatible with '([?, 128, 128, 4],[?])', but got dataset with output shapes '([?, 128, 128, 4],[?, 72])'.
      org.platanios.tensorflow.api.ops.io.data.Iterator.createInitializer(Iterator.scala:67)
      org.platanios.tensorflow.api.learn.estimators.FileBasedEstimator$$anonfun$trainWithHooks$1.apply$mcV$sp(FileBasedEstimator.scala:155)
      org.platanios.tensorflow.api.learn.estimators.FileBasedEstimator$$anonfun$trainWithHooks$1.apply(FileBasedEstimator.scala:137)
      org.platanios.tensorflow.api.learn.estimators.FileBasedEstimator$$anonfun$trainWithHooks$1.apply(FileBasedEstimator.scala:137)
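
The message reports a mismatch in the label component: the estimator expects labels of shape [?] (rank 1), while the dataset produces [?, 72] (rank 2). The sketch below is a simplified model of shape compatibility for shapes of known rank, with -1 standing for an unknown dimension. It is an assumption about the rule, not the library's own code, and `ShapeCompatSketch` is a hypothetical name.

```scala
object ShapeCompatSketch {
  // A shape is a sequence of dimension sizes; -1 stands for an unknown dimension ("?").
  // Two shapes of known rank are treated as compatible when the ranks agree and every
  // pair of known dimensions is equal (simplified model; unknown rank is not handled).
  def isCompatible(expected: Seq[Int], actual: Seq[Int]): Boolean =
    expected.length == actual.length &&
      expected.zip(actual).forall { case (e, a) => e == -1 || a == -1 || e == a }

  def main(args: Array[String]): Unit = {
    val datasetLabels = Seq(-1, 72) // label shape produced by the dataset: [?, 72]
    println(isCompatible(Seq(-1), datasetLabels))     // false: rank 1 vs rank 2
    println(isCompatible(Seq(-1, 72), datasetLabels)) // true: ranks and sizes agree
  }
}
```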

Proposal

I am wondering whether this could be supported, either as a new feature or through a workaround. It would be super helpful if the labels and the predictions did not have to match in shape, as long as the loss function can be computed from the two.

eaplatanios commented 6 years ago

@mandar2812 Could you please show me the code snippet that you use to construct the model that you pass to your estimator? In particular, how do you define your model inputs?

mandar2812 commented 6 years ago

@eaplatanios Sorry for the delayed response! Here are the details of the model and its inputs.

Architecture

Note: I have added a convenience function, dtflearn.conv2d_unit, which creates a 2D convolutional layer with a unique layer name (determined by its index, i.e., the last curried argument of the function).

    val arch =
      tf.learn.Cast("Input/Cast", FLOAT32) >>
      dtflearn.conv2d_unit(Shape(2, 2, 4, 64), (1, 1))(0) >>
      dtflearn.conv2d_unit(Shape(2, 2, 64, 32), (2, 2))(1) >>
      dtflearn.conv2d_unit(Shape(2, 2, 32, 16), (4, 4))(2) >>
      dtflearn.conv2d_unit(Shape(2, 2, 16, 8), (8, 8), dropout = false)(3) >>
      tf.learn.MaxPool("MaxPool_3", Seq(1, 2, 2, 1), 1, 1, SamePadding) >>
      tf.learn.Flatten("Flatten_3") >>
      tf.learn.Linear("FC_Layer_4", 128) >>
      tf.learn.ReLU("ReLU_4", 0.1f) >>
      tf.learn.Linear("FC_Layer_5", 64) >>
      tf.learn.ReLU("ReLU_5", 0.1f) >>
      tf.learn.Linear("FC_Layer_6", 8) >>
      tf.learn.Sigmoid("Sigmoid_6") >>
      tf.learn.Linear("OutputLayer", 2)
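
As a rough check of what this stack produces, the sketch below tracks the per-example shape through the layers in plain Scala. It assumes that conv2d_unit uses SAME padding (so each spatial size becomes ceil(size / stride)) and that the activation, dropout and stride-1 max-pool layers leave the shape unchanged. `LayerSpec`, `conv2dUnit` and the `Conv2D_i` names are hypothetical stand-ins, with the index as the last curried argument, as in dtflearn.conv2d_unit.

```scala
object ArchShapeSketch {
  // Simplified layer description: a name plus a map from per-example input shape to output shape.
  final case class LayerSpec(name: String, out: Seq[Int] => Seq[Int])

  // Hypothetical analogue of dtflearn.conv2d_unit: the index (last curried argument)
  // only makes the layer name unique. Assumes SAME padding: spatial size -> ceil(size / stride).
  def conv2dUnit(filterShape: Seq[Int], stride: (Int, Int))(index: Int): LayerSpec =
    LayerSpec(s"Conv2D_$index", shape => shape match {
      case Seq(h, w, _) =>
        Seq(
          math.ceil(h.toDouble / stride._1).toInt,
          math.ceil(w.toDouble / stride._2).toInt,
          filterShape(3)) // number of output channels
    })

  // Flatten collapses all per-example dimensions into one.
  val flatten: LayerSpec = LayerSpec("Flatten_3", shape => Seq(shape.product))

  // A fully connected layer outputs `units` features regardless of the input width.
  def linear(name: String, units: Int): LayerSpec = LayerSpec(name, _ => Seq(units))

  // Shape-changing layers only; ReLU, Sigmoid, dropout and the stride-1 SAME max-pool
  // are omitted because they keep the shape unchanged.
  val layers: Seq[LayerSpec] = Seq(
    conv2dUnit(Seq(2, 2, 4, 64), (1, 1))(0),
    conv2dUnit(Seq(2, 2, 64, 32), (2, 2))(1),
    conv2dUnit(Seq(2, 2, 32, 16), (4, 4))(2),
    conv2dUnit(Seq(2, 2, 16, 8), (8, 8))(3),
    flatten,
    linear("FC_Layer_4", 128),
    linear("FC_Layer_5", 64),
    linear("FC_Layer_6", 8),
    linear("OutputLayer", 2))

  // Per-example shapes after each layer, starting with the input shape itself.
  def shapesFrom(input: Seq[Int]): Seq[Seq[Int]] =
    layers.scanLeft(input)((shape, layer) => layer.out(shape))

  def main(args: Array[String]): Unit = {
    val shapes = shapesFrom(Seq(128, 128, 4))
    layers.zip(shapes.tail).foreach { case (layer, s) =>
      println(s"${layer.name}: ${s.mkString("[?, ", ", ", "]")}")
    }
  }
}
```

Under these assumptions the network ends in a [?, 2] output (one value and one time index per example), which is the prediction shape RBFWeightedSWLoss consumes, independent of the label window width.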

Training/Test Data

Note: I don't think there is anything non-standard here. I create a tensor dataset of [?, 128, 128, 4] images and labels of shape [?, 72].

    val dataSet = helios.create_helios_data_set(
      collated_data,
      tt_partition,
      scaleDownFactor = 2,
      resample)

    val trainImages = tf.data.TensorSlicesDataset(dataSet.trainData)

    val train_labels = dataSet.trainLabels

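    // Column-wise mean and (population) standard deviation of the training labels
    val labels_mean = dataSet.trainLabels.mean(axes = Tensor(0))

    val labels_stddev = dataSet.trainLabels.subtract(labels_mean).square.mean(axes = Tensor(0)).sqrt

    // Standardize the labels before wrapping them in a dataset
    val norm_train_labels = train_labels.subtract(labels_mean).divide(labels_stddev)

    val trainLabels = tf.data.TensorSlicesDataset(norm_train_labels)

    // Pair each image with its label, then repeat, shuffle, batch and prefetch
    val trainData =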
      trainImages.zip(trainLabels)
        .repeat()
        .shuffle(10000)
        .batch(64)
        .prefetch(10)
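
The label normalization above standardizes each label column with the training mean and the population standard deviation, i.e. (y - mean) / sqrt(mean((y - mean)^2)). Here is the same arithmetic on plain arrays, as a sketch for checking the math; `LabelNormalizationSketch` and its methods are hypothetical names.

```scala
object LabelNormalizationSketch {
  // Column-wise mean of a [rows, cols] matrix (mirrors mean(axes = 0)).
  def columnMeans(x: Array[Array[Double]]): Array[Double] =
    x.transpose.map(col => col.sum / col.length)

  // Column-wise population standard deviation: sqrt(mean((x - mean)^2)).
  def columnStddevs(x: Array[Array[Double]], means: Array[Double]): Array[Double] =
    x.transpose.zip(means).map { case (col, m) =>
      math.sqrt(col.map(v => (v - m) * (v - m)).sum / col.length)
    }

  // Standardize each column: (x - mean) / stddev.
  def standardize(x: Array[Array[Double]]): Array[Array[Double]] = {
    val means = columnMeans(x)
    val stds = columnStddevs(x, means)
    x.map(row => row.indices.map(j => (row(j) - means(j)) / stds(j)).toArray)
  }

  def main(args: Array[String]): Unit = {
    val labels = Array(Array(1.0, 10.0), Array(3.0, 30.0))
    standardize(labels).foreach(row => println(row.mkString(", ")))
  }
}
```

A label column with zero variance would divide by zero in this sketch, and the tensor version above has the same property.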

Model Definition


    val input = tf.learn.Input(
      UINT8,
      Shape(
        -1,
        dataSet.trainData.shape(1),
        dataSet.trainData.shape(2),
        dataSet.trainData.shape(3))
    )

    val trainInput = tf.learn.Input(FLOAT32, Shape(-1))

    val trainingInputLayer = tf.learn.Cast("TrainInput", INT64)

    val lossFunc = new RBFWeightedSWLoss("Loss/RBFWeightedL2", collated_data.head._2._2.length)

    val loss = lossFunc >>
      tf.learn.Mean("Loss/Mean") >>
      tf.learn.ScalarSummary("Loss", "ModelLoss")

    val optimizer = tf.train.AdaGrad(0.002)

    val summariesDir = java.nio.file.Paths.get(tf_summary_dir.toString())

    //Now create the model
    val (model, estimator) = tf.createWith(graph = Graph()) {
      val model = tf.learn.Model.supervised(
        input, arch, trainInput, trainingInputLayer,
        loss, optimizer)

      println("Training the CNN model.")

      val estimator = tf.learn.FileBasedEstimator(
        model,
        tf.learn.Configuration(Some(summariesDir)),
        tf.learn.StopCriteria(maxSteps = Some(iterations)),
        Set(
          tf.learn.StepRateLogger(log = false, summaryDir = summariesDir, trigger = tf.learn.StepHookTrigger(5000)),
          tf.learn.SummarySaver(summariesDir, tf.learn.StepHookTrigger(5000)),
          tf.learn.CheckpointSaver(summariesDir, tf.learn.StepHookTrigger(5000))),
        tensorBoardConfig = tf.learn.TensorBoardConfig(summariesDir, reloadInterval = 5000))

      estimator.train(() => trainData, tf.learn.StopCriteria(maxSteps = Some(iterations)))

      (model, estimator)
    }
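
Both the architecture and the loss are built by chaining layers with `>>` (for example `lossFunc >> tf.learn.Mean("Loss/Mean") >> tf.learn.ScalarSummary(...)`), where each layer's output feeds the next layer's input. The toy sketch below illustrates only that idea: the `Layer` class is hypothetical and is not tensorflow_scala's implementation, and a plain squared error stands in for RBFWeightedSWLoss.

```scala
object LayerCompositionSketch {
  // A toy layer: a named function from input to output (not the tensorflow_scala Layer class).
  final case class Layer[A, B](name: String, f: A => B) {
    // Chain with the next layer, feeding this layer's output into its input.
    def >>[C](next: Layer[B, C]): Layer[A, C] = Layer(s"$name >> ${next.name}", f andThen next.f)
    def apply(a: A): B = f(a)
  }

  // Stand-in per-example loss: squared error between predictions and labels.
  val perExampleLoss: Layer[(Seq[Double], Seq[Double]), Seq[Double]] =
    Layer("Loss/SquaredError", (py: (Seq[Double], Seq[Double])) =>
      py._1.zip(py._2).map { case (p, y) => (p - y) * (p - y) })

  // Reduce the per-example losses to a scalar, like tf.learn.Mean in the thread.
  val mean: Layer[Seq[Double], Double] = Layer("Loss/Mean", (xs: Seq[Double]) => xs.sum / xs.size)

  val loss: Layer[(Seq[Double], Seq[Double]), Double] = perExampleLoss >> mean

  def main(args: Array[String]): Unit =
    println(s"${loss.name}: ${loss((Seq(1.0, 3.0), Seq(1.0, 1.0)))}")
}
```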

If you want to check the full source, see the linked file, lines 842 onward.

Let me know what you think!

eaplatanios commented 6 years ago

@mandar2812 No worries! :) I think the issue lies with your trainInput. The dataset you're loading seems to produce tensors of shape [-1, 72] as the labels for supervision, whereas you define your train input as:

    val trainInput = tf.learn.Input(FLOAT32, Shape(-1))

Try changing that to the following and see if the inputs to the model work:

    val trainInput = tf.learn.Input(FLOAT32, Shape(-1, 72))

I hope this helps. :)

mandar2812 commented 6 years ago

@eaplatanios It's working now. Thanks a lot!