google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

polynomial-to-standard: error on integer width for poly.add when ring coefficientModulus width is less than coefficient width #990

Open ZenithalHourlyRate opened 1 month ago

ZenithalHourlyRate commented 1 month ago

When running tests/bgv/to_polynomial.mlir with the additional passes -convert-elementwise-to-affine -polynomial-to-standard, the following error occurs:

error: 'arith.extsi' op operand type 'tensor<1024xi32>' and result type 'tensor<1024xi26>' are cast incompatible
    %mul = bgv.mul %x, %y : (!ct1, !ct1) -> !ct2
           ^
note: see current operation: %19 = "arith.extsi"(%13) : (tensor<1024xi32>) -> tensor<1024xi26>

This reduces to the following minimal reproducer, run with -polynomial-to-standard. Changing the alias line to !p = !p2 makes the error go away:

#my_poly = #polynomial.int_polynomial<1 + x**1024>
#ring1 = #polynomial.ring<coefficientType = i32, coefficientModulus = 33538049 : i32, polynomialModulus=#my_poly>
#ring2 = #polynomial.ring<coefficientType = i25, coefficientModulus = 33538049 : i25, polynomialModulus=#my_poly>

!p1 = !polynomial.polynomial<ring = #ring1>
!p2 = !polynomial.polynomial<ring = #ring2>

!p = !p1

module {
  func.func @polymul(%x : !p, %y : !p) -> (!p) {
    %add = polynomial.add %x, %y : !p
    return %add : !p
  }
}

The relevant code is the following:

https://github.com/google/heir/blob/b79bccf6bf3863083e339bd582b4e06fdc3dd561/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp#L522-L543

It assumes that coefficientModulus is either a power of two or has the same bit width as coefficientType.

Maybe we should check that coefficientType has the same width as coefficientModulus when the modulus is a prime.
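The width mismatch can be checked with a few lines of arithmetic. The following is a minimal sketch, not the pass's actual code: the helper names are hypothetical, and the "+1 bit for the sum" rule is an assumption that happens to match the i26 seen in the error message. It shows that the modulus 33538049 needs 25 bits, that the add intermediate would then be 26 bits wide, and that a sign extension into that width is only well-formed from a narrower coefficient type.

```python
Q = 33538049  # coefficientModulus from the reproducer


def modulus_width(modulus: int) -> int:
    """Bits needed to represent every residue in [0, modulus)."""
    return (modulus - 1).bit_length()


def add_width(modulus: int) -> int:
    """Assumed intermediate width for an add: one extra bit for the carry."""
    return modulus_width(modulus) + 1


def extsi_is_valid(src_bits: int, dst_bits: int) -> bool:
    """arith.extsi requires the result to be strictly wider than the operand."""
    return dst_bits > src_bits


print("modulus width:", modulus_width(Q))
print("add intermediate width:", add_width(Q))
print("extsi i25 -> intermediate valid:", extsi_is_valid(25, add_width(Q)))
print("extsi i32 -> intermediate valid:", extsi_is_valid(32, add_width(Q)))
```

With coefficientType = i32 the operand is already wider than the intermediate, so emitting arith.extsi cannot verify; with i25 the extension is a genuine widening, which is consistent with !p = !p2 avoiding the error.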

Also, should we consider lowering it to mod_arith so that polynomial-to-standard can be simplified? Currently extsi/extui is everywhere.

asraa commented 1 month ago

cc @AlexanderViand @inbelic on this issue

> Also, should we consider lowering it to mod_arith so that polynomial-to-standard can be simplified?

I think so - but I will need to catch up on this particular pass pipeline.

AlexanderViand-Intel commented 1 month ago

Thanks for bringing this up! I’ll add it to the list of “issues that show why we really need formal verification around mod_arith stuff”!

I agree: switching this pipeline over to mod_arith would be the best way to make sure that at least these kinds of issues only occur in one place xD

ZenithalHourlyRate commented 1 month ago

polynomial.mul_scalar also has an erroneous lowering (the test only covers the power-of-two branch):

The code is similar to the example above, and it reports an error regardless of whether !p is !p1 or !p2:

#my_poly = #polynomial.int_polynomial<1 + x**1024>
#ring1 = #polynomial.ring<coefficientType = i32, coefficientModulus = 33538049 : i32, polynomialModulus=#my_poly>
#ring2 = #polynomial.ring<coefficientType = i25, coefficientModulus = 33538049 : i25, polynomialModulus=#my_poly>

!p1 = !polynomial.polynomial<ring = #ring1>
!p2 = !polynomial.polynomial<ring = #ring2>

!p = !p1

module {
  func.func @polymul(%x : !p, %y : !p) -> (!p) {
    %add = polynomial.sub %x, %y : !p
    return %add : !p
  }
}

Note that polynomial.sub is canonicalized to a polynomial.mul_scalar by -1 followed by a polynomial.add; lowering these further results in:

error: 'arith.remsi' op requires the same type for all operands and results
    %add = polynomial.sub %x, %y : !p
           ^
note: see current operation: %4 = "arith.remsi"(%arg1, %3) : (tensor<1024xi32>, i32) -> tensor<1024xi32>
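The canonicalization mentioned above relies on x - y being equal to x + (-1) * y in the coefficient ring. The sketch below models that identity on plain Python lists; the function names are hypothetical and this is a reference for the intended ring semantics only, not a model of the emitted arith ops. In particular, Python's % always returns a value in [0, q), whereas arith.remsi takes the sign of the dividend.

```python
Q = 33538049  # coefficientModulus from the reproducer


def poly_add(x, y, q=Q):
    """Coefficient-wise addition modulo q."""
    return [(a + b) % q for a, b in zip(x, y)]


def poly_mul_scalar(x, s, q=Q):
    """Multiply every coefficient by the scalar s modulo q."""
    return [(a * s) % q for a in x]


def poly_sub(x, y, q=Q):
    """Direct coefficient-wise subtraction modulo q."""
    return [(a - b) % q for a, b in zip(x, y)]


def poly_sub_canonical(x, y, q=Q):
    """Subtraction rewritten as add(x, mul_scalar(y, -1))."""
    return poly_add(x, poly_mul_scalar(y, -1, q), q)


x = [0, 1, Q - 1, 12345]
y = [1, 1, 5, Q - 1]
print(poly_sub(x, y) == poly_sub_canonical(x, y))
```

A reference like this could serve as an oracle when testing the non-power-of-two branch of the lowering, which the existing test reportedly does not cover.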

The emitted IR is shown here (one arith.remsi comes from the add lowering, the other from the mul_scalar lowering):

  %4 = "arith.constant"() <{value = -1 : i32}> : () -> i32
  %5 = "tensor.splat"(%4) : (i32) -> tensor<1024xi32>
  %6 = "arith.muli"(%1, %5) <{overflowFlags = #arith.overflow<none>}> : (tensor<1024xi32>, tensor<1024xi32>) -> tensor<1024xi32>
  %7 = "arith.constant"() <{value = 33538049 : i32}> : () -> i32
  %8 = "arith.remsi"(%1, %7) : (tensor<1024xi32>, i32) -> tensor<1024xi32>
  %10 = "arith.constant"() <{value = dense<33538049> : tensor<1024xi26>}> : () -> tensor<1024xi26>
  %11 = "arith.extsi"(%3) : (tensor<1024xi32>) -> tensor<1024xi26>
  %12 = "arith.extsi"(%8) : (tensor<1024xi32>) -> tensor<1024xi26>
  %13 = "arith.addi"(%11, %12) <{overflowFlags = #arith.overflow<none>}> : (tensor<1024xi26>, tensor<1024xi26>) -> tensor<1024xi26>
  %14 = "arith.remsi"(%13, %10) : (tensor<1024xi26>, tensor<1024xi26>) -> tensor<1024xi26>
  %15 = "arith.trunci"(%14) : (tensor<1024xi26>) -> tensor<1024xi32>
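To make the type problems in the IR above easier to spot, here is a small sketch that re-checks a few of the listed ops against the arith verifier rules as I understand them: arith.remsi needs identical operand and result types, arith.extsi needs a strictly wider result, and arith.trunci needs a strictly narrower one. The checker functions are hypothetical helpers written for this illustration and only handle the `iN` and `tensor<DxiN>` types that appear in the listing.

```python
import re


def parse_type(ty: str):
    """Return (shape, bits) for 'iN' or 'tensor<DxiN>'; shape is None for scalars."""
    m = re.fullmatch(r"tensor<(\d+)xi(\d+)>", ty)
    if m:
        return int(m.group(1)), int(m.group(2))
    m = re.fullmatch(r"i(\d+)", ty)
    if m:
        return None, int(m.group(1))
    raise ValueError(f"unsupported type: {ty}")


def remsi_ok(lhs: str, rhs: str, result: str) -> bool:
    """All operands and the result must have the same type."""
    return lhs == rhs == result


def extsi_ok(src: str, dst: str) -> bool:
    """Same shape, and the result must be strictly wider."""
    (s_shape, s_bits), (d_shape, d_bits) = parse_type(src), parse_type(dst)
    return s_shape == d_shape and d_bits > s_bits


def trunci_ok(src: str, dst: str) -> bool:
    """Same shape, and the result must be strictly narrower."""
    (s_shape, s_bits), (d_shape, d_bits) = parse_type(src), parse_type(dst)
    return s_shape == d_shape and d_bits < s_bits


T32, T26 = "tensor<1024xi32>", "tensor<1024xi26>"
checks = {
    "%8  remsi (tensor, scalar i32)": remsi_ok(T32, "i32", T32),
    "%11 extsi i32 -> i26": extsi_ok(T32, T26),
    "%14 remsi (all i26 tensors)": remsi_ok(T26, T26, T26),
    "%15 trunci i26 -> i32": trunci_ok(T26, T32),
}
for name, ok in checks.items():
    print(f"{name}: {'ok' if ok else 'INVALID'}")
```

Under these rules only %14 is well-typed: %8 mixes a tensor with a scalar modulus (the constant is not splatted), and %11, %12 and %15 cast in the wrong direction for an i32 coefficient type.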