google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

Accelerator-Friendly Lowering of polynomial.mul to NTT #724

Closed. AlexanderViand-Intel closed this issue 4 months ago.

AlexanderViand-Intel commented 4 months ago

(Copied from my Discord post earlier today.)

As part of the HEIR -> DPRIVE-style accelerator flow, I've been thinking about the NTT rewrite of polynomial.mul (i.e., replacing a polynomial mul with ntt + elementwise mul + intt). Currently, the "elementwise mul" in that rewrite is represented like this:

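// Widen both operands so the product of two 5-bit values cannot overflow.
%2 = arith.extui %0 : tensor<4xi5, #ring> to tensor<4xi10, #ring>
%3 = arith.extui %1 : tensor<4xi5, #ring> to tensor<4xi10, #ring>
// Multiply, reduce modulo cmod (%cst), then narrow back to 5 bits.
%4 = arith.muli %2, %3 : tensor<4xi10, #ring>
%5 = arith.remui %4, %cst : tensor<4xi10, #ring>
%6 = arith.trunci %5 : tensor<4xi10, #ring> to tensor<4xi5, #ring>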

which makes perfect sense, but is somewhat annoying to deal with for HW where "multiplication modulo cmod" is a native operation.
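To make the rewrite concrete, here is a toy Python sketch of multiplication in Z_q[x]/(x^n + 1), done once by schoolbook negacyclic convolution and once as ntt + elementwise modular mul + intt. The parameters q = 17, n = 4 and psi = 2 are illustrative assumptions (picked so coefficients fit in 5 bits, like the tensor<4xi5> values above), not what #ring actually denotes, and the O(n^2) transforms favor clarity over speed.

```python
# Toy negacyclic NTT over Z_q[x]/(x^n + 1); Q, N, PSI are illustrative choices.
Q = 17   # prime with Q ≡ 1 (mod 2N); fits in 5 bits like the i5 tensors
N = 4    # ring degree
PSI = 2  # primitive 2N-th (8th) root of unity mod 17, since 2**4 ≡ -1 (mod 17)


def schoolbook_mul(a, b):
    """Reference product in Z_q[x]/(x^n + 1), i.e. negacyclic convolution."""
    c = [0] * N
    for i in range(N):
        for j in range(N):
            k = i + j
            if k < N:
                c[k] = (c[k] + a[i] * b[j]) % Q
            else:  # x^n ≡ -1, so wrap around with a sign flip
                c[k - N] = (c[k - N] - a[i] * b[j]) % Q
    return c


def ntt(a):
    """Evaluate a at the odd powers PSI^(2i+1), which are the roots of x^n + 1."""
    return [sum(a[j] * pow(PSI, j * (2 * i + 1), Q) for j in range(N)) % Q
            for i in range(N)]


def intt(a_hat):
    """Inverse of ntt: interpolate back to coefficient form."""
    n_inv = pow(N, -1, Q)
    psi_inv = pow(PSI, -1, Q)
    return [n_inv * sum(a_hat[i] * pow(psi_inv, j * (2 * i + 1), Q)
                        for i in range(N)) % Q
            for j in range(N)]


def ntt_mul(a, b):
    """polynomial mul rewritten as ntt + elementwise modular mul + intt."""
    a_hat, b_hat = ntt(a), ntt(b)
    # This elementwise step is what the arith sequence above implements.
    c_hat = [(x * y) % Q for x, y in zip(a_hat, b_hat)]
    return intt(c_hat)
```

For example, ntt_mul([0, 0, 0, 1], [0, 1, 0, 0]) gives [16, 0, 0, 0], since x^3 * x = x^4 ≡ -1.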

Right now, I'm sidestepping the issue by not using --convert-polynomial-mul-to-ntt, but that's clearly not ideal: one of the basic, but very important, optimizations HEIR should do at the polynomial level is eliminating redundant ntt/intt operations, which of course can only happen after --convert-polynomial-mul-to-ntt.
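The redundancy only becomes visible once each mul has been expanded. A minimal sketch of cancelling ntt/intt round trips, over a hypothetical straight-line list of op names rather than real MLIR IR (a real pass would walk the SSA use-def graph instead):

```python
# Toy model: one chain of ops applied to a single value, in program order.
# "emul" stands for an elementwise modular mul whose other operand is assumed
# to already be in NTT form (an assumption that keeps the model linear).
INVERSE_PAIRS = {("intt", "ntt"), ("ntt", "intt")}


def cancel_ntt_round_trips(ops):
    """Drop adjacent ntt/intt pairs, since ntt(intt(x)) == x and vice versa."""
    out = []
    for op in ops:
        if out and (out[-1], op) in INVERSE_PAIRS:
            out.pop()  # the pair is a no-op round trip, so drop both ops
        else:
            out.append(op)
    return out


# Two chained polynomial muls, each expanded by the NTT rewrite.
chain = ["ntt", "emul", "intt", "ntt", "emul", "intt"]
print(cancel_ntt_round_trips(chain))  # ['ntt', 'emul', 'emul', 'intt']
```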

My first idea to solve this would be to introduce a poly_ext (or tensor_ext?) operation that takes two tensors and a ring attribute and represents the elementwise modular multiplication. The current --convert-polynomial-mul-to-ntt would then be split into two passes: the first lowers only as far as that elementwise_mul (or similar) operation, and the second lowers that op to the arith implementation above.
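A minimal Python sketch of the proposed split, modeling semantics rather than MLIR: the first function stands in for the new op (elementwise_modmul is a placeholder name, as in the proposal), and the second emulates the arith sequence above that the second pass would emit. The modulus 17 used in the example is a hypothetical value that fits the i5 lanes.

```python
def elementwise_modmul(lhs, rhs, modulus):
    """Semantics of the proposed op: (lhs[i] * rhs[i]) mod modulus, per element."""
    return [(a * b) % modulus for a, b in zip(lhs, rhs)]


def lower_to_arith(lhs, rhs, modulus, width):
    """Emulate the extui / muli / remui / trunci sequence on unsigned lanes.

    Assumes every operand lies in [0, modulus) and modulus < 2**width.
    """
    wide = 2 * width  # e.g. i5 -> i10, as in the snippet above
    wide_mask = (1 << wide) - 1
    narrow_mask = (1 << width) - 1
    out = []
    for a, b in zip(lhs, rhs):
        a_ext, b_ext = a & narrow_mask, b & narrow_mask  # arith.extui: zero-extend
        # arith.muli wraps mod 2**wide, but two width-bit values never overflow it.
        prod = (a_ext * b_ext) & wide_mask
        rem = prod % modulus  # arith.remui
        out.append(rem & narrow_mask)  # arith.trunci back to `width` bits
    return out
```

A backend with native modular multiplication could consume the first form directly, while software targets would continue on to the second.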

Thoughts? Am I missing a better solution here? Pinging @inbelic :wink:

j2kun commented 4 months ago

Is there a reason not to match on mul + remui and convert that to the HW-specific op, provided the values are legal for the HW?

j2kun commented 4 months ago

Ah, either way we still need a dedicated op. I could see those "op+mod" ops being added to arith_ext.

inbelic commented 4 months ago

The proposed method looks good, but I think the new operation fits better in the arith_ext dialect. I have another pull request up that adds the lowerings for arith_ext ops, and you can add this lowering to it as well.

One reason not to pattern match afterwards is that, once the HEaaN paper is fully implemented, the reduction may be an arith.remui, an arith_ext.barrett_reduce, or an arith_ext.subifge op.
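To see why matching muli + remui alone is fragile, here is a Python sketch of two equivalent ways to finish a modular multiply: a plain remainder, and a textbook Barrett reduction followed by a conditional subtraction. The Barrett parameters (w = bit length of q, mu = floor(4^w / q)) follow the standard construction and are an assumption; HEIR's arith_ext.barrett_reduce and arith_ext.subifge may differ in detail.

```python
def barrett_reduce(x, q):
    """Textbook Barrett reduction: for 0 <= x < q*q, returns x mod q or x mod q + q."""
    w = q.bit_length()          # smallest width with q < 2**w
    mu = (1 << (2 * w)) // q    # floor(4**w / q), precomputable since q is static
    quot = (x * mu) >> (2 * w)  # underestimates x // q by at most 1
    return x - quot * q         # lands in [0, 2q)


def subifge(x, q):
    """Conditional subtraction: (x >= q) ? x - q : x."""
    return x - q if x >= q else x


def modmul_remui(a, b, q):
    """mul followed by a plain unsigned remainder."""
    return (a * b) % q


def modmul_barrett(a, b, q):
    """mul followed by Barrett reduction and one conditional subtraction."""
    return subifge(barrett_reduce(a * b, q), q)
```

Both spellings compute the same value, so a late pattern match would have to recognize every one of them. That is an argument for keeping the modular op explicit until the reduction strategy is chosen.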

Edit: the pull request is #715.

inbelic commented 4 months ago

I also agree that we should add the other types of modular arithmetic to the arith_ext dialect.

AlexanderViand-Intel commented 4 months ago

👍 So how about arith_ext.add %lhs, %rhs { modulus = ..... } : i32 and similar for sub/mul? This uses the same syntax for the modulus attribute that the existing arith_ext ops use (I guess anything prettier might require a custom parser/printer and not just TableGen?)
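Assuming operands are already reduced into [0, q) (the thread does not pin this down), the intended semantics of the proposed add/sub/mul ops can be written as a short Python reference model. modular_ref is a hypothetical name, not part of HEIR.

```python
def modular_ref(op, lhs, rhs, modulus):
    """Reference semantics for add/sub/mul with a static modulus attribute.

    Assumes operands are already reduced, i.e. in [0, modulus); so is the result.
    """
    if not (0 <= lhs < modulus and 0 <= rhs < modulus):
        raise ValueError("operands must be reduced mod modulus")
    if op == "add":
        return (lhs + rhs) % modulus
    if op == "sub":
        # Python's % returns a non-negative result for a positive modulus.
        return (lhs - rhs) % modulus
    if op == "mul":
        return (lhs * rhs) % modulus
    raise ValueError(f"unknown op: {op}")
```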

j2kun commented 4 months ago

That syntax seems fine, as long as the modulus is always statically known.

AlexanderViand-Intel commented 4 months ago

I added the ops (see PR #725) and, unless someone objects, will proceed to change --convert-polynomial-mul-to-ntt to lower only to arith_ext.mul (and add support for arith_ext.add/sub/mul to the --arith-ext-to-arith pass from #715) :)
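Lowering add and sub to arith needs less headroom than mul. Below is a Python sketch of one possible lowering on unsigned lanes (not necessarily the one in #715 or its follow-ups), assuming operands in [0, q) and q < 2^width. Subtraction adds q first so the unsigned intermediate can never underflow.

```python
def _mask(bits):
    """All-ones mask for an unsigned integer of the given width."""
    return (1 << bits) - 1


def lower_add(a, b, q, width):
    """Emulate extui to width+1 bits, addi, remui, trunci."""
    wide = width + 1                # a + b < 2q <= 2**(width+1), so no wrap
    s = (a + b) & _mask(wide)       # arith.addi on the widened lanes
    return (s % q) & _mask(width)   # arith.remui, then arith.trunci


def lower_sub(a, b, q, width):
    """Emulate (a + q - b) mod q on width+1 unsigned bits to avoid underflow."""
    wide = width + 1                # 0 < a + q - b < 2q <= 2**(width+1)
    d = (a + q - b) & _mask(wide)   # arith.addi then arith.subi
    return (d % q) & _mask(width)   # arith.remui, then arith.trunci
```

Since both intermediates are below 2q, a single conditional subtraction (what arith_ext.subifge computes) would also suffice in place of remui.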

Thinking more long-term, I wonder if there's enough demand/need from the accelerator side to justify turning this into a full "modular arithmetic" abstraction (iirc, that's what HEaaN.MLIR had), with modular-aware versions of the expected canonicalization/folding, including constant folding. Of course, anything that lowers to "standard" arith can wait until that stage for these optimizations, but for the RLWE accelerators, that point will never be reached.
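As a rough illustration of what "modular-aware" folding could mean, here is a Python sketch of a fold hook in the spirit of MLIR's fold (return a constant, forward an existing value, or do nothing). The function name and operand encoding are hypothetical. The key point is that constants are interpreted mod q, so, e.g., multiplying by q + 1 folds away, which folding on plain arith ops would not see.

```python
def fold_modular(op, lhs, rhs, q):
    """Toy fold hook for modular add/sub/mul.

    Operands are ints (constants) or strs (SSA value names). Returns an int
    (folded constant), a str (forward an existing value), or None (no fold).
    """
    # Modular-aware step: constants only matter mod q, so normalize them first.
    if isinstance(lhs, int):
        lhs %= q
    if isinstance(rhs, int):
        rhs %= q
    # Full constant folding.
    if isinstance(lhs, int) and isinstance(rhs, int):
        if op == "add":
            return (lhs + rhs) % q
        if op == "sub":
            return (lhs - rhs) % q
        if op == "mul":
            return (lhs * rhs) % q
    # Identities that hold in Z_q.
    if op == "add" and rhs == 0:
        return lhs
    if op == "add" and lhs == 0:
        return rhs
    if op == "sub" and rhs == 0:
        return lhs
    if op == "sub" and lhs == rhs:
        return 0
    if op == "mul" and (lhs == 0 or rhs == 0):
        return 0
    if op == "mul" and rhs == 1:
        return lhs
    if op == "mul" and lhs == 1:
        return rhs
    return None
```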

AlexanderViand-Intel commented 4 months ago

Btw, @inbelic: I noticed that the modulus in arith_ext ops is hardcoded to be I64Attr (and that's what I followed for the new ops, too). Is this a requirement of the passes/lowerings? If not, it'd be great if we could switch to APIntAttr (i.e., any integer attribute) as there are HW designs with native word sizes ranging all the way from 32 to 128 bits.
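To put numbers on the word-size concern, here is a small Python sketch that reports how many bits a modulus and the unreduced product of two residues need. The helper name is hypothetical, and Mersenne primes serve only as convenient example moduli.

```python
def width_report(q):
    """Bits needed to store modulus q and the product of two residues mod q."""
    return {
        "modulus_bits": q.bit_length(),
        "fits_64_bits": q.bit_length() <= 64,
        # Widest value produced by a mul before it is reduced.
        "product_bits": ((q - 1) ** 2).bit_length(),
    }
```

For q = 2**127 - 1 it reports 127 modulus bits (which do not fit in 64 bits) and a 254-bit unreduced product, so a 64-bit attribute cannot even describe such a modulus.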

inbelic commented 4 months ago

We have a hard-coded variable APINT_BIT_WIDTH = 64 in Polynomial.h that is used when we parse the coefficient modulus of a poly.ring. This is why I did not use an APInt right away.

I agree that we should support higher bit widths, but we should make it consistent everywhere. I said I would make an issue to address this but forgot to actually create it. Will do so now.

j2kun commented 4 months ago

FWIW, that particular constant should no longer be used for the coefficient modulus in the upstream polynomial dialect, since the coefficient type and modulus are now integer attributes that specify their own bit width in the IR.

(Sent by email in reply to Finn Plummer's comment of Thu, Jun 6, 2024, 5:01 PM.)

inbelic commented 4 months ago

Nice! No current lowering or pass depends on it, so we can use arbitrary integer attributes as you suggest.

AlexanderViand-Intel commented 4 months ago

I noticed a few "fun" things while putting this together that I'm noting here (for lack of a better place):

AlexanderViand-Intel commented 4 months ago

Done from my side, but there's a bit of a backlog of arith_ext-related PRs now. We already had:

- #715 from @inbelic (adds the --arith-ext-to-arith pass)
- #725 (adds arith_ext.add/sub/mul)

and I've now added:

- #731 (removes the i64 constraint and adds support for arith_ext.add/sub/mul in the lowering pass)
- #732 (switches the ntt pass over to use the new stuff)