linkedin / spark-tfrecord

Read and write Tensorflow TFRecord data from Apache Spark.
BSD 2-Clause "Simplified" License

Fixed deserializer for case when sequential rows have different features #26

Closed schCRABicus closed 3 years ago

schCRABicus commented 3 years ago

Problem

We have an issue when deserializing TFRecords. The problem appears when sequential records contain different sets of features.

When such records are deserialized, each subsequent row inherits the values of its missing features from the preceding row, which leads to incorrect deserialization.

Here is an example that highlights the issue:

import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import org.apache.spark.sql._

// Declare DataFrame data
val testRows: Array[Row] = Array(
  new GenericRow(Array[Any](11, Map("foo" -> "1", "bar" -> "2"))),
  new GenericRow(Array[Any](21, Map("foo" -> "4", "baz" -> "3"))),
  new GenericRow(Array[Any](31, Map("abc" -> "7", "bcs" -> "8"))),
  new GenericRow(Array[Any](41, Map("foo" -> "6", "baz" -> "5")))
)

// DataFrame schema
val schema = StructType(List(StructField("id", IntegerType), 
                             StructField("features", MapType(StringType, StringType, false))))
// Create DataFrame
val rdd = spark.sparkContext.parallelize(testRows)
val df: DataFrame = spark.createDataFrame(rdd, schema).repartition(1)

// Explode the map into one column per key, so different rows end up with different non-null features
def explodeMap(df: DataFrame, mapColumn: String): DataFrame = {
  val keysDF = df
    .select(explode(map_keys(col(mapColumn))))
    .distinct()

  val keys = keysDF
    .collect()
    .map(f => f.getAs[String](0))

  val explodedDF = keys
    .foldLeft(df)((d, key) => d.withColumn(key, col(s"$mapColumn.$key")))

  explodedDF.drop(col(mapColumn))
}

val explodedDf = explodeMap(df, "features")
explodedDf.show()

// Now, write and read in tfrecord format
val path = "/tmp/dl/spark-tfrecord/test-output.tfrecord"
explodedDf.write.format("tfrecord").option("recordType", "Example").mode("overwrite").save(path)

// Explicit schema for reading the TFRecords back
val explodedSchema = StructType(List(StructField("id", IntegerType),
                                     StructField("foo", StringType),
                                     StructField("bar", StringType),
                                     StructField("baz", StringType),
                                     StructField("bcs", StringType),
                                     StructField("abc", StringType)))

// Read back with the explicit schema (it would otherwise be inferred from the TFRecords)
val loadedWithSchemaDf: DataFrame = spark.read.format("tfrecord").schema(explodedSchema).load(path)
loadedWithSchemaDf.show()

The outcome is that the initial explodedDf.show() prints the correct dataset:

+---+----+----+----+----+----+
| id| bcs| bar| foo| abc| baz|
+---+----+----+----+----+----+
| 11|null|   2|   1|null|null|
| 21|null|null|   4|null|   3|
| 31|   8|null|null|   7|null|
| 41|null|null|   6|null|   5|
+---+----+----+----+----+----+

meanwhile the dataset read back from the TFRecord file and printed by loadedWithSchemaDf.show() looks as follows:

+---+---+---+----+----+----+
| id|foo|bar| baz| bcs| abc|
+---+---+---+----+----+----+
| 11|  1|  2|null|null|null|
| 21|  4|  2|   3|null|null|
| 31|  4|  2|   3|   8|   7|
| 41|  6|  2|   5|   8|   7|
+---+---+---+----+----+----+

Note that every row from the second one onward inherited its missing column values from the preceding rows, resulting in an incorrect dataset.

Root Cause

The root cause of the issue is that TFRecordDeserializer stores the result row in a private field, i.e. state shared across multiple deserializations: private val resultRow = new SpecificInternalRow(dataSchema.map(_.dataType)). Because of this, every column of each subsequent record is already pre-filled, so any feature missing from the current record silently keeps the value left over from the previous record's deserialization.
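
For illustration, here is a minimal, self-contained sketch of the failure mode (not the library's actual code; the deserialize helper and the two-column schema are made up for the example). Because the setters only touch the features present in the current record, every untouched slot keeps whatever the previous record wrote:

import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// Two string columns, e.g. "foo" and "bar"
val fieldTypes = Seq[DataType](StringType, StringType)

// One row instance reused for every record, like the shared private val
val resultRow = new SpecificInternalRow(fieldTypes)

def deserialize(features: Map[Int, String]): SpecificInternalRow = {
  // Only the features present in this record are written;
  // the other slots silently keep their old values.
  features.foreach { case (i, v) => resultRow.update(i, UTF8String.fromString(v)) }
  resultRow
}

deserialize(Map(0 -> "1", 1 -> "2"))  // row is [1, 2]
deserialize(Map(0 -> "4"))            // row is [4, 2] -- "bar" leaked from the previous record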

Solution

I suggest we initialise a fresh result row for each record being deserialised. It solves the issue for us.
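
Roughly, the change amounts to moving the row allocation inside the per-record deserialization (a sketch continuing the example above, not the exact patch):

def deserialize(features: Map[Int, String]): SpecificInternalRow = {
  // Allocate a new row per record so missing features stay null
  // instead of inheriting values from the previously deserialized record.
  val resultRow = new SpecificInternalRow(fieldTypes)
  features.foreach { case (i, v) => resultRow.update(i, UTF8String.fromString(v)) }
  resultRow
}

deserialize(Map(0 -> "1", 1 -> "2"))  // row is [1, 2]
deserialize(Map(0 -> "4"))            // row is [4, null], as expected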

junshi15 commented 3 years ago

Thanks for your fix.

For the example above, can you show the result after your fix? What would loadedWithSchemaDf.show() look like?

schCRABicus commented 3 years ago

@junshi15, yes, with the fix applied the output looks as follows (it corresponds to the initial dataset written to the file):

+---+----+----+----+----+----+
| id| foo| bar| baz| bcs| abc|
+---+----+----+----+----+----+
| 11|   1|   2|null|null|null|
| 21|   4|null|   3|null|null|
| 31|null|null|null|   8|   7|
| 41|   6|null|   5|null|null|
+---+----+----+----+----+----+

Also, I've provided a dedicated test case as part of the PR which shows the expected behaviour. Without the changes in TFRecordDeserializer, the test case fails because the second deserialized row comes out as

InternalRow.fromSeq(
  Array[Any](10.0F, 1, null)
)

instead of the expected

val expectedInternalRow2 = InternalRow.fromSeq(
  Array[Any](null, 1, null)
)

i.e. it inherits the FloatLabel value from the first deserialized record.

junshi15 commented 3 years ago

Thanks for your contribution!

schCRABicus commented 3 years ago

@junshi15, could you please publish a new artifact so that we can start using it in production code? Thank you!

junshi15 commented 3 years ago

It's already published here: https://search.maven.org/search?q=a:spark-tfrecord_2.12

Or are you using Scala 2.11? Spark 3.x uses Scala 2.12. For Scala 2.11, I need to build with Spark 2.4 or Spark 2.3.
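
For reference, adding the published artifact to an sbt build would look roughly like this (the group ID and version below are assumptions; check the Maven Central link above for the exact coordinates):

// build.sbt sketch; groupId and version are placeholders, verify them on Maven Central
libraryDependencies += "com.linkedin.sparktfrecord" % "spark-tfrecord_2.12" % "<latest-version>"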

I was not able to publish to Bintray since it has been deprecated: https://jfrog.com/center-sunset/.

schCRABicus commented 3 years ago

Got it, thank you! I'm using 2.12, so everything is fine; I just hadn't seen it before, sorry. Thank you!