Open hsmajlovic opened 10 months ago
It would be nice if the compiler could figure out that only one truncation is necessary for F, rather than a separate truncation after each intermediate multiplication:
F = A * B + C * D
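To make the request concrete, here is a minimal sketch in plain Python of why one truncation can be enough. It uses ordinary integers as fixed-point values and is not the compiler's or the MPC library's implementation. `FRAC_BITS`, `encode`, `decode`, `trunc`, `eager` and `fused` are hypothetical names introduced only for this illustration. Each product of two fixed-point values carries twice the fractional bits, and both products share that same scale, so they can be added first and truncated once.

```python
# Plain-integer fixed-point sketch (assumption: no secret sharing involved;
# FRAC_BITS and all helper names are hypothetical, chosen for illustration).
FRAC_BITS = 16
SCALE = 1 << FRAC_BITS


def encode(x: float) -> int:
    """Encode a float as a fixed-point integer with FRAC_BITS fractional bits."""
    return round(x * SCALE)


def decode(v: int) -> float:
    """Decode a fixed-point integer with FRAC_BITS fractional bits."""
    return v / SCALE


def trunc(v: int) -> int:
    """Drop FRAC_BITS fractional bits (arithmetic shift, rounds toward -inf)."""
    return v >> FRAC_BITS


def eager(a: int, b: int, c: int, d: int) -> int:
    """Truncate after every multiplication: two truncations in total."""
    return trunc(a * b) + trunc(c * d)


def fused(a: int, b: int, c: int, d: int) -> int:
    """Add the double-scale products first, then truncate once."""
    return trunc(a * b + c * d)


A, B, C, D = (encode(x) for x in (1.5, 2.25, -0.75, 3.1))
eager_result = eager(A, B, C, D)
fused_result = fused(A, B, C, D)
print(decode(eager_result), decode(fused_result))
```

In this sketch the two results differ by at most one unit in the last place, because flooring a sum differs from summing the floors by at most one. The fused form performs a single truncation, which is the saving the issue asks the compiler to discover automatically.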
Example from DTI code:
vW_prev = vW[l].copy()
>>>
# vW[l] = vW[l] * MOMENTUM - dW[l] * LEARN_RATE
# vW[l] = vW[l].trunc(mpc.fp)
# temp = vW[l] * (MOMENTUM + 1) - vW_prev * MOMENTUM
# temp = temp.trunc(mpc.fp)
# W[l] = W[l] + temp
>>> should be
vW[l] = vW[l] * MOMENTUM - dW[l] * LEARN_RATE
W[l] = W[l] + vW[l] * (MOMENTUM + 1) - vW_prev * MOMENTUM

vb_prev = vb[l].copy()
>>>
# vb[l] = vb[l] * MOMENTUM - db[l] * LEARN_RATE
# vb[l] = vb[l].trunc(mpc.fp)
# temp_v = vb[l] * (MOMENTUM + 1) - vb_prev * MOMENTUM
# temp_v = temp_v.trunc(mpc.fp)
# b[l] = b[l] + temp_v
>>> should be
vb[l] = vb[l] * MOMENTUM - db[l] * LEARN_RATE
b[l] = b[l] + vb[l] * (MOMENTUM + 1) - vb_prev * MOMENTUM