Open j2kun opened 4 months ago
If I understand it correctly, is https://github.com/llvm/llvm-project/blob/30f5a3ca150e98d482abc6a4d0e3fe6c12f77695/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td#L29-L33 the only issue here?
I think the ones introducing arith ops (https://github.com/llvm/llvm-project/blob/30f5a3ca150e98d482abc6a4d0e3fe6c12f77695/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td#L59-L64, and the corresponding sub one) are also incorrect in the case where there's a modulus or the coefficients might wrap around the underlying data type. For example, converting poly.add(intt(p), intt(q)) to intt(arith.add(p, q)) would be incorrect if arith.add(p, q) overflows the coefficient modulus or the data type, since arith.add doesn't take the modulus into account.
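To make the wrap-around concern concrete, here is a minimal Python sketch. It does not model intt at all; it only compares a fixed-width add (standing in for arith.add on a narrow integer type) with a proper modular add on a single coefficient. The helper names, the 8-bit width, and the modulus 251 are assumptions chosen to keep the numbers small.

```python
def wrapping_add(a: int, b: int, bits: int) -> int:
    """Add like a fixed-width integer op: keep only the low `bits` bits."""
    return (a + b) & ((1 << bits) - 1)


def modular_add(a: int, b: int, modulus: int) -> int:
    """Add coefficients in Z/modulus, reducing with full precision."""
    return (a + b) % modulus


# Hypothetical parameters: an 8-bit storage type and modulus 251.
modulus = 251
a = b = 250

correct = modular_add(a, b, modulus)            # 500 % 251 = 249
wrapped = wrapping_add(a, b, bits=8) % modulus  # 500 wraps to 244 first

print(correct, wrapped)  # the two results differ
```

Because 251 does not divide 2^8, wrapping in the data type changes the residue, so a later reduction mod 251 cannot recover the right coefficient.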
@AlexanderViand has this nice doc! https://docs.google.com/document/d/1aS-rSZP3GJvzemoeAG8ingAyyLUNq2IB5Rgqksilbdk/edit
I think the upstream patterns would need to fail to apply when there is a modulus or wrap-around, and I suppose we can write our HEIR-internal canonicalization with mod_arith in this repo.
After careful inspection, I now think the canonicalization rule SubAsAdd is OK with a modulus. I was thinking that coefficients should be in the range [0, +mod) (which means -1 would need to be mod-1 instead of 2^32-1), but now I understand they are in (-mod, +mod), so using polynomial.mul_scalar with -1 directly is OK and won't overflow, provided that the semantics of polynomial.mul_scalar are correct.
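As a sanity check on the SubAsAdd reasoning, here is a small Python sketch. It assumes coefficients are signed representatives in (-mod, +mod) and that reduction behaves like a signed remainder (sign follows the dividend), as the arith.remsi lowering mentioned in this thread suggests. The function names and the modulus 17 are hypothetical.

```python
def remsi(a: int, m: int) -> int:
    """Signed remainder that truncates toward zero (sign follows `a`)."""
    r = abs(a) % abs(m)
    return -r if a < 0 else r


def sub_as_add(a: int, b: int, m: int) -> int:
    """Compute a - b as a + (-1 * b), reducing after each step."""
    neg_b = remsi(-1 * b, m)  # stays inside (-m, m)
    return remsi(a + neg_b, m)


m = 17  # hypothetical small modulus
all_ok = all(
    -m < sub_as_add(a, b, m) < m and (sub_as_add(a, b, m) - (a - b)) % m == 0
    for a in range(-m + 1, m)
    for b in range(-m + 1, m)
)
print(all_ok)
```

For representatives in (-mod, +mod), negation never leaves the range, which is why multiplying by -1 needs no special encoding such as mod-1.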
Checking the lowering of polynomial.mul_scalar confused me, though: it may produce a wrong result. It emits arith.muli followed by arith.remsi, and arith.muli may overflow if the scalar is large enough. Fortunately, remsi avoids the issue in #928. Again, I think these lowerings should be handled in mod_arith.
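The overflow concern can be sketched in Python by simulating a 32-bit wrapping multiply followed by a signed remainder. This is an illustration under assumptions, not the actual lowering: the 32-bit width, the modulus 7681, and all function names are hypothetical.

```python
def to_signed(x: int, bits: int = 32) -> int:
    """Reinterpret the low `bits` bits of x as a two's-complement integer."""
    x &= (1 << bits) - 1
    return x - (1 << bits) if x >= (1 << (bits - 1)) else x


def remsi(a: int, m: int) -> int:
    """Signed remainder that truncates toward zero (sign follows `a`)."""
    r = abs(a) % abs(m)
    return -r if a < 0 else r


def lowered_mul_scalar(coeff: int, scalar: int, modulus: int, bits: int = 32) -> int:
    """Mimic a wrapping multiply followed by a signed remainder."""
    product = to_signed(coeff * scalar, bits)  # may wrap before reduction
    return remsi(product, modulus)


modulus = 7681
small = lowered_mul_scalar(3, 5, modulus)              # no wrap
large = lowered_mul_scalar(7000, 1_000_000, modulus)   # product wraps in 32 bits
exact = (7000 * 1_000_000) % modulus

# Congruence check: the wrapped result is no longer the right residue.
print((large - exact) % modulus == 0)
```

The wrapped product differs from the exact one by a multiple of 2^32, which is not a multiple of an odd modulus here, so reducing after the multiply comes too late. Widening the multiply or reducing the operands first are the usual remedies.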
Only apply if there is no coefficient modulus, otherwise we need arith_ext
However, in the case of NTT/iNTT, there must be a coefficientModulus, otherwise primitiveRoot won't exist, so the fix proposed in https://github.com/llvm/llvm-project/pull/110318 is to remove those patterns upstream for now.