data61 / MP-SPDZ

Versatile framework for multi-party computation

Question about overflow error #1450

Closed Ascurius closed 4 months ago

Ascurius commented 4 months ago

Hello,

I recently encountered an overflow error when running my program, and I cannot find its origin.

def test(
        left: sint.Matrix, 
        right: sint.Matrix, 
        l_key: int, 
        r_key: int
    ) -> sint.Matrix:
    start_timer(1000)
    left.sort((l_key,))
    right.sort((r_key))
    stop_timer(1000)

    result = sint.Matrix(
        rows=right.shape[0],
        columns=left.shape[1] + right.shape[1] + 1
    )
    result.assign_all(0)

    i = regint(0)
    j = regint(0)
    mark = regint(-1)
    cnt = regint(0)

    @while_do(lambda: (i < left.shape[0]) & (j < right.shape[0]))
    def _():
        left_value = left[i][l_key]

        lt = (left[i][l_key] < right[j][r_key]).if_else(1,0).reveal()
        gt = (left[i][l_key] > right[j][r_key]).if_else(1,0).reveal()
        eq = (left[i][l_key] == right[j][r_key]).if_else(1,0).reveal()

        @if_(lt)
        def _():
            i.update(i+1)
        @if_(gt)
        def _():
            j.update(j+1)
        @if_(eq)
        def _():
            @while_do(lambda: (j < right.shape[0]) & (right[j][r_key] == left_value).if_else(1,0).reveal())
            def _():
                result[cnt] = left[i].concat(right[j])
                j.update(j+1)
                cnt.update(cnt+1)
            i.update(i+1)
    return result

a = sint.Matrix(10, 13)
a.input_from(0)
b = sint.Matrix(10, 13)
b.input_from(1)

join = test(a, b, 0, 1)

I compiled the program with compile.py -R 64 test.py and executed it with ./Scripts/ring.sh test, but the execution aborted with the following error. It seems that one of the index variables exceeded the bounds of the matrices, but I cannot see where. I implemented a reference version in plain Python using an array of fixed size, and it worked flawlessly. Do you have an idea of what is broken here?

Running /MP-SPDZ/Scripts/../replicated-ring-party.x 0 main -pn 12258 -h localhost
Running /MP-SPDZ/Scripts/../replicated-ring-party.x 1 main -pn 12258 -h localhost
Running /MP-SPDZ/Scripts/../replicated-ring-party.x 2 main -pn 12258 -h localhost
Using statistical security parameter 40
Trying to run 64-bit computation
Starting timer 1000 at 0 (0 MB, 0 rounds) after 3.983e-06
Stopped timer 1000 at 0.0781423 (0.265008 MB, 3036 rounds)
overflow: 10/[10, 13]
terminate called after throwing an instance of 'crash_requested'
  what():  Crash requested by program
/MP-SPDZ/Scripts/run-common.sh: line 90: 3140642 Aborted                 (core dumped) $my_prefix $SPDZROOT/$bin $i $params 2>&1
     3140643 Done                    | { if test "$BENCH"; then
    if test $i = $front_player; then
        tee -a $log;
    else
        cat >> $log;
    fi;
else
    if test $i = $front_player; then
        tee $log;
    else
        cat > $log;
    fi;
fi; }
/MP-SPDZ/Scripts/run-common.sh: line 90: 3140647 Aborted                 (core dumped) $my_prefix $SPDZROOT/$bin $i $params 2>&1
     3140648 Done                    | { if test "$BENCH"; then
    if test $i = $front_player; then
        tee -a $log;
    else
        cat >> $log;
    fi;
else
    if test $i = $front_player; then
        tee $log;
    else
        cat > $log;
    fi;
fi; }
=== Party 1
Stopped timer 1000 at 0.0781439 (0.265008 MB, 3035 rounds)
terminate called after throwing an instance of 'crash_requested'
  what():  Crash requested by program
=== Party 2
Stopped timer 1000 at 0.0782289 (0.265008 MB, 3036 rounds)
terminate called after throwing an instance of 'crash_requested'
  what():  Crash requested by program
Ascurius commented 4 months ago

Edit: For comparison, my plain Python code looks like this. As far as I understand, both implementations should have the same semantics.

from typing import List

def test(
        left: List[List[int]], 
        right: List[List[int]], 
        l_key: int, 
        r_key: int
    ) -> List[List[int]]:
    left.sort(key=lambda row: row[l_key])
    right.sort(key=lambda row: row[r_key])

    n_rows = len(right)
    n_cols = len(left[0]) + len(right[0])
    result = [[0 for _ in range(n_cols)] for _ in range(n_rows)]

    i, j, cnt = 0,0,0

    while i < len(left) and j < len(right):
        left_value = left[i][l_key]

        lt = 1 if left[i][l_key] < right[j][r_key] else 0
        gt = 1 if left[i][l_key] > right[j][r_key] else 0
        eq = 1 if left[i][l_key] == right[j][r_key] else 0

        if lt:
            i += 1
        if gt:
            j += 1
        if eq:
            while j < len(right) and right[j][r_key] == left_value:
                result[cnt] = left[i] + right[j]
                j += 1
                cnt += 1
            i += 1
    return result
Ascurius commented 4 months ago

Another edit: After a lot of back and forth, I found that the problem seems to be related to the value of the index variable j. I suspect that the inner while loop tries to evaluate the expression (right[j][r_key] == left_value).if_else(1,0).reveal() for j = 10, which is out of bounds for right and should already be ruled out by the inner loop's first condition, j < right.shape[0]. In other words, the program seems to evaluate the second part of the condition (which raises an error for j >= right.shape[0]) even though the first part has already evaluated to false. Since the two parts are combined with a logical "and", the second part should not need to be evaluated once the first is false. At least that is my suspicion as to why the error occurs.
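The suspicion above can be reproduced in plain Python, independently of MP-SPDZ. The following is a minimal sketch in which the list `rows` and the helper `key_matches` are made-up stand-ins for the sorted `right` matrix and the revealed equality test; it only illustrates how the two operators differ in ordinary Python.

```python
rows = [[3], [5], [7]]   # hypothetical stand-in for the sorted `right` matrix
key = 7
j = len(rows)            # one past the last valid index, like j = 10 above


def key_matches(index):
    """Read rows[index]; raises IndexError when index is out of range."""
    return rows[index][0] == key


# `and` short-circuits: key_matches() is never called because the
# bounds check on the left is already False.
short_circuit_result = (j < len(rows)) and key_matches(j)

# `&` is an ordinary binary operator: both operands are evaluated
# before they are combined, so the out-of-range access still happens.
try:
    eager_result = (j < len(rows)) & key_matches(j)
    eager_failed = False
except IndexError:
    eager_failed = True

print(short_circuit_result, eager_failed)
```

The plain Python reference version uses `and`, which would explain why it never reads past the end of the list.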

mkskeller commented 4 months ago

I cannot compile the code example. When I correct right.sort((r_key)) to right.sort((r_key,)), it does compile, and it executes without error. What options did you use with ./compile.py? And what input data did you use? This might be relevant because the control flow depends on it.
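The difference between the two spellings is plain Python syntax: parentheses alone do not create a tuple, the trailing comma does. A short sketch, with `r_key = 1` taken from the call `test(a, b, 0, 1)` above:

```python
r_key = 1

not_a_tuple = (r_key)    # just the integer 1 in parentheses
one_tuple = (r_key,)     # a tuple with a single element

print(type(not_a_tuple).__name__, type(one_tuple).__name__)
```

So `right.sort((r_key))` passes a bare integer, while `right.sort((r_key,))` passes a one-element tuple like the one used for `left`.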

Ascurius commented 4 months ago

Thank you for your answer, I corrected the sorting, but I don't think that this is causing the issue here. My data is generated using the function generate_player_input() shown below. Note that the function random_date() just generates a random timestamp between two boundaries.

import random

def generate_player_input(filename: str, rows: int) -> None:
    with open(filename, "w") as file:
        for i in range(1, rows + 1):
            year = random.randint(1900, 2100)
            # random_date() is defined elsewhere; it returns a random timestamp between two bounds
            timestamp = int(random_date(year, 2100))
            data = [
                i,
                random.randint(1, 10),
                year,
                random.randint(1, 10),
                random.randint(0, 1),
                random.randint(0, 1),
                random.randint(0, 1),
                random.randint(0, 1),
                random.choice([8, 414]),
                random.randint(0, 1),
                timestamp,
                int(random.uniform(0, 1000)),
                int(random.uniform(0, 1000))
            ]
            file.write(" ".join(map(str, data)) + "\n")
Ascurius commented 4 months ago

I adjusted my code to account for my initial suspicion that the second part of the inner while loop's condition is evaluated even though it shouldn't be. I added two nested checks, as shown below: the code first checks whether j < right.shape[0], and only if that holds does it access the j-th element of right, which avoids the index error. As it turns out, this approach works and produces the desired result.

As far as I understand, this shows that the problem originated from the while loop condition. However, even after diving into your code, I could not figure out what is wrong with it.

def test(
        left: sint.Matrix, 
        right: sint.Matrix, 
        l_key: int, 
        r_key: int
    ) -> sint.Matrix:
    start_timer(1000)
    left.sort((l_key,))
    right.sort((r_key,))
    stop_timer(1000)

    result = sint.Matrix(
        rows=right.shape[0],
        columns=left.shape[1] + right.shape[1]
    )
    result.assign_all(0)

    i = regint(0)
    j = regint(0)
    cnt = regint(0)

    @while_do(lambda: (i < left.shape[0]) & (j < right.shape[0]))
    def _():
        left_value = left[i][l_key]

        lt = (left[i][l_key] < right[j][r_key]).if_else(1,0).reveal()
        gt = (left[i][l_key] > right[j][r_key]).if_else(1,0).reveal()
        eq = (left[i][l_key] == right[j][r_key]).if_else(1,0).reveal()

        @if_(lt)
        def _():
            i.update(i+1)
        @if_(gt)
        def _():
            j.update(j+1)
        @if_(eq)
        def _():
            @while_do(lambda: True)
            def _():
                @if_e(j < right.shape[0])
                def _():
                    @if_e((right[j][r_key] == left_value).if_else(1,0).reveal())
                    def _():
                        result[cnt] = left[i].concat(right[j])
                        j.update(j+1)
                        cnt.update(cnt+1)
                    @else_
                    def _():
                        break_loop()
                @else_
                def _():
                    break_loop()
            i.update(i+1)
    return result
Ascurius commented 4 months ago

> I cannot compile the code example. When I correct right.sort((r_key)) to right.sort((r_key,)), it does compile, and it executes without error. What options did you use with ./compile.py? And what input data did you use? This might be relevant because the control flow depends on it.

Regarding the compilation, I used no extra options or arguments. The command I used was compile.py -R 64 test.py.

mkskeller commented 4 months ago

I think the issue in the original program is that you replaced and with &. The former prevents execution of the second part but the latter doesn't. You would need to use @while_do(and_(lambda: <condition1>, lambda: <condition2>)).
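The idea of passing each condition as its own lambda can be sketched in plain Python. The helper `lazy_and` below is hypothetical and is not MP-SPDZ's `and_`, which works on run-time register values and may be implemented quite differently; the sketch only shows why wrapping the conditions in callables lets the second one be skipped.

```python
def lazy_and(*conditions):
    """Combine zero-argument callables; stop at the first falsy result."""
    def combined():
        for condition in conditions:
            if not condition():
                return False
        return True
    return combined


rows = [[3], [5], [7]]   # hypothetical stand-in for a sorted matrix
state = {"j": 0}         # mutable loop index, playing the role of regint j
visited = []

in_range = lambda: state["j"] < len(rows)
below_limit = lambda: rows[state["j"]][0] < 100   # would fail if j were out of range

loop_condition = lazy_and(in_range, below_limit)
while loop_condition():
    visited.append(rows[state["j"]][0])
    state["j"] += 1

print(visited, state["j"])
```

Because `below_limit` is only called after `in_range` has returned true, the loop ends cleanly when the index reaches the end of the list.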

Ascurius commented 4 months ago

I adjusted my original function test() like you suggested, but that still does not work. I used MP-SPDZ version 0.3.9 to make sure that I am on the latest state of the framework. Am I missing something here?

def test(
        left: sint.Matrix, 
        right: sint.Matrix, 
        l_key: int, 
        r_key: int
    ) -> sint.Matrix:
    start_timer(1000)
    left.sort((l_key,))
    right.sort((r_key,))
    stop_timer(1000)

    result = sint.Matrix(
        rows=right.shape[0],
        columns=left.shape[1] + right.shape[1] + 1
    )
    result.assign_all(0)

    i = regint(0)
    j = regint(0)
    mark = regint(-1)
    cnt = regint(0)

    @while_do(and_(lambda: i < left.shape[0], lambda: j < right.shape[0]))
    def _():
        left_value = left[i][l_key]

        lt = (left[i][l_key] < right[j][r_key]).if_else(1,0).reveal()
        gt = (left[i][l_key] > right[j][r_key]).if_else(1,0).reveal()
        eq = (left[i][l_key] == right[j][r_key]).if_else(1,0).reveal()

        @if_(lt)
        def _():
            i.update(i+1)
        @if_(gt)
        def _():
            j.update(j+1)
        @if_(eq)
        def _():
            @while_do(and_(lambda: j < right.shape[0], lambda: (right[j][r_key] == left_value).if_else(1,0).reveal()))
            def _():
                result[cnt] = left[i].concat(right[j])
                j.update(j+1)
                cnt.update(cnt+1)
            i.update(i+1)
    return result
mkskeller commented 4 months ago

Thank you for raising this. You should find that 77aaaab6a0fe84e474f99327596a346faabf36c4 fixes it.

Ascurius commented 4 months ago

Thank you very much for your help. I tested the original program with your latest changes, but I noticed that for larger inputs, specifically from 10k rows upwards, the compilation seems to get stuck. I terminated it after 2 hours. Is there a reason why and_ increases the compile time so drastically?

mkskeller commented 4 months ago

What do you mean by the original program? The one at the top of this thread works fine for me.

Ascurius commented 4 months ago

I was referring to the program in my earlier comment, where I had already implemented the function using and_. Is there a way to inspect the compilation, or a log file I can look into, to see why the compilation takes so much time on my system?

mkskeller commented 4 months ago

What command are you using to compile it?

Ascurius commented 4 months ago

Still the same as mentioned before, that is: "compile.py -R 64 program.py"

mkskeller commented 4 months ago

I'm afraid it works for me this way. What you can do is abort the compilation after a few minutes with Ctrl-C. The resulting backtrace should give some insight as to where the compilation is stuck.
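As a rough illustration of what such a backtrace shows, the sketch below simulates Ctrl-C inside a nested call in plain Python. The function names `compile_program` and `allocate_registers` are invented for the example and do not correspond to MP-SPDZ internals; `signal.raise_signal` requires Python 3.8 or later and the code is assumed to run in the main thread.

```python
import signal
import traceback


def allocate_registers():
    """Hypothetical stand-in for the step in which a long compilation is stuck."""
    signal.raise_signal(signal.SIGINT)  # simulates pressing Ctrl-C at this point


def compile_program():
    """Hypothetical stand-in for the top-level compilation entry point."""
    allocate_registers()


try:
    compile_program()
    backtrace = ""
except KeyboardInterrupt:
    # The formatted traceback lists the chain of calls that was active
    # when the interrupt arrived.
    backtrace = traceback.format_exc()

print(backtrace)
```

The frames listed in the traceback are the functions that were running when the interrupt arrived, which is the information the suggestion above relies on.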

Ascurius commented 4 months ago

I will close this issue because the problem no longer occurs; the compilation now runs fine on my machine.