bartosz25 / spark-scala-playground

Sample processing code using Spark 2.1+ and Scala

Use Spark SQL optimizer to optimize two ranges join #11

Closed bithw1 closed 5 years ago

bithw1 commented 5 years ago

Hi @bartosz25 ,

I have a question that I'd like you to take a look at, thank you.

I want to use the Spark SQL optimizer to optimize a join between two ranges: compute the intersection of the two ranges directly, so that the join itself can be avoided.

test("SparkTest") {
    object RangeIntersectRule extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
        case Join(Project(_, Range(start1, end1, _, _, _, _)), Project(_, Range(start2, end2, _, _, _, _)), _, _) => {
          val start = start1 max start2
          val end = end1 min end2
          if (start1 > end2 || end1 < start2) Range(0, 0, 1, Some(1), false) else Range(start, end, 1, Some(1), false)

        }
      }
    }

    val spark = SparkSession.builder().master("local").appName("SparkTest").enableHiveSupport().getOrCreate()
    spark.experimental.extraOptimizations = Seq(RangeIntersectRule)
    spark.range(10, 40).toDF("x").createOrReplaceTempView("t1")
    spark.range(20, 50).toDF("y").createOrReplaceTempView("t2")
    val df = spark.sql("select t1.x from t1 join t2 on t1.x = t2.y")
    df.explain(true)
    df.show(truncate = false)
  }

The rule takes effect, but it throws the exception below; it looks like I haven't implemented the apply method appropriately.

The plan is

== Parsed Logical Plan ==
'Project ['t1.x]
+- 'Join Inner, ('t1.x = 't2.y)
   :- 'UnresolvedRelation `t1`
   +- 'UnresolvedRelation `t2`

== Analyzed Logical Plan ==
x: bigint
Project [x#2L]
+- Join Inner, (x#2L = y#6L)
   :- SubqueryAlias `t1`
   :  +- Project [id#0L AS x#2L]
   :     +- Range (10, 40, step=1, splits=Some(1))
   +- SubqueryAlias `t2`
      +- Project [id#4L AS y#6L]
         +- Range (20, 50, step=1, splits=Some(1))

== Optimized Logical Plan ==
!Project [x#2L]
+- Range (20, 40, step=1, splits=Some(1))

== Physical Plan ==
!Project [x#2L]
+- Range (20, 40, step=1, splits=1)

The exception is:

Caused by: java.lang.RuntimeException: Couldn't find x#2L in [id#14L]
    at scala.sys.package$.error(package.scala:27)
    at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:107)
    at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:101)
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:59)
    ... 50 more
bartosz25 commented 5 years ago

Hi @bithw1

I'll try to take a look today or tomorrow.

Best regards, Bartosz.

bartosz25 commented 5 years ago

@bithw1

The range method returns a Dataset with a single column called id. That's why the engine looks for it in your query declaration (you can also see the ! prefix on Project in your optimized plan: it flags a node whose child no longer produces the attributes it references):

   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements
   * in a range from `start` to `end` (exclusive) with step value 1.
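
A minimal standalone sketch showing both shapes (the exprIds below are illustrative, assuming Spark 2.3/2.4):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").appName("RangeSchema").getOrCreate()
    // a bare range exposes the generated column directly:
    spark.range(10, 40).printSchema()
    // root
    //  |-- id: long (nullable = false)

    // renaming it with toDF("x") adds a Project alias on top of the Range:
    println(spark.range(10, 40).toDF("x").queryExecution.analyzed)
    // Project [id#0L AS x#2L]
    // +- Range (10, 40, step=1, splits=Some(1))

Your rule replaces the join with a Range whose companion apply generates a brand new id attribute (id#14L in your stack trace), so the parent Project [x#2L] can no longer bind its reference.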

If you rewrite your query like this, it should work:

    val spark = SparkSession.builder().master("local").appName("SparkTest").enableHiveSupport().getOrCreate()
    spark.experimental.extraOptimizations = Seq(RangeIntersectRule)
    spark.range(10, 40).createOrReplaceTempView("t1")
    spark.range(20, 50).createOrReplaceTempView("t2")
    val df = spark.sql("select t1.id from t1 join t2 on t1.id = t2.id")
    df.explain(true)
    df.show(truncate = false)

Best regards, Bartosz.

bithw1 commented 5 years ago

Thanks @bartosz25, but when I rewrite it using the id column, the rule doesn't take effect; you can see that the physical plan still uses BroadcastHashJoin.

When using id, I tried to modify the rule's apply method as below, but I still can't make it work; it still throws Caused by: java.lang.RuntimeException: Couldn't find id#0L in [id#2L]

    object RangeIntersectRule extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
        case Join(Range(start1, end1, 1, Some(1), output1, false), Range(start2, end2, 1, Some(1), output2, false), Inner, _) => {
          val start = start1 max start2
          val end = end1 min end2
          if (start1 > end2 || end1 < start2) Range(0, 0, 1, Some(1), output1, false)
          else Range(start, end, 1, Some(1), output1, false)
        }
      }
    }
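
To make the mismatch visible, a debugging variant of the rule can print which attributes each side exposes (a sketch; the prints are the only addition, and the rule name is just for illustration):

    import org.apache.spark.sql.catalyst.plans.Inner
    import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Range}
    import org.apache.spark.sql.catalyst.rules.Rule

    object RangeIntersectDebugRule extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
        case j @ Join(Range(start1, end1, 1, Some(1), output1, false), Range(start2, end2, 1, Some(1), output2, false), Inner, _) => {
          // what parent operators resolved their references against
          println(s"join output: ${j.output.map(a => a.name + "#" + a.exprId.id)}")
          val start = start1 max start2
          val end = end1 min end2
          val replacement =
            if (start1 > end2 || end1 < start2) Range(0, 0, 1, Some(1), output1, false)
            else Range(start, end, 1, Some(1), output1, false)
          // if these exprIds differ from the referenced ones, binding fails at runtime
          println(s"replacement output: ${replacement.output.map(a => a.name + "#" + a.exprId.id)}")
          replacement
        }
      }
    }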
bithw1 commented 5 years ago

Hi @bartosz25

I wrapped Range with Project as in the following code and it works, but I have no idea why it needs to be wrapped with Project. Could you please take a look? Thank you.

    object RangeIntersectRule extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
        case Join(Range(start1, end1, 1, Some(1), output1, false), Range(start2, end2, 1, Some(1), output2, false), Inner, _) => {
          val start = start1 max start2
          val end = end1 min end2
          // wrap Range with Project
          if (start1 > end2 || end1 < start2) Project(output1, Range(0, 0, 1, Some(1), output1, false))
          else Project(output1, Range(start, end, 1, Some(1), output1, false))
        }
      }
    }
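
For what it's worth, a possible explanation: a rewritten subtree has to expose attributes the parent plan already resolved against, and reusing output1 only covers the left side of the join. A sketch that makes this bookkeeping explicit (assuming Spark 2.3/2.4 signatures for Alias and Range; the rule name is just for illustration) aliases the fresh Range column back to every attribute the original Join produced, keeping the old exprIds:

    import org.apache.spark.sql.catalyst.expressions.Alias
    import org.apache.spark.sql.catalyst.plans.Inner
    import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project, Range}
    import org.apache.spark.sql.catalyst.rules.Rule

    object RangeIntersectAliasRule extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
        case j @ Join(Range(start1, end1, 1, _, _, false), Range(start2, end2, 1, _, _, false), Inner, _) => {
          val start = start1 max start2
          val end = end1 min end2
          // the companion apply generates a fresh `id` attribute with a new exprId
          val newRange =
            if (start1 > end2 || end1 < start2) Range(0, 0, 1, Some(1), false)
            else Range(start, end, 1, Some(1), false)
          // re-expose the single new column under every attribute the Join used to
          // output, keeping the old exprIds so parent references still bind; in an
          // inner equi-join of two ranges both columns hold the same value per row
          val restoredOutput = j.output.map { attr =>
            Alias(newRange.output.head, attr.name)(exprId = attr.exprId)
          }
          Project(restoredOutput, newRange)
        }
      }
    }

This may also be why the Project wrap helps in your version: what matters is that the node you return still resolves for the parent plan.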
bartosz25 commented 5 years ago

Hi @bithw1

Sorry, I missed your message last week. I'll add the topic of extra optimizations to my backlog and try to answer your question here once I write about it.

Best regards, Bartosz.

bithw1 commented 5 years ago

Sure, thank you @bartosz25 !

bartosz25 commented 5 years ago

Hi @bithw1

Today I started the topic of custom optimizations. Since the topic is quite new for me, I will go slowly, starting from the basics and covering more advanced concepts at the end. The first post is here: https://www.waitingforcode.com/apache-spark-sql/introduction-custom-optimization-apache-spark-sql/read

Best regards, Bartosz.

bithw1 commented 5 years ago

That's great, thanks @bartosz25. Looking forward to reading your posts and learning from them. :-)