data61 / MP-SPDZ

Versatile framework for multi-party computation

How best to add multi-threading in MP-SPDZ #1436

Closed Ascurius closed 4 months ago

Ascurius commented 4 months ago

Hello,

With regard to my question in issue #1429, I wanted to ask how best to implement multi-threading in MP-SPDZ.

Specifically, I have a large program, shown below. When compiling with compile.py -R 64, I currently run into the RegisterOverflowError described in #1429 for max_rows >= 4000. Could you give me a hint on how to add multi-threading to my program? My goal is to extend the program so that it can process up to 100k rows of data, which is a significant increase.


```python
from typing import Callable  # used only for the type hints below

def select_distinct(
        matrix: sint.Matrix,
        key: int,
        condition: Callable[[sint.Array], bool] = lambda row: True
    ) -> sint.Matrix:
    # Sort by the key column so that equal key values become adjacent (timer 600).
    start_timer(600)
    matrix.sort((key,))
    stop_timer(600)
    result = sint.Matrix(
        rows=matrix.shape[0],
        columns=matrix.shape[1] + 1
    )

    prev_value = sint(-1)
    @for_range_opt(matrix.shape[0])
    def _(i):
        result[i].assign_vector(matrix[i])
        # Extra last column: 1 for a new key value that also satisfies condition.
        dbit = (
            (matrix[i][key] != prev_value) &
            condition(matrix[i])
        ).if_else(sint(1),sint(0))
        result[i][-1] = dbit
        new_value = (dbit == 1).if_else(matrix[i][key], prev_value)
        prev_value.update(new_value)
    return result

def join_nested_loop(
        left: sint.Matrix,
        right: sint.Matrix,
        left_key: int,
        right_key: int,
        condition: Callable[[sint.Array, sint.Array], bool] = lambda left, right: True
    ) -> sint.Matrix:
    # One output row per (left, right) pair, i.e. left.shape[0] * right.shape[0] rows.
    result = sint.Matrix(
        rows=left.shape[0] * right.shape[0],
        columns=left.shape[1] + right.shape[1] + 1
    )
    current_idx = regint(0)
    @for_range_opt(left.shape[0])
    def _(left_row):
        @for_range_opt(right.shape[0])
        def _(right_row):
            new_row = sint.Array(result.shape[1])
            @for_range_opt(left.shape[1])
            def _(left_col):
                new_row[left_col] = left[left_row][left_col]
            @for_range_opt(right.shape[1])
            def _(right_col):
                new_row[left.shape[1]+right_col] = right[right_row][right_col]
            result[current_idx].assign(new_row)
            # Extra last column: 1 if the join keys match and condition holds.
            result[current_idx][-1] = (
                (left[left_row][left_key] == right[right_row][right_key]) &
                condition(left[left_row], right[right_row])
            ).if_else(sint(1),sint(0))
            current_idx.update(current_idx + regint(1))
    return result

def where(matrix: sint.Matrix, key: int, value: int) -> sint.Matrix:
    result = sint.Matrix(
        rows=matrix.shape[0],
        columns=matrix.shape[1] + 1
    )
    @for_range_opt(matrix.shape[0])
    def _(i):
        result[i].assign_vector(matrix[i])
        result[i][-1] = (matrix[i][key] == value).if_else(1,0)
    return result

def where_less_then(matrix: sint.Matrix, col_1: int, col_2: int) -> sint.Matrix:
    result = sint.Matrix(
        rows=matrix.shape[0],
        columns=matrix.shape[1] + 1
    )
    @for_range_opt(matrix.shape[0])
    def _(i):
        result[i].assign_vector(matrix[i])
        # Extra last column: 1 if matrix[i][col_1] <= matrix[i][col_2].
        result[i][-1] = (
            matrix[i][col_1] <= matrix[i][col_2]
        ).if_else(1,0)
    return result

max_rows = 4000
a = sint.Matrix(max_rows, 13)
a.input_from(0)
b = sint.Matrix(max_rows, 13)
b.input_from(1)

aw = where(a, 8, 414)
bw = where(b, 4, 0)

join = join_nested_loop(aw, bw, 1, 1)

wlt = where_less_then(join, 2, aw.shape[1]+2)

def distinct_condition(row):
    return (
        # row[-1] == 1
        (row[-1] == 1) &
        (row[-2] == 1) &
        (row[-3] == 1) &
        (row[13] == 1)
    ).if_else(1,0)

select = select_distinct(wlt, 0, condition=distinct_condition)

# Count the rows whose distinct bit is set (this reveals each row's bit).
count = regint(0)
@for_range_opt(select.shape[0])
def _(i):
    dbit_5 = (select[i][-1] == 1).if_else(1,0)
    @if_(dbit_5.reveal())
    def _():
        count.update(count + 1)
print_ln("%s", count)
```

mkskeller commented 4 months ago

This seems to be related to the sorting, which currently isn't capable of multithreading. However, I would also note that your algorithm is quadratic in the size that you're trying to increase. What running time are you seeing for 1000 rows? Unless it's just a few seconds for 1000 rows, the running time for 100000 rows will be more than a week, so I'm not sure it's worth putting effort into this.
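As a rough back-of-the-envelope check of this estimate, the plain-Python sketch below (not MP-SPDZ code) extrapolates a running time measured at 1000 rows to 100000 rows. It assumes the cost grows purely with the square of the row count and ignores constant overheads; the function names are only illustrative.

```python
# Plain Python, not MP-SPDZ: back-of-the-envelope quadratic extrapolation.

WEEK_S = 7 * 24 * 3600  # seconds in a week (604800)


def extrapolate_quadratic(t_base_s: float, n_base: int, n_target: int) -> float:
    """Scale a time measured at n_base rows to n_target rows,
    assuming the cost grows with the square of the row count."""
    return t_base_s * (n_target / n_base) ** 2


def break_even_seconds(n_base: int, n_target: int, budget_s: float) -> float:
    """Largest time at n_base rows that still fits into budget_s at n_target rows."""
    return budget_s / (n_target / n_base) ** 2


for t in (1.0, 5.0, 60.0, 600.0):
    days = extrapolate_quadratic(t, 1000, 100_000) / 86400
    print(f"{t:6.1f} s at 1000 rows -> ~{days:7.1f} days at 100000 rows")
print(f"one-week break-even at 1000 rows: {break_even_seconds(1000, 100_000, WEEK_S):.2f} s")
```

Under this simple model, 100 times more rows means 10,000 times more work, so anything above roughly one minute at 1000 rows already exceeds a week at 100000 rows.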

Ascurius commented 4 months ago

Thank you for your answer. On average, I measured ~1100 s in total, of which ~240 s are spent on the nested-loop join and around 630 s on sorting the table. I assume that the sorting is bloated by the nested-loop join, because the sort has to deal with a matrix of $m \times n$ rows. Do you think that this problem can be mitigated by reducing the output size of the nested-loop operator?

mkskeller commented 4 months ago

Yes, that's where the quadratic complexity comes from.
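To make the quadratic blow-up concrete, here is a plain-Python shape calculation (not MP-SPDZ code) that mirrors the program from the opening post: where() and where_less_then() each append one flag column, join_nested_loop() emits one row per (left, right) pair plus a match flag, and select_distinct() sorts its input before appending the distinct bit. The helper name `pipeline_shapes` is hypothetical.

```python
# Plain Python, not MP-SPDZ: track (rows, columns) through the program in the opening post.

def pipeline_shapes(rows_a: int, rows_b: int, cols: int = 13) -> dict:
    """Shapes of the intermediate matrices, assuming both inputs have `cols` columns."""
    aw = (rows_a, cols + 1)                    # where() appends one flag column
    bw = (rows_b, cols + 1)
    join = (aw[0] * bw[0], aw[1] + bw[1] + 1)  # one row per (left, right) pair + match flag
    wlt = (join[0], join[1] + 1)               # where_less_then() appends one flag
    select = (wlt[0], wlt[1] + 1)              # select_distinct() appends the distinct bit
    return {"aw": aw, "bw": bw, "join": join, "wlt": wlt, "select": select}


for n in (1000, 4000, 100_000):
    rows, columns = pipeline_shapes(n, n)["wlt"]  # wlt is what select_distinct() sorts
    print(f"{n:>7} input rows -> sort input of {rows:,} rows x {columns} columns")
```

So the sort inside select_distinct() works on $m \cdot n$ rows rather than on $m$ or $n$, which is why shrinking the join output (or replacing the nested-loop join) directly reduces the sorting time.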

Ascurius commented 4 months ago

Besides changing the JOIN algorithm, I thought about another solution to this problem, because as far as I understand, Array.sort() supports multi-threading, whereas Matrix.sort() does not.

Alternatively, we could therefore take one column of the matrix and pair each value with its original row index as a tuple. We could then sort this array using multiple threads, thus reducing the run time of the sorting. However, from what I understand, the type Array does not support storing tuples. Is that correct?
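The idea can be illustrated in cleartext Python (this is neither secure nor MP-SPDZ code): sort (key, original index) pairs, then use the resulting index order as a permutation to reorder the full rows. The helper names are hypothetical, and whether MP-SPDZ's Array can hold such pairs is exactly the open question above.

```python
# Plain Python, not secure and not MP-SPDZ: sort one key column together with
# the original row indices, then apply the resulting permutation to whole rows.
from typing import List, Sequence


def sort_with_indices(keys: Sequence[int]) -> List[int]:
    """Return the original indices in ascending key order.
    Ties are broken by the original index, so equal keys keep their order."""
    tagged = [(k, i) for i, k in enumerate(keys)]
    tagged.sort()
    return [i for _, i in tagged]


def apply_permutation(rows: Sequence[Sequence[int]], perm: Sequence[int]) -> List[List[int]]:
    """Reorder full rows so that row j of the result is rows[perm[j]]."""
    return [list(rows[i]) for i in perm]


rows = [[7, 3, 1], [2, 9, 0], [7, 1, 5], [4, 4, 4]]
perm = sort_with_indices([r[0] for r in rows])  # sort by column 0 only
print(perm)                                     # [1, 3, 0, 2]
print(apply_permutation(rows, perm))
```

In cleartext the benefit is that only two values per element move during the sort. In MPC, however, the index permutation would still have to be applied to the remaining columns obliviously, which has its own cost.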

Additionally, if I may ask, do you have any plans on adding multi-threading support for Matrix.sort()?

mkskeller commented 4 months ago

Multithreading support for sorting is in the making. You could also think of using library.loopy_odd_even_merge_sort.
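For context, odd-even merge sort (Batcher's network) is a sorting network: the positions of the compare-exchange steps depend only on the input length, never on the data, which is what makes it usable on secret values. The cleartext sketch below generates the comparator layers for a power-of-two length using the textbook iterative formulation and applies them. It is not MP-SPDZ's loopy_odd_even_merge_sort implementation, and the function names are hypothetical.

```python
# Plain Python illustration of Batcher's odd-even merge sort as a sorting network.
# The comparator positions depend only on n, not on the values being sorted.
from typing import List, Tuple


def odd_even_merge_layers(n: int) -> List[List[Tuple[int, int]]]:
    """Comparator layers of index pairs (i, j), i < j, for n a power of two."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    layers = []
    p = 1
    while p < n:
        k = p
        while k >= 1:
            layer = []
            for j in range(k % p, n - k, 2 * k):
                for i in range(min(k, n - j - k)):
                    # only compare elements that lie in the same block of size 2p
                    if (i + j) // (2 * p) == (i + j + k) // (2 * p):
                        layer.append((i + j, i + j + k))
            layers.append(layer)
            k //= 2
        p *= 2
    return layers


def network_sort(values: List[int]) -> List[int]:
    """Apply every compare-exchange of the network, layer by layer."""
    out = list(values)
    for layer in odd_even_merge_layers(len(out)):
        for i, j in layer:
            if out[i] > out[j]:
                out[i], out[j] = out[j], out[i]
    return out


print(network_sort([5, 1, 4, 8, 2, 7, 3, 6]))  # [1, 2, 3, 4, 5, 6, 7, 8]
```

Within one layer no index appears twice, so those compare-exchanges are independent of each other. The network has log2(n)(log2(n)+1)/2 layers (6 for n = 8) and O(n log² n) comparators in total.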

Ascurius commented 4 months ago

Thank you for your answer. I implemented a small code snippet, shown below, but during testing I noticed that using library.loopy_odd_even_merge_sort brings no improvement (even with multiple threads) compared with sorting via a.sort((1,)), which in my understanding uses radix sort.

For the ring domain, I compiled the program with compile.py -R 64 test.py and executed it with /Scripts/ring.sh. For the binary domain, I compiled with compile.py -B 64 test.py and used /Scripts/replicated.sh.

```python
from Compiler.library import loopy_odd_even_merge_sort

a = sint.Matrix(1000, 13)
a.input_from(0)

loopy_odd_even_merge_sort(a, key_indices=(1,), n_threads=5)
```

However, I noticed that the run time does improve when I switch from computation modulo 2^(ring size) to the binary domain; only in this case did I see a considerable improvement. For example, in the binary domain the sample code above took ~13 s, whereas sorting directly with a.sort((1,)) took ~31 s.

In the domain modulo 2^(ring size), on the other hand, loopy_odd_even_merge_sort takes ~5 s, whereas a.sort((1,)) needs only ~0.4 s.
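To put the reported figures side by side, here is a small plain-Python tabulation. The numbers are the approximate timings from this comment for sorting a 1000 x 13 matrix by column 1, and the ratios are simple arithmetic on them.

```python
# Approximate timings reported above (seconds), sorting a 1000 x 13 matrix by column 1.
timings = {
    "ring (-R 64, ring.sh)": {"loopy_odd_even_merge_sort": 5.0, "a.sort": 0.4},
    "binary (-B 64, replicated.sh)": {"loopy_odd_even_merge_sort": 13.0, "a.sort": 31.0},
}

for domain, t in timings.items():
    ratio = t["loopy_odd_even_merge_sort"] / t["a.sort"]
    faster = "a.sort" if ratio > 1 else "loopy_odd_even_merge_sort"
    print(f"{domain}: loopy/default time ratio = {ratio:.2f} -> {faster} is faster")
```

In the ring setting the default sort is about 12.5 times faster, while in the binary setting the odd-even merge sort takes roughly 0.42 of the default's time.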

Unless I have made a mistake, multi-threading with loopy_odd_even_merge_sort does not improve the run time of the sorting, which is what I am trying to achieve here. From what I understand so far, also with regard to what you explained in #1429, the main challenge is the secure sorting, which is quite costly in terms of run time.

mkskeller commented 4 months ago

Sorry, I meant to say that loopy_odd_even_merge_sort might help with the register overflow, not with overall performance. Radix sort is chosen as default as it performs better overall.
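For readers unfamiliar with the default: a least-significant-digit radix sort processes the key in a fixed number of passes, each a stable partition, so the number of passes is set by the key bit length rather than by element-to-element comparisons. The cleartext sketch below shows only that structure, assuming one bit per pass and non-negative keys below 2**n_bits (both assumptions of this sketch). MP-SPDZ's secure radix sort has to realize each pass obliviously, which is not reproduced here, and `radix_sort_rows` is a hypothetical name.

```python
# Cleartext LSD radix sort on non-negative integer keys, one bit per pass.
from typing import List, Sequence


def radix_sort_rows(rows: Sequence[Sequence[int]], key: int, n_bits: int) -> List[List[int]]:
    """Sort rows by rows[r][key], assuming 0 <= key value < 2**n_bits."""
    out = [list(r) for r in rows]
    for b in range(n_bits):  # least significant bit first
        zeros = [r for r in out if not (r[key] >> b) & 1]
        ones = [r for r in out if (r[key] >> b) & 1]
        out = zeros + ones   # stable partition preserves the order from earlier passes
    return out


print(radix_sort_rows([[5, 0], [3, 1], [6, 2], [3, 3]], key=0, n_bits=3))
```

Because each pass is stable, rows with equal keys keep their input order, as the two rows with key 3 show.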

Ascurius commented 4 months ago

Thank you for the clarification, I am sorry for the confusion.