google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

Port optimizations from HEaaN.mlir paper #635

Open j2kun opened 2 months ago

j2kun commented 2 months ago

https://dl.acm.org/doi/pdf/10.1145/3591228

The lowest-hanging fruit for us seems to be the loop fusion passes, which we could apply to the polynomial dialect after the NTT is lowered to affine in the mlir-polynomial-to-llvm pipeline.
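For concreteness, here is a rough before/after sketch of what a pass like -affine-loop-fusion would do to two adjacent elementwise loops produced by the lowering (the buffers %x, %y, %z, %t, %u and the particular ops are illustrative, not the exact output of the pipeline):

// Before fusion: two loops over the same coefficient range, with an
// intermediate buffer %t between them.
affine.for %i = 0 to 1024 {
  %a = affine.load %x[%i] : memref<1024xi32>
  %b = affine.load %y[%i] : memref<1024xi32>
  %s = arith.addi %a, %b : i32
  affine.store %s, %t[%i] : memref<1024xi32>
}
affine.for %i = 0 to 1024 {
  %s = affine.load %t[%i] : memref<1024xi32>
  %c = affine.load %z[%i] : memref<1024xi32>
  %m = arith.muli %s, %c : i32
  affine.store %m, %u[%i] : memref<1024xi32>
}

// After fusion: a single loop; the load/store traffic through %t can
// then be forwarded away entirely.
affine.for %i = 0 to 1024 {
  %a = affine.load %x[%i] : memref<1024xi32>
  %b = affine.load %y[%i] : memref<1024xi32>
  %s = arith.addi %a, %b : i32
  %c = affine.load %z[%i] : memref<1024xi32>
  %m = arith.muli %s, %c : i32
  affine.store %m, %u[%i] : memref<1024xi32>
}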

inbelic commented 1 month ago

In the paper they also describe a potential optimization of the modular arithmetic: removing the modulo operator (Rem[S|U]Op) from multiplication/addition of polynomials in the NTT domain. They introduce an operation to represent the first Barrett reduction step, say arith_ext.barrett_reduce, and an operation arith_ext.subifge x y that denotes z = (x >= y) ? x - y : x. For instance:

%mul = arith.muli %x, %y : i32
%res = arith.remui %mul, %cmod : i32

becomes

%mul = arith.muli %x, %y : i32
%barrett = arith_ext.barrett_reduce %mul, %cmod : i32
%res = arith_ext.subifge %barrett, %cmod : i32

We avoid a potential division in the remainder operation and are left with only a runtime multiplication and bit-shift, as the Barrett ratio can be statically computed. We could use this optimization directly in the current NTT lowering.
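To spell out the arithmetic (a sketch under assumed parameters, not a committed lowering): take, say, cmod = 7681 (13 bits) with %x, %y in [0, cmod) and the product widened to i64 to avoid overflow (the snippet above elides the widening). The ratio floor(2^26 / 7681) = 8736 is a compile-time constant, and barrett_reduce + subifge could themselves expand to plain arith ops:

%wx     = arith.extui %x : i32 to i64
%wy     = arith.extui %y : i32 to i64
%mul    = arith.muli %wx, %wy : i64
%ratio  = arith.constant 8736 : i64         // floor(2^26 / 7681), statically computed
%cmod   = arith.constant 7681 : i64
%shift  = arith.constant 26 : i64
%quot   = arith.muli %mul, %ratio : i64
%approx = arith.shrui %quot, %shift : i64   // approximate quotient of %mul / cmod
%qc     = arith.muli %approx, %cmod : i64
%rem    = arith.subi %mul, %qc : i64        // barrett_reduce result, in [0, 2*cmod)
// subifge: one compare-and-select completes the reduction into [0, cmod)
%ge     = arith.cmpi uge, %rem, %cmod : i64
%sub    = arith.subi %rem, %cmod : i64
%res    = arith.select %ge, %sub, %rem : i64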

Further, they describe a data-flow analysis to reduce the number of arith_ext.subifge ops when there are subsequent uses by muls/adds/subs. These optimizations require the assumption that the input polynomial coefficients are in the range [0, cmod); however, the polynomial dialect does not currently enforce this (coefficients can be negative). This could be addressed with an operation poly.normalise that ensures all polynomial coefficients are in the range [0, cmod).
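For reference, subifge itself is branch-free; a plausible lowering (a sketch, not a committed design) is just a compare-and-select:

%ge  = arith.cmpi uge, %x, %y : i32
%dif = arith.subi %x, %y : i32
%res = arith.select %ge, %dif, %x : i32

As I understand the paper, the analysis would then drop an intermediate subifge wherever the consumer (e.g. an add feeding another barrett_reduce) tolerates inputs in the wider range [0, 2*cmod).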

So I would propose the following steps to implement the paper's optimizations:

1. Add an arith_ext dialect with the barrett_reduce and subifge operations and their lowerings to arith.
2. Rewrite arith.remui in the NTT lowering to barrett_reduce + subifge.
3. Implement the data-flow analysis to elide redundant subifge operations.
4. Add poly.normalise (or an equivalent mechanism) to establish the [0, cmod) invariant on coefficients.

Looking for feedback in all aspects of the proposed solution, especially operation names.

asraa commented 1 month ago

Nice! I'm excited to see the difference when applied to the NTT lowering :)

> as the Barrett ratio can be statically computed

Just to check my understanding: during the poly-to-standard lowering, the pass would statically compute the Barrett ratio and insert the instructions for the Barrett reduction, correct? If so, that makes sense, and I like the arith_ext dialect style and name. Since the modulus for arith_ext.barrett_reduce must be constant and statically known, I wonder if it should be an attribute?

%barrett = arith_ext.barrett_reduce {modulus = cmod} %mul

> This could be addressed with an operation poly.normalise that ensures all polynomial coefficients are in the range [0, cmod).

Hmm, yes, that's a good point. I'm a little curious how just an operation will play out. Do we need an attribute on the polynomial type itself to mark that it is normalized? Without it, I would think a polynomial.mul lowering to standard would need some analysis to determine whether its inputs were the result of a poly.normalise.

j2kun commented 1 month ago

> require the assumption that the input polynomial coefficients are in the range [0, cmod)

Another possibility I was considering while writing https://github.com/google/heir/pull/675 is that we should ensure this invariant always holds. I ultimately didn't do it in that PR because I found some confusing behavior around remsi/remui: either BOTH operands are treated as signed or BOTH as unsigned, which gives the wrong result either way for (-1 : i32) % cmod.
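For illustration, a sketch of one way to establish the invariant (not what #675 does): use remsi, which keeps the sign of the dividend, and fix up negative remainders with a conditional add.

%zero = arith.constant 0 : i32
%rem  = arith.remsi %x, %cmod : i32          // in (-cmod, cmod), sign of %x
%neg  = arith.cmpi slt, %rem, %zero : i32
%fix  = arith.addi %rem, %cmod : i32
%res  = arith.select %neg, %fix, %rem : i32  // in [0, cmod)

E.g. for %x = -1 and cmod = 7681 this yields 7680, whereas plain remsi gives -1 and remui reduces 2^32 - 1 instead.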

But we could consider that. I think @AlexanderViand-Intel should chime in since this would have to be compatible with polynomial ISA considerations.

Otherwise I think this is a great plan.

inbelic commented 1 month ago

Yes, exactly: we can compute the ratio from half the operand bit-width and cmod. I agree with making the modulus an attribute since it is static.

Hm, good point. Thinking out loud: the optimizations would happen after polynomial-to-standard; we would first apply a set of passes to introduce the arith_ext operations, followed by the data-flow analysis. We would then need a way to denote that the values of a tensor are normalized after lowering from the poly level.

We could use an encoding on the tensor type to denote that it is normalised with respect to some cmod. Then, when we lower to_tensor, we can mark those tensors: either all of them, if the invariant holds globally, or just the polynomials that carry the attribute. Although I think we would still need some sort of analysis to propagate the ranges once we operate at the arith/arith_ext level.
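Sketching what that might look like (the attribute name #arith_ext.normalized is made up for illustration), the encoding would ride along in the tensor type's encoding slot when lowering out of poly:

tensor<1024xi32, #arith_ext.normalized<modulus = 7681>>

arith_ext-level passes could then treat any value of such a type as already in [0, cmod).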

Another option would be to introduce the poly.normalise op, which would lower to an arith_ext.normalise { modulus = cmod }, allowing the range to be propagated at the arith/arith_ext level.

I can start now with adding the arith_ext dialect and its operations/passes. That gives us some time to think about how to handle the rest.