google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

Fix upstream polynomial canonicalization rules #749

Open j2kun opened 4 months ago

j2kun commented 4 months ago

Only apply if there is no coefficient modulus, otherwise we need arith_ext

ZenithalHourlyRate commented 1 month ago

If I understand it correctly, is https://github.com/llvm/llvm-project/blob/30f5a3ca150e98d482abc6a4d0e3fe6c12f77695/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td#L29-L33 the only issue here?

asraa commented 1 month ago

I think the ones introducing arith ops (https://github.com/llvm/llvm-project/blob/30f5a3ca150e98d482abc6a4d0e3fe6c12f77695/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td#L59-L64, and the corresponding sub one) are also incorrect when there is a modulus or when the coefficients might wrap around the underlying data type. For example, converting poly.add(intt(p), intt(q)) to intt(arith.add(p, q)) would be incorrect if arith.add(p, q) overflows the coefficient modulus or the data type, since arith.add doesn't take the modulus into account.

@AlexanderViand has this nice doc! https://docs.google.com/document/d/1aS-rSZP3GJvzemoeAG8ingAyyLUNq2IB5Rgqksilbdk/edit

I think the upstream passes would need to fail to apply when there is a modulus or wrap-around, and I suppose we can write our HEIR-internal canonicalization with mod_arith in this repo.
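To make the overflow concern above concrete, here is a minimal Python sketch. It is not MLIR or HEIR code: `wrap_signed`, `plain_add`, `modular_add`, and `congruent` are hypothetical helpers, the 32-bit width is an assumption, and the moduli 7681 and 2^31 - 1 are chosen only for illustration. `plain_add` models an element-wise add that knows nothing about the modulus; `modular_add` models what a modulus-aware add would produce.

```python
def wrap_signed(value: int, bits: int = 32) -> int:
    """Reduce value to a two's-complement signed integer of the given width."""
    mask = (1 << bits) - 1
    value &= mask
    return value - (1 << bits) if value >> (bits - 1) else value


def plain_add(p, q, bits=32):
    """Element-wise add with fixed-width wraparound and no modular reduction."""
    return [wrap_signed(a + b, bits) for a, b in zip(p, q)]


def modular_add(p, q, modulus):
    """Element-wise add reduced into [0, modulus)."""
    return [(a + b) % modulus for a, b in zip(p, q)]


def congruent(xs, ys, modulus):
    """True if xs and ys agree coefficient-wise modulo `modulus`."""
    return all((x - y) % modulus == 0 for x, y in zip(xs, ys))


# Case 1: the sum exceeds the modulus but fits in 32 bits.
# The plain add leaves an unreduced coefficient behind.
small_mod = 7681
unreduced = plain_add([7000, 1], [1000, 2])             # [8000, 3]
reduced = modular_add([7000, 1], [1000, 2], small_mod)  # [319, 3]

# Case 2: the sum wraps around the 32-bit data type.
# The wrapped value is no longer even congruent to the true sum.
big_mod = 2**31 - 1
wrapped = plain_add([big_mod - 1], [big_mod - 1])             # [-4]
expected = modular_add([big_mod - 1], [big_mod - 1], big_mod)

print(unreduced, reduced, congruent(unreduced, reduced, small_mod))
print(wrapped, expected, congruent(wrapped, expected, big_mod))
```

In the first case the result is still in the right residue class but outside the expected range; in the second the data-type wraparound changes the residue class, which is the situation where the rewrite would produce a wrong polynomial.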

ZenithalHourlyRate commented 1 month ago

After careful inspection, I now think the canonicalization rule SubAsAdd is OK with a modulus. I was thinking that coefficients should be in the range [0, +mod) (which means -1 needs to be mod-1 instead of 2^32-1), but now I understand they are in (-mod, +mod), so applying polynomial.mul_scalar with -1 directly is OK and won't overflow, provided that the semantics of polynomial.mul_scalar are correct.

Checking the lowering of polynomial.mul_scalar confused me, though: it may produce a wrong result. It emits arith.muli followed by arith.remsi, and arith.muli may overflow if the scalar is large enough; fortunately, remsi avoids the issue in #928. Again, I think these lowerings should be handled in mod_arith.
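The muli-then-remsi concern can be sketched in Python as follows. This is only a model under stated assumptions, not the actual lowering: `mul_scalar_lowered`, `mul_scalar_exact`, `remsi`, and `wrap_signed` are hypothetical names, the coefficients are assumed to be stored in 32 bits, and the modulus 7681 is chosen for illustration. `remsi` here models a truncated signed remainder, whose sign follows the dividend.

```python
def wrap_signed(value: int, bits: int = 32) -> int:
    """Reduce value to a two's-complement signed integer of the given width."""
    mask = (1 << bits) - 1
    value &= mask
    return value - (1 << bits) if value >> (bits - 1) else value


def remsi(lhs: int, rhs: int) -> int:
    """Truncated signed remainder: the result takes the sign of lhs."""
    r = abs(lhs) % abs(rhs)
    return -r if lhs < 0 else r


def mul_scalar_lowered(coeffs, scalar, modulus, bits=32):
    """Model of 'multiply in fixed width, then take the signed remainder'."""
    return [remsi(wrap_signed(c * scalar, bits), modulus) for c in coeffs]


def mul_scalar_exact(coeffs, scalar, modulus):
    """Reference result using arbitrary-precision integers."""
    return [(c * scalar) % modulus for c in coeffs]


def congruent(xs, ys, modulus):
    """True if xs and ys agree coefficient-wise modulo `modulus`."""
    return all((x - y) % modulus == 0 for x, y in zip(xs, ys))


modulus = 7681

# Multiplying by -1 (as SubAsAdd does) stays within (-mod, +mod): no overflow.
neg = mul_scalar_lowered([7680, 1], -1, modulus)         # [-7680, -1]
neg_ok = congruent(neg, mul_scalar_exact([7680, 1], -1, modulus), modulus)

# A large scalar makes the 32-bit product wrap before the remainder is taken.
big = mul_scalar_lowered([7680], 1_000_000, modulus)
big_ok = congruent(big, mul_scalar_exact([7680], 1_000_000, modulus), modulus)

print(neg, neg_ok)   # the -1 case is congruent to the exact result
print(big, big_ok)   # the large-scalar case is not
```

The -1 case matches the observation that SubAsAdd is safe when coefficients live in (-mod, +mod), while the large-scalar case shows how the product can wrap before the remainder is applied.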

"Only apply if there is no coefficient modulus, otherwise we need arith_ext"

However, in the case of NTT/iNTT there must be a coefficientModulus, otherwise primitiveRoot won't exist, so the fix proposed in https://github.com/llvm/llvm-project/pull/110318 removes those rules upstream for now.