hail-is / hail

Cloud-native genomic dataframes and batch computing
https://hail.is

[query] reference genome operations are too slow #13862

Open · danking opened this issue 1 year ago

danking commented 1 year ago

What happened?

Consider this code:

    male = karyotype_expr == xy_karyotype_str
    female = karyotype_expr == xx_karyotype_str
    x_nonpar = locus_expr.in_x_nonpar()
    y_par = locus_expr.in_y_par()
    y_nonpar = locus_expr.in_y_nonpar()
    return (
        hl.case(missing_false=True)
        .when(female & (y_par | y_nonpar), hl.null(hl.tcall))
        .when(male & (x_nonpar | y_nonpar) & gt_expr.is_het(), hl.null(hl.tcall))
        .when(male & (x_nonpar | y_nonpar), hl.call(gt_expr[0], phased=False))
        .default(gt_expr)
    )

A single partition is taking a very long time to compute. Manual sampling of stack traces via jstack or the Spark UI reveals that we spend a lot of time computing the inPar predicates:

app//is.hail.utils.Interval.contains(Interval.scala:67)
app//is.hail.variant.ReferenceGenome.$anonfun$inPar$1(ReferenceGenome.scala:298)
app//is.hail.variant.ReferenceGenome.$anonfun$inPar$1$adapted(ReferenceGenome.scala:298)
app//is.hail.variant.ReferenceGenome$Lambda$924/0x00000008009b2840.apply(Unknown Source)
app//scala.collection.IndexedSeqOptimized.prefixLengthImpl(IndexedSeqOptimized.scala:41)
app//scala.collection.IndexedSeqOptimized.exists(IndexedSeqOptimized.scala:49)
app//scala.collection.IndexedSeqOptimized.exists$(IndexedSeqOptimized.scala:49)
app//scala.collection.mutable.ArrayOps$ofRef.exists(ArrayOps.scala:198)
app//is.hail.variant.ReferenceGenome.inPar(ReferenceGenome.scala:298)
app//is.hail.variant.ReferenceGenome.inYPar(ReferenceGenome.scala:302)
__C9622collect_distributed_array_table_native_writer.__m10668inYPar(Unknown Source)
__C9622collect_distributed_array_table_native_writer.__m10656split_Let(Unknown Source)
__C9622collect_distributed_array_table_native_writer.__m10638split_ToArray_region3_65(Unknown Source)
__C9622collect_distributed_array_table_native_writer.__m10638split_ToArray(Unknown Source)
__C9622collect_distributed_array_table_native_writer.__m9658split_Let_region608_615(Unknown Source)
__C9622collect_distributed_array_table_native_writer.__m9658split_Let_region21_922(Unknown Source)
__C9622collect_distributed_array_table_native_writer.__m9658split_Let(Unknown Source)
__C9622collect_distributed_array_table_native_writer.apply(Unknown Source)
__C9622collect_distributed_array_table_native_writer.apply(Unknown Source)
app//is.hail.backend.BackendUtils.$anonfun$collectDArray$6(BackendUtils.scala:52)
app//is.hail.backend.BackendUtils$Lambda$783/0x000000080080c040.apply(Unknown Source)
app//is.hail.utils.package$.using(package.scala:635)
app//is.hail.annotations.RegionPool.scopedRegion(RegionPool.scala:162)
app//is.hail.backend.BackendUtils.$anonfun$collectDArray$5(BackendUtils.scala:51)
app//is.hail.backend.BackendUtils$Lambda$757/0x00000008007bcc40.apply(Unknown Source)
app//is.hail.backend.spark.SparkBackendComputeRDD.compute(SparkBackend.scala:751)
app//org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
app//org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
app//org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
app//org.apache.spark.scheduler.Task.run(Task.scala:136)
app//org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
app//org.apache.spark.executor.Executor$TaskRunner$Lambda$608/0x0000000800652c40.apply(Unknown Source)
app//org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
app//org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
java.base@11.0.17/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
java.base@11.0.17/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
java.base@11.0.17/java.lang.Thread.run(Thread.java:829)

A few things:

  1. Verify that this case statement is evaluated intelligently. In particular, we really want to evaluate each predicate once, and only if necessary (a minimal sketch of that evaluation behavior follows this list).
  2. We should not allocate just to evaluate these reference genome predicates, but that is exactly what we do.
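
For item 1, here is a minimal plain-Scala sketch (hypothetical names, not Hail IR or generated code) of the evaluation behavior we want from the case statement: each predicate is bound lazily, so it is computed at most once and only when some arm actually inspects it.

    // Plain-Scala sketch (hypothetical standalone function, not Hail's generated
    // code): `lazy val` plus short-circuit evaluation mirrors "each predicate is
    // computed at most once, and only if some arm actually needs it".
    def adjustCall(
      isFemale: Boolean, isMale: Boolean, isHet: => Boolean,
      inXNonpar: => Boolean, inYPar: => Boolean, inYNonpar: => Boolean
    ): String = {
      lazy val xNonpar = inXNonpar  // evaluated at most once, on first use
      lazy val yPar = inYPar
      lazy val yNonpar = inYNonpar
      if (isFemale && (yPar || yNonpar)) "missing call"
      else if (isMale && (xNonpar || yNonpar) && isHet) "missing call"
      else if (isMale && (xNonpar || yNonpar)) "haploid call from first allele"
      else "gt_expr unchanged"
    }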

It seems like the right fix is for the ReferenceGenome's intervals to be shipped as literals so that we can perform inXPar or isAutosomal checks without allocating contig strings or locus objects.
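
To make the intent concrete, here is a rough sketch (plain Scala with hypothetical names, not the actual emit-level code) of what an allocation-free check could look like once the PAR intervals are available as literal integer data: contigs are referred to by index, so containment is just a few integer comparisons, with no Locus or contig-String allocation.

    // Hypothetical literal representation of the PAR intervals: parallel arrays of
    // (contig index, start, end), treating intervals as half-open for simplicity.
    final case class LiteralPar(contigIdx: Array[Int], start: Array[Int], end: Array[Int])

    // Allocation-free containment check over the literal intervals.
    def inPar(par: LiteralPar, contigIdx: Int, pos: Int): Boolean = {
      var i = 0
      while (i < par.contigIdx.length) {
        if (par.contigIdx(i) == contigIdx && pos >= par.start(i) && pos < par.end(i))
          return true
        i += 1
      }
      false
    }

The same idea extends to isAutosomal, inX, and inY: resolve contig names to indices once when the literal is built, and do only primitive comparisons per locus.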

Version

0.2.124

Relevant log output

No response

danking commented 1 year ago

Something along these lines might work. The trouble comes when you get to replacing the functions in LocusFunctions.scala: there you need both set containment and interval containment. Set containment is currently implemented in terms of IR (see SetFunctions.contains), and I'm not sure of the easiest way to fix that. We can't reference Code things from IR, and I don't know how to compile the IR in-line like that.

(base) dking@wm28c-761 hail % g diff
diff --git a/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala b/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala
index 115df824b3..6e5ee81e6a 100644
--- a/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala
@@ -59,6 +59,35 @@ class EmitModuleBuilder(val ctx: ExecuteContext, val modb: ModuleBuilder) {
     new StaticFieldRef(rgField)
   }

+  class LoweredReferenceGenome(
+    name: SStringPointerValue,
+    contigs: SIndexablePointerValue,
+    lengths: SIndexablePointerValue,
+    xContigs: SIndexablePointerValue,
+    yContigs: SIndexablePointerValue,
+    mtContigs: SIndexablePointerValue,
+    parInterval: SIntervalPointerValue
+  )
+
+  private val loweredReferences: mutable.Map[String, LoweredReferenceGenome] = mutable.Map.empty
+
+  def getLoweredReferenceGenome(cb: EmitCodeBuilder, name: String): LoweredReferenceGenome = {
+    loweredReferences.getOrElseUpdate(name, {
+      val ecb = genEmitClass[Unit](s"lowered_reference_${name}")
+      val rg = ctx.getReference(name)
+      assert(rg.name == name)
+      new LoweredReferenceGenome(
+        ecb.addLiteral(cb, rg.name, VirtualTypeWithReq.fullyRequired(TString)).asInstanceOf[SStringPointerValue],
+        ecb.addLiteral(cb, rg.contigs, VirtualTypeWithReq.fullyRequired(TArray(TString))).asInstanceOf[SIndexablePointerValue],
+        ecb.addLiteral(cb, rg.lengths, VirtualTypeWithReq.fullyRequired(TArray(TInt32))).asInstanceOf[SIndexablePointerValue],
+        ecb.addLiteral(cb, rg.xContigs, VirtualTypeWithReq.fullyRequired(TSet(TString))).asInstanceOf[SIndexablePointerValue],
+        ecb.addLiteral(cb, rg.yContigs, VirtualTypeWithReq.fullyRequired(TSet(TString))).asInstanceOf[SIndexablePointerValue],
+        ecb.addLiteral(cb, rg.mtContigs, VirtualTypeWithReq.fullyRequired(TSet(TString))).asInstanceOf[SIndexablePointerValue],
+        ecb.addLiteral(cb, Interval(rg.parInput._1, rg.parInput._2), VirtualTypeWithReq.fullyRequired(TInterval(TLocus(rg.name)))).asInstanceOf[SIntervalPointerValue]
+      )
+    })
+  }
+
   def referenceGenomes(): IndexedSeq[ReferenceGenome] = rgContainers.keys.map(ctx.getReference(_)).toIndexedSeq.sortBy(_.name)
   def referenceGenomeFields(): IndexedSeq[StaticField[ReferenceGenome]] = rgContainers.toFastSeq.sortBy(_._1).map(_._2)
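
As a rough illustration of sidestepping the SetFunctions.contains issue mentioned above (plain Scala with hypothetical names, not the emit-level code): since xContigs, yContigs, and mtContigs are known when the literal is built, they can be resolved to sets of contig indices once, so the per-locus check becomes a primitive integer membership test rather than an IR-level set lookup on contig strings.

    // Hypothetical precomputation done once per reference genome literal.
    final case class LiteralContigSets(x: Set[Int], y: Set[Int], mt: Set[Int])

    def buildContigSets(
      contigs: IndexedSeq[String],
      xContigs: Set[String], yContigs: Set[String], mtContigs: Set[String]
    ): LiteralContigSets = {
      val idx = contigs.zipWithIndex.toMap  // contig name -> contig index
      LiteralContigSets(
        xContigs.flatMap(idx.get),
        yContigs.flatMap(idx.get),
        mtContigs.flatMap(idx.get))
    }

    // Per-locus checks reduce to integer set membership; no String allocation.
    def inX(sets: LiteralContigSets, contigIdx: Int): Boolean = sets.x.contains(contigIdx)
    def isMitochondrial(sets: LiteralContigSets, contigIdx: Int): Boolean = sets.mt.contains(contigIdx)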

danking commented 1 year ago

FWIW, this pipeline was performing these checks perhaps as many as 10 times per genotype, which is obviously unreasonable. Nonetheless, sending the RG along as a literal should improve the speed of these operations.