rstudio / sparktf

R interface to Spark TensorFlow Connector

How to run inference with a TensorFlow model in Spark? #3

Open harryprince opened 5 years ago

harryprince commented 5 years ago

Reference: TensorFlowOnSpark wiki, Scala Inference API, section "Inference via TFoS Scala API":
https://github.com/yahoo/TensorFlowOnSpark/wiki/Scala-Inference-API#inference-via-tfos-scala-api

harryprince commented 5 years ago

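Prediction in TFRecord mode: read TFRecords from `input`, run the model exported to `export_dir`, and write the predictions to `output`.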

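import com.yahoo.tensorflowonspark.Inference
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}

// "${TFoS_HOME}" is not expanded inside a plain Scala string, so read it from the environment (assumed to be set)
val tfosHome = sys.env("TFoS_HOME")

val conf = Inference.Config(
  export_dir = s"$tfosHome/mnist_export",      // exported TensorFlow model
  input = s"$tfosHome/mnist/tfr/test",         // TFRecord input
  schema_hint = StructType(Seq(                // struct<image:array<float>,label:array<float>>
    StructField("image", ArrayType(FloatType)),
    StructField("label", ArrayType(FloatType)))),
  input_mapping = Map("image" -> "inputs/x", "label" -> "inputs/y_"),  // column -> input tensor
  output_mapping = Map("prediction" -> "prediction", "layer/hidden_layer/Relu" -> "features"),  // output tensor -> column
  output = s"$tfosHome/predictions",
  verbose = false)

// sc: the SparkContext (predefined in spark-shell)
Inference.run(sc, conf)
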
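`Inference.run(sc, conf)` relies on the `sc` that spark-shell predefines; a standalone job launched with spark-submit has to create its own session. The following is a minimal sketch of such a driver, assuming the TensorFlowOnSpark Scala jar is on the classpath, that `Inference.Config` and `Inference.run` accept the fields used in the TFRecord example, and that `TFoS_HOME` is set in the environment. The object name `MnistTfrInference` is hypothetical. It also writes the schema hint as a DDL string with Spark's `StructType.fromDDL` (available since Spark 2.2) instead of building the fields by hand.

```scala
import com.yahoo.tensorflowonspark.Inference
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

// Hypothetical driver object for spark-submit; not part of TensorFlowOnSpark itself.
object MnistTfrInference {
  // Same schema as struct<image:array<float>,label:array<float>>, written as a DDL string
  val schemaHint: StructType = StructType.fromDDL("image ARRAY<FLOAT>, label ARRAY<FLOAT>")

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("mnist-tfr-inference").getOrCreate()
    val sc = spark.sparkContext // spark-shell predefines `sc`; a standalone job must obtain it itself
    // Assumption: TFoS_HOME points at the TensorFlowOnSpark checkout; fall back to the working directory
    val tfosHome = sys.env.getOrElse("TFoS_HOME", ".")

    val conf = Inference.Config(
      export_dir = s"$tfosHome/mnist_export",
      input = s"$tfosHome/mnist/tfr/test",
      schema_hint = schemaHint,
      input_mapping = Map("image" -> "inputs/x", "label" -> "inputs/y_"),
      output_mapping = Map("prediction" -> "prediction", "layer/hidden_layer/Relu" -> "features"),
      output = s"$tfosHome/predictions",
      verbose = false)

    Inference.run(sc, conf)
    spark.stop()
  }
}
```

Keeping `schemaHint` as a field of the object rather than a local makes it easy to reuse the same schema when reading the TFRecords elsewhere.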
harryprince commented 5 years ago

Prediction in DataFrame mode: wrap the exported model in a `TFModel` and call `transform` on an existing DataFrame.

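Here `config.export_dir` is the path to the exported TensorFlow model; `config.input_mapping` maps DataFrame columns to input tensors and `config.output_mapping` maps output tensors to new DataFrame columns.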

import com.yahoo.tensorflowonspark.TFModel

// `config` is an Inference.Config like the one in TFRecord mode; `df` is the input DataFrame
val model = new TFModel()
  .setModel(config.export_dir)
  .setInputMapping(config.input_mapping)
  .setOutputMapping(config.output_mapping)

// transform the input DataFrame
// Note: we're currently dropping input columns for simplicity; you can retrieve them as Tensors if needed.
val predDF = model.transform(df)
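The DataFrame-mode snippet assumes `config` and `df` already exist. Below is a self-contained sketch of the same flow: it loads the MNIST TFRecords into a DataFrame with the spark-tensorflow-connector data source (`format("tfrecords")`, the connector this repository wraps for R) and then applies `TFModel` with the mappings from the TFRecord example. The object name `TfModelInference`, the `TFoS_HOME` environment variable, and the presence of both the connector and TensorFlowOnSpark jars on the classpath are assumptions.

```scala
import com.yahoo.tensorflowonspark.TFModel
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

// Hypothetical driver object; paths follow the MNIST example layout under TFoS_HOME.
object TfModelInference {
  val schemaHint: StructType = StructType.fromDDL("image ARRAY<FLOAT>, label ARRAY<FLOAT>")
  // DataFrame column -> model input tensor
  val inputMapping: Map[String, String] = Map("image" -> "inputs/x", "label" -> "inputs/y_")
  // model output tensor -> new DataFrame column
  val outputMapping: Map[String, String] =
    Map("prediction" -> "prediction", "layer/hidden_layer/Relu" -> "features")

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("mnist-tfmodel-inference").getOrCreate()
    val tfosHome = sys.env.getOrElse("TFoS_HOME", ".")

    // Read tf.train.Example records via spark-tensorflow-connector with an explicit schema
    val df = spark.read
      .format("tfrecords")
      .option("recordType", "Example")
      .schema(schemaHint)
      .load(s"$tfosHome/mnist/tfr/test")

    val model = new TFModel()
      .setModel(s"$tfosHome/mnist_export")
      .setInputMapping(inputMapping)
      .setOutputMapping(outputMapping)

    // Expect the output columns named in outputMapping ("prediction", "features")
    val predDF = model.transform(df)
    predDF.printSchema()
    predDF.show(5)
    spark.stop()
  }
}
```

Passing an explicit schema presumably plays the same role as `schema_hint` in TFRecord mode and avoids depending on the connector's schema inference for the array columns.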