google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

Data-oblivious Programming and its Transformations #777

Open MeronZerihun opened 2 weeks ago

MeronZerihun commented 2 weeks ago

Data-oblivious Programming

A data-oblivious program is one that decouples data input from program execution. Such programs exhibit control-flow and memory access patterns that are independent of their input(s). This programming model is necessary for expressing FHE programs. There are 3 major transformations applied to convert a conventional program into a data-oblivious program:

(1) If-Transformation

If-operations conditioned on inputs create data-dependent control flow in programs. An scf.if operation must define at least a 'then' region (the true path) and is always terminated by scf.yield, even when it produces no result. To convert a data-dependent scf.if operation into an equivalent set of data-oblivious operations in MLIR, we hoist all safely speculatable operations out of the scf.if operation and convert the scf.yield operation to an arith.select operation. The following code snippet demonstrates an application of this transformation.

// Before applying If-transformation
func.func @my_function(%input : i1) -> () {
  ...
  // Violation: %input is used as a condition causing a data-dependent branch
  %result = scf.if %input -> (i16) {
        %a = arith.muli %b, %c : i16
        scf.yield %a : i16
      } else {
        scf.yield %b : i16
      }
  ...
}

// After applying If-transformation
func.func @my_function(%input : i1) -> () {
  ...
  %a = arith.muli %b, %c : i16
  %result = arith.select %input, %a, %b : i16
  ...
}

We have implemented a ConvertIfToSelect pass that transforms scf.if operations whose condition is a secret input and whose bodies contain only Pure operations (i.e., operations that have no memory side effects and are speculatable). This transformation cannot be applied to scf.if operations when side effects are present in only one of the two regions. Although possible, we currently do not support transformations for scf.if operations where both regions contain operations with matching side effects.

(2) Loop-Transformation

Loop statements with input-dependent bounds or iteration counts introduce data-dependent branches that violate data-obliviousness. We can make such loops data-oblivious by replacing input-dependent bounds with static, input-independent parameters (e.g., a constant upper bound) and by replacing early loop exits with update operations, where the value returned from the loop is selectively updated using conditional predication. In MLIR, loops are expressed using the affine.for, scf.for, or scf.while operations.

- `affine.for`: This operation lends itself well to expressing data-oblivious programs because it requires constant loop bounds, eliminating input-dependent limits. Implementing early exits is a challenge, since the affine.for construct requires a body block that [terminates with `affine.yield`](https://mlir.llvm.org/docs/Dialects/Affine/#affinefor-affineaffineforop).
```llvm
%sum_0 = arith.constant 0.0 : f32
// The for-loop's bound is a fixed constant
%sum = affine.for %i = 0 to 10 step 2 iter_args(%sum_iter = %sum_0) -> (f32) {
  %t = affine.load %buffer[%i] : memref<1024xf32>
  %sum_next = arith.addf %sum_iter, %t : f32
  affine.yield %sum_next : f32
}
...
```

- `scf.for`: Similar to affine.for, scf.for requires a body region containing exactly one block that [terminates with `scf.yield`](https://mlir.llvm.org/docs/Dialects/SCFDialect/#scffor-scfforop), which makes implementing early loop exits at the same level of abstraction a challenge. Unlike affine.for, however, scf.for allows input-dependent bounds, which do not adhere to data-obliviousness constraints. A solution is to have either the programmer or the compiler specify an input-independent upper bound, so we can transform the loop to use this bound and carefully update the values returned from the loop using conditional predication. Our current solution is simply to hardcode the upper bounds in our programs. We are unsure how to express these bounds in the IR and would love suggestions that promote reusing existing MLIR frontends.
```llvm
// Before applying Loop-Transformation
func.func @my_function(%value: i32, %inputIndex: index) -> i32 {
  ...
  // Violation: for-loop uses %inputIndex as upper bound, causing input-dependent control flow
  %result = scf.for %iv = %begin to %inputIndex step %step_value iter_args(%arg1 = %value) -> i32 {
    %output = arith.muli %arg1, %arg1 : i32
    scf.yield %output : i32
  }
  ...
}

// After applying Loop-Transformation
func.func @my_function(%value: i32, %inputIndex: index) -> i32 {
  ...
  // %MAX_INDEX is a constant that defines the maximum possible index value
  %result = scf.for %iv = %begin to %MAX_INDEX step %step_value iter_args(%arg1 = %value) -> i32 {
    %output = arith.muli %arg1, %arg1 : i32
    // Keep the update only for iterations the original loop would have executed
    %cond = arith.cmpi ult, %iv, %inputIndex : index
    %newOutput = arith.select %cond, %output, %arg1 : i32
    scf.yield %newOutput : i32
  }
  ...
}
```
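The bound-replacement idea can likewise be sketched in cleartext pseudocode (hypothetical Python, with `MAX_INDEX` as an assumed public bound): every iteration up to the public bound executes, and predication decides whether each iteration's result is kept.

```python
MAX_INDEX = 8  # assumed public, input-independent upper bound

def before(value: int, input_index: int) -> int:
    # Data-dependent trip count: leaks input_index.
    result = value
    for _ in range(input_index):
        result = result * result
    return result

def after(value: int, input_index: int) -> int:
    # Fixed trip count; each update is kept only if the original
    # loop would have executed that iteration.
    result = value
    for i in range(MAX_INDEX):
        squared = result * result
        cond = 1 if i < input_index else 0  # stays encrypted in FHE
        result = cond * squared + (1 - cond) * result
    return result
```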

- `scf.while`: This operation represents a generic while/do-while loop that keeps iterating as long as a condition holds. An input-dependent while condition introduces data-dependent control flow that violates data-oblivious constraints. It is also challenging to implement early exits, since scf.while is expected to [terminate with `scf.yield`](https://mlir.llvm.org/docs/Dialects/SCFDialect/#scfwhile-scfwhileop).
```llvm
// Before applying Loop-Transformation
func.func @my_function(%input: i16){
  %zero = arith.constant 0 : i16
  %result = scf.while (%arg1 = %input) : (i16) -> i16 {
    %cond = arith.cmpi slt, %arg1, %zero : i16
    // Violation: scf.while uses %cond whose value depends on %input
    scf.condition(%cond) %arg1 : i16
  } do {
  ^bb0(%arg2: i16):
    %mul = arith.muli %arg2, %arg2 : i16
    scf.yield %mul : i16
  }
  ...
  return
}

// After applying Loop-Transformation
func.func @my_function(%input: i16){
  %zero = arith.constant 0 : i16
  %begin = arith.constant 1 : index
  ...
  // Replace while-loop with a for-loop with a constant bound %MAX_ITER
  %result = scf.for %iv = %begin to %MAX_ITER step %step_value iter_args(%iter_arg = %input) -> i16 {
    %cond = arith.cmpi slt, %iter_arg, %zero : i16
    %mul = arith.muli %iter_arg, %iter_arg : i16
    %output = arith.select %cond, %mul, %iter_arg : i16
    scf.yield %output : i16
  }
  ...
  return
}
```
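A cleartext sketch of this while-to-for conversion (hypothetical Python; `MAX_ITER` is an assumed public bound on the number of iterations):

```python
MAX_ITER = 16  # assumed public bound on loop iterations

def before(x: int) -> int:
    # Data-dependent trip count: loops while x is negative.
    while x < 0:
        x = x * x
    return x

def after(x: int) -> int:
    # Always runs MAX_ITER iterations; the loop-carried value is
    # updated only while the original condition would still hold.
    for _ in range(MAX_ITER):
        cond = 1 if x < 0 else 0  # encrypted comparison in FHE
        x = cond * (x * x) + (1 - cond) * x
    return x
```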

(3) Access-Transformation

Input-dependent memory accesses cause data-dependent memory footprints. A naive data-oblivious solution is to perform read/write operations over the entire data structure while only applying the desired load/update at the index of interest. For simplicity, we only look at load/store operations on tensors, as they are well-supported structures in high-level MLIR likely emitted by most frontends. We drafted the following non-SIMD and SIMD-friendly approaches for this transformation:

// Before applying Access Transformation
func.func @my_function(%input: tensor<16xi32>, %inputIndex: index) {
  ...
  %c_10 = arith.constant 10 : i32
  // Violation: tensor.extract loads value at %inputIndex 
  %extractedValue = tensor.extract %input[%inputIndex] : tensor<16xi32>
  %newValue = arith.addi %extractedValue, %c_10 : i32
  // Violation: tensor.insert stores value at %inputIndex 
  %inserted = tensor.insert %newValue into %input[%inputIndex] : tensor<16xi32>
  ...
}

// After applying Access Transformation

// (1) Non-SIMD solution
func.func @my_function(%input: tensor<16xi32>, %inputIndex: index) {
  ...
  %c_10 = arith.constant 10 : i32
  %i_0 = arith.constant 0 : index
  %dummyValue = arith.constant 0 : i32

  %extractedValue = affine.for %i=0 to 16 iter_args(%arg= %dummyValue) -> (i32) {
    // 1. Extract value at %i
    // 2. Check if %i matches %inputIndex
    // 3. If %i matches %inputIndex, select the %value extracted in (1), else keep the running value
    // 4. Yield selected value
    %value = tensor.extract %input[%i] : tensor<16xi32>
    %cond = arith.cmpi eq, %i, %inputIndex : index
    %selected = arith.select %cond, %value, %arg : i32
    affine.yield %selected : i32
  }

  %newValue = arith.addi %extractedValue, %c_10 : i32

  %inserted = affine.for %i=0 to 16 iter_args(%inputArg = %input) -> tensor<16xi32> {
    // 1. Check if %i matches the %inputIndex
    // 2. Insert %newValue and produce %newTensor
    // 3. If %i matches %inputIndex, select %newTensor else select input tensor
    // 4. Yield final tensor
    %cond = arith.cmpi eq, %i, %inputIndex : index
    %newTensor = tensor.insert %newValue into %inputArg[%i] : tensor<16xi32>
    %finalTensor = arith.select %cond, %newTensor, %inputArg : tensor<16xi32>
    affine.yield %finalTensor : tensor<16xi32>
  }
  ...
}
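A cleartext sketch of the non-SIMD approach (hypothetical Python, not HEIR code): the read touches every slot and keeps the value only where the index matches; the write rewrites every slot but changes only the one at the target index.

```python
def oblivious_read(t: list, idx: int) -> int:
    # Scan all slots; the running value is replaced only when i == idx.
    out = 0  # dummy initial value
    for i, v in enumerate(t):
        cond = 1 if i == idx else 0
        out = cond * v + (1 - cond) * out
    return out

def oblivious_write(t: list, idx: int, new: int) -> list:
    # Write every slot; only slot idx actually receives the new value.
    out = []
    for i, v in enumerate(t):
        cond = 1 if i == idx else 0
        out.append(cond * new + (1 - cond) * v)
    return out
```

The memory access pattern is now a full linear scan, independent of `idx`.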

// (2) SIMD-style solution
func.func @my_function(%input: tensor<16xi32>, %inputIndex: index){
  ...
  %c_10 = arith.constant 10 : i32

  // 1. Create a one-hot encoded mask: 1 at %inputIndex and 0 everywhere else
  // 2. Zero-extend the mask to 32 bits
  // 3. Multiply the extended mask with the input tensor
  // 4. Rotate and sum
  // 5. Extract the value from slot 0
  %mask = tensor_ext.index_to_mask %inputIndex {size = 16} : index -> tensor<16xi1>
  %mask_ext = tensor_ext.extui %mask to i32 : tensor<16xi32>
  %masked = arith.muli %input, %mask_ext : tensor<16xi32>
  %maskedSum = tensor_ext.rotate_and_sum %masked : tensor<16xi32>
  %i_0 = arith.constant 0 : index
  %extractedValue = tensor.extract %maskedSum[%i_0] : tensor<16xi32> 

  %newValue = arith.addi %extractedValue, %c_10 : i32

  // 1. Clear the index (slot) where we'll insert the value
  // 2. Zero-extend the negated mask to 32 bits
  // 3. Multiply the existing tensor with the negated mask
  // 4. Replicate the element to be inserted across all slots
  // 5. Multiply the replicated tensor with the mask
  // 6. Add %maskedInsert to %maskedInputTensor
  %negatedMask = tensor_ext.negate_mask %mask : tensor<16xi1>
  %negatedMask_ext = tensor_ext.extui %negatedMask to i32 : tensor<16xi32>
  %maskedInputTensor = arith.muli %negatedMask_ext, %input : tensor<16xi32>
  %replicated = tensor_ext.replicate %newValue {size = 16} : i32 -> tensor<16xi32>
  %maskedInsert = arith.muli %replicated, %mask_ext : tensor<16xi32>
  %inserted = arith.addi %maskedInsert, %maskedInputTensor : tensor<16xi32>
  ...
}
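The SIMD-style masking can be sketched the same way (hypothetical Python over plain lists; in the FHE setting each list would be a ciphertext with 16 slots and the sum would be realized by rotate-and-sum):

```python
SIZE = 16

def index_to_mask(idx: int) -> list:
    # One-hot mask: 1 at idx, 0 elsewhere (analogue of tensor_ext.index_to_mask).
    return [1 if i == idx else 0 for i in range(SIZE)]

def simd_extract(t: list, idx: int) -> int:
    mask = index_to_mask(idx)
    masked = [v * m for v, m in zip(t, mask)]  # zero out all other slots
    return sum(masked)                          # rotate-and-sum reduction

def simd_insert(t: list, idx: int, new: int) -> list:
    mask = index_to_mask(idx)
    cleared = [v * (1 - m) for v, m in zip(t, mask)]  # clear the target slot
    broadcast = [new] * SIZE                          # replicate the new value
    return [c + b * m for c, b, m in zip(cleared, broadcast, mask)]
```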

More notes on these transformations

PRs related to this issue

To-dos and Open Questions

- If-Transformation: Handle cases where both the 'then' and 'else' regions contain operations with matching side effects
- Loop-Transformation: Implement the transformation for input-dependent scf.for and scf.while loops
- Loop-Transformation: Find ways to express constant upper bounds in the IR
- Access-Transformation: Implement the non-SIMD and SIMD solutions

j2kun commented 2 weeks ago

Great overview! I might add, because it seems implicit here, that the memory access path (non-SIMD) reduces to the loop+if conversions, so I would expect you can lower as affine.for, apply relevant affine loop fusion/fission passes, then apply your affine.for/scf.if conversions. This might require splitting the work into multiple passes, or at least multiple applications of applyPatternsAndFoldGreedily within the same pass.


MeronZerihun commented 2 weeks ago

Yes, that's a great point Jeremy. Thank you for the comment!