0xTCG / sequre

A high-performance, Pythonic framework for secure computing in bioinformatics
Apache License 2.0

Implement clever joint truncations pattern matcher #13

Open hsmajlovic opened 10 months ago

hsmajlovic commented 10 months ago

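It would be nice if the compiler could figure out that an expression such as F = A * B + C * D needs only one truncation, applied to the final sum, instead of one after each intermediate multiplication. Both products carry twice the usual number of fractional bits (2 * mpc.fp), so they can be added first and truncated once.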
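A minimal plaintext sketch of the arithmetic behind this, assuming a fixed-point encoding with a hypothetical `FP = 20` fractional bits standing in for `mpc.fp`, and a floor right-shift standing in for `trunc` (real MPC truncation protocols may round differently; some are probabilistic). There is no secret sharing here. It only shows that truncating each product and truncating the sum once give the same value, while the joint version needs one truncation instead of two. In MPC, each truncation of a secret-shared value typically requires communication between the parties, which is why removing one per sum is worthwhile.

```python
# Plaintext stand-in for fixed-point arithmetic (no secret sharing).
FP = 20            # hypothetical fractional bits, playing the role of mpc.fp
SCALE = 1 << FP


def encode(x: float) -> int:
    # Fixed-point encoding: value * 2**FP, rounded to an integer.
    return round(x * SCALE)


def decode(v: int) -> float:
    return v / SCALE


def trunc(v: int) -> int:
    # Drop FP fractional bits (floor shift; stands in for .trunc(mpc.fp)).
    return v >> FP


def f_per_product(a: int, b: int, c: int, d: int) -> tuple[int, int]:
    # One truncation after each multiplication: 2 truncations.
    return trunc(a * b) + trunc(c * d), 2


def f_joint(a: int, b: int, c: int, d: int) -> tuple[int, int]:
    # Both products are at scale 2**(2*FP), so add first, truncate once.
    return trunc(a * b + c * d), 1


a, b, c, d = (encode(x) for x in (1.5, -2.25, 0.75, 3.0))
sep, n_sep = f_per_product(a, b, c, d)
jnt, n_jnt = f_joint(a, b, c, d)
print(decode(sep), n_sep)  # -1.125 2
print(decode(jnt), n_jnt)  # -1.125 1
```

With inputs that are not exactly representable, the two versions can differ in the last bit; the joint version rounds only once, so it is never less precise.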

Example from DTI code:

vW_prev = vW[l].copy()
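# Current version, with an explicit truncation after each sum of products:
# vW[l] = vW[l] * MOMENTUM - dW[l] * LEARN_RATE
# vW[l] = vW[l].trunc(mpc.fp)
# temp = vW[l] * (MOMENTUM + 1) - vW_prev * MOMENTUM
# temp = temp.trunc(mpc.fp)
# W[l] = W[l] + temp
# Should be (the compiler inserts the truncations itself):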
vW[l] = vW[l] * MOMENTUM - dW[l] * LEARN_RATE
W[l] = W[l] + vW[l] * (MOMENTUM + 1) - vW_prev * MOMENTUM

vb_prev = vb[l].copy()
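# Current version, with an explicit truncation after each sum of products:
# vb[l] = vb[l] * MOMENTUM - db[l] * LEARN_RATE
# vb[l] = vb[l].trunc(mpc.fp)
# temp_v = vb[l] * (MOMENTUM + 1) - vb_prev * MOMENTUM
# temp_v = temp_v.trunc(mpc.fp)
# b[l] = b[l] + temp_v
# Should be (the compiler inserts the truncations itself):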
vb[l] = vb[l] * MOMENTUM - db[l] * LEARN_RATE
b[l] = b[l] + vb[l] * (MOMENTUM + 1) - vb_prev * MOMENTUM
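One possible shape for such a pattern matcher, sketched on a toy expression tree rather than Sequre's actual compiler IR: the `Var`, `BinOp`, `Trunc` and `Upscale` nodes and the `insert_truncations` pass are all hypothetical. The pass tracks each subexpression's scale in multiples of `mpc.fp`. A product of two scale-1 operands has scale 2; an addition or subtraction aligns its operands by upscaling the lower-scale side (assuming additive secret sharing, multiplying by the public constant 2**fp is local) instead of truncating the higher-scale side; and a single truncation is inserted only where a scale-2 value has to return to scale 1.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Var:
    name: str  # a fixed-point value at scale 1 (fp fractional bits)


@dataclass(frozen=True)
class BinOp:
    op: str  # "+", "-" or "*"
    left: Expr
    right: Expr


@dataclass(frozen=True)
class Trunc:
    arg: Expr  # drop fp fractional bits (interactive in MPC)


@dataclass(frozen=True)
class Upscale:
    arg: Expr  # multiply by the public constant 2**fp (local in MPC)


Expr = Union[Var, BinOp, Trunc, Upscale]


def _lower(e: Expr) -> tuple[Expr, int]:
    """Rewrite e (built only from Var and BinOp) and return it with its scale."""
    if isinstance(e, Var):
        return e, 1
    left, ls = _lower(e.left)
    right, rs = _lower(e.right)
    if e.op == "*":
        # Truncate scale-2 operands first so a product never exceeds scale 2.
        if ls == 2:
            left = Trunc(left)
        if rs == 2:
            right = Trunc(right)
        return BinOp("*", left, right), 2
    # "+" / "-": upscale the lower-scale side so a whole sum of products
    # stays at scale 2 and can be truncated once at the end.
    if ls < rs:
        left = Upscale(left)
    elif rs < ls:
        right = Upscale(right)
    return BinOp(e.op, left, right), max(ls, rs)


def insert_truncations(e: Expr) -> Expr:
    lowered, scale = _lower(e)
    return Trunc(lowered) if scale == 2 else lowered


def count_truncations(e: Expr) -> int:
    if isinstance(e, Trunc):
        return 1 + count_truncations(e.arg)
    if isinstance(e, Upscale):
        return count_truncations(e.arg)
    if isinstance(e, BinOp):
        return count_truncations(e.left) + count_truncations(e.right)
    return 0


# F = A * B + C * D  ->  Trunc(A * B + C * D)
A, B, C, D = (Var(n) for n in "ABCD")
f = insert_truncations(BinOp("+", BinOp("*", A, B), BinOp("*", C, D)))
print(count_truncations(f))  # 1

# W[l] + vW[l] * (MOMENTUM + 1) - vW_prev * MOMENTUM, grouped left to right
W, vW, vW_prev, M = Var("W"), Var("vW"), Var("vW_prev"), Var("MOMENTUM")
m_plus_1 = BinOp("+", M, Var("1"))
update = insert_truncations(
    BinOp("-", BinOp("+", W, BinOp("*", vW, m_plus_1)), BinOp("*", vW_prev, M))
)
print(count_truncations(update))  # 1
```

Upscaling rather than truncating is what lets a left-associated chain such as `W[l] + vW[l] * (MOMENTUM + 1) - vW_prev * MOMENTUM` keep a single truncation: truncating at each addition would cost two here. A real pass would also need to check that the scale-2 intermediates still fit in the ring or field without overflowing.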