google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

convert-polynomial-mul-to-ntt: missing primitive root attribute when lowering further #993

Open ZenithalHourlyRate opened 3 weeks ago

ZenithalHourlyRate commented 3 weeks ago

Since https://github.com/llvm/llvm-project/pull/93227, primitiveRoot is no longer a parameter of #polynomial.ring, so convert-polynomial-mul-to-ntt currently rewrites mul to ntt without a #root specified (the root attribute is left as nullptr).

We should therefore pass the primitive root as an option to convert-polynomial-mul-to-ntt; otherwise the converted result cannot be lowered further (with polynomial-to-standard).

#cycl = #polynomial.int_polynomial<1 + x**4>
#root = #polynomial.primitive_root<value=1925:i32, degree=8:i32>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 7681 : i32, polynomialModulus=#cycl>
!p = !polynomial.polynomial<ring=#ring>

module {
  func.func @polymul(%x : !p, %y : !p) -> (!p) {
    %add = polynomial.mul %x, %y : !p
    return %add : !p
  }
}
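
The #root attribute above has degree=8, twice the degree of the polynomial modulus 1 + x**4, which is what a negacyclic NTT over this ring needs. As a quick sanity check, here is a minimal Python sketch (the helper name is hypothetical, not part of HEIR) that tests whether a value has multiplicative order exactly `degree` modulo the coefficient modulus:

```python
def is_primitive_root(root: int, degree: int, modulus: int) -> bool:
    """True if `root` has multiplicative order exactly `degree` mod `modulus`."""
    if pow(root, degree, modulus) != 1:
        return False
    # The order divides `degree`; it equals `degree` iff root**(degree // p) != 1
    # for every distinct prime factor p of `degree`.
    primes, n, p = set(), degree, 2
    while p * p <= n:
        while n % p == 0:
            primes.add(p)
            n //= p
        p += 1
    if n > 1:
        primes.add(n)
    return all(pow(root, degree // p, modulus) != 1 for p in primes)

# Parameters from the #ring and #root attributes above.
Q, N = 7681, 4  # coefficientModulus, degree of 1 + x**4
print(is_primitive_root(1925, 2 * N, Q))  # True
print(pow(1925, 4, Q))                    # 7680
```

Here 1925**4 mod 7681 is 7680, i.e. -1, so 1925 has order 8 rather than 4.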

For example, running --convert-polynomial-mul-to-ntt --polynomial-to-standard results in:

error: missing root attribute
    %add = polynomial.mul %x, %y : !p
    ^
note: see current operation: %4 = "polynomial.ntt"(<<UNKNOWN SSA VALUE>>) : (!polynomial.polynomial<ring = <coefficientType = i32, coefficientModulus = 7681 : i32, polynomialModulus = <1 + x**4>>>) -> tensor<4xi13, #polynomial.ring<coefficientType = i32, coefficientModulus = 7681 : i32, polynomialModulus = <1 + x**4>>>

Note that tests/polynomial/ntt_rewrites.mlir also cannot be lowered further to standard.

ZenithalHourlyRate commented 1 week ago

pass the primitive root in the argument of convert-polynomial-mul-to-ntt

This is not feasible; there are too many combinations of cmod and degree.

Related to #644 and https://github.com/google/heir/issues/543#issuecomment-2031067952

Why shouldn't we just let the user specify one in ringAttr? This is especially painful for cases like RNS, where many primitive roots must be specified somehow, let alone cases where the root must be a fixed value mandated by a standard.

A better design is to put the semantics on the ops. In this case the ntt/intt op can be lowered in multiple ways depending on the polynomial ring modulus (cyclic polymul -> ntt needs an nth root of unity, while negacyclic polymul -> ntt needs a 2nth root).
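
To make the cyclic/negacyclic distinction concrete, here is a minimal Python sketch using naive O(n^2) transforms and the example parameters q = 7681, n = 4, psi = 1925 (all function names are illustrative, not HEIR APIs). A cyclic product (mod x^n - 1) only needs omega = psi^2, a primitive nth root, while a negacyclic product (mod x^n + 1) first twists the coefficients by powers of the 2nth root psi:

```python
# Example parameters from the issue: 1 + x**4 over Z_7681, #root value 1925.
Q, N, PSI = 7681, 4, 1925  # modulus, ring degree, primitive 2N-th root of unity
OMEGA = PSI * PSI % Q      # primitive N-th root of unity (3383)

def ntt(a, w):
    """Naive O(N^2) transform: A[k] = sum_j a[j] * w**(j*k) mod Q."""
    return [sum(a[j] * pow(w, j * k, Q) for j in range(N)) % Q for k in range(N)]

def intt(A, w):
    """Inverse transform: the forward transform with w**-1, scaled by N**-1."""
    n_inv, w_inv = pow(N, -1, Q), pow(w, -1, Q)
    return [x * n_inv % Q for x in ntt(A, w_inv)]

def cyclic_mul(a, b):
    """a * b mod (x**N - 1, Q): needs only an N-th root of unity."""
    A, B = ntt(a, OMEGA), ntt(b, OMEGA)
    return intt([x * y % Q for x, y in zip(A, B)], OMEGA)

def negacyclic_mul(a, b):
    """a * b mod (x**N + 1, Q): twist by PSI**j, multiply cyclically, untwist."""
    ta = [x * pow(PSI, j, Q) % Q for j, x in enumerate(a)]
    tb = [x * pow(PSI, j, Q) % Q for j, x in enumerate(b)]
    psi_inv = pow(PSI, -1, Q)
    return [x * pow(psi_inv, j, Q) % Q for j, x in enumerate(cyclic_mul(ta, tb))]

def schoolbook(a, b, sign):
    """Reference product; sign=+1 for x**N = 1 (cyclic), -1 for x**N = -1 (negacyclic)."""
    c = [0] * N
    for i in range(N):
        for j in range(N):
            if i + j < N:
                c[i + j] += a[i] * b[j]
            else:
                c[i + j - N] += sign * a[i] * b[j]
    return [x % Q for x in c]

a, b = [1, 2, 3, 4], [5, 6, 7, 8]
print(cyclic_mul(a, b))      # [66, 68, 66, 60]
print(negacyclic_mul(a, b))  # [7625, 7645, 2, 60]
```

The same n-point transform serves both cases; what changes is which root the lowering must be given, which is why the choice belongs to how the op is lowered rather than to the ring alone.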

Then, if we want two polynomials with the same degree/cmod but different roots to interoperate, we should add an op that adds or discards the root attribute.

One caveat: the current ntt implementation in polynomial-to-standard assumes negacyclic polymul.

j2kun commented 1 week ago

I'm not sure I follow the logic here. A primitive root is not part of the semantic specification of a ring. It's extra data needed for certain ops. Moreover, it's semantic information specifically about how the op is implemented, so it makes a lot more sense for it to be on the op rather than the type (my thinking on this has changed since the linked comment, which pre-dates my implementation of the polynomial types upstream).

Putting lots of data on the type adds many extra headaches in the rest of the MLIR landscape, around type conversion and type checking, as well as performance barriers since the type is used everywhere, whereas the op is only processed when it is actually used.

This is especially painful for cases like RNS, where many primitive roots must be specified somehow, let alone cases where the root must be a fixed value mandated by a standard.

I don't see what this has to do with the data being specified on the type/op.

j2kun commented 1 week ago

As a potential alternative: we could have a pass that is responsible for populating the primitive roots, and allow ntt/intt to exist root-less until a certain lowering requires it.
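
A minimal sketch of that alternative, using plain Python records as a stand-in for MLIR ops (Op, populate_roots, and lower are hypothetical names, not HEIR code). ntt/intt ops may carry no root until a lowering that needs one runs; the populating pass assigns one root per (cmod, degree) pair, so a matching ntt/intt pair always receives the same value, and ops that already carry a root are left alone:

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

@dataclass
class Op:
    name: str                   # "ntt" or "intt"
    cmod: int                   # coefficient modulus of the ring
    degree: int                 # degree of the root of unity the lowering needs
    root: Optional[int] = None  # stays None until populated

def populate_roots(ops: List[Op], root_for: Callable[[int, int], int]) -> List[Op]:
    """Fill in missing roots, computing each (cmod, degree) pair once.

    Ops that already carry a root (e.g. user-specified) are left untouched.
    """
    cache = {}
    for op in ops:
        if op.name in ("ntt", "intt") and op.root is None:
            key = (op.cmod, op.degree)
            if key not in cache:
                cache[key] = root_for(*key)
            op.root = cache[key]
    return ops

def lower(op: Op) -> str:
    """Stand-in for a lowering that genuinely needs the root."""
    if op.root is None:
        raise ValueError(f"missing root attribute on {op.name}")
    return f"{op.name}[root={op.root}]"

ops = [Op("ntt", 7681, 8), Op("intt", 7681, 8)]
populate_roots(ops, lambda cmod, degree: 1925)  # e.g. a table lookup
print([lower(op) for op in ops])  # ['ntt[root=1925]', 'intt[root=1925]']
```

Calling lower before populate_roots reproduces the "missing root attribute" failure from the issue, which is the point at which the pass would have to run.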

ZenithalHourlyRate commented 1 week ago

we could have a pass that is responsible for populating the primitive roots

Then that would be a table like --polynomial-populate-primitive-roots=mod1,degree1,root1,mod2,degree2,root2,mod3,degree3,root3....

That is quite verbose. Since it can already be written in the input code (embedded in the type), why put it on the command line?

ZenithalHourlyRate commented 1 week ago

An end user running --mlir-to-bgv='entry-function=dot_product ciphertext-degree=8' --bgv-to-lwe --lwe-to-polynomial --heir-polynomial-to-llvm is also likely to have to specify a lot of parameters, as pointed out in https://github.com/google/heir/issues/536#issuecomment-2009753131

If we go in this direction, we then have to use a config file.

j2kun commented 1 week ago

we could have a pass that is responsible for populating the primitive roots

Then that would be a table like --polynomial-populate-primitive-roots=mod1,degree1,root1,mod2,degree2,root2,mod3,degree3,root3....

That is quite verbose. Since it can already be written in the input code (embedded in the type), why put it on the command line?

Why can't the pass determine the primitive roots? It is an implementation detail, not something a user would specify directly.
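
A pass could indeed compute roots on demand. One standard approach, sketched here in Python under the assumption that the coefficient modulus is prime (the function names are hypothetical): if degree divides q - 1, then for a generator g of the multiplicative group, g**((q - 1) // degree) has order exactly degree. Trying g = 2, 3, ... in order and keeping the first candidate of full order gives a deterministic answer, so an ntt and its matching intt agree. The result is a valid root, not necessarily the smallest one:

```python
def prime_factors(n: int) -> set:
    """Distinct prime factors of n by trial division."""
    factors, p = set(), 2
    while p * p <= n:
        while n % p == 0:
            factors.add(p)
            n //= p
        p += 1
    if n > 1:
        factors.add(n)
    return factors

def find_primitive_root(modulus: int, degree: int) -> int:
    """Return a primitive `degree`-th root of unity modulo a prime `modulus`.

    Deterministic: tries bases 2, 3, ... in order, so repeated calls agree.
    """
    if (modulus - 1) % degree != 0:
        raise ValueError(f"{degree} does not divide {modulus} - 1; no such root exists")
    factors = prime_factors(degree)
    for g in range(2, modulus):
        # cand**degree == g**(modulus - 1) == 1 by Fermat, so its order divides degree.
        cand = pow(g, (modulus - 1) // degree, modulus)
        # The order is exactly `degree` iff cand**(degree // p) != 1 for every prime p | degree.
        if all(pow(cand, degree // p, modulus) != 1 for p in factors):
            return cand
    raise ValueError(f"no primitive {degree}-th root found; is {modulus} prime?")

r = find_primitive_root(7681, 8)  # 2n-th root for the negacyclic ring 1 + x**4
print(pow(r, 4, 7681))            # 7680
```

For RNS, the same call would simply be made once per prime modulus.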

ZenithalHourlyRate commented 1 week ago

we could have a pass that is responsible for populating the primitive roots

Then that would be a table like --polynomial-populate-primitive-roots=mod1,degree1,root1,mod2,degree2,root2,mod3,degree3,root3.... That is quite verbose. Since it can already be written in the input code (embedded in the type), why put it on the command line?

Why can't the pass determine the primitive roots? It is an implementation detail, not something a user would specify directly.

I was thinking about applications like #232, where the choice of primitive root is predetermined rather than an internal compiler detail. The user can specify the root every time they call ntt/intt, but then they cannot use polynomial.mul and have to write mod_arith.mul with the modulus themselves. That is OK for a case like Dilithium, where the user does not need polynomial.mul.

I agree that for mul-to-ntt itself it is OK for the compiler to compute the primitive root, since ntt/intt come in pairs and the root is guaranteed to be the same. The question then is whether we should use the old StaticRoot.h or integrate a third-party math library. I think the easier way is to reuse the old StaticRoot.h and look up roots from it in the mul-to-ntt pass.

Another point that occurs to me: for individual ntt/intt ops, the root should be embedded in the tensor type:

%ntt0 = polynomial.ntt %poly0: !poly -> tensor<nxi32, #ring> // lowered using #root0
%ntt1 = polynomial.ntt %poly1 {root=#root1} : !poly -> tensor<nxi32, #ring> // should be tensor<nxi32, #ring, #root1>
mod_arith.mul %ntt0, %ntt1 // should be forbidden

%ntt0 = polynomial.ntt %poly0: !poly -> tensor<nxi32, #ring> // lowered using #root0
%ntt1 = polynomial.intt %ntt0 {root=#root1} : tensor<nxi32, #ring> -> !poly // should be forbidden
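
The two "should be forbidden" cases can be modelled by letting an NTT-domain value remember the root it was produced with. A minimal Python sketch of that check (NttTensor, pointwise_mul, and check_intt_root are hypothetical; in MLIR this corresponds to exactly the extra type-level data the proposal would add):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class NttTensor:
    values: tuple  # evaluation-domain coefficients
    root: int      # root used by the forward ntt that produced this value

def pointwise_mul(a: NttTensor, b: NttTensor, modulus: int) -> NttTensor:
    # Mirrors "mod_arith.mul %ntt0, %ntt1 // should be forbidden" when roots differ.
    if a.root != b.root:
        raise TypeError(f"cannot multiply NTT values with roots {a.root} and {b.root}")
    return NttTensor(tuple(x * y % modulus for x, y in zip(a.values, b.values)), a.root)

def check_intt_root(t: NttTensor, root: int) -> None:
    # Mirrors "polynomial.intt %ntt0 {root=#root1} ... // should be forbidden".
    if t.root != root:
        raise TypeError(f"intt with root {root} applied to a value produced with root {t.root}")

print(pointwise_mul(NttTensor((1, 2), 1925), NttTensor((3, 4), 1925), 7681))
# NttTensor(values=(3, 8), root=1925)
```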

j2kun commented 6 days ago

I think it's safe to have a static roots file and use it as a lookup, and raise an error if the static file does not contain some needed values. If/when we end up implementing something that needs specific NTT roots, providing a config file and passing the name of the config file as a command-line flag seems OK to me.
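
A minimal Python sketch of that approach (the table contents, the JSON layout, and all names are hypothetical): a built-in table in the spirit of StaticRoot.h, optionally extended by entries loaded from a config file named on the command line, and a hard error when a (cmod, degree) pair is missing:

```python
import json
from typing import Dict, Optional, Tuple

RootTable = Dict[Tuple[int, int], int]

# Built-in entries keyed by (coefficient modulus, root degree); illustrative only.
STATIC_ROOTS: RootTable = {(7681, 8): 1925}

def load_root_config(path: str) -> RootTable:
    """Read a hypothetical JSON file: [{"cmod": 7681, "degree": 8, "root": 1925}, ...]."""
    with open(path) as f:
        return {(e["cmod"], e["degree"]): e["root"] for e in json.load(f)}

def lookup_root(cmod: int, degree: int, extra: Optional[RootTable] = None) -> int:
    """Look up a root, letting config-file entries extend or override the static table."""
    table = {**STATIC_ROOTS, **(extra or {})}
    if (cmod, degree) not in table:
        raise ValueError(
            f"no primitive {degree}-th root of unity for modulus {cmod}; "
            "add it to the static table or a root config file"
        )
    return table[(cmod, degree)]

print(lookup_root(7681, 8))                      # 1925
print(lookup_root(7681, 4, {(7681, 4): 3383}))  # 3383
```

Here 3383 is 1925**2 mod 7681, a primitive 4th root; the lookup itself does not validate entries.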

for individual ntt/intt the root should be embedded in the tensor type

We can't put more attributes on tensor type because it is defined upstream, and I'm not even particularly happy we decided to put the ring attribute on the tensor. It seems like it may cause us problems later since the tensor type is beyond our control; upstream can remove it from the type and then we will be forced to change. But for the root, I think we don't need to know the root when converting a tensor type, we only need to know it when converting the op. So it shouldn't need to go on the type.