google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

Determine how to generically identify and handle plaintext-ciphertext ops in `insert-rotate` #586

Open j2kun opened 3 months ago

j2kun commented 3 months ago

The gx-kernel HECO example ported from https://github.com/google/heir/pull/587 shows a shortcoming of the existing insert-rotate pass.

Specifically, the pass anchors on binary operations that consume the result of extracting values from (1D) tensors, but in the gx-kernel example one of those tensors is a 2D tensor containing the weights of a convolution kernel. We worked around that by canonicalizing the constant weight matrix into individual arith.constant ops and adding special patterns that identify those constants and splat them into tensors.

However, there are two analogous situations:

  1. A not-quite-constant but still plaintext 1D tensor is combined with bona fide secret tensors. By 'not quite constant' I mean that canonicalization won't split it into individual arith.constant ops, but it is still semantically cleartext. E.g., it could be a function input that is not annotated with {secret.secret}, which might be further obscured by some arith ops processing it locally before the extract op.
  2. A non-constant scalar (non-tensor) value that needs to be splatted before it can be combined, via arith.muli or similar, with a normal "ciphertext" tensor value. Again, a non-secret scalar function input that is processed inside the function by arith ops seems like a sufficiently general case of this.
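To make the two cases concrete, here is a minimal Python model of what handling an unrestricted operand could look like once it meets a ciphertext-like 1D tensor. It is not HEIR code: plain lists stand in for tensors, and `splat`, `as_tensor_operand`, and `mul_with_plaintext` are hypothetical names. It assumes the plaintext tensor in case 1 already has the ciphertext's width and slot alignment.

```python
from typing import Union

Scalar = int
Tensor1D = list[int]  # stand-in for a 1D tensor (requires Python 3.9+)


def splat(x: Scalar, width: int) -> Tensor1D:
    """Model of a splat: a 1D tensor whose every element is x."""
    return [x] * width


def as_tensor_operand(v: Union[Scalar, Tensor1D], width: int) -> Tensor1D:
    """Bring an unrestricted (plaintext) operand to the ciphertext's shape.

    Case 2 (scalar): splat it. Case 1 (plaintext 1D tensor): use it as-is,
    assuming it already has the same width and slot alignment.
    """
    if isinstance(v, int):
        return splat(v, width)
    if len(v) != width:
        raise ValueError(f"plaintext width {len(v)} != ciphertext width {width}")
    return list(v)


def mul_with_plaintext(ct: Tensor1D, pt: Union[Scalar, Tensor1D]) -> Tensor1D:
    """Elementwise product of a 'ciphertext' tensor and a plaintext operand."""
    return [a * b for a, b in zip(ct, as_tensor_operand(pt, len(ct)))]


ct = [1, 2, 3, 4]                            # stands in for a secret tensor
print(mul_with_plaintext(ct, 3))             # scalar is splatted: [3, 6, 9, 12]
print(mul_with_plaintext(ct, [2, 0, 1, 5]))  # plaintext tensor: [2, 0, 3, 20]
```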

Both of the examples above suggest to me that we should have some pre-insert-rotate analysis pass that annotates which SSA values are restricted by the "tensor_ext.rotate" paradigm and which are not. Call such values "restricted" and "unrestricted" for brevity. For unrestricted SSA values, we can freely use tensor.extract and other arbitrary tensor ops, and splat scalars whenever they need to be combined with a restricted value. The insert-rotate pass needn't concern itself with how these identifications were chosen, but can match on the annotations in its patterns.
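As a hedged sketch of such a pre-pass, "restricted" can be modelled as forward taint propagation from secret inputs: anything computed from a secret value is restricted, and everything else is unrestricted. The toy IR below (`Op`, `classify_restricted`, and `mixed_operands` are hypothetical names) ignores regions, loops, and block arguments; in MLIR this would more plausibly be a sparse forward dataflow analysis whose results become the annotations the insert-rotate patterns match on. The example reproduces both situations from the list above.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    """One SSA op in a toy IR: `result = name(operands...)`."""
    result: str
    name: str
    operands: list[str] = field(default_factory=list)


def classify_restricted(secret_args: set[str], ops: list[Op]) -> set[str]:
    """Forward propagation over ops listed in SSA (def-before-use) order.

    A value is restricted if it is a secret function argument or is produced
    by an op with at least one restricted operand; all other values are
    unrestricted.
    """
    restricted = set(secret_args)
    for op in ops:
        if any(v in restricted for v in op.operands):
            restricted.add(op.result)
    return restricted


def mixed_operands(ops: list[Op], restricted: set[str]) -> list[tuple[str, str]]:
    """(op result, operand) pairs where an arith op combines a restricted value
    with an unrestricted one; the unrestricted operand is what a pattern would
    splat (if scalar) or otherwise treat as plaintext."""
    pairs = []
    for op in ops:
        if not op.name.startswith("arith."):
            continue  # e.g. skip the index operand of tensor.extract
        flags = [v in restricted for v in op.operands]
        if any(flags) and not all(flags):
            pairs.extend((op.result, v) for v, r in zip(op.operands, flags) if not r)
    return pairs


# %x is a secret tensor; %w (tensor), %s (scalar), %i (index) are plain inputs.
ops = [
    Op("%w1", "arith.addi", ["%w", "%w"]),       # plaintext tensor processed locally
    Op("%s1", "arith.muli", ["%s", "%s"]),       # plaintext scalar processed locally
    Op("%xe", "tensor.extract", ["%x", "%i"]),   # extract from the secret tensor
    Op("%we", "tensor.extract", ["%w1", "%i"]),  # extract from the plaintext tensor
    Op("%p", "arith.muli", ["%xe", "%we"]),      # case 1: secret * plaintext element
    Op("%q", "arith.addi", ["%p", "%s1"]),       # case 2: secret + plaintext scalar
]
restricted = classify_restricted({"%x"}, ops)
print(sorted(restricted))               # ['%p', '%q', '%x', '%xe']
print(mixed_operands(ops, restricted))  # [('%p', '%we'), ('%q', '%s1')]
```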

github-actions[bot] commented 3 months ago

This issue has 1 outstanding TODO:

This comment was autogenerated by todo-backlinks

AlexanderViand-Intel commented 3 months ago

👍 Isn't the classification just "is the element type of the tensor secret.secret"? If that's the case, it might not be worth adding a whole pass/analysis; being able to provide an "isOperandRestricted" callback instead might be sufficient.
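A minimal sketch of the callback alternative, assuming patterns are parameterized by a predicate rather than reading precomputed annotations. `Value`, `IsOperandRestricted`, `is_secret_typed`, and `plaintext_operand` are hypothetical names, and types are modelled as strings rather than real MLIR types.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class Value:
    name: str
    type: str  # textual type, e.g. "!secret.secret<i16>" or "i16"


# Hypothetical callback type that insert-rotate-style patterns could accept.
IsOperandRestricted = Callable[[Value], bool]


def is_secret_typed(v: Value) -> bool:
    """Type-based classification: restricted iff the value's type is secret.
    (A string check stands in for a real MLIR type check.)"""
    return v.type.startswith("!secret.secret<")


def plaintext_operand(lhs: Value, rhs: Value,
                      is_operand_restricted: IsOperandRestricted) -> Optional[Value]:
    """For a binary op, return the unrestricted operand that must be handled
    as plaintext (e.g. splatted if scalar) when the other operand is
    restricted; return None if both or neither are restricted."""
    lhs_r, rhs_r = is_operand_restricted(lhs), is_operand_restricted(rhs)
    if lhs_r and not rhs_r:
        return rhs
    if rhs_r and not lhs_r:
        return lhs
    return None


x = Value("%x", "!secret.secret<i16>")
s = Value("%s", "i16")
print(plaintext_operand(x, s, is_secret_typed))  # Value(name='%s', type='i16')
print(plaintext_operand(s, s, is_secret_typed))  # None
```

The appeal of this design is that the pattern logic only sees the predicate, so changing how values are classified does not require touching the patterns.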

j2kun commented 3 months ago

I think that's right, with the only caveat being that we have tests that use tensor_ext without any secret IR (and tensor_ext has no internal concept of secret). Hence my suggestion of a "restricted" annotation.
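One way to reconcile the two comments, sketched under assumptions: a predicate that honors an explicit `restricted` annotation when one is present (so tensor_ext-only tests can mark values directly) and otherwise falls back to the secret-type check. The attribute name, `Value`, and `make_is_operand_restricted` are hypothetical, and where such an annotation would actually be attached in MLIR is left open here.

```python
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class Value:
    name: str
    type: str
    attrs: dict[str, bool] = field(default_factory=dict)  # hypothetical annotations


def make_is_operand_restricted(attr: str = "restricted") -> Callable[[Value], bool]:
    """Build a predicate that honors an explicit annotation when present and
    otherwise falls back to checking for a secret type."""
    def is_operand_restricted(v: Value) -> bool:
        if attr in v.attrs:
            return v.attrs[attr]
        return v.type.startswith("!secret.secret<")
    return is_operand_restricted


pred = make_is_operand_restricted()
print(pred(Value("%a", "tensor<16xi16>", {"restricted": True})))  # True: annotated, no secret IR
print(pred(Value("%b", "tensor<16xi16>")))                         # False: no annotation
print(pred(Value("%c", "!secret.secret<tensor<16xi16>>")))         # True: secret type
```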