TimelyDataflow / differential-dataflow

An implementation of differential dataflow using timely dataflow in Rust.

Difficulty understanding how to use prefix_sum / how to implement topK #381

Closed: oli-w closed this issue 1 year ago

oli-w commented 1 year ago

Hi 👋, I'm still quite new to Rust; I've been learning it to be able to try out differential dataflow. I've had a lot of success using the map/filter/join operators, but have struggled for a couple of days now to understand how to use prefix_sum for my use case.

What I'm trying to achieve: I have a collection of entities which have a unique key (String) and a numeric value. I want to calculate the top K entities ordered by value, where K is a reasonable number, around 100.

I saw @frankmcsherry (really appreciate your work!) mentioned somewhere (can't find it now) that you can use prefix_sum to compute the position of each entity. This makes sense to me - for each key, count the number of times that other keys appear before it (count rather than summing the previous values).

So for example, if I have inputs as (key, value) pairs: [(a, 7), (b, 11), (c, 3), (d, 2)] and I want to compute the top K=3, I would like to:

  1. Compute the position/index of each key, which would be: [(d, 0), (c, 1), (a, 2), (b, 3)]
  2. Filter to find index < 3: [(d, 0), (c, 1), (a, 2)]

I see that prefix_sum requires a Collection where each Data value has shape ((usize, K), D). My initial thoughts were that these types mean: usize is the value to order by, K is the entity's key, and D is the count to accumulate (hence the 1 in my code below).

Here is my initial code (timestamps excluded because not relevant):

fn main() {
    timely::execute(timely::Config::thread(), |worker| {
        worker.dataflow(|scope| {
            let (mut input, collection) = scope.new_collection::<(String, i32), isize>();
            collection
                .map(|(key, value)| ((value as usize, key), 1))
                .prefix_sum(0, |_k, first_value, second_value| first_value + second_value)
                .inspect(|(data, _timestamp, diff)| println!("{:?}, {:?}", data, diff));

            input.advance_to(0);
            input.insert((String::from("a"), 7));
            input.insert((String::from("b"), 11));
            input.insert((String::from("c"), 3));
            input.insert((String::from("d"), 2));
        });
    })
    .expect("Computation terminated abnormally");
}

It prints out the following, but the position/index is 0 for all:

((2, "d"), 0), 1
((3, "c"), 0), 1
((7, "a"), 0), 1
((11, "b"), 0), 1

I thought maybe I got the usize and D mixed up, so I tried changing the map to return ((1 as usize, key), value); and I get:

((1, "a"), 0), 1
((1, "b"), 0), 1
((1, "c"), 0), 1
((1, "d"), 0), 1

I've tried all sorts of combinations to achieve the step 1 output of [(d, 0), (c, 1), (a, 2), (b, 3)] but no luck. Any help would be much appreciated!

frankmcsherry commented 1 year ago

I apologize that it isn't at all clear. I agree about the state of the code.

The secret meanings are

  1. The usize values are indeed the ordered values at which something happens.
  2. K is a key, but probably not in the way that you hope. Each distinct key K results in an independent subproblem, so you'll end up with counts for each letter you have.
  3. D is the accumulator. Counts if you want counts, but... you might not want counts!

I think what you may want, if you want to use prefix_sum here, is to put the letters in the D. You'll probably need to roll your own D type, but you can imagine a D that implements an ordered list of length at most K (your "top K"), where "adding" two together keeps only the top K.
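For illustration, such a bounded accumulator might look like the following sketch (TopK, LIMIT, and merge are hypothetical names, not anything in the library, and it glosses over prefix_sum's exact trait bounds, e.g. serialization):

// A hypothetical bounded accumulator: an ordered list that never grows
// past LIMIT entries. "Adding" two accumulators merges their entries and
// keeps only the smallest LIMIT, so the prefix sum at each position is
// the running top-K of everything before it.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct TopK {
    entries: Vec<(i32, String)>, // (score, letter), kept sorted ascending
}

const LIMIT: usize = 3;

impl TopK {
    fn merge(&self, other: &TopK) -> TopK {
        let mut entries = self.entries.clone();
        entries.extend(other.entries.iter().cloned());
        entries.sort();
        entries.truncate(LIMIT);
        TopK { entries }
    }
}

// which would then combine as something like:
//   .prefix_sum(TopK::default(), |_k, a, b| a.merge(b))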

Now, that being said, I think you might find that this all simplifies to something that looks more like how Materialize does its TopK rendering: https://github.com/MaterializeInc/materialize/blob/main/src/compute/src/render/top_k.rs There, this is also a hierarchical pattern, but what it does is repeatedly group by (key, hash(val) >> round) and retain the Top K for each group. As round increases, eventually the hash(val) vanishes, and you have for each key the Top K values.

If the above sounds confusing, it may be because I think you are using "key" in a different sense than we would. Your characters are "keys" to you because they map to numbers. But for us they would be "values", because they are just the payload you want to recover. I think, for example, that you are doing TopK with a key of () and an ordering by the integer score, with an associated payload of the character. In the paragraph just above, where I said key you should think (), and where I said val you should think "character".

frankmcsherry commented 1 year ago

Depending on how much sense the above ends up making (it's hard to be clear about such weird stuff without some careful framing), and depending on whether you want to understand prefix_sum or get to the best TopK implementation, feel free to follow up with more questions!

oli-w commented 1 year ago

Thank you for such a detailed answer! The K you describe makes sense to me - I think of it as a "group key". I was a bit mixed up having used the join methods where I would use the letters as K. I managed to get it working with prefix_sum, renaming key/value to letter/score for clarity:

let limit = 2;
collection
  .map(|(letter, score)| {
      return ((score as usize, ()), vec![(letter, score)]);
  })
  .prefix_sum(vec![], move |_k, first, second| {
      let mut result = vec![];
      result.extend(first.into_iter().map(|v| v.to_owned()));
      result.extend(second.into_iter().map(|v| v.to_owned()));
      result.sort_by(|(_, first_score), (_, second_score)| first_score.cmp(&second_score));
      return if result.len() > limit {
          result.split_at(limit).0.to_vec()
      } else {
          result
      };
  })
  .inspect(|(data, _timestamp, diff)| println!("{:?}, {}", data, diff));

(Please ignore my poor use of collections / inefficiencies 😅, my focus is more on how to use differential dataflow well)

Using limit=2, this outputs:

((2, ()), []), 1
((3, ()), [("d", 2)]), 1
((7, ()), [("d", 2), ("c", 3)]), 1
((11, ()), [("d", 2), ("c", 3)]), 1

So for each score, find the sorted top (at most 2) entries that have a score less than it. Then to get (letter, index) outputs, I add the following after the prefix_sum and before the inspect:

  .flat_map(|(_score_and_key, letters_with_scores)| {
      return if !letters_with_scores.is_empty() {
          let last_index = letters_with_scores.len() - 1;
          let (letter, _) = letters_with_scores[last_index].to_owned();
          vec![(letter, last_index)].into_iter()
      } else {
          vec![].into_iter()
      };
  })
  .distinct()

Having to use the distinct feels a bit odd, since the list of topK will be produced for all larger scores, so the last entry needs to be de-duplicated. Is this what you had in mind or am I missing something?

While this does work, I'm not sure if it's much more efficient than a naive implementation using reduce (which if I understand correctly will end up being called on every input change). The only thing I can think of is that the prefix_sum version might save some computation for changes with scores larger than the first K entries.

The "naive" reduce implementation I have in mind:

collection
  .map(|entry| {
      return ((), entry);
  })
  .reduce(move |_key, input, output| {
      let mut sorted: Vec<&(String, i32)> =
          input.to_vec().into_iter().map(|(entry, _diff)| entry).collect();
      sorted.sort_by(|(_, first_score), (_, second_score)| first_score.cmp(&second_score));
      for index in 0..min(sorted.len(), limit) {
          output.push(((sorted[index].0.clone(), index), 1));
      }
  })
  .map(|(_key, letter_with_score)| letter_with_score)
  .inspect(|(data, _timestamp, diff)| println!("{:?}, {}", data, diff));

Ultimately, I'm after an implementation that's better than the naive one (sort everything on every change), but doesn't need to be ultra-performant (at least not yet - I'm working on connecting everything end-to-end before optimising too much).

Thanks for the pointer to the Materialize code. I will need to study it a bit more to fully understand it. My understanding so far: build_topk takes a collection of rows (representing database rows), with group_key specifying which columns to group by (the GROUP BY part of the SQL syntax) and order_key the ORDER BY part. After grouping, it repeatedly calls build_topk_stage on the collection, trimming each group down to only the first offset + limit rows, and then on the final iteration applies the actual offset and limit. My brain is still slowly adjusting to the new way of thinking, e.g. I read through your blog post on Sudoku solving, where we start with all possible solutions and then repeatedly trim away invalid ones, which is the opposite of how you would typically write code to solve Sudoku.

frankmcsherry commented 1 year ago

Ultimately, I'm after an implementation that's better than the naive one (sort everything on every change), but doesn't need to be ultra-performant (at least not yet - I'm working on connecting everything end-to-end before optimising too much).

I recommend the Materialize approach, then; it will be much easier (and better, I think?) than prefix_sum. More on that in a moment!

The "naive" reduce implementation I have in mind:

A few thoughts on this:

  1. Differential dataflow will sort the input for you, so if you present (i32, String) instead, the data will arrive in sorted order and you can just leap to the first / last elements as you like. This spares you creating a vector and sorting it and all that (see the sketch after this list).
  2. I think you have a possible bug in the code, in that DD will present you with (data, count) pairs, and if you see three copies of a record, that counts for three of the outputs. Right now you are just producing (data, 1), which will be one record in the output. I'm not sure what you intend, or whether you know the inputs are distinct, but I wanted to call that out!
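As a minimal sketch of point 1 (hypothetical code, not from the thread, reusing the limit capture from earlier; per point 2 it still assumes each input record has count 1):

collection
    .map(|(letter, score)| ((), (score, letter)))
    .reduce(move |_key, input, output| {
        // `input` arrives sorted by (score, letter), so the smallest
        // `limit` entries are just the front of the slice: no Vec, no sort.
        // Caveat (point 2): this counts records, not multiplicities, so it
        // is only correct if the inputs are known to be distinct.
        for &(entry, _diff) in input.iter().take(limit) {
            output.push((entry.clone(), 1));
        }
    })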

I will need to study it a bit more to fully understand it, my understanding so far is that ...

There's a pretty easy way to understand it, I think. If you generalize your naive approach, however it works out, to start instead from

collection
  .map(|entry| {
      return (entry.hashed() >> round, entry);
  })
  ...

where round is an argument to the method. Your () key is this with round = 64, but you can use smaller values of round as well. If you do, you'll form multiple smaller groups, each of which is easier to update (because it is smaller). If you do what Materialize does and take round from 0 to 64 in steps of, say, 4, you'll have 16 stages in which each group merges at most 16 groups from the stage before. When a change occurs, there is a bound on the amount of work it can cause, even with unboundedly large data (well, up to 2^64 elements).
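To make the staging concrete (my arithmetic, assuming 64-bit hashes and 4-bit steps):

// round =  0: group by hash >> 0   -- finest grouping
// round =  4: group by hash >> 4   -- each group merges 2^4 = 16 groups
//                                      from the stage before
// ...
// round = 60: group by hash >> 60  -- 16 groups remain
// One final 4-bit shift leaves key 0: a single group holding the global
// top K. Each group only ever merges the (at most K) survivors of 16
// child groups, so one input change causes bounded work per stage.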

oli-w commented 1 year ago

I've got it working. Thank you!!

Differential dataflow will sort input for you

Ahah I should have seen that in the docs for reduce 👍, this cleans things up nicely.

Right now you are just producing (data, 1) which will be one record in the output.

I should have mentioned initially, the "letters" in my example will in practice be globally unique IDs. Understood that this would need more work to use the diffs if not dealing with unique entries.

When a change occurs, it has a bounded amount of work it can cause to happen, even with unboundedly large data (well, up to 2^64 elements).

Awesome. I think I understand this now, because the data is split into multiple different groups, which get combined over and over, and most of the groups won't change for a single input change (right?).

Here is the code I ended up with:

use std::cmp::min;
use timely::communication::allocator::Generic;
use timely::dataflow::scopes::Child;
use timely::worker::Worker;
use differential_dataflow::hashable::Hashable;
use differential_dataflow::input::Input;
use differential_dataflow::operators::Reduce;
use differential_dataflow::Collection;

type ScopeType<'a> = Child<'a, Worker<Generic>, i32>;
fn top_k_stage(
    collection: Collection<ScopeType, (u64, (i32, String)), isize>,
    limit: usize,
    round: u32,
) -> Collection<ScopeType, (u64, (i32, String)), isize> {
    return collection
        .map(move |(hash, entry)| (hash >> round, entry))
        .reduce(move |_key, input, mut output| {
            for index in 0..min(input.len(), limit) {
                let (entry, _diff) = input[index];
                output.push((entry.clone(), 1));
            }
        });
}

fn main() {
    timely::execute(timely::Config::thread(), |worker| {
        let mut input = worker.dataflow(|scope| {
            let (mut input, collection) = scope.new_collection::<(String, i32), isize>();
            let limit = 4;
            let mut collection = collection
                .map(|(letter, score)| (score, letter))
                .map(|entry| (entry.hashed(), entry));
            for round in (0..64).step_by(4) {
                collection = top_k_stage(collection, limit, round);
            }
            // Aggregate final result into a single Vec<String>
            collection
                .map(|(_key, entry)| ((), entry))
                .reduce(move |_key, input, mut output| {
                    let result: Vec<String> = input.to_vec().iter().map(|&(entry, _diff)| entry.1.clone()).collect();
                    output.push((result.clone(), 1));
                })
                .map(|(_key, ordered_letters)| ordered_letters)
                .inspect(|(data, timestamp, diff)| println!("{:?}, {}, {}", data, timestamp, diff));

            return input;
        });

        input.advance_to(0);
        // Expected: [d, c, a, b]
        input.insert((String::from("a"), 7));
        input.insert((String::from("b"), 11));
        input.insert((String::from("c"), 3));
        input.insert((String::from("d"), 2));
        input.insert((String::from("e"), 13));
        input.insert((String::from("f"), 17));
    })
    .expect("Computation terminated abnormally");
}

Thanks again for all your help!

frankmcsherry commented 1 year ago

Looks good, with one nit: your top_k_stage shifts the hash by round and emits the shifted hash as the key. This shifted hash gets fed into the next stage, rather than the original hash. This means that you probably shouldn't step round up each time around the loop, and should instead just use a constant shift some number of times; 16 stages that each shift by 4 still compound to the full 64 bits. E.g. instead do

            for round in (0..64).step_by(4) {
                collection = top_k_stage(collection, limit, 4); // `4` not `round`
            }