Open issue, opened by harryprince 5 years ago.
Prediction in TFRecord mode:
import com.yahoo.tensorflowonspark.Inference
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}

// Root of the TensorFlowOnSpark checkout. Shell-style "${TFoS_HOME}" placeholders are not
// expanded inside Scala string literals, so read the variable from the environment instead.
// sys.env(...) throws NoSuchElementException if TFoS_HOME is not set.
val tfosHome: String = sys.env("TFoS_HOME")

// Inference configuration for reading TFRecord input and writing predictions.
// Arguments are passed by name (`name = value`); the original `name: Type = value`
// form is parameter-declaration syntax and does not compile at a call site.
val conf = Inference.Config(
  export_dir = s"$tfosHome/mnist_export",
  input = s"$tfosHome/mnist/tfr/test",
  // Equivalent to the DDL type struct<image:array<float>,label:array<float>>.
  schema_hint = StructType(Seq(
    StructField("image", ArrayType(FloatType)),
    StructField("label", ArrayType(FloatType))
  )),
  // Input column name -> model input tensor name.
  input_mapping = Map("image" -> "inputs/x", "label" -> "inputs/y_"),
  // Model output tensor name -> output column name (presumably; keys look like tensor names).
  output_mapping = Map("prediction" -> "prediction", "layer/hidden_layer/Relu" -> "features"),
  output = s"$tfosHome/predictions",
  verbose = false
)

// `sc` is assumed to be an existing SparkContext (e.g. the one provided by spark-shell).
Inference.run(sc, conf)
Prediction in DataFrame mode:
`config.export_dir` is the path to the exported model.
import com.yahoo.tensorflowonspark.TFModel

// Configure a TFModel from the exported model directory and its tensor/column mappings.
// NOTE(review): `config` is assumed to be an Inference.Config defined elsewhere — confirm.
val model = new TFModel()
  .setModel(config.export_dir)
  .setInputMapping(config.input_mapping)
  .setOutputMapping(config.output_mapping)

// Run inference over the input DataFrame `df` (assumed to be defined elsewhere).
// Note: we're currently dropping input columns for simplicity, you can retrieve them as Tensors if needed.
val predDF = model.transform(df)
See: https://github.com/yahoo/TensorFlowOnSpark/wiki/Scala-Inference-API#inference-via-tfos-scala-api
References