bartosz25 / spark-scala-playground

Sample processing code using Spark 2.1+ and Scala

Some(null) as the result of UserDefinedAggregateFunction#evaluate #14

Closed · bithw1 closed this issue 5 years ago

bithw1 commented 5 years ago

Hi @bartosz25, I'm posting the code here. It gets the max digit for a given English word:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StringType, StructType}
import org.scalatest.{BeforeAndAfter, FunSuite}

object MaxNumber extends UserDefinedAggregateFunction {
  // -1 is the sentinel for "nothing seen yet"; note that it maps to null.
  val map1 = Map[Int, String](
    -1 -> null,
    1 -> "One",
    2 -> "Two",
    3 -> "Three"
  )
  // Reverse mapping from word to digit, excluding the null sentinel.
  val map2 = map1.filter(_._1 >= 0).map { case (k, v) => (v, k) }

  override def inputSchema: StructType = new StructType().add("gap", "string")

  override def bufferSchema: StructType = new StructType().add("gap_int", "int")

  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Start from the sentinel so any real digit beats it.
    buffer(0) = -1
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Ignore nulls and unknown words; keep the highest digit seen so far.
    val str = input.getString(0)
    if (str != null) {
      map2.get(str).foreach {
        v =>
          if (v > buffer.getInt(0)) {
            buffer.update(0, v)
          }
      }
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val v1 = buffer1.getInt(0)
    val v2 = buffer2.getInt(0)
    if (v1 < v2) {
      buffer1.update(0, v2)
    }
  }

  override def evaluate(buffer: Row): Any = {
    val v = buffer.getInt(0)

    // Returns an Option; for the -1 sentinel this is Some(null),
    // which triggers the NPE described below.
    map1.get(v)

    // map1.get(v).get
  }

}

The test case code is as follows:

 test("SparkTest") {
    val spark = SparkSession.builder().master("local").appName("SparkTest").getOrCreate()
    import spark.implicits._
    spark.createDataset(
      Seq(
        (1, "One"),
        (1, "Three"),
        (1, null),
        (2, null),
        (3, "Two"),
        (3, "Three")
      )
    ).toDF("x", "y").createOrReplaceTempView("t")
    spark.udf.register("max_num", MaxNumber)

    spark.sql("select x, max_num(y) as y from t group by x").show(truncate = false)
  }

It throws an NPE when I run it.

bartosz25 commented 5 years ago

Hi @bithw1 ,

I analyzed the code and IMO the problem is a misuse of Option rather than a bug in Spark. If you want to mark a value as missing, you should use None instead of Some(null). The Some type by itself says that some(thing) exists, so callers may freely invoke functions on the underlying object. A null wrapped in Some contradicts that contract, and when the engine applies a computation to it, the result is the NPE you see.
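To illustrate the difference outside Spark, here is a tiny standalone snippet (plain Scala, hypothetical values):

val missing: Option[String] = None
val trap: Option[String] = Some(null)

missing.map(_.toUpperCase) // None: the function is never applied
trap.map(_.toUpperCase)    // NullPointerException: Some promises a value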

Instead you can unwrap the Option before returning from evaluate, so that the missing case becomes a plain null, which Spark maps to a SQL NULL.
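A minimal sketch of that fix, assuming the same map1 and -1 sentinel as in your snippet (only evaluate changes):

override def evaluate(buffer: Row): Any = {
  val v = buffer.getInt(0)
  // Option#orNull turns Some("One") into "One", and both None and
  // Some(null) into a plain null, which Spark renders as SQL NULL.
  map1.get(v).orNull
}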

Best regards, Bartosz.

bithw1 commented 5 years ago

Thanks @bartosz25. I agree that I misused Option: I didn't realize that Map.get in the evaluate method returns an Option instead of the real value. But I remember that returning Some(null) from the evaluate method worked for you?

bartosz25 commented 5 years ago

In fact, the function from the post returns a TreeMap[Long, String]() that is held inside the MutableAggregationBuffer at index 0. That map is created in the initialize method, so it is never an Option nor null.
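For completeness, a minimal sketch of that pattern with assumed names (not the exact code from the post): the buffer slot is declared as a map column and filled in initialize, so evaluate always sees a concrete map, never an Option or a null.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object WordsByPosition extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = new StructType().add("word", StringType)

  // The single buffer slot is declared as a map column...
  override def bufferSchema: StructType =
    new StructType().add("values", MapType(LongType, StringType))

  override def dataType: DataType = MapType(LongType, StringType)
  override def deterministic: Boolean = true

  // ...and filled with an empty map up front, so it is never null.
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, Map.empty[Long, String])

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val word = input.getString(0)
    if (word != null) {
      val current = buffer.getMap[Long, String](0)
      buffer.update(0, current + (current.size.toLong -> word))
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1.update(0, buffer1.getMap[Long, String](0) ++ buffer2.getMap[Long, String](0))

  // evaluate reads the always-present map back: no Option, no null.
  override def evaluate(buffer: Row): Any = buffer.getMap[Long, String](0)
}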

Best regards, Bartosz.

bithw1 commented 5 years ago

Thanks, @bartosz25.