data61 / MP-SPDZ

Versatile framework for multi-party computation

sort-merge join performs worse than nested-loop join in MP-SPDZ #1445

Closed Ascurius closed 4 months ago

Ascurius commented 4 months ago

Hello,

I recently implemented variants of sort-merge join and nested-loop join in MP-SPDZ. However, the sort-merge join runs slower than the nested-loop join, which should not be the case in theory. I compiled the code with compile.py -R 64 test.py and ran it with ./Scripts/ring.sh test. I assume the difference is caused by the secure sorting at the start of the sort-merge join. Or do you think there is a different explanation?

My sort-merge join looks like this:

def sort_merge_join(
        left: sint.Matrix, 
        right: sint.Matrix, 
        l_key: int, 
        r_key: int
    ) -> sint.Matrix:
    left.sort((l_key,))
    right.sort((r_key,))

    result = sint.Matrix(
        rows=left.shape[0]*right.shape[0],
        columns=left.shape[1] + right.shape[1]
    )
    result.assign_all(0)

    i = regint(0)
    j = regint(0)
    mark = regint(-1)
    cnt = regint(0)

    @while_do(lambda: (i < left.shape[0]) & (j < right.shape[0]+1))
    def _():
        @if_(j >= right.shape[0])
        def _():
            j.update(mark)
            i.update(i+1)
            mark.update(-1)
            @if_(i >= len(left))
            def _():
                break_loop()
        @if_(mark == -1)
        def _():
            # The right table is indexed with r_key, not l_key.
            @while_do(lambda: (left[i][l_key] < right[j][r_key]).if_else(1,0).reveal())
            def _():
                i.update(i+1)
            @while_do(lambda: (left[i][l_key] > right[j][r_key]).if_else(1,0).reveal())
            def _():
                j.update(j+1)
            mark.update(j)
        @if_e(
            (left[i][l_key] == right[j][r_key]).if_else(1,0).reveal()
        )
        def _():
            result[cnt] = left[i].concat(right[j])
            j.update(j+1)
            cnt.update(cnt+1)
        @else_
        def _():
            j.update(mark)
            i.update(i+1)
            mark.update(-1)
    return result
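
For checking the expected output on cleartext data, the mark-based merge logic above can be sketched in plain Python. This is only a reference sketch, not MP-SPDZ code: it uses ordinary lists instead of sint.Matrix, the function names sort_merge_join_plain and nested_loop_join_plain are hypothetical, it indexes the right table with r_key, and it adds explicit bounds checks on i and j. It says nothing about the running time of the secure versions.

```python
from typing import List

Row = List[int]


def sort_merge_join_plain(left: List[Row], right: List[Row],
                          l_key: int, r_key: int) -> List[Row]:
    """Cleartext sort-merge join using a 'mark' to rewind over duplicate keys."""
    left = sorted(left, key=lambda row: row[l_key])
    right = sorted(right, key=lambda row: row[r_key])
    result: List[Row] = []
    i, j, mark = 0, 0, -1

    while i < len(left):
        if mark == -1:
            # Advance whichever side has the smaller key until the keys match
            # or one table is exhausted.
            while (i < len(left) and j < len(right)
                   and left[i][l_key] != right[j][r_key]):
                if left[i][l_key] < right[j][r_key]:
                    i += 1
                else:
                    j += 1
            if i >= len(left) or j >= len(right):
                break
            # Remember where the run of equal keys starts on the right side.
            mark = j
        if j < len(right) and left[i][l_key] == right[j][r_key]:
            result.append(left[i] + right[j])
            j += 1
        else:
            # End of the run: rewind the right side and move to the next left row.
            j = mark
            i += 1
            mark = -1
    return result


def nested_loop_join_plain(left: List[Row], right: List[Row],
                           l_key: int, r_key: int) -> List[Row]:
    """Cleartext nested-loop join: compare every left row with every right row."""
    return [l + r for l in left for r in right if l[l_key] == r[r_key]]


if __name__ == "__main__":
    left_rows = [[3, 30], [1, 10], [2, 20], [2, 21]]
    right_rows = [[2, 200], [4, 400], [2, 201]]
    merged = sort_merge_join_plain(left_rows, right_rows, 0, 0)
    nested = nested_loop_join_plain(left_rows, right_rows, 0, 0)
    # Row order may differ between the two joins, so compare as sorted lists.
    print(sorted(merged) == sorted(nested))
```

Both functions should return the same set of rows, which gives a simple cleartext baseline to compare against the revealed output of the secure joins.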

mkskeller commented 4 months ago

You can benchmark parts using start_timer()/stop_timer(): https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.library.start_timer

What are you using for nested loop join? Also note that there are dedicated algorithms for database joining: https://eprint.iacr.org/2024/141

Ascurius commented 4 months ago

Thank you for the helpful reference to the paper.