google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

PolyToStandard: handling tensors of poly? #143

Closed. asraa closed this issue 1 week ago.

asraa commented 10 months ago

@mr0re1 brought up a great review comment about lowering tensors of polys to standard: https://github.com/google/heir/pull/134/files#r1312057016

Tensors of polys arise from BGV lowerings (and likely from all other scheme lowerings to poly). When lowering to standard, we convert each poly to a tensor of ints.

Tensors of tensors are disallowed in MLIR: tensor types are not valid tensor element types.

We have a few options here:

(1) Flatten the tensor: e.g. if we have a tensor of polynomials, each of which is lowered to tensor<1024xi64>, then a size-2 tensor would become tensor<2048xi64>. However, this would appear to severely complicate the lowering math.

(2) Detensorize or lower to linalg first, and only then lower to standard? I.e., the inputs to PolyToStandard would not include tensors.

(3) It feels like I'm missing something: surely type conversion with tensors has been done before?
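To make the trade-off in option (1) concrete, here is a minimal Python sketch of the index arithmetic a flattened layout would push into every lowered op. It sits next to the direct indexing that a multi-dimensional layout such as tensor<2x1024xi64> gives you. It is purely illustrative: every name is hypothetical rather than HEIR code, and it assumes N = 1024 coefficients per polynomial, as in the example.

```python
# Hypothetical illustration: addressing coefficient j of polynomial i in a
# size-2 tensor of polynomials with N coefficients each.
N = 1024


def flat_index(poly_idx: int, coeff_idx: int, n: int = N) -> int:
    """Option (1): tensor<2048xi64>, polynomials stored back to back."""
    if not 0 <= coeff_idx < n:
        raise IndexError("coefficient index out of range")
    return poly_idx * n + coeff_idx


def flat_poly_range(poly_idx: int, n: int = N) -> range:
    """Positions one polynomial occupies in the flattened tensor; every lowered
    op would have to carry this offset arithmetic."""
    start = flat_index(poly_idx, 0, n)
    return range(start, start + n)


def multi_dim_index(poly_idx: int, coeff_idx: int) -> tuple:
    """Multi-dimensional layout tensor<2x1024xi64>: the indices map directly."""
    return (poly_idx, coeff_idx)


print(flat_index(1, 5))                               # 1029
print(flat_poly_range(1)[0], flat_poly_range(1)[-1])  # 1024 2047
print(multi_dim_index(1, 5))                          # (1, 5)
```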

asraa commented 10 months ago

@j2kun very kindly reminded me that we can just use multi-dimensional tensors.

mr0re1 commented 10 months ago

> (1) Flatten the tensor: e.g. if we have a tensor of polynomials, each of which is lowered to tensor<1024xi64>, then a size-2 tensor would become tensor<2048xi64>. However, this would appear to severely complicate the lowering math.

tensor<2x1024xi64> seems more intuitive than tensor<2048xi64>.

asraa commented 10 months ago

Yep :) multi-dim it is. Haha, it totally slipped my mind. I'll update with some code; I believe I need to add a type conversion for ranked tensors and then handle some tensor ops.

What's really strange is that there's no interface pattern for this :/ Even once the type converter exists, I need to add support for converting ops like tensor.from_elements; I believe some conversions write a basic catch-all op conversion pattern that just changes the types of these ops.
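The thread doesn't show HEIR's actual converter, so the following is only a Python model of the mapping being discussed. The class and function names are hypothetical. A polynomial lowers to a 1-D tensor of its coefficients. A tensor of polynomials lowers to a higher-rank tensor by appending the coefficient dimension, instead of nesting tensors. The storage-width rule (bit length of cmod) is an assumption, but it happens to reproduce the i25 that appears for cmod=33538049 later in this thread. In MLIR itself, this logic would sit in a TypeConverter::addConversion callback.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class PolyType:
    """Stand-in for !polynomial.polynomial with ideal 1 + x**degree and modulus cmod."""
    degree: int
    cmod: int


@dataclass(frozen=True)
class TensorType:
    shape: Tuple[int, ...]
    element: Union[PolyType, str]  # a str is an integer type name such as "i25"


def convert_type(t):
    """Return the type a poly-dialect type would lower to (model only)."""
    if isinstance(t, PolyType):
        # Assumption: coefficients are stored in an integer as wide as cmod.
        return TensorType((t.degree,), f"i{t.cmod.bit_length()}")
    if isinstance(t, TensorType) and isinstance(t.element, PolyType):
        inner = convert_type(t.element)
        # Tensors of tensors are invalid, so append the coefficient dimension
        # after the outer ones: tensor<2x!poly> -> tensor<2x1024xi25>.
        return TensorType(t.shape + inner.shape, inner.element)
    return t  # already a legal type


ring_poly = PolyType(degree=1024, cmod=33538049)
print(convert_type(ring_poly))                    # TensorType(shape=(1024,), element='i25')
print(convert_type(TensorType((2,), ring_poly)))  # TensorType(shape=(2, 1024), element='i25')
```

Appending the coefficient dimension last keeps each polynomial contiguous along the innermost dimension, which matches the tensor<2x1024xi64> layout suggested above.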

j2kun commented 10 months ago

Let's chat a bit about the type conversion when you get back to this. I don't think the type conversion should be particularly hard...

j2kun commented 4 months ago

In our HEIR meeting today, I suggested that we should be able to use upstream passes to map elementwise operations to loops, perhaps via linalg.

@AlexanderViand-Intel could you post your sample IR?

AlexanderViand-Intel commented 4 months ago

Sure, here's a minimal example; something like this also appears when lowering tests/bgv/to_polynomial.mlir:

func.func @test_bin_ops(%arg0: tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>, %arg1: tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>) ->  tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>> {
  %0 = polynomial.add(%arg0, %arg1) : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
  return %0 :  tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
}

Running -polynomial-to-standard on this as-is (as of commit 04f6106b42e99a04f9382c71f855929f0f3a280f) simply crashes on a cast in ConvertAdd, because the pass assumes the operands are polynomials, which here, of course, they aren't.

Since the whole idea of traits like ElementwiseMappable is that we shouldn't have to re-invent the wheel, I tried applying the existing -convert-elementwise-to-linalg pass to the example, which results in:

#map = affine_map<(d0) -> (d0)>
func.func @test_bin_ops(%arg0: tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>, %arg1: tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>) -> tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>, tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>) outs(%arg0 : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>) {
  ^bb0(%in: !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>, %in_0: !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>, %out: !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>):
    %1 = polynomial.add(%in, %in_0) : !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>
    linalg.yield %1 : !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>
  } -> tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
  return %0 : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
}

However, from here I'm a bit stuck: none of the passes I tried simplified this further in a meaningful way (e.g., by converting it to a loop).

Note that applying -polynomial-to-standard to the linalg.generic version as-is crashes because of unresolved materializations around linalg.yield, but adding builtin.unrealized_conversion_cast ops as materializations lets you run the pass and see what's going on a bit better:

#map = affine_map<(d0) -> (d0)>
func.func @test_bin_ops(%arg0: tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>, %arg1: tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>) -> tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>, tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>) outs(%arg0 : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>) {
  ^bb0(%in: !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>, %in_0: !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>, %out: !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>):
    %1 = builtin.unrealized_conversion_cast %in : !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>> to tensor<1024xi25>
    %2 = builtin.unrealized_conversion_cast %in_0 : !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>> to tensor<1024xi25>
    %cst = arith.constant dense<33538049> : tensor<1024xi26>
    %3 = arith.extsi %1 : tensor<1024xi25> to tensor<1024xi26>
    %4 = arith.extsi %2 : tensor<1024xi25> to tensor<1024xi26>
    %5 = arith.addi %3, %4 : tensor<1024xi26>
    %6 = arith.remsi %5, %cst : tensor<1024xi26>
    %7 = arith.trunci %6 : tensor<1024xi26> to tensor<1024xi25>
    %8 = builtin.unrealized_conversion_cast %7 : tensor<1024xi25> to !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>
    linalg.yield %8 : !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>
  } -> tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
  return %0 : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
}
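For readers less familiar with the arith dialect, the body between the casts is coefficient-wise addition modulo cmod. It is computed in a type one bit wider, so the sum cannot overflow before the reduction. The following is a minimal Python model with a hypothetical function name. It treats coefficients as non-negative integers in [0, cmod) and the widened type as unsigned, so it does not check how the signed extsi and remsi interpret the underlying bit patterns.

```python
def lowered_poly_add(a, b, cmod):
    """Model of the lowered polynomial.add body: widen, add, reduce, narrow."""
    if len(a) != len(b):
        raise ValueError("operands must have the same number of coefficients")
    width = cmod.bit_length()  # storage width: 25 for cmod = 33538049
    wide = width + 1           # the i26 intermediate (arith.extsi)
    result = []
    for x, y in zip(a, b):
        s = x + y              # arith.addi in the widened type
        if s >= 2 ** wide:
            raise OverflowError("sum does not fit in the (unsigned) widened type")
        r = s % cmod           # arith.remsi by the dense<cmod> constant
        result.append(r)       # r < cmod, so arith.trunci back to `width` bits is lossless
    return result


print(lowered_poly_add([33538048, 5], [2, 3], 33538049))  # [1, 8]
```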

From here, we could think about how these builtin.unrealized_conversion_casts should be resolved and might arrive at something like this (manually adapted from the above) with multi-dimensional tensors:

#map = affine_map<(d0) -> (d0)>
func.func @test_bin_ops(%arg0: tensor<2x1024xi25>, %arg1: tensor<2x1024xi25>) -> tensor<2x1024xi25> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<2x1024xi25>, tensor<2x1024xi25>) outs(%arg0 : tensor<2x1024xi25>) {
  ^bb0(%in: tensor<1024xi25>, %in_0: tensor<1024xi25>, %out: tensor<1024xi25>):
    %cst = arith.constant dense<33538049> : tensor<1024xi26>
    %3 = arith.extsi %in : tensor<1024xi25> to tensor<1024xi26>
    %4 = arith.extsi %in_0 : tensor<1024xi25> to tensor<1024xi26>
    %5 = arith.addi %3, %4 : tensor<1024xi26>
    %6 = arith.remsi %5, %cst : tensor<1024xi26>
    %7 = arith.trunci %6 : tensor<1024xi26> to tensor<1024xi25>
    linalg.yield %7 : tensor<1024xi25>
  } -> tensor<2x1024xi25>
  return %0 : tensor<2x1024xi25>
}

Unfortunately, this doesn't work, because linalg.generic expects to be "looping" (conceptually) over all the dimensions: error: 'linalg.yield' op type of yield operand 1 ('tensor<1024xi25>') doesn't match the element type of the enclosing linalg.generic op ('i25'). Of course, we could flatten tensor<2x1024xi25> into tensor<2048xi25>, but that seems like a lot of work and boilerplate for a pretty messy result.
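The error follows from how linalg.generic works. With identity indexing maps, the body runs once per point of the full iteration space and sees one scalar per operand. The following toy Python model of a 2-D generic with two "parallel" iterators over tensor<2x1024xi25>-shaped data uses a hypothetical helper, not MLIR's implementation. It shows why the yielded value has to be an i25 rather than a tensor<1024xi25>. Coefficient-wise ops like add fit this scalar form. Ops that need a whole polynomial at once, such as multiplication modulo the ideal, do not.

```python
from itertools import product


def generic_2d(body, *inputs):
    """Toy model of linalg.generic with identity indexing maps
    (d0, d1) -> (d0, d1) and iterator_types ["parallel", "parallel"].
    The body is called once per (i, j) and only ever sees scalars."""
    rows, cols = len(inputs[0]), len(inputs[0][0])
    out = [[None] * cols for _ in range(rows)]
    for i, j in product(range(rows), range(cols)):
        out[i][j] = body(*(t[i][j] for t in inputs))  # the body yields a scalar
    return out


CMOD = 7                # toy modulus instead of 33538049, to keep numbers small
lhs = [[1, 2], [3, 4]]  # two "polynomials" with two coefficients each
rhs = [[6, 6], [6, 6]]
print(generic_2d(lambda x, y: (x + y) % CMOD, lhs, rhs))  # [[0, 1], [2, 3]]
```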

AlexanderViand-Intel commented 4 months ago

I think it would help to have a pass that converts elementwise-mappable ops to affine.for instead of linalg.generic. The output might look a bit like this:

#map = affine_map<(d0) -> (d0)>
func.func @test_bin_ops(%arg0: tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>, %arg1: tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>) -> tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>> {
  %0 = affine.for %i = 0 to 2 iter_args(%a = %arg0) -> tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>> {
    %in = tensor.extract %arg0[%i] :  tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
    %in_0 = tensor.extract %arg1[%i] :  tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
    %1 = polynomial.add(%in, %in_0) : !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>
    %t = tensor.insert %1 into %a[%i] :  tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
    affine.yield %t :  tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
  }
  return %0 : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
}

We could then handle the type conversion gracefully with multi-dimensional tensors. I haven't yet found a pass that does this (either via linalg.generic or directly), but I also haven't spent as much time with the "standard" part of MLIR as @j2kun and @asraa, so maybe you have an idea here.
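The value semantics of that affine.for sketch can be modeled in a few lines of Python. The names are hypothetical, and tuples stand in for immutable tensors. The result tensor is threaded through iter_args. Each iteration extracts one polynomial from each operand, applies the scalar op, and inserts the result into a fresh tensor value.

```python
def elementwise_to_loop(scalar_op, lhs, rhs):
    """Model of the affine.for form above."""
    if len(lhs) != len(rhs):
        raise ValueError("operands must have the same outer shape")
    acc = tuple(lhs)                        # iter_args(%a = %arg0)
    for i in range(len(lhs)):               # affine.for %i = 0 to len(lhs)
        x, y = lhs[i], rhs[i]               # tensor.extract %arg0[%i], %arg1[%i]
        r = scalar_op(x, y)                 # e.g. polynomial.add
        acc = acc[:i] + (r,) + acc[i + 1:]  # tensor.insert produces a new value
    return acc                              # affine.yield, then the loop result


def toy_poly_add(p, q, cmod=7):
    """Coefficient-wise addition modulo a toy modulus."""
    return tuple((a + b) % cmod for a, b in zip(p, q))


print(elementwise_to_loop(toy_poly_add, ((1, 2), (3, 4)), ((6, 6), (6, 6))))
# ((0, 1), (2, 3))
```

Because the loop body only ever handles one whole polynomial at a time, the per-polynomial lowering of polynomial.add can be reused unchanged inside it.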

j2kun commented 4 months ago

@AlexanderViand-Intel https://github.com/google/heir/pull/504 should show you how to convert to affine loops, though it still runs into the problem of a failed type conversion in --polynomial-to-standard, which I suspect is an unrelated bug.

AlexanderViand-Intel commented 4 months ago

> @AlexanderViand-Intel #504 should show you how to convert to affine loops, though it still runs into the problem of a failed type conversion in --polynomial-to-standard, which I suspect is an unrelated bug.

Even my "manually converted" affine loop above runs into failed type conversions, which I think is because the type converter doesn't realize that it can convert from tensor<2x!poly.poly<...>> to tensor<2x1024xi25> (or whatever the poly in the tensor happens to convert to), so #505 might not be entirely unrelated?

j2kun commented 4 months ago

I see what you're saying, but I would expect it to just type-convert the content of the memref opaquely. Maybe the problem is that you can't have a memref of tensors? Anyhow, if you have time to root-cause that type conversion issue, that'd be helpful. Otherwise, at least we have a clear path to convert elementwise ops to loops now :)

AlexanderViand-Intel commented 4 months ago

I don't think it's memref-related, because lowering my example above (no memrefs, just affine.for) with -polynomial-to-standard gives similar errors, and using the "augmented" version (with unrealized conversion casts) results in this:

func.func @test_bin_ops(%arg0: tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>, %arg1: tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>) -> tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>> {
  %0 = affine.for %arg2 = 0 to 2 iter_args(%arg3 = %arg0) -> (tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>) {
    %extracted = tensor.extract %arg0[%arg2] : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
    %1 = builtin.unrealized_conversion_cast %extracted : !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>> to tensor<1024xi25>
    %extracted_0 = tensor.extract %arg1[%arg2] : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
    %2 = builtin.unrealized_conversion_cast %extracted_0 : !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>> to tensor<1024xi25>
    %cst = arith.constant dense<33538049> : tensor<1024xi26>
    %3 = arith.extsi %1 : tensor<1024xi25> to tensor<1024xi26>
    %4 = arith.extsi %2 : tensor<1024xi25> to tensor<1024xi26>
    %5 = arith.addi %3, %4 : tensor<1024xi26>
    %6 = arith.remsi %5, %cst : tensor<1024xi26>
    %7 = arith.trunci %6 : tensor<1024xi26> to tensor<1024xi25>
    %8 = builtin.unrealized_conversion_cast %7 : tensor<1024xi25> to !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>
    %inserted = tensor.insert %8 into %arg3[%arg2] : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
    affine.yield %inserted : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
  }
  return %0 : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
}

If I add an explicit type conversion, we get a bit closer (note that the func has been handled properly):

func.func @test_bin_ops(%arg0: tensor<1024x2xi25>, %arg1: tensor<1024x2xi25>) -> tensor<1024x2xi25> {
  %0 = builtin.unrealized_conversion_cast %arg1 : tensor<1024x2xi25> to tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
  %1 = builtin.unrealized_conversion_cast %arg0 : tensor<1024x2xi25> to tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
  %2 = affine.for %arg2 = 0 to 2 iter_args(%arg3 = %1) -> (tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>) {
    %extracted = tensor.extract %1[%arg2] : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
    %4 = builtin.unrealized_conversion_cast %extracted : !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>> to tensor<1024xi25>
    %extracted_0 = tensor.extract %0[%arg2] : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
    %5 = builtin.unrealized_conversion_cast %extracted_0 : !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>> to tensor<1024xi25>
    %cst = arith.constant dense<33538049> : tensor<1024xi26>
    %6 = arith.extsi %4 : tensor<1024xi25> to tensor<1024xi26>
    %7 = arith.extsi %5 : tensor<1024xi25> to tensor<1024xi26>
    %8 = arith.addi %6, %7 : tensor<1024xi26>
    %9 = arith.remsi %8, %cst : tensor<1024xi26>
    %10 = arith.trunci %9 : tensor<1024xi26> to tensor<1024xi25>
    %11 = builtin.unrealized_conversion_cast %10 : tensor<1024xi25> to !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>
    %inserted = tensor.insert %11 into %arg3[%arg2] : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
    affine.yield %inserted : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>
  }
  %3 = builtin.unrealized_conversion_cast %2 : tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>> to tensor<1024x2xi25>
  return %3 : tensor<1024x2xi25>
}

but it's still not quite there, because the affine.for isn't converted. I think this is because it's not part of the operations that addStructuralConversionPatterns handles. Do you know if there's an equivalent for affine, or do we need to manually mark it as a dynamically legal op?
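A common way to express the "dynamically legal" option is to treat an op as legal once none of its operand, result, or block-argument types still need conversion. In MLIR this is usually done with ConversionTarget::addDynamicallyLegalOp plus a TypeConverter legality check. The Python model below just spells out that predicate, using hypothetical names and types written as strings. Note that legality alone doesn't rewrite the op. A pattern that actually updates affine.for's types, like the catch-all type-changing pattern mentioned earlier in this thread, is still needed.

```python
def needs_conversion(ty: str) -> bool:
    """Hypothetical check: does a (textual) type still contain a polynomial?"""
    return "!polynomial.polynomial" in ty


def is_dynamically_legal(operand_types, result_types, block_arg_types=()):
    """An op such as affine.for is legal only when none of its types need conversion."""
    all_types = (*operand_types, *result_types, *block_arg_types)
    return not any(needs_conversion(t) for t in all_types)


POLY_TENSOR = "tensor<2x!polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>>"
print(is_dynamically_legal([POLY_TENSOR], [POLY_TENSOR]))                    # False
print(is_dynamically_legal(["tensor<2x1024xi25>"], ["tensor<2x1024xi25>"]))  # True
```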

AlexanderViand-Intel commented 4 months ago

I think I have an idea of how to handle this (and also #505), WIP here. It nearly works, but right now only with --verify-each=false, because it produces invalid tensor.extract and affine.load operations. The reason is that when we go from extracting a poly from tensor<2x!polynomial.polynomial<..>> to trying the same from a tensor<2x1024xi25>, we need to switch from tensor.extract to tensor.extract_slice (and likewise from affine.load to affine.vector_load). I'll put together a PR later today.

PS: this seems like it would pop up nearly as frequently as the func.func/control-flow conversion needs, so I wouldn't be surprised if there's already something like addStructuralConversionPatterns but for containers; I just couldn't find one when I looked around the upstream examples.

PPS: 🤦 addStructuralConversionPatterns is a HEIR thing, not an MLIR thing, so no wonder there's no equivalent (yet!)
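The switch from tensor.extract to a slice is needed because, after conversion, "one polynomial" is a whole row of the 2-D tensor rather than a single element. In MLIR, the rank-reducing form would look roughly like tensor.extract_slice %t[%i, 0] [1, 1024] [1, 1] : tensor<2x1024xi25> to tensor<1024xi25>. The following is a toy Python model of those semantics, 2-D only, with hypothetical helper names.

```python
def extract_slice_2d(t, offsets, sizes, strides=(1, 1)):
    """Toy model of tensor.extract_slice on a 2-D nested list."""
    (o0, o1), (s0, s1), (st0, st1) = offsets, sizes, strides
    return [[t[o0 + a * st0][o1 + b * st1] for b in range(s1)] for a in range(s0)]


def extract_poly(t, i):
    """What `tensor.extract %t[%i]` on tensor<k x poly> corresponds to once the
    tensor is tensor<k x N x iW>: a 1 x N slice at [i, 0], rank-reduced to 1-D."""
    n = len(t[0])
    return extract_slice_2d(t, (i, 0), (1, n))[0]


converted = [[1, 2, 3], [4, 5, 6]]  # two "polynomials" with N = 3 coefficients
print(extract_poly(converted, 1))   # [4, 5, 6]
```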

AlexanderViand-Intel commented 4 months ago

The "tensor of tensor" PR (#508) lays the groundwork for supporting poly operations on tensors, but it's not a full end-to-end solution, because it doesn't support the code generated by the --convert-elementwise-to-linalg chain (cf. #504), so even with that we still can't quite lower polynomial.add %t1, %t2 : tensor<2x!polynomial.polynomial>.

Instead of adding "memref of memref" helpers and then trying to figure out how to elegantly handle the raising from memref back to tensor (cf. #505), I figured it might actually be easier to just add a --convert-elementwise-to-affine pass, which seems like a generally useful thing to have, so I started a draft PR for that: #524

j2kun commented 3 months ago

Reopening just to clean up the remaining comments

AlexanderViand-Intel commented 2 months ago

Digging this back up, as I noticed (as part of starting on #559) that poly.ntt and poly.intt aren't ElementwiseMappable and, in fact, can't be, because ElementwiseMappable requires that all non-scalar types match, and a "tensorized" version of an NTT might look something like this:

#ring = #polynomial.ring<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>
%0 = polynomial.ntt %some_poly : !polynomial.polynomial<#ring> -> tensor<1024xi64, #ring> // normal NTT
%1 = polynomial.ntt %some_poly_tensor : tensor<2x!polynomial.polynomial<#ring>> -> tensor<2x1024xi64, #ring> // tensor NTT

and so the ElementwiseMappable verifier will complain that "all non-scalar operands/results must have the same shape and base type". In the future, it might be interesting to start a conversation upstream about expanding ElementwiseMappable, but for now I'll just avoid adding the trait to the op and instead add a TODO for custom poly.ntt and poly.intt patterns in our elementwise-to-affine pass.
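To see concretely why the verifier rejects the tensorized NTT, here is a Python model of just the quoted rule. The helper is hypothetical, and each type is written as a (kind, shape) pair, where kind is "scalar", "tensor" or "vector". Element types are deliberately ignored. As far as I know, the rule compares shapes and container kinds rather than element types, which is why ops such as arith.cmpi (tensor<4xi32> in, tensor<4xi1> out) can still be elementwise mappable. This is a simplification, not MLIR's verifier.

```python
def non_scalar_types_agree(types):
    """Model of the ElementwiseMappable rule quoted above: all non-scalar
    operands/results must share one shape and one container kind."""
    non_scalar = {(kind, shape) for kind, shape in types if kind != "scalar"}
    return len(non_scalar) <= 1


# polynomial.add on tensor<2x!poly>: operands and result all have shape (2,)
tensor_add = [("tensor", (2,)), ("tensor", (2,)), ("tensor", (2,))]
# tensorized NTT: tensor<2x!poly> -> tensor<2x1024xi64>
tensor_ntt = [("tensor", (2,)), ("tensor", (2, 1024))]
print(non_scalar_types_agree(tensor_add))  # True
print(non_scalar_types_agree(tensor_ntt))  # False
```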

AlexanderViand-Intel commented 2 weeks ago

I'm re-opening this issue, as the problem mentioned above appeared again as part of #763 and therefore needs attention.