NVIDIA / spark-rapids

Spark RAPIDS plugin - accelerate Apache Spark with GPUs
https://nvidia.github.io/spark-rapids
Apache License 2.0
765 stars 225 forks source link

[FEA] Support short mode of DayTimeIntervalType when cast string to daytime #10980

Open thirtiseven opened 1 month ago

thirtiseven commented 1 month ago

Is your feature request related to a problem? Please describe. The Spark unit test cases `Cast string to day-time interval` and `Take into account day-time interval fields in cast` fail because we do not support the short mode of DayTimeIntervalType when casting a string to a day-time interval.

Uncomment the `Cast string to day-time interval` or `Take into account day-time interval fields in cast` entry in RapidsTestSettings.scala, then run RapidsCastSuite:

mvn test -Dbuildver=330 -DwildcardSuites=org.apache.spark.sql.rapids.suites.RapidsCastSuite

This fails with:

org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 13762.0 failed 1 times, most recent failure: Lost task 1.0 in stage 13762.0 (TID 27525) (spark-haoyang executor driver): java.lang.IllegalArgumentException: Cast string to day time interval failed, may be the format is invalid, range check failed or overflow
    at com.nvidia.spark.rapids.shims.GpuIntervalUtilsBase.$anonfun$castStringToDayTimeIntervalWithThrow$1(GpuIntervalUtilsBase.scala:174)
    at com.nvidia.spark.rapids.Arm$.withResource(Arm.scala:30)
    at com.nvidia.spark.rapids.shims.GpuIntervalUtilsBase.castStringToDayTimeIntervalWithThrow(GpuIntervalUtilsBase.scala:172)
    at com.nvidia.spark.rapids.shims.GpuIntervalUtilsBase.castStringToDayTimeIntervalWithThrow$(GpuIntervalUtilsBase.scala:171)
    at com.nvidia.spark.rapids.shims.GpuIntervalUtils$.castStringToDayTimeIntervalWithThrow(GpuIntervalUtils.scala:34)
    at com.nvidia.spark.rapids.shims.GpuIntervalUtilsBase.castStringToDayTimeIntervalWithThrow(GpuIntervalUtilsBase.scala:160)
    at com.nvidia.spark.rapids.shims.GpuIntervalUtilsBase.castStringToDayTimeIntervalWithThrow$(GpuIntervalUtilsBase.scala:159)
    at com.nvidia.spark.rapids.shims.GpuIntervalUtils$.castStringToDayTimeIntervalWithThrow(GpuIntervalUtils.scala:34)
    at com.nvidia.spark.rapids.GpuCast$.doCast(GpuCast.scala:593)
    at com.nvidia.spark.rapids.GpuCast.doColumnar(GpuCast.scala:1903)
    at com.nvidia.spark.rapids.GpuUnaryExpression.doItColumnar(GpuExpressions.scala:250)
    at com.nvidia.spark.rapids.GpuUnaryExpression.$anonfun$columnarEval$1(GpuExpressions.scala:261)
    at com.nvidia.spark.rapids.Arm$.withResource(Arm.scala:30)
    at com.nvidia.spark.rapids.GpuUnaryExpression.columnarEval(GpuExpressions.scala:260)
    at com.nvidia.spark.rapids.RapidsPluginImplicits$ReallyAGpuExpression.columnarEval(implicits.scala:35)
    at com.nvidia.spark.rapids.GpuAlias.columnarEval(namedExpressions.scala:110)
    at com.nvidia.spark.rapids.RapidsPluginImplicits$ReallyAGpuExpression.columnarEval(implicits.scala:35)
    at com.nvidia.spark.rapids.GpuProjectExec$.$anonfun$project$1(basicPhysicalOperators.scala:110)
    at com.nvidia.spark.rapids.RapidsPluginImplicits$MapsSafely.$anonfun$safeMap$1(implicits.scala:221)
    at com.nvidia.spark.rapids.RapidsPluginImplicits$MapsSafely.$anonfun$safeMap$1$adapted(implicits.scala:218)
    at scala.collection.immutable.List.foreach(List.scala:431)
    at com.nvidia.spark.rapids.RapidsPluginImplicits$MapsSafely.safeMap(implicits.scala:218)
    at com.nvidia.spark.rapids.RapidsPluginImplicits$AutoCloseableProducingSeq.safeMap(implicits.scala:253)
    at com.nvidia.spark.rapids.GpuProjectExec$.project(basicPhysicalOperators.scala:110)
    at com.nvidia.spark.rapids.GpuTieredProject.$anonfun$project$2(basicPhysicalOperators.scala:619)
    at com.nvidia.spark.rapids.Arm$.withResource(Arm.scala:30)
    at com.nvidia.spark.rapids.GpuTieredProject.recurse$2(basicPhysicalOperators.scala:618)
    at com.nvidia.spark.rapids.GpuTieredProject.project(basicPhysicalOperators.scala:631)
    at com.nvidia.spark.rapids.GpuTieredProject.$anonfun$projectWithRetrySingleBatchInternal$5(basicPhysicalOperators.scala:567)
    at com.nvidia.spark.rapids.RmmRapidsRetryIterator$.withRestoreOnRetry(RmmRapidsRetryIterator.scala:272)
    at com.nvidia.spark.rapids.GpuTieredProject.$anonfun$projectWithRetrySingleBatchInternal$4(basicPhysicalOperators.scala:567)
    at com.nvidia.spark.rapids.Arm$.withResource(Arm.scala:30)
    at com.nvidia.spark.rapids.GpuTieredProject.$anonfun$projectWithRetrySingleBatchInternal$3(basicPhysicalOperators.scala:565)
    at com.nvidia.spark.rapids.RmmRapidsRetryIterator$NoInputSpliterator.next(RmmRapidsRetryIterator.scala:395)
    at com.nvidia.spark.rapids.RmmRapidsRetryIterator$RmmRapidsRetryIterator.next(RmmRapidsRetryIterator.scala:613)
    at com.nvidia.spark.rapids.RmmRapidsRetryIterator$RmmRapidsRetryAutoCloseableIterator.next(RmmRapidsRetryIterator.scala:517)
    at com.nvidia.spark.rapids.RmmRapidsRetryIterator$.drainSingleWithVerification(RmmRapidsRetryIterator.scala:291)
    at com.nvidia.spark.rapids.RmmRapidsRetryIterator$.withRetryNoSplit(RmmRapidsRetryIterator.scala:185)
    at com.nvidia.spark.rapids.GpuTieredProject.$anonfun$projectWithRetrySingleBatchInternal$1(basicPhysicalOperators.scala:565)
    at com.nvidia.spark.rapids.Arm$.withResource(Arm.scala:39)
    at com.nvidia.spark.rapids.GpuTieredProject.projectWithRetrySingleBatchInternal(basicPhysicalOperators.scala:562)
    at com.nvidia.spark.rapids.GpuTieredProject.projectAndCloseWithRetrySingleBatch(basicPhysicalOperators.scala:601)
    at com.nvidia.spark.rapids.GpuProjectExec.$anonfun$internalDoExecuteColumnar$2(basicPhysicalOperators.scala:384)
    at com.nvidia.spark.rapids.Arm$.withResource(Arm.scala:30)
    at com.nvidia.spark.rapids.GpuProjectExec.$anonfun$internalDoExecuteColumnar$1(basicPhysicalOperators.scala:380)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
    at com.nvidia.spark.rapids.ColumnarToRowIterator.$anonfun$fetchNextBatch$3(GpuColumnarToRowExec.scala:290)
    at com.nvidia.spark.rapids.Arm$.withResource(Arm.scala:30)
    at com.nvidia.spark.rapids.ColumnarToRowIterator.fetchNextBatch(GpuColumnarToRowExec.scala:287)
    at com.nvidia.spark.rapids.ColumnarToRowIterator.loadNextBatch(GpuColumnarToRowExec.scala:257)
    at com.nvidia.spark.rapids.ColumnarToRowIterator.hasNext(GpuColumnarToRowExec.scala:304)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
    at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:364)
    at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:890)
    at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:890)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:136)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:750)

Describe the solution you'd like

The GPU implementation checks the subtype of the DayTimeIntervalType, selects the corresponding regex pattern in GpuIntervalUtilsBase, and then extracts the input data with that pattern. We can update those patterns so that they also match data in short mode.

Describe alternatives you've considered. If we do not plan to support it, we should at least note in the compatibility doc that the short mode of DayTimeIntervalType is not supported.

thirtiseven commented 1 month ago

Specific cases in the Spark unit tests that fail (commented out):

 // Spark unit test for SPARK-35112: casting strings to DayTimeIntervalType, as quoted in this issue.
 // The commented-out checks are the cases reported as failing on the GPU. Most use the short form
 // without the `INTERVAL '...' DAY TO SECOND` wrapper (e.g. "1 2:03:04"). One uses the wrapper
 // with mixed case and surrounding whitespace (" interval '0 0:0:0' Day TO second   ").
 test("SPARK-35112: Cast string to day-time interval2") {
//    checkEvaluation(cast(Literal.create("0"), DayTimeIntervalType()), 0L)
//    checkEvaluation(cast(Literal.create("0 0:0:0"), DayTimeIntervalType()), 0L)
//    checkEvaluation(cast(Literal.create(" interval '0 0:0:0' Day TO second   "),
//      DayTimeIntervalType()), 0L)
    checkEvaluation(cast(Literal.create("INTERVAL '1 2:03:04' DAY TO SECOND"),
      DayTimeIntervalType()), 93784000000L)
    checkEvaluation(cast(Literal.create("INTERVAL '1 03:04:00' DAY TO SECOND"),
      DayTimeIntervalType()), 97440000000L)
    checkEvaluation(cast(Literal.create("INTERVAL '1 03:04:00.0000' DAY TO SECOND"),
      DayTimeIntervalType()), 97440000000L)
//    checkEvaluation(cast(Literal.create("1 2:03:04"), DayTimeIntervalType()), 93784000000L)
    checkEvaluation(cast(Literal.create("INTERVAL '-10 2:03:04' DAY TO SECOND"),
      DayTimeIntervalType()), -871384000000L)
//    checkEvaluation(cast(Literal.create("-10 2:03:04"), DayTimeIntervalType()), -871384000000L)
//    checkEvaluation(cast(Literal.create("-106751991 04:00:54.775808"), DayTimeIntervalType()),
//      Long.MinValue)
//    checkEvaluation(cast(Literal.create("106751991 04:00:54.775807"), DayTimeIntervalType()),
//      Long.MaxValue)
//
//    Seq("-106751991 04:00:54.775808", "106751991 04:00:54.775807").foreach { interval =>
//      val ansiInterval = s"INTERVAL '$interval' DAY TO SECOND"
//      checkEvaluation(
//        cast(cast(Literal.create(interval), DayTimeIntervalType()), StringType), ansiInterval)
//      checkEvaluation(cast(cast(Literal.create(ansiInterval),
//        DayTimeIntervalType()), StringType), ansiInterval)
//    }
//
    // Non-try cast only: inputs one microsecond beyond the Long.MinValue / Long.MaxValue
    // boundaries (compare the commented-out boundary checks above) must throw an
    // ArithmeticException whose message mentions "long overflow".
    if (!isTryCast) {
      Seq("INTERVAL '-106751991 04:00:54.775809' DAY TO SECOND",
        "INTERVAL '106751991 04:00:54.775808' DAY TO SECOND").foreach { interval =>
        val e = intercept[ArithmeticException] {
          cast(Literal.create(interval), DayTimeIntervalType()).eval()
        }.getMessage
        assert(e.contains("long overflow"))
      }
    }

    // Round trip: interval literal -> StringType -> DayTimeIntervalType must reproduce the
    // original number of microseconds, including the Long.MinValue / Long.MaxValue extremes.
    Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Long.MaxValue, Long.MinValue + 1,
      Long.MinValue).foreach { duration =>
      val interval = Literal.create(
        Duration.of(duration, ChronoUnit.MICROS),
        DayTimeIntervalType())
      checkEvaluation(cast(cast(interval, StringType), DayTimeIntervalType()), duration)
    }
  }

  // Spark unit test for SPARK-35735: the start/end fields of the target DayTimeIntervalType
  // determine which string formats a cast accepts. The commented-out short-form check inside the
  // first loop is the case reported as failing on the GPU.
  test("SPARK-35735: Take into account day-time interval fields in cast2") {
    // Builds the unit suffix of an ANSI interval literal for `dataType`, e.g. "DAY" for
    // DayTimeIntervalType(DAY, DAY) or "<start> TO <END>" when the fields differ.
    def typeName(dataType: DayTimeIntervalType): String = {
      if (dataType.startField == dataType.endField) {
        DayTimeIntervalType.fieldToString(dataType.startField).toUpperCase(Locale.ROOT)
      } else {
        // NOTE(review): method application binds tighter than `+`, so `.toUpperCase` applies
        // only to the second interpolated string. Only the end field is upper-cased; the start
        // field keeps whatever case fieldToString returns (presumably lower case). This is
        // probably harmless because the parser accepts lower-case input (see the lower-cased
        // check under "Check max value" below) — confirm.
        s"${DayTimeIntervalType.fieldToString(dataType.startField)} TO " +
          s"${DayTimeIntervalType.fieldToString(dataType.endField)}".toUpperCase(Locale.ROOT)
      }
    }

    // (input string, target type, expected microseconds) for each start/end field combination,
    // with positive and negative inputs.
    Seq(("1", DayTimeIntervalType(DAY, DAY), (86400) * MICROS_PER_SECOND),
      ("-1", DayTimeIntervalType(DAY, DAY), -(86400) * MICROS_PER_SECOND),
      ("1 01", DayTimeIntervalType(DAY, HOUR), (86400 + 3600) * MICROS_PER_SECOND),
      ("-1 01", DayTimeIntervalType(DAY, HOUR), -(86400 + 3600) * MICROS_PER_SECOND),
      ("1 01:01", DayTimeIntervalType(DAY, MINUTE), (86400 + 3600 + 60) * MICROS_PER_SECOND),
      ("-1 01:01", DayTimeIntervalType(DAY, MINUTE), -(86400 + 3600 + 60) * MICROS_PER_SECOND),
      ("1 01:01:01.12345", DayTimeIntervalType(DAY, SECOND),
        ((86400 + 3600 + 60 + 1.12345) * MICROS_PER_SECOND).toLong),
      ("-1 01:01:01.12345", DayTimeIntervalType(DAY, SECOND),
        (-(86400 + 3600 + 60 + 1.12345) * MICROS_PER_SECOND).toLong),

      ("01", DayTimeIntervalType(HOUR, HOUR), (3600) * MICROS_PER_SECOND),
      ("-01", DayTimeIntervalType(HOUR, HOUR), -(3600) * MICROS_PER_SECOND),
      ("01:01", DayTimeIntervalType(HOUR, MINUTE), (3600 + 60) * MICROS_PER_SECOND),
      ("-01:01", DayTimeIntervalType(HOUR, MINUTE), -(3600 + 60) * MICROS_PER_SECOND),
      ("01:01:01.12345", DayTimeIntervalType(HOUR, SECOND),
        ((3600 + 60 + 1.12345) * MICROS_PER_SECOND).toLong),
      ("-01:01:01.12345", DayTimeIntervalType(HOUR, SECOND),
        (-(3600 + 60 + 1.12345) * MICROS_PER_SECOND).toLong),

      ("01", DayTimeIntervalType(MINUTE, MINUTE), (60) * MICROS_PER_SECOND),
      ("-01", DayTimeIntervalType(MINUTE, MINUTE), -(60) * MICROS_PER_SECOND),
      ("01:01", DayTimeIntervalType(MINUTE, SECOND), ((60 + 1) * MICROS_PER_SECOND)),
      ("01:01.12345", DayTimeIntervalType(MINUTE, SECOND),
        ((60 + 1.12345) * MICROS_PER_SECOND).toLong),
      ("-01:01.12345", DayTimeIntervalType(MINUTE, SECOND),
        (-(60 + 1.12345) * MICROS_PER_SECOND).toLong),

      ("01.12345", DayTimeIntervalType(SECOND, SECOND), ((1.12345) * MICROS_PER_SECOND).toLong),
      ("-01.12345", DayTimeIntervalType(SECOND, SECOND), (-(1.12345) * MICROS_PER_SECOND).toLong))
      .foreach { case (str, dataType, dt) =>
        // Short-form check (bare string, no INTERVAL wrapper); commented out because it is one
        // of the cases reported as failing on the GPU.
//        checkEvaluation(cast(Literal.create(str), dataType), dt)
        checkEvaluation(
          cast(Literal.create(s"INTERVAL '$str' ${typeName(dataType)}"), dataType), dt)
        // A minus sign outside the quotes must negate the expected value.
        checkEvaluation(
          cast(Literal.create(s"INTERVAL -'$str' ${typeName(dataType)}"), dataType), -dt)
      }

    // Check max value
    // Upper bound for each field combination; each literal is also checked fully lower-cased.
    Seq(("INTERVAL '106751991' DAY", DayTimeIntervalType(DAY), 106751991L * MICROS_PER_DAY),
      ("INTERVAL '106751991 04' DAY TO HOUR", DayTimeIntervalType(DAY, HOUR), 9223372036800000000L),
      ("INTERVAL '106751991 04:00' DAY TO MINUTE",
        DayTimeIntervalType(DAY, MINUTE), 9223372036800000000L),
      ("INTERVAL '106751991 04:00:54.775807' DAY TO SECOND", DayTimeIntervalType(), Long.MaxValue),
      ("INTERVAL '2562047788' HOUR", DayTimeIntervalType(HOUR), 9223372036800000000L),
      ("INTERVAL '2562047788:00' HOUR TO MINUTE",
        DayTimeIntervalType(HOUR, MINUTE), 9223372036800000000L),
      ("INTERVAL '2562047788:00:54.775807' HOUR TO SECOND",
        DayTimeIntervalType(HOUR, SECOND), Long.MaxValue),
      ("INTERVAL '153722867280' MINUTE", DayTimeIntervalType(MINUTE), 9223372036800000000L),
      ("INTERVAL '153722867280:54.775807' MINUTE TO SECOND",
        DayTimeIntervalType(MINUTE, SECOND), Long.MaxValue),
      ("INTERVAL '9223372036854.775807' SECOND", DayTimeIntervalType(SECOND), Long.MaxValue))
      .foreach { case (interval, dataType, dt) =>
        checkEvaluation(cast(Literal.create(interval), dataType), dt)
        checkEvaluation(cast(Literal.create(interval.toLowerCase(Locale.ROOT)), dataType), dt)
      }

    // Check min value: lower bound for each field combination.
    Seq(("INTERVAL '-106751991' DAY", DayTimeIntervalType(DAY), -106751991L * MICROS_PER_DAY),
      ("INTERVAL '-106751991 04' DAY TO HOUR",
        DayTimeIntervalType(DAY, HOUR), -9223372036800000000L),
      ("INTERVAL '-106751991 04:00' DAY TO MINUTE",
        DayTimeIntervalType(DAY, MINUTE), -9223372036800000000L),
      ("INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND", DayTimeIntervalType(), Long.MinValue),
      ("INTERVAL '-2562047788' HOUR", DayTimeIntervalType(HOUR), -9223372036800000000L),
      ("INTERVAL '-2562047788:00' HOUR TO MINUTE",
        DayTimeIntervalType(HOUR, MINUTE), -9223372036800000000L),
      ("INTERVAL '-2562047788:00:54.775808' HOUR TO SECOND",
        DayTimeIntervalType(HOUR, SECOND), Long.MinValue),
      ("INTERVAL '-153722867280' MINUTE", DayTimeIntervalType(MINUTE), -9223372036800000000L),
      ("INTERVAL '-153722867280:54.775808' MINUTE TO SECOND",
        DayTimeIntervalType(MINUTE, SECOND), Long.MinValue),
      ("INTERVAL '-9223372036854.775808' SECOND", DayTimeIntervalType(SECOND), Long.MinValue))
      .foreach { case (interval, dataType, dt) =>
        checkEvaluation(cast(Literal.create(interval), dataType), dt)
      }

    // Non-try cast only: each input below must throw an IllegalArgumentException. The message
    // must list the formats supported by the target type's start/end fields.
    if (!isTryCast) {
      // Strings whose layout does not match the target type's fields. Some entries deliberately
      // pair an INTERVAL unit with a different target type (e.g. "DAY TO HOUR" cast to
      // DayTimeIntervalType(DAY, SECOND)).
      Seq(
        ("INTERVAL '1 01:01:01.12345' DAY TO SECOND", DayTimeIntervalType(DAY, HOUR)),
        ("INTERVAL '1 01:01:01.12345' DAY TO HOUR", DayTimeIntervalType(DAY, SECOND)),
        ("INTERVAL '1 01:01:01.12345' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
        ("1 01:01:01.12345", DayTimeIntervalType(DAY, DAY)),
        ("1 01:01:01.12345", DayTimeIntervalType(DAY, HOUR)),
        ("1 01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)),

        ("INTERVAL '01:01:01.12345' HOUR TO SECOND", DayTimeIntervalType(DAY, HOUR)),
        ("INTERVAL '01:01:01.12345' HOUR TO HOUR", DayTimeIntervalType(DAY, SECOND)),
        ("INTERVAL '01:01:01.12345' HOUR TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
        ("01:01:01.12345", DayTimeIntervalType(DAY, DAY)),
        ("01:01:01.12345", DayTimeIntervalType(HOUR, HOUR)),
        ("01:01:01.12345", DayTimeIntervalType(DAY, MINUTE)),
        ("INTERVAL '1.23' DAY", DayTimeIntervalType(DAY)),
        ("INTERVAL '1.23' HOUR", DayTimeIntervalType(HOUR)),
        ("INTERVAL '1.23' MINUTE", DayTimeIntervalType(MINUTE)),
        ("INTERVAL '1.23' SECOND", DayTimeIntervalType(MINUTE)),
        ("1.23", DayTimeIntervalType(DAY)),
        ("1.23", DayTimeIntervalType(HOUR)),
        ("1.23", DayTimeIntervalType(MINUTE)),
        // NOTE(review): the next entry is identical to the one above.
        ("1.23", DayTimeIntervalType(MINUTE)))
        .foreach { case (interval, dataType) =>
          val e = intercept[IllegalArgumentException] {
            cast(Literal.create(interval), dataType).eval()
          }.getMessage
          assert(e.contains(s"Interval string does not match day-time format of " +
            s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
              .map(format => s"`$format`").mkString(", ")} " +
            s"when cast to ${dataType.typeName}: $interval, " +
            s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
            "to restore the behavior before Spark 3.0."))
        }

      // Check first field out of bounds: each first field has extra digit(s) compared with the
      // maximum values checked above, and must be rejected with the same format message.
      Seq(("INTERVAL '1067519911' DAY", DayTimeIntervalType(DAY)),
        ("INTERVAL '10675199111 04' DAY TO HOUR", DayTimeIntervalType(DAY, HOUR)),
        ("INTERVAL '1067519911 04:00' DAY TO MINUTE", DayTimeIntervalType(DAY, MINUTE)),
        ("INTERVAL '1067519911 04:00:54.775807' DAY TO SECOND", DayTimeIntervalType()),
        ("INTERVAL '25620477881' HOUR", DayTimeIntervalType(HOUR)),
        ("INTERVAL '25620477881:00' HOUR TO MINUTE", DayTimeIntervalType(HOUR, MINUTE)),
        ("INTERVAL '25620477881:00:54.775807' HOUR TO SECOND", DayTimeIntervalType(HOUR, SECOND)),
        ("INTERVAL '1537228672801' MINUTE", DayTimeIntervalType(MINUTE)),
        ("INTERVAL '1537228672801:54.7757' MINUTE TO SECOND", DayTimeIntervalType(MINUTE, SECOND)),
        ("INTERVAL '92233720368541.775807' SECOND", DayTimeIntervalType(SECOND)))
        .foreach { case (interval, dataType) =>
          val e = intercept[IllegalArgumentException] {
            cast(Literal.create(interval), dataType).eval()
          }.getMessage
          assert(e.contains(s"Interval string does not match day-time format of " +
            s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
              .map(format => s"`$format`").mkString(", ")} " +
            s"when cast to ${dataType.typeName}: $interval, " +
            s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
            "to restore the behavior before Spark 3.0."))
        }
    }
  }
mattahrens commented 4 weeks ago

Initial scope: disable GPU support for casting strings to day-time intervals (falling back to the CPU), and add a config to enable it, with the default set to off.