Open j2kun opened 2 months ago
In the paper they also describe a potential optimization of modular arithmetic that removes the modulo operator (Rem[S|U]Op) for multiplication/addition of polynomials in the NTT domain. They introduce an operation representing the first Barrett reduction step, say arith_ext.barrett_step, and an operation arith_ext.subifge x y that denotes z = (x >= y) ? x - y : x.
For instance:

```
mul = arith.muli x y
res = arith.remui mul cmod
```

becomes

```
mul = arith.muli x y
barrett = arith_ext.barrett_reduce mul cmod
res = arith_ext.subifge barrett cmod
```
This avoids the potential division inside the remainder operation, leaving only runtime multiplications and a bitshift, since the Barrett ratio can be computed statically. We could use this optimization directly in the current NTT lowering.
Further, they describe a data-flow analysis that reduces the number of arith_ext.subifge operations when the results feed subsequent muls/adds/subs. These optimizations assume the input polynomial coefficients are in the range [0, cmod); however, the polynomial dialect does not currently enforce this restriction (coefficients can be negative). This could be addressed with an operation poly.normalise that ensures all polynomial coefficients are in the range [0, cmod).
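As a toy illustration of what eliding subifge ops buys (my own sketch, not the paper's algorithm): a chain of modular additions does not need a conditional subtract after every add; reductions only need to happen often enough that the accumulator stays within the machine word.

```python
m = 7681                          # hypothetical modulus
vals = [5000, 7000, 3000, 7600]   # all in [0, m)

def subifge(x: int, y: int) -> int:
    return x - y if x >= y else x

# Eager: a conditional subtract after every add.
eager = 0
for v in vals:
    eager = subifge(eager + v, m)

# Lazy: accumulate unreduced (well below 2**32 here), then finish
# with repeated conditional subtracts.
lazy = sum(vals)
while lazy >= m:
    lazy = subifge(lazy, m)

assert eager == lazy == sum(vals) % m
```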
So I would propose the following steps to implement the paper's optimizations:

1. A new dialect arith_ext which has the operations barrett_reduce and subifge.
2. A pass that rewrites arith.muli + arith.remui into arith.muli + arith_ext.barrett_reduce + arith_ext.subifge when the operands are in the range [0, cmod).
3. A pass that rewrites arith.addi + arith.remui into arith.addi + arith_ext.subifge when the operands are in the range [0, cmod).
4. A pass that rewrites arith.subi + arith.remui into arith.subi + arith.addi cmod + arith_ext.subifge when the operands are in the range [0, cmod).
5. A poly.normalise operation in poly to provide a fixed range [0, cmod) for use in the above passes and when modelling the ranges in the data-flow analysis.
6. A data-flow analysis to reduce the number of arith_ext.subifge operations.

Looking for feedback on all aspects of the proposed solution, especially operation names.
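For what it's worth, the add and sub rewrites above rest on two small identities, which a quick Python check confirms (a sketch with a hypothetical modulus):

```python
def subifge(x: int, y: int) -> int:
    return x - y if x >= y else x

m = 7681  # hypothetical modulus
for a in range(0, m, 97):
    for b in range(0, m, 89):
        # add: a + b lies in [0, 2m), so one conditional subtract suffices
        assert subifge(a + b, m) == (a + b) % m
        # sub: a - b lies in (-m, m); adding m first shifts it into [0, 2m)
        assert subifge(a - b + m, m) == (a - b) % m
```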
Nice! I'm excited to see the difference when applied to the NTT lowering :)
> as the Barrett ratio is able to be statically computed
Just to check me: during the poly-to-standard lowering, the pass would statically compute the Barrett ratio and insert the computation for the Barrett reduction, correct? If so, that makes sense and I think I like the arith_ext style dialect and name. Since the modulus for arith_ext.barrett_reduce must be constant and statically known, I wonder if it should be an attribute?

```
%barrett = arith_ext.barrett_reduce {modulus = cmod} %mul
```
> This could be addressed with an operation poly.normalise that will ensure all the polynomial coefficients are in the range [0, cmod).
Hmm yes, that's a good point. I'm a little curious how just an operation will play out. Do we need an attribute on the polynomial type itself to mark that it is normalized? Without it, I would think that a polynomial.mul lowering to standard would need to do some analysis to determine whether its inputs were the result of a poly.normalise.
> require the assumption that the input polynomial coefficients are in the range [0, cmod)
Another possibility I was considering while writing https://github.com/google/heir/pull/675 is that we should ensure this invariant always holds. I ultimately didn't do it in that PR because I found some confusing behavior around remsi/remui: either BOTH operands are treated as signed or BOTH as unsigned, which is wrong either way if you have (-1 : i32) % cmod.
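A small Python model of that pitfall for (-1 : i32) % cmod (my own sketch; the modulus is hypothetical): the signed remainder keeps the dividend's sign, and the unsigned remainder reinterprets the bit pattern, so neither gives the mathematically desired residue.

```python
import math

m = 7681                      # hypothetical cmod
x = -1                        # an i32 coefficient that went negative
desired = x % m               # mathematical residue: 7680

# remsi-like: both operands signed; truncated division keeps the
# sign of the dividend, giving -1 instead of 7680.
remsi = int(math.fmod(x, m))

# remui-like: both operands reinterpreted as unsigned i32, so -1
# becomes 2^32 - 1 before the remainder is taken.
remui = (x & 0xFFFFFFFF) % m

assert desired == 7680
assert remsi == -1
assert remui == 5568          # neither matches the desired residue
```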
But we could consider that. I think @AlexanderViand-Intel should chime in since this would have to be compatible with polynomial ISA considerations.
Otherwise I think this is a great plan.
Yes exactly, we can compute the ratio from half the operand bit-width and cmod. I agree with making it an attribute since it is static.
Hm, good point. Thinking out loud: the optimizations would happen after polynomial-to-standard, where we would apply the first set of passes to introduce the arith_ext operations, followed by the data-flow analysis. We would then need a way to denote that the values of the tensor are normalized after lowering from the poly level.
We could use an encoding on the tensor to denote that it is normalised wrt. some cmod. Then when we lower to_tensor we can mark those tensors: either all of them, if the invariant always holds, or just the polynomials that carry the attribute. Although I think we would still need some sort of analysis to propagate the ranges once we operate at the arith/arith_ext level.
Another option would be to introduce the poly.normalise op, which would lower to an arith_ext.normalise { modulus = cmod }, allowing the range to be propagated at the arith/arith_ext level.
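As a sketch of what such a lowering would compute elementwise (my own model; arith_ext.normalise and its lowering are hypothetical here): a signed truncated remainder followed by a conditional add of the modulus to shift negative residues into [0, cmod).

```python
import math

m = 7681  # hypothetical cmod

def normalise(c: int, m: int) -> int:
    # remsi-style truncated remainder, then fix up negative results
    r = int(math.fmod(c, m))
    return r + m if r < 0 else r

coeffs = [-1, 0, 7680, -7681, 123456]
out = [normalise(c, m) for c in coeffs]
assert out == [c % m for c in coeffs]   # [7680, 0, 7680, 0, 560]
```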
I can start now with adding the arith_ext dialect and its operations/passes. Then we have some time to think about how to go about the rest.
https://dl.acm.org/doi/pdf/10.1145/3591228

The lowest-hanging fruit for us seems to be loop fusion passes, which we could apply to the polynomial dialect after the NTT is lowered to affine in the mlir-polynomial-to-llvm pipeline.