MeronZerihun opened this issue 2 weeks ago
Great overview! I might add, because it seems implicit here, that the memory access path (non-SIMD) reduces to the loop+if conversions, so I would expect you can lower as affine.for, apply relevant affine loop fusion/fission passes, then apply your affine.for/scf.if conversions. This might require splitting the work into multiple passes, or at least multiple applications of applyPatternsAndFoldGreedily within the same pass.
On Tue, Jul 9, 2024, 10:20 AM MeronZerihun wrote:
Data-oblivious Programming
A data-oblivious program decouples program execution from its data inputs: its control flow and memory access patterns are independent of its input(s). This programming model is necessary for expressing FHE programs, because encrypted values cannot be inspected at runtime to choose a branch or compute an address. Three major transformations convert a conventional program into a data-oblivious program:
(1) If-Transformation
If-operations conditioned on inputs create data-dependent control flow. An scf.if operation must define at least a 'then' region (the true path), and its regions are always terminated by scf.yield, even when the scf.if produces no result. To convert a data-dependent scf.if operation into an equivalent set of data-oblivious operations in MLIR, we hoist all safely speculatable operations out of its regions and replace the scf.if with an arith.select over the values that each region yields. The following code snippet demonstrates this transformation.
```mlir
// Before applying If-transformation
func.func @my_function(%input : i1) -> () {
  ...
  // Violation: %input is used as a condition causing a data-dependent branch
  %result = scf.if %input -> (i16) {
    %a = arith.muli %b, %c : i16
    scf.yield %a : i16
  } else {
    scf.yield %b : i16
  }
  ...
}
```
```mlir
// After applying If-transformation
func.func @my_function(%input : i1) -> () {
  ...
  %a = arith.muli %b, %c : i16
  %result = arith.select %input, %a, %b : i16
  ...
}
```
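To make the equivalence concrete, here is a minimal plain-Python model of the before/after snippets. It is a sketch rather than HEIR code, and the function names are hypothetical. The select is written arithmetically as `cond * t + (1 - cond) * f`, which is how a branch-free select can be evaluated on encrypted bits. Python ints stand in for i16, so overflow is not modeled.

```python
def select(cond: int, t: int, f: int) -> int:
    """Arithmetic select for cond in {0, 1}: uses only multiplications and
    additions, so no branch on `cond` is needed."""
    return cond * t + (1 - cond) * f


def if_version(cond: int, b: int, c: int) -> int:
    # Mirrors the "before" IR: only the arm chosen by `cond` executes.
    if cond:
        return b * c
    return b


def select_version(cond: int, b: int, c: int) -> int:
    # Mirrors the "after" IR: the pure multiply is hoisted and always runs,
    # then the two yielded values are merged with a select.
    a = b * c
    return select(cond, a, b)


print(select_version(1, 3, 4), select_version(0, 3, 4))  # 12 3
```

Hoisting is only safe because `b * c` is Pure. An operation with side effects in one arm could not be executed unconditionally like this, which matches the pass's restriction described below.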
We have implemented a ConvertIfToSelect pass. It transforms scf.if operations that have secret-input conditions and contain only Pure operations (i.e., operations that have no memory side effects and are speculatable) in their regions. This transformation cannot be applied to scf.if operations with side effects in only one of the two regions. We currently do not support scf.if operations where both regions contain operations with matching side effects, although that case is possible in principle.
(2) Loop-Transformation
Loops whose bounds, conditions, or number of iterations depend on inputs introduce data-dependent branches that violate data-obliviousness. We can make such loops data-oblivious in two ways. First, we replace input-dependent conditions (bounds) with static, input-independent parameters (e.g., a constant upper bound). Second, we replace early loop exits with update operations, where the value returned from the loop is selectively updated using conditional predication. In MLIR, loops are expressed using affine.for, scf.for, or scf.while operations.
affine.for: This operation lends itself well to expressing data-oblivious programs when its bounds are constants, since the trip count is then fixed at compile time and cannot depend on inputs. Note that affine.for bounds may also be affine functions of symbols such as function arguments, so the operation itself does not guarantee input-independent bounds. Implementing early exits is a challenge because the body of an affine.for must be a single block terminated by affine.yield (https://mlir.llvm.org/docs/Dialects/Affine/#affinefor-affineaffineforop), so there is no way to break out of the loop.
```mlir
%sum_0 = arith.constant 0.0 : f32
// The for-loop's bounds are fixed constants
%sum = affine.for %i = 0 to 10 step 2 iter_args(%sum_iter = %sum_0) -> (f32) {
  %t = affine.load %buffer[%i] : memref<1024xf32>
  %sum_next = arith.addf %sum_iter, %t : f32
  affine.yield %sum_next : f32
}
...
```
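The body cannot break out early, so an early exit has to become the predicated update described in the Loop-Transformation overview. Below is a minimal plain-Python sketch of that idea. The example loop ("sum elements until the first negative one") and all names are hypothetical, chosen only to illustrate the pattern. A sticky `done` flag replaces `break`, and every update is predicated on it, so the trip count is always `len(xs)`.

```python
def sum_until_negative_break(xs):
    # Data-dependent: the trip count depends on where the first negative is.
    total = 0
    for x in xs:
        if x < 0:
            break
        total += x
    return total


def sum_until_negative_predicated(xs):
    # Fixed trip count: every element is visited. A sticky `done` flag
    # replaces `break`, and each update is predicated on it.
    total = 0
    done = 0
    for x in xs:
        done = done | int(x < 0)            # stays 1 once the exit condition fires
        total = total + (1 - done) * x      # i.e. select(done, total, total + x)
    return total


print(sum_until_negative_predicated([1, 2, -1, 5]))  # 3
```

The flag is updated before the accumulation, so the element that triggers the exit is excluded, exactly as with `break`.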
scf.for: Like affine.for, scf.for requires a body region containing exactly one block terminated by scf.yield (https://mlir.llvm.org/docs/Dialects/SCFDialect/#scffor-scfforop). This makes implementing early loop exits at the same level of abstraction a challenge. In addition, scf.for accepts arbitrary SSA index values as bounds, so its trip count can depend on inputs, which violates data-obliviousness. One solution is to have either the programmer or the compiler specify an input-independent upper bound, rewrite the loop to use that bound, and carefully update the values returned from the loop using conditional predication. Our current solution is simply to hardcode the upper bounds in our programs. We are unsure how to express these bounds in the IR and would welcome suggestions that promote reusing existing MLIR frontends.
```mlir
// Before applying Loop-Transformation
func.func @my_function(%value: i32, %inputIndex: index) -> i32 {
  ...
  // Violation: the for-loop uses %inputIndex as its upper bound, which causes input-dependent control flow
  %result = scf.for %iv = %begin to %inputIndex step %step_value iter_args(%arg1 = %value) -> i32 {
    %output = arith.muli %arg1, %arg1 : i32
    scf.yield %output : i32
  }
  ...
}
```
```mlir
// After applying Loop-Transformation
func.func @my_function(%value: i32, %inputIndex: index) -> i32 {
  ...
  // %MAX_INDEX is a constant that defines the maximum possible index value
  %result = scf.for %iv = %begin to %MAX_INDEX step %step_value iter_args(%arg1 = %value) -> i32 {
    %output = arith.muli %arg1, %arg1 : i32
    // Keep the update only on iterations the original loop would have run (%iv < %inputIndex)
    %cond = arith.cmpi slt, %iv, %inputIndex : index
    %newOutput = arith.select %cond, %output, %arg1 : i32
    scf.yield %newOutput : i32
  }
  ...
}
```
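As a sanity check on the predicate, here is a plain-Python model of the two scf.for versions. It is a sketch assuming a hypothetical public bound `MAX_INDEX = 8` and `input_index <= MAX_INDEX`. The update is committed only on iterations the original loop would have executed (`iv < input_index`), so both versions agree. Python ints are unbounded, so i32 overflow is not modeled.

```python
MAX_INDEX = 8  # hypothetical public upper bound on %inputIndex


def loop_input_bound(value, input_index, begin=0, step=1):
    # "Before": the trip count depends on input_index.
    result = value
    for _ in range(begin, input_index, step):
        result = result * result
    return result


def loop_constant_bound(value, input_index, begin=0, step=1):
    # "After": the trip count is fixed by MAX_INDEX; the squared value is kept
    # only on iterations the original loop would have executed (iv < input_index).
    result = value
    for iv in range(begin, MAX_INDEX, step):
        output = result * result
        cond = int(iv < input_index)
        result = cond * output + (1 - cond) * result
    return result


print(loop_constant_bound(2, 3))  # 256
```

The equivalence holds only while `input_index <= MAX_INDEX`. Choosing that bound is exactly the open question of how to express it in the IR.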
scf.while: This operation represents a generic while/do-while loop that keeps iterating as long as a condition holds. An input-dependent while condition introduces data-dependent control flow that violates data-obliviousness. Early exits are again hard to express, since the "after" region of scf.while must also be terminated by scf.yield (https://mlir.llvm.org/docs/Dialects/SCFDialect/#scfwhile-scfwhileop).
```mlir
// Before applying Loop-Transformation
func.func @my_function(%input: i16) {
  %zero = arith.constant 0 : i16
  %result = scf.while (%arg1 = %input) : (i16) -> i16 {
    %cond = arith.cmpi slt, %arg1, %zero : i16
    // Violation: scf.while uses %cond, whose value depends on %input
    scf.condition(%cond) %arg1 : i16
  } do {
  ^bb0(%arg2: i16):
    %mul = arith.muli %arg2, %arg2 : i16
    scf.yield %mul : i16
  }
  ...
  return
}
```
```mlir
// After applying Loop-Transformation
func.func @my_function(%input: i16) {
  %zero = arith.constant 0 : i16
  %begin = arith.constant 1 : index
  ...
  // Replace the while-loop with a for-loop with a constant bound %MAX_ITER
  %result = scf.for %iv = %begin to %MAX_ITER step %step_value iter_args(%iter_arg = %input) -> i16 {
    %cond = arith.cmpi slt, %iter_arg, %zero : i16
    %mul = arith.muli %iter_arg, %iter_arg : i16
    %output = arith.select %cond, %mul, %iter_arg : i16
    scf.yield %output : i16
  }
  ...
  return
}
```
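The while-to-for rewrite can be checked the same way with a plain-Python sketch. `MAX_ITER = 4` is a hypothetical bound, and the only requirement is that it be at least the number of iterations the original loop can take. Each fixed iteration re-evaluates the while condition and commits the body's result only when the condition holds. Once the condition is false it stays false, so the extra iterations are no-ops. Python ints are unbounded, so i16 wraparound is not modeled, and with that simplification the original loop runs at most once.

```python
MAX_ITER = 4  # hypothetical public bound on the number of while-loop iterations


def while_version(x):
    # "Before": keep squaring while x < 0. With unbounded Python ints
    # (i16 wraparound is not modeled) this body runs at most once.
    while x < 0:
        x = x * x
    return x


def bounded_for_version(x):
    # "After": a fixed number of iterations; each one re-evaluates the
    # while condition and commits the body's result only when it holds.
    for _ in range(MAX_ITER):
        cond = int(x < 0)
        mul = x * x
        x = cond * mul + (1 - cond) * x
    return x


print(bounded_for_version(-3), bounded_for_version(5))  # 9 5
```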
(3) Access-Transformation
Input-dependent memory accesses produce data-dependent memory access patterns. A naive data-oblivious solution is to perform read/write operations over the entire data structure while applying the desired store/update only at the index of interest. For simplicity, we only consider load/store operations on tensors. Tensors are well-supported structures in high-level MLIR and are likely to be emitted by most frontends. We drafted the following non-SIMD and SIMD-friendly approaches for this transformation:
```mlir
// Before applying Access-Transformation
func.func @my_function(%input: tensor<16xi32>, %inputIndex: index) {
  ...
  %c_10 = arith.constant 10 : i32
  // Violation: tensor.extract loads the value at %inputIndex
  %extractedValue = tensor.extract %input[%inputIndex] : tensor<16xi32>
  %newValue = arith.addi %extractedValue, %c_10 : i32
  // Violation: tensor.insert stores the value at %inputIndex
  %inserted = tensor.insert %newValue into %input[%inputIndex] : tensor<16xi32>
  ...
}
```
```mlir
// After applying Access-Transformation
// (1) Non-SIMD solution
func.func @my_function(%input: tensor<16xi32>, %inputIndex: index) {
  ...
  %c_10 = arith.constant 10 : i32
  %i_0 = arith.constant 0 : index
  %dummyValue = arith.constant 0 : i32
  %tempTensor = tensor.empty() : tensor<1xi32>
  %dummyTensor = tensor.insert %dummyValue into %tempTensor[%i_0] : tensor<1xi32>

  %valueStore = affine.for %i = 0 to 16 iter_args(%tensorArg = %dummyTensor) -> tensor<1xi32> {
    // 1. Extract value at %i
    // 2. Get value already saved in tensor
    // 3. Check if %i matches %inputIndex
    // 4. If %i matches %inputIndex, select the value extracted in (1)
    // 5. Insert value and yield new tensor
    %inputValue = tensor.extract %input[%i] : tensor<16xi32>
    %oldValue = tensor.extract %tensorArg[%i_0] : tensor<1xi32>
    %cond = arith.cmpi eq, %i, %inputIndex : index
    %value = arith.select %cond, %inputValue, %oldValue : i32
    %newTensor = tensor.insert %value into %tensorArg[%i_0] : tensor<1xi32>
    affine.yield %newTensor : tensor<1xi32>
  }
  %extractedValue = tensor.extract %valueStore[%i_0] : tensor<1xi32>

  %newValue = arith.addi %extractedValue, %c_10 : i32

  %inserted = affine.for %i = 0 to 16 iter_args(%inputArg = %input) -> tensor<16xi32> {
    // 1. Extract value at %i
    // 2. Check if %i matches %inputIndex
    // 3. If %i matches %inputIndex, select %newValue
    // 4. Insert value and yield new tensor
    %oldValue = tensor.extract %inputArg[%i] : tensor<16xi32>
    %cond = arith.cmpi eq, %i, %inputIndex : index
    %value = arith.select %cond, %newValue, %oldValue : i32
    %newTensor = tensor.insert %value into %inputArg[%i] : tensor<16xi32>
    affine.yield %newTensor : tensor<16xi32>
  }
  ...
}
```
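The non-SIMD solution can be modeled in plain Python to show the property that matters: the sequence of touched slots is the same for every index. This is a sketch, and the functions and the `trace` list are hypothetical instrumentation, not part of the IR. Each function mirrors one of the two affine.for loops above and records every slot it visits.

```python
def oblivious_read(tensor, index, trace):
    # Mirrors the first affine.for: visit every slot, keep the one at `index`.
    value = 0  # plays the role of %dummyValue
    for i in range(len(tensor)):
        trace.append(i)  # record which slot was touched
        cond = int(i == index)
        value = cond * tensor[i] + (1 - cond) * value
    return value


def oblivious_write(tensor, index, new_value, trace):
    # Mirrors the second affine.for: rewrite every slot, changing only `index`.
    out = list(tensor)
    for i in range(len(out)):
        trace.append(i)
        cond = int(i == index)
        out[i] = cond * new_value + (1 - cond) * out[i]
    return out


def add_ten_at(tensor, index):
    # Same computation as the "before" IR: read, add 10, write back.
    trace = []
    value = oblivious_read(tensor, index, trace)
    return oblivious_write(tensor, index, value + 10, trace), trace


data = list(range(100, 116))
result, trace = add_ten_at(data, 3)
print(result[3])  # 113
```

Comparing the traces of `add_ten_at(data, 0)` and `add_ten_at(data, 15)` shows they are identical. That is the data-obliviousness property, paid for with O(n) work per access.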
```mlir
// (2) SIMD-style solution
func.func @my_function(%input: tensor<16xi32>, %inputIndex: index) {
  ...
  %c_10 = arith.constant 10 : i32

  // 1. Create a one-hot encoded tensor with 1 at %inputIndex and 0 in all other slots
  // 2. Zero-extend to 32 bits
  // 3. Multiply the one-hot encoded tensor with the input tensor
  // 4. Rotate and sum
  // 5. Get the extracted value
  %mask = tensor_ext.index_to_mask %inputIndex {size = 16} : index -> tensor<16xi1>
  %mask_ext = tensor_ext.extui %mask to i32 : tensor<16xi32>
  %masked = arith.muli %input, %mask_ext : tensor<16xi32>
  %maskedSum = tensor_ext.rotate_and_sum %masked : tensor<16xi32>
  %i_0 = arith.constant 0 : index
  %extractedValue = tensor.extract %maskedSum[%i_0] : tensor<16xi32>

  %newValue = arith.addi %extractedValue, %c_10 : i32

  // 1. Clear the slot where the value will be inserted
  // 2. Zero-extend to 32 bits
  // 3. Multiply the existing tensor with the negated mask
  // 4. Replicate the element to be inserted across all slots
  // 5. Multiply the replicated tensor with the mask
  // 6. Add maskedInsert to maskedInputTensor
  %negatedMask = tensor_ext.negate_mask %mask : tensor<16xi1>
  %negatedMask_ext = tensor_ext.extui %negatedMask to i32 : tensor<16xi32>
  %maskedInputTensor = arith.muli %input, %negatedMask_ext : tensor<16xi32>
  %replicated = tensor_ext.replicate %newValue {size = 16} : i32 -> tensor<16xi32>
  %maskedInsert = arith.muli %replicated, %mask_ext : tensor<16xi32>
  %inserted = arith.addi %maskedInsert, %maskedInputTensor : tensor<16xi32>
  ...
}
```
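The SIMD-style steps can be checked with a slot-wise plain-Python model. Lists stand in for tensors, and the helpers only mimic what the tensor_ext ops above are described to do; they are not their implementations. `rotate_and_sum` here uses log2(n) rotate-and-add steps, one common way to realize such an op, assumed for illustration. For a power-of-two length it leaves the total sum in every slot, so slot 0 holds the masked element.

```python
def index_to_mask(index, size):
    # One-hot mask: 1 at `index`, 0 elsewhere.
    return [int(i == index) for i in range(size)]


def rotate(v, k):
    # Cyclic left rotation by k slots.
    return v[k:] + v[:k]


def rotate_and_sum(v):
    # log2(n) rotate-and-add steps. For a power-of-two length n, every
    # slot ends up holding the sum of all n slots.
    k = 1
    while k < len(v):
        v = [a + b for a, b in zip(v, rotate(v, k))]
        k *= 2
    return v


def simd_extract(tensor, index):
    mask = index_to_mask(index, len(tensor))
    masked = [t * m for t, m in zip(tensor, mask)]
    return rotate_and_sum(masked)[0]


def simd_insert(tensor, index, value):
    mask = index_to_mask(index, len(tensor))
    negated = [1 - m for m in mask]
    kept = [t * n for t, n in zip(tensor, negated)]    # clear the target slot
    replicated = [value] * len(tensor)                  # value in every slot
    placed = [r * m for r, m in zip(replicated, mask)]  # value only in the target slot
    return [a + b for a, b in zip(kept, placed)]


data = list(range(100, 116))
new_value = simd_extract(data, 3) + 10
print(simd_insert(data, 3, new_value)[3])  # 113
```

The mask must be zero-extended, not sign-extended. A sign-extended i1 value of 1 becomes -1 and would flip the sign of the selected element.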
To-dos and Open Questions
- If-Transformation: Handle cases where both the 'then' and 'else' regions contain operations with matching side effects
- Loop-Transformation: Implement the transformation for input-dependent scf.for and scf.while loops
- Loop-Transformation: Find ways to express constant upper bounds in the IR
- Access-Transformation: Implement the non-SIMD and SIMD solutions
Yes, that's a great point, Jeremy. Thank you for the comment!
More notes on these transformations
- … scf.for and scf.if operations) → Loop-Transformation (change data-dependent loops to use constant bounds and condition the loop's yield results with scf.if operations) → If-Transformation (substitute data-dependent conditionals with arith.select operations).
- … tensor.extract operations over the same tensor, we can also apply upstream affine transformations to the multiple affine loops produced by the Access-Transformation to fuse these loops.

PRs related to this issue
- #737
- #778
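The fusion opportunity for several tensor.extract operations over the same tensor (also raised in Jeremy's comment) can be sketched in plain Python. This models only what a fused loop would compute. It does not model any particular upstream MLIR pass, and the function names are hypothetical. Two independent oblivious scans over the same tensor can share one loop, halving the number of passes while keeping the access pattern input-independent.

```python
def two_reads_unfused(tensor, i, j):
    # Two separate oblivious scans over the same tensor.
    a = 0
    for k in range(len(tensor)):
        c = int(k == i)
        a = c * tensor[k] + (1 - c) * a
    b = 0
    for k in range(len(tensor)):
        c = int(k == j)
        b = c * tensor[k] + (1 - c) * b
    return a, b


def two_reads_fused(tensor, i, j):
    # One scan serving both reads: each slot is loaded once.
    a = b = 0
    for k in range(len(tensor)):
        x = tensor[k]
        ci, cj = int(k == i), int(k == j)
        a = ci * x + (1 - ci) * a
        b = cj * x + (1 - cj) * b
    return a, b


print(two_reads_fused([10, 20, 30, 40], 1, 3))  # (20, 40)
```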