Feng-Jiang28 opened this issue 1 month ago
So what is happening here is that the data is being partitioned by the decimal column on the write.
scala> val df = Seq((new java.math.BigDecimal("2.125"), 100L),(new java.math.BigDecimal("2"), 200L)).toDF("db", "lval")
df: org.apache.spark.sql.DataFrame = [db: decimal(38,18), lval: bigint]
scala> df.show()
+--------------------+----+
| db|lval|
+--------------------+----+
|2.125000000000000000| 100|
|2.000000000000000000| 200|
+--------------------+----+
scala> df.printSchema
root
|-- db: decimal(38,18) (nullable = true)
|-- lval: long (nullable = false)
scala> df.write.mode("overwrite").partitionBy("db").parquet("./target/TMP")
scala> val df_read = spark.read.parquet("./target/TMP")
df_read: org.apache.spark.sql.DataFrame = [lval: bigint, db: double]
The value of db is turned into a double on read. It is no longer a decimal once you have partitioned by that column, unless you explicitly ask for it:
scala> val df_read = spark.read.schema("lval LONG, db DECIMAL(38,18)").parquet("./target/TMP")
df_read: org.apache.spark.sql.DataFrame = [lval: bigint, db: decimal(38,18)]
The problem is how we cast doubles to big decimal values.
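Worth noting: 2.125 is an exact binary fraction (2 + 1/8), so the double read back from the partition path holds exactly 2.125 and a correct double-to-decimal cast has nothing to round for this value. A quick plain-JVM sanity check with the JDK's BigDecimal (no Spark involved) gives the reference result:
scala> new java.math.BigDecimal(2.125).setScale(18)
res0: java.math.BigDecimal = 2.125000000000000000
So any non-zero trailing digits in the GPU output come from the conversion itself, not from the double representation of the partition value.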
On the GPU:
scala> val df_read = spark.read.parquet("./target/TMP")
df_read: org.apache.spark.sql.DataFrame = [lval: bigint, db: double]
scala> df_read.selectExpr("*", "CAST(db as DECIMAL(38,18))").show()
24/11/05 17:12:36 WARN GpuOverrides:
!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it
@Partitioning <SinglePartition$> could run on GPU
+----+-----+--------------------+
|lval| db| db|
+----+-----+--------------------+
| 200| 2.0|2.000000000000000000|
| 100|2.125|2.125000000000000200|
+----+-----+--------------------+
On the CPU:
scala> val df_read = spark.read.parquet("./target/TMP")
df_read: org.apache.spark.sql.DataFrame = [lval: bigint, db: double]
scala> df_read.selectExpr("*", "CAST(db as DECIMAL(38,18))").show()
+----+-----+--------------------+
|lval| db| db|
+----+-----+--------------------+
| 200| 2.0|2.000000000000000000|
| 100|2.125|2.125000000000000000|
+----+-----+--------------------+
We have tried to fix this in the past, but it is still an issue.
The issue: As you noted, the value of p_0 is not correct in the output. The original value was 2.125, but the output shows 2.125000000000000200. The discrepancy comes from precision issues when working with decimal types in spark-rapids and Parquet: the decimal partition value is written out, inferred as a double on read, and the GPU cast back to DECIMAL(38,18) introduces the additional digits.
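For convenience, the round-trip from the shell sessions above can be condensed into a single snippet that can be pasted into spark-shell (the ./target/TMP path and the db_dec alias are just placeholders for illustration):

import spark.implicits._

val df = Seq(
  (new java.math.BigDecimal("2.125"), 100L),
  (new java.math.BigDecimal("2"), 200L)
).toDF("db", "lval")

// partitionBy moves the decimal values into the directory names (db=...),
// so the column type has to be re-inferred on read and comes back as double
df.write.mode("overwrite").partitionBy("db").parquet("./target/TMP")

// in the sessions above, the cast column shows 2.125000000000000200 on the GPU
// and 2.125000000000000000 on the CPU
spark.read.parquet("./target/TMP")
  .selectExpr("*", "CAST(db AS DECIMAL(38,18)) AS db_dec")
  .show(false)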