databricks / spark-deep-learning

Deep Learning Pipelines for Apache Spark
https://databricks.github.io/spark-deep-learning
Apache License 2.0

Working with Estimators #148

Open jpurnell01 opened 6 years ago

jpurnell01 commented 6 years ago

I've tried using GraphModelFactory to load the estimator-based model exported in this regression example, but when I try to register the UDF like this:

val graph = new com.databricks.sparkdl.python.GraphModelFactory()
  .sqlContext(spark.sqlContext)
  .fetches(asJava(Seq("dnn/logits/BiasAdd")))
  .inputs(asJava(Seq("csv_rows")), asJava(Seq("csv_rows")))
  .graphFromFile("tensorflow/tf-estimator-tutorials/trained_models/reg-model-01/export/1532033282/saved_model.pb")

graph.registerUDF("my_scheduling_blocked", blocked = true)

I get an error like this:

Exception in thread "main" com.google.protobuf.InvalidProtocolBufferException: While parsing a protocol message, the input ended unexpectedly in the middle of a field.  This could mean either that the input has been truncated or that an embedded message misreported its own length.
    at com.google.protobuf.InvalidProtocolBufferException.truncatedMessage(InvalidProtocolBufferException.java:86)
    at com.google.protobuf.CodedInputStream$ArrayDecoder.skipRawBytes(CodedInputStream.java:1295)
    at com.google.protobuf.CodedInputStream$ArrayDecoder.skipField(CodedInputStream.java:683)
    at com.google.protobuf.CodedInputStream$ArrayDecoder.skipMessage(CodedInputStream.java:745)
    at com.google.protobuf.CodedInputStream$ArrayDecoder.skipField(CodedInputStream.java:676)
    at com.google.protobuf.MapEntryLite.parseEntry(MapEntryLite.java:186)
    at com.google.protobuf.MapEntry.<init>(MapEntry.java:106)
    at com.google.protobuf.MapEntry.<init>(MapEntry.java:50)
    at com.google.protobuf.MapEntry$Metadata$1.parsePartialFrom(MapEntry.java:70)
    at com.google.protobuf.MapEntry$Metadata$1.parsePartialFrom(MapEntry.java:64)
    at com.google.protobuf.CodedInputStream$ArrayDecoder.readMessage(CodedInputStream.java:911)
    at org.tensorflow.framework.FunctionDef.<init>(FunctionDef.java:90)
    at org.tensorflow.framework.FunctionDef.<init>(FunctionDef.java:17)
    at org.tensorflow.framework.FunctionDef$1.parsePartialFrom(FunctionDef.java:1730)
    at org.tensorflow.framework.FunctionDef$1.parsePartialFrom(FunctionDef.java:1725)
    at com.google.protobuf.CodedInputStream$ArrayDecoder.readMessage(CodedInputStream.java:911)
    at org.tensorflow.framework.FunctionDefLibrary.<init>(FunctionDefLibrary.java:64)
    at org.tensorflow.framework.FunctionDefLibrary.<init>(FunctionDefLibrary.java:13)
    at org.tensorflow.framework.FunctionDefLibrary$1.parsePartialFrom(FunctionDefLibrary.java:1067)
    at org.tensorflow.framework.FunctionDefLibrary$1.parsePartialFrom(FunctionDefLibrary.java:1062)
    at com.google.protobuf.CodedInputStream$ArrayDecoder.readMessage(CodedInputStream.java:911)
    at org.tensorflow.framework.GraphDef.<init>(GraphDef.java:72)
    at org.tensorflow.framework.GraphDef.<init>(GraphDef.java:13)
    at org.tensorflow.framework.GraphDef$1.parsePartialFrom(GraphDef.java:1543)
    at org.tensorflow.framework.GraphDef$1.parsePartialFrom(GraphDef.java:1538)
    at com.google.protobuf.AbstractParser.parsePartialFrom(AbstractParser.java:163)
    at com.google.protobuf.AbstractParser.parseFrom(AbstractParser.java:197)
    at com.google.protobuf.AbstractParser.parseFrom(AbstractParser.java:209)
    at com.google.protobuf.AbstractParser.parseFrom(AbstractParser.java:214)
    at com.google.protobuf.AbstractParser.parseFrom(AbstractParser.java:49)
    at org.tensorflow.framework.GraphDef.parseFrom(GraphDef.java:443)
    at org.tensorframes.impl.TensorFlowOps$.readGraphSerial(TensorFlowOps.scala:73)
    at com.databricks.sparkdl.python.GraphModelFactory.buildGraphDef(ModelFactory.scala:174)
    at com.databricks.sparkdl.python.GraphModelFactory.makeUDF(ModelFactory.scala:136)
    at com.databricks.sparkdl.python.GraphModelFactory.registerUDF(ModelFactory.scala:160)

I'm stumped on how to debug this. Do I need to take extra steps in exporting the model?
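
My current guess is that graphFromFile expects a serialized GraphDef, while the saved_model.pb written by the estimator export is a SavedModel protobuf that wraps a MetaGraphDef, which would explain GraphDef.parseFrom running out of input partway through the message. Here's a sketch of what I plan to try next, re-serializing just the GraphDef with the TensorFlow Java API (I'm assuming the standard "serve" tag and TF Java on the classpath; the output file name is arbitrary):

import java.nio.file.{Files, Paths}
import org.tensorflow.SavedModelBundle

// Load the SavedModel directory (not saved_model.pb itself) with the serving tag.
val bundle = SavedModelBundle.load(
  "tensorflow/tf-estimator-tutorials/trained_models/reg-model-01/export/1532033282",
  "serve")
try {
  // Graph.toGraphDef() returns the serialized GraphDef bytes, which should be
  // the format graphFromFile actually wants.
  Files.write(Paths.get("reg-model-01.graphdef.pb"), bundle.graph().toGraphDef())
} finally {
  bundle.close()
}

If that works, I'd point graphFromFile at reg-model-01.graphdef.pb instead of saved_model.pb.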

jpurnell01 commented 6 years ago

I'm using Spark 2.3.0, spark-deep-learning 1.1.0, and Scala 2.11.8. I'm currently running on a local cluster with the following setup:

import org.apache.spark.sql.SparkSession

val spark: SparkSession = SparkSession
    .builder()
    .appName("test")
    .master("local[*]")
    .config("spark.driver.memory", "4G")
    .config("spark.kryoserializer.buffer.max", "200M")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.network.timeout", "10001s")
    .config("spark.executor.heartbeatInterval", "10000s")
    .config("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator") // Kryo support for ND4J types
    .getOrCreate()
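
For reference, once registerUDF goes through, this is how I expect to call the UDF from Spark SQL (the input path and view name below are placeholders, and I'm assuming a DataFrame with a single csv_rows string column matching the graph input):

val examples = spark.read.text("path/to/test.csv").toDF("csv_rows") // hypothetical input
examples.createOrReplaceTempView("examples")
spark.sql("SELECT my_scheduling_blocked(csv_rows) AS prediction FROM examples").show()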