tresata / spark-sorted

Secondary sort and streaming reduce for Apache Spark
Apache License 2.0

Question: Memory Performance #10

Open thobrien opened 3 years ago

thobrien commented 3 years ago

Hi, this looks like a great package - THANK YOU very much for your work here.

I had a question and wasn't sure where to ask: I've been trying to reduce the memory usage of a pattern like the one below.

Basically, what I actually want is groupBy + sort + map. But since there isn't a sort-within-a-group that is fast enough for really big data sets, I'm using repartition + sortWithinPartitions + mapPartitions to do the work instead, which is somehow pretty fast.

The downside of my old approach is that, although I compute something independently per group ("id"), iterating the whole partition with mapPartitions forces me to keep the result of that computation in memory for EVERY group ("id") until I reach the end of the partition. The result per group is relatively large. So this is my concern, and why I find this library promising.

My question is: when using GroupSorted(rdd, comparator).mapStreamByKey, does spark-sorted actually release each group's results from memory as it proceeds down the partition? I stared at the code for a couple of hours and couldn't see how mapStreamByKey manages the memory of each group. Ideally, whatever my function returns for a group would be cleaned from memory once it's passed downstream, before the next group is processed.

This would be a huge win for us.

Pseudo-code for the old way:

data.repartition(col("id"))
    .sortWithinPartitions("id", "time")
    .mapPartitions((MapPartitionsFunction<Row, Row>) partitionIterator -> {
        String previousId = null;
        List<Row> output = new ArrayList<>(); // holds results for EVERY id in the partition until the end
        while (partitionIterator.hasNext()) {
            Row event = partitionIterator.next();
            String nextId = event.getAs("id");
            if (!nextId.equals(previousId)) {
                // put some stuff we learned while processing the previous id into the output - ugly!
                previousId = nextId;
            }
            // remember/buffer other state while processing this id, so we have something to output when the id changes
        }
        // ...and the state of the last id in the partition has to be flushed into the output here too
        return output.iterator();
    }, myRowEncoder);

Pseudo-code for the new way:

GroupSorted<String, Row> groupSortedRDD = new GroupSorted<>(pairRDD, new RowTimeComparator());
JavaRDD<Row> output = groupSortedRDD.mapStreamByKey(rowInGroupIterator -> {
    MyStateMachine myStateMachine = new MyStateMachine(); // fresh state for each group (id)
    rowInGroupIterator.forEachRemaining(myStateMachine::receiveEvent);
    return myStateMachine.getOutput().iterator();
}).map(kvp -> kvp._2()); // keep only the output rows, dropping the key

So ideally everything I return from myStateMachine.getOutput().iterator() would not stay in memory until the full partition walk is completed.
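The RowTimeComparator in the snippet above isn't shown in the thread. As a rough, Spark-free illustration of the secondary sort that GroupSorted relies on, the sketch below orders a partition by group key first and by a time comparator second, so each group's events end up contiguous and in time order. TimedEvent, EventTimeComparator and SecondarySortSketch are hypothetical names standing in for Spark Rows and the real comparator.

```java
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

// Hypothetical stand-in for a Spark Row: one event with a group key (id) and a timestamp.
final class TimedEvent implements Serializable {
    final String id;
    final long time;

    TimedEvent(String id, long time) {
        this.id = id;
        this.time = time;
    }

    @Override
    public String toString() {
        return id + "@" + time;
    }
}

// Hypothetical value comparator in the spirit of RowTimeComparator: orders events by time only.
// Serializable, since a comparator used inside a Spark job is typically shipped to executors.
final class EventTimeComparator implements Comparator<TimedEvent>, Serializable {
    @Override
    public int compare(TimedEvent a, TimedEvent b) {
        return Long.compare(a.time, b.time);
    }
}

final class SecondarySortSketch {
    // Secondary sort: primary order by key, secondary order by the value comparator,
    // so each key's events are contiguous and already in time order.
    static List<TimedEvent> secondarySort(List<TimedEvent> partition) {
        List<TimedEvent> sorted = new ArrayList<>(partition);
        sorted.sort(Comparator.comparing((TimedEvent e) -> e.id).thenComparing(new EventTimeComparator()));
        return sorted;
    }

    public static void main(String[] args) {
        List<TimedEvent> partition = Arrays.asList(
            new TimedEvent("b", 2), new TimedEvent("a", 3), new TimedEvent("b", 1), new TimedEvent("a", 1));
        System.out.println(secondarySort(partition)); // [a@1, a@3, b@1, b@2]
    }
}
```

With the partition in this order, a per-group function only ever needs to look at one contiguous run of rows at a time.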

koertkuipers commented 3 years ago

the premise is that by using a secondary sort within a partition, processing the partition as an iterator, and writing out an iterator, we never materialize the entire result for a partition in memory (but note that spark might still store the entire result for a partition together, in serialized form).

the plumbing of mapStreamByKey basically reads from an iterator, calls the provided function f repeatedly (once per key), and concatenates the results lazily into a new iterator. this is the tricky bit: we are chaining iterators to avoid creating a result for the partition in memory. see com.tresata.spark.sorted.mapStreamIteratorWithContext. the goal is that spark, as it writes out the result of the partition, ends up calling the function f repeatedly to generate the result on the fly (as opposed to f being called repeatedly to build a strict in-memory collection holding the result for the entire partition, which is then handed to spark to write out).
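To make the iterator chaining concrete, here is a minimal Java sketch of the idea, assuming the input is an iterator of (key, value) pairs already grouped by key. It is not spark-sorted's code; LazyGroupStream and pair are hypothetical names. f is applied to one key's values only when the consumer pulls output for that key, and its results are concatenated lazily, so no result for the whole partition is ever built.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.Function;

// Sketch of lazy per-key streaming over (key, value) pairs that are already grouped by key.
// f is called once per key, and only when the consumer pulls output for that key.
final class LazyGroupStream<K, V, W> implements Iterator<Map.Entry<K, W>> {
    private final Iterator<Map.Entry<K, V>> input;
    private final Function<Iterator<V>, Iterator<W>> f;
    private Map.Entry<K, V> head;       // one-element lookahead; null when the input is exhausted
    private K currentKey;               // key whose output is currently being emitted
    private Iterator<V> currentValues;  // values of the current key, read straight from the input
    private Iterator<W> currentOutput;  // f's output for the current key

    LazyGroupStream(Iterator<Map.Entry<K, V>> input, Function<Iterator<V>, Iterator<W>> f) {
        this.input = input;
        this.f = f;
        advance();
    }

    static <K, V> Map.Entry<K, V> pair(K key, V value) {
        return new SimpleEntry<>(key, value);
    }

    private void advance() {
        head = input.hasNext() ? input.next() : null;
    }

    // A view of the input that ends as soon as the key changes; nothing is buffered.
    private Iterator<V> valuesOf(K key) {
        return new Iterator<V>() {
            @Override
            public boolean hasNext() {
                return head != null && Objects.equals(head.getKey(), key);
            }

            @Override
            public V next() {
                if (!hasNext()) throw new NoSuchElementException("next on empty iterator");
                V value = head.getValue();
                advance();
                return value;
            }
        };
    }

    @Override
    public boolean hasNext() {
        // Move on to the next key while the current key's output is empty or used up.
        while (currentOutput == null || !currentOutput.hasNext()) {
            if (currentValues != null) {
                currentValues.forEachRemaining(v -> { }); // skip values f left unread
            }
            if (head == null) {
                return false; // end of the partition
            }
            currentKey = head.getKey();
            currentValues = valuesOf(currentKey);
            currentOutput = f.apply(currentValues); // f runs here, one key at a time
        }
        return true;
    }

    @Override
    public Map.Entry<K, W> next() {
        if (!hasNext()) throw new NoSuchElementException("next on empty iterator");
        return pair(currentKey, currentOutput.next());
    }
}
```

For example, with the input a→1, a→2, b→5 and an f that sums a group's values, the first call to next() returns a=3 after f has run exactly once; f is not called for b until more output is requested. The drain step in hasNext() keeps a group from being split in two if f stops reading its values early.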

with dataframes we have less control, and your approach of using repartition + sortWithinPartitions + mapPartitions is just as good. the same approach is used in com.tresata.spark.sorted.sql.GroupSortedDataset. the crux is that per partition you do not collect your result in an in-memory data structure (like the output list in your old-way code) but instead build your output lazily using an iterator (or TraversableOnce).
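Applied to the old approach, the same idea means replacing the list that collects every id's output with an iterator that emits each id's result as soon as the id changes. The sketch below is a hypothetical, Spark-free version of the previousId loop: EmitOnKeyChange and its three callbacks (loosely modeled on new MyStateMachine(), receiveEvent and getOutput) are assumptions, and the input is assumed to be sorted by key, as it would be after sortWithinPartitions. Inside mapPartitions, an iterator like this would be returned directly instead of output.iterator().

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;

// Sketch: the old "previousId" loop as a lazy iterator over a partition already sorted by key.
// Only the state for the key currently being scanned is alive; each key's result is returned
// from next() as soon as its last row has been read, instead of being added to a List.
final class EmitOnKeyChange<T, K, S, R> implements Iterator<R> {
    private final Iterator<T> partition;
    private final Function<T, K> keyOf;        // e.g. row -> row.getAs("id") for a Spark Row
    private final Supplier<S> newState;        // fresh per-key state (like new MyStateMachine())
    private final BiConsumer<S, T> receive;    // like myStateMachine.receiveEvent(row)
    private final Function<S, R> output;       // like myStateMachine.getOutput()
    private T lookahead;                       // first unread row
    private boolean hasLookahead;

    EmitOnKeyChange(Iterator<T> partition, Function<T, K> keyOf,
                    Supplier<S> newState, BiConsumer<S, T> receive, Function<S, R> output) {
        this.partition = partition;
        this.keyOf = keyOf;
        this.newState = newState;
        this.receive = receive;
        this.output = output;
        advance();
    }

    private void advance() {
        hasLookahead = partition.hasNext();
        lookahead = hasLookahead ? partition.next() : null;
    }

    @Override
    public boolean hasNext() {
        return hasLookahead;
    }

    @Override
    public R next() {
        if (!hasLookahead) throw new NoSuchElementException("next on empty iterator");
        K key = keyOf.apply(lookahead);
        S state = newState.get(); // state for this key only
        while (hasLookahead && Objects.equals(keyOf.apply(lookahead), key)) {
            receive.accept(state, lookahead);
            advance();
        }
        return output.apply(state); // this key's state is no longer referenced by the iterator
    }

    public static void main(String[] args) {
        Iterator<String> results = new EmitOnKeyChange<String, Character, StringBuilder, String>(
            Arrays.asList("a1", "a2", "b1").iterator(),
            row -> row.charAt(0),
            StringBuilder::new,
            (sb, row) -> sb.append(sb.length() == 0 ? "" : "+").append(row),
            StringBuilder::toString);
        results.forEachRemaining(System.out::println); // prints a1+a2, then b1
    }
}
```

Running main prints a1+a2 and then b1. Only one per-key state object is reachable from the iterator at a time; as koertkuipers notes, Spark may still hold a partition's serialized output further downstream.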

thobrien commented 3 years ago

Thanks @koertkuipers - the code pointers you sent are definitely insightful - this is quite a nice piece of work. I hope I can use it.

So I've spent the afternoon looking at this code and I think I see how it is doing the magic and chaining the iterators so that the function f() is called once per group. This should indeed improve the memory utilization by allowing memory to be flushed per-group rather than per-partition.

Oddly, though, in testing, when I convert our old approach to the GroupSorted + mapStreamByKey approach using the code above, the job on the same input data now produces 50+ TB of shuffle spill, whereas it had none before. After a lot of digging, I found that if I increase my partition count 10x (from 4200 to 42000), I can run with minimal, but never zero, shuffle spill (down to 500 GB or so from 50+ TB). BUT even in that case, the code block surrounding my f() now takes about 10x longer to execute.

In the latter case, the executor stack traces show the code spending a lot of time deep inside GroupSorted, always under ExternalSorter, which I imagine is part of the lazy sort triggered when mapStreamByKey processes my function (whose output I later write to Hive). In sub-stacks like:

For more stacks, see here

Have you seen behaviors like this?

The shuffle spill is unexpected, since we both believe the new code hands data to the output iterator sooner, but it's manageable by increasing the number of partitions. Taking 10x longer, though, and seemingly always inside sort-related functions like ExternalSorter.insertAll(), has me wondering whether some inefficiency in ExternalSorter is triggered in this case but not in my previous .repartition().sortWithinPartitions() case.

You did mention that with dataframes we have less control and that my approach of using repartition + sortWithinPartitions + mapPartitions is just as good. My processing pipeline begins with the data as a dataframe, converts it to an RDD for both the previous approach and this new approach, and converts back to a dataframe afterwards. Could this be why I'm seeing the odd shuffle spill and execution times? Working in dataframes before and after used to be helpful, but perhaps that's no longer needed if it causes problems.

Thank you!

BTW, I'm open to setting up a quick conference call if you are interested; that way I don't burn your time with back-and-forth.

koertkuipers commented 3 years ago

i wonder if the DataFrames are simply a lot more efficient than RDDs in your case?

did you try staying in dataframe land but not materializing the results in a List[Row]? we have an api for this (groupSort on Dataset instead of RDD) but it is only available in scala :(

it's basically a matter of doing repartition + sortWithinPartitions + mapPartitions, and then using mapStreamIterator yourself. that is exactly what we do here: https://github.com/tresata/spark-sorted/blob/master/src/main/scala/com/tresata/spark/sorted/sql/GroupSortedDataset.scala#L32

thobrien commented 3 years ago

That's an interesting option - I spent last evening trying to port that Iterator to Java since it was hard to access otherwise in the Dataframe use case. I'll let you know how it goes. Thanks!

koertkuipers commented 3 years ago

in scala we can wrap it in a version that's java friendly. it goes something like this:

$ git diff
diff --git a/src/main/scala/com/tresata/spark/sorted/package.scala b/src/main/scala/com/tresata/spark/sorted/package.scala
index 28d0cfa..ba12bd3 100644
--- a/src/main/scala/com/tresata/spark/sorted/package.scala
+++ b/src/main/scala/com/tresata/spark/sorted/package.scala
@@ -1,8 +1,10 @@
 package com.tresata.spark.sorted

+import java.util.{Iterator => JIterator}
 import java.nio.ByteBuffer
 import scala.annotation.tailrec
 import scala.reflect.ClassTag
+import scala.collection.JavaConverters._

 import org.apache.spark.SparkEnv

@@ -56,6 +58,8 @@ object `package` {
   private[sorted] def mapStreamIterator[K, V, W](iter: Iterator[(K, V)])(f: Iterator[V] => TraversableOnce[W]): Iterator[(K, W)] =
     mapStreamIteratorWithContext[K, V, W, Unit](iter)(() => (), (_: Unit, it: Iterator[V]) => f(it))

+  def mapStreamIterator[K, V, W](iter: JIterator[(K, V)])(f: JIterator[V] => JIterator[W]): JIterator[(K, W)] =
+    mapStreamIterator(iter.asScala){ iter => f(iter.asJava).asScala }.asJava

   private[sorted] def fMergeJoinOuter[V1, V2: ClassTag]: (Iterator[V1], Iterator[V2]) => TraversableOnce[(Option[V1], Option[V2])] = { (it1, it2) =>
     if (it1.hasNext) {

might have to put this somewhere else, since i am not sure if java can access methods in a package object. i don't know how to do this in java easily.

thobrien commented 3 years ago

No problem - let me try the one I manually ported and if it performs well then it might be worth doing this. If not, then we might learn more about the performance issue. If it's still 10x slower and/or spilling, then the performance issue must be coming from the iterator strategy itself.

thobrien commented 3 years ago

So, recoding the iterator in Java and driving it from Dataset.repartition(group_by_id).sortWithinPartitions(group_by_id, time_id) rather than the spark-sorted RDD sort: a) was WAY faster, and b) made all the shuffle spill problems go away.

Overall, memory usage was about the same as with the full-partition walk, which was unexpected, since the custom iterators now dump the state incrementally (in my case once at the end of each group) rather than once at the end of the partition. And it's not a degenerate case: with UUID hash partitioning we have on average 625 groups per partition.

[Screenshot: spark_stage_memory]

I'll talk it over with my team tomorrow. I bet we will proceed with the custom iterators, but I wish we could figure out why we're not using less memory than with the full-partition walks.

Here's the Java iterator code. I took the liberty of making it iterative rather than (tail) recursive, so that I could understand it more easily and not worry about overflowing the stack, since Java doesn't do tail-call optimization:

// SparkPartitionToGroupIterator.java
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.function.Function;
import org.apache.commons.collections4.IteratorUtils;
import org.apache.commons.collections4.iterators.PeekingIterator;

public class SparkPartitionToGroupIterator<InputIteratorType, KeyFieldType, OutputIteratorType> implements Iterator<OutputIteratorType> {
    private final PeekingIterator<InputIteratorType> biter; // the partition iterator Spark hands us, sorted by key

    private final Function<InputIteratorType, KeyFieldType> getKeyFunction;

    private final Function<Iterator<InputIteratorType>, Iterator<OutputIteratorType>> f;

    private SparkValueInGroupIterator<InputIteratorType, KeyFieldType> viter; // input rows of the current group

    private Iterator<OutputIteratorType> kwiter; // f's output for the current group; null once the partition is exhausted

    public SparkPartitionToGroupIterator(Iterator<InputIteratorType> partitionIterator,
                                         Function<InputIteratorType, KeyFieldType> getKeyFunction,
                                         Function<Iterator<InputIteratorType>, Iterator<OutputIteratorType>> myFunction) {
        this.biter = (PeekingIterator<InputIteratorType>) IteratorUtils.peekingIterator(partitionIterator); // This is the original iterator Spark knows about
        this.getKeyFunction = getKeyFunction;
        this.f = myFunction; // The function we're asked to call to turn one group's rows into an iterator of new rows
        this.kwiter = perKeyIterator(biter); // Set up the first per-entity iterator and call myFunction on it. The per-entity iterator passes through the Spark iterator until it sees a new entityId, at which point its hasNext() returns false and myFunction stops iterating.
    }

    @Override
    public boolean hasNext() {
        while (kwiter != null && !kwiter.hasNext()) { // kwiter becomes null at the end of the partition; until then, keep moving to the next entity, even if an entity produced no output
            this.kwiter = perKeyIterator(biter);
        }
        return kwiter != null; // after the loop, kwiter is either null or has a next element
    }

    @Override
    public OutputIteratorType next() {
        if (hasNext()) // advance to the next non-empty group if needed, then return from it
            return kwiter.next();
        else
            throw new NoSuchElementException("next on empty iterator");
    }

    protected Iterator<OutputIteratorType> perKeyIterator(PeekingIterator<InputIteratorType> biter) {
        if (viter != null) {
            viter.forEachRemaining(row -> { }); // If myFunction stopped early, skip the rest of the previous group so it isn't treated as a new group
        }
        if (biter.hasNext()) { // If we haven't reached the end of the partition
            viter = new SparkValueInGroupIterator<>(biter, getKeyFunction); // Passes through the Spark iterator, but stops when the key (visitor_id) changes
            return f.apply(viter); // Call the function, which iterates all events for that visitor_id and returns 0 or more output rows
        } else { // End of the partition: return null so that hasNext() starts returning false
            return null;
        }
    }
}

// SparkValueInGroupIterator.java
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.Function;
import org.apache.commons.collections4.iterators.PeekingIterator;

public class SparkValueInGroupIterator<InputIteratorType, KeyFieldType> implements Iterator<InputIteratorType> {
    private final PeekingIterator<InputIteratorType> biter;
    private final Function<InputIteratorType, KeyFieldType> getKeyFunction;
    private final KeyFieldType keyFieldValue;

    public SparkValueInGroupIterator(PeekingIterator<InputIteratorType> partitionIterator, Function<InputIteratorType, KeyFieldType> getKeyFunction) {
        this.biter = partitionIterator;
        this.getKeyFunction = getKeyFunction;

        // Remember the key (visitor_id) of this group's first row, so hasNext() can return false once the key changes
        this.keyFieldValue = getKeyFunction.apply(biter.peek());
    }

    @Override
    public boolean hasNext() {
        // Check hasNext() before peeking: on an exhausted commons-collections PeekingIterator, peek() returns null
        return biter.hasNext() && Objects.equals(keyFieldValue, getKeyFunction.apply(biter.peek()));
    }

    @Override
    public InputIteratorType next() {
        if (hasNext()) {
            return biter.next();
        } else {
            throw new NoSuchElementException("next on empty iterator");
        }
    }
}

koertkuipers commented 3 years ago

i am not surprised by DataFrame performing way better than RDD. it has more efficient serialization plus the ability to sort on serialized formats, which this giant exercise in sorting greatly benefits from.

it is surprising that you don't use less memory. how big are the output objects per group?

thobrien commented 3 years ago

It's variable: we emit 1.5 KB every time we detect the first of a certain kind of event, so maybe 30-50 of those in the worst case?

So per group we might end up with 50-100 KB, and per partition we multiply that by about 600 groups. That's up to 60 MB of state per partition if I'm counting object sizes properly, versus 50-100 KB if it's released group by group.
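The estimate above, written out (taking about 600 groups per partition and treating 1 MB as 1000 KB):

```latex
\begin{aligned}
\text{per group} &\approx (30\text{--}50) \times 1.5\,\mathrm{KB} = 45\text{--}75\,\mathrm{KB} \quad (\text{estimated as } 50\text{--}100\,\mathrm{KB})\\
\text{per partition, buffered in a list} &\approx 600 \times (50\text{--}100)\,\mathrm{KB} = 30\text{--}60\,\mathrm{MB}\\
\text{per partition, streamed group by group} &\approx 1 \times (50\text{--}100)\,\mathrm{KB}
\end{aligned}
```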

Did you see anything that looks like a regression in my translation of your code? I think/hope I did a decent job.

Anyway, it's super fast with this approach, so my team is definitely interested in going in this direction, since at least it's not worse on memory and may actually be better.