NVIDIA / spark-rapids

Spark RAPIDS plugin - accelerate Apache Spark with GPUs
https://nvidia.github.io/spark-rapids
Apache License 2.0

[BUG] Spark UT framework: Various partition value types. Decimal precision issue. #11583

Open Feng-Jiang28 opened 1 month ago

Feng-Jiang28 commented 1 month ago

The issue: as noted, the value of p_0 in the output is not correct. The original value was 2.125, but the output shows 2.125000000000000200. The discrepancy comes from a precision issue when handling decimal partition values with Spark RAPIDS and Parquet: when the decimal value is written out and then read back, extra digits are introduced.

Reproduce:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import org.apache.spark.sql.internal.SQLConf

val row = Row(
  new java.math.BigDecimal("2.125"),
  "This is not a partition column"
)

// BooleanType is not supported yet
val partitionColumnTypes = Seq(
  DecimalType.SYSTEM_DEFAULT
 )

val partitionColumns = partitionColumnTypes.zipWithIndex.map {
  case (t, index) => StructField(s"p_$index", t)
}

val schema = StructType(partitionColumns :+ StructField("i", StringType))

val df = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(row)), schema)

val outputDir = "/home/fejiang/Desktop/temp2" 
df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(outputDir)
val fields = Seq(
  col("p_0").cast(DecimalType.SYSTEM_DEFAULT),
  col("i").cast(StringType))  
val readDf = spark.read.parquet(outputDir)
readDf.select(fields: _*).show(truncate = false)
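
The GPU and CPU runs below execute the same select against the written data, once with the plugin enabled and once with plain Spark. As a sketch (not part of the original report), assuming a spark-shell with the RAPIDS Accelerator already on the classpath, the two runs can be produced by toggling spark.rapids.sql.enabled:

// Sketch: toggle the RAPIDS Accelerator at runtime to compare GPU and CPU
// results. Assumes the plugin jar is already on this spark-shell's classpath.
spark.conf.set("spark.rapids.sql.enabled", "true")   // run the scan/cast on the GPU
spark.read.parquet(outputDir).select(fields: _*).show(truncate = false)

spark.conf.set("spark.rapids.sql.enabled", "false")  // fall back to CPU Spark
spark.read.parquet(outputDir).select(fields: _*).show(truncate = false)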

GPU:

scala> readDf.select(fields: _*).show(truncate = false)
24/10/10 16:05:31 WARN GpuOverrides: 
!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it
  @Partitioning <SinglePartition$> could run on GPU
  *Exec <ProjectExec> will run on GPU
    *Expression <Alias> cast(cast(p_0#9 as decimal(38,18)) as string) AS p_0#19 will run on GPU
      *Expression <Cast> cast(cast(p_0#9 as decimal(38,18)) as string) will run on GPU
        *Expression <Cast> cast(p_0#9 as decimal(38,18)) will run on GPU
    *Exec <FileSourceScanExec> will run on GPU

+--------------------+------------------------------+
|p_0                 |i                             |
+--------------------+------------------------------+
|2.125000000000000200|This is not a partition column|
+--------------------+------------------------------+

CPU:

scala> readDf.select(fields: _*).show(truncate = false)
+--------------------+------------------------------+
|p_0                 |i                             |
+--------------------+------------------------------+
|2.125000000000000000|This is not a partition column|
+--------------------+------------------------------+
revans2 commented 1 week ago

So what is happening here is that the data is being partitioned by the decimal column on the write.

scala> val df = Seq((new java.math.BigDecimal("2.125"), 100L),(new java.math.BigDecimal("2"), 200L)).toDF("db", "lval")
df: org.apache.spark.sql.DataFrame = [db: decimal(38,18), lval: bigint]

scala> df.show()
+--------------------+----+
|                  db|lval|
+--------------------+----+
|2.125000000000000000| 100|
|2.000000000000000000| 200|
+--------------------+----+

scala> df.printSchema
root
 |-- db: decimal(38,18) (nullable = true)
 |-- lval: long (nullable = false)

scala> df.write.mode("overwrite").partitionBy("db").parquet("./target/TMP")

scala> val df_read = spark.read.parquet("./target/TMP")
df_read: org.apache.spark.sql.DataFrame = [lval: bigint, db: double]

The value of db is immediately turned into a double when the data is read back. Once you partition by that column it is no longer inferred as a decimal, unless you explicitly ask for it with a read schema.

scala> val df_read = spark.read.schema("lval LONG, db DECIMAL(38,18)").parquet("./target/TMP")
df_read: org.apache.spark.sql.DataFrame = [lval: bigint, db: decimal(38,18)]

The problem is how we cast doubles to big decimal values.
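
For reference, the same rescale is exact on the JVM, and 2.125 is even exactly representable as a double, so the extra digits have to come from the arithmetic inside the GPU cast itself. A minimal JVM-only sketch (not the plugin's code path) that matches the CPU result:

// Sketch (plain JVM, not the plugin's kernel): rescaling 2.125 to 18
// fractional digits through java.math.BigDecimal is exact and matches the
// CPU output.
val d = 2.125
println(java.math.BigDecimal.valueOf(d).setScale(18))   // 2.125000000000000000

// 2.125 (= 17/8) is exactly representable in binary64, so converting the raw
// binary value is exact too; the trailing ...200 on the GPU is introduced by
// the double -> decimal(38,18) rescaling arithmetic in the cast.
println(new java.math.BigDecimal(d).setScale(18))       // 2.125000000000000000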

On the GPU:

scala> val df_read = spark.read.parquet("./target/TMP")
df_read: org.apache.spark.sql.DataFrame = [lval: bigint, db: double]

scala> df_read.selectExpr("*", "CAST(db as DECIMAL(38,18))").show()
24/11/05 17:12:36 WARN GpuOverrides: 
!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it
  @Partitioning <SinglePartition$> could run on GPU

+----+-----+--------------------+
|lval|   db|                  db|
+----+-----+--------------------+
| 200|  2.0|2.000000000000000000|
| 100|2.125|2.125000000000000200|
+----+-----+--------------------+

On the CPU:

scala> val df_read = spark.read.parquet("./target/TMP")
df_read: org.apache.spark.sql.DataFrame = [lval: bigint, db: double]

scala> df_read.selectExpr("*", "CAST(db as DECIMAL(38,18))").show()
+----+-----+--------------------+
|lval|   db|                  db|
+----+-----+--------------------+
| 200|  2.0|2.000000000000000000|
| 100|2.125|2.125000000000000000|
+----+-----+--------------------+

We have tried to fix this in the past, but it is still an issue.
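
A hedged workaround for the reproducer at the top, applying the same explicit-schema read shown above: if p_0 is declared as DECIMAL(38,18) in the read schema, the partition value is parsed straight from the directory name as a decimal and the lossy double-to-decimal cast never runs. This is a sketch, not a verified fix for every case:

// Sketch: read the partitioned output with an explicit schema so p_0 comes
// back as DECIMAL(38,18) directly, instead of being inferred as a double and
// cast back to decimal afterwards.
val readDfExact = spark.read
  .schema("p_0 DECIMAL(38,18), i STRING")
  .parquet(outputDir)
readDfExact.show(truncate = false)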