google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

Adding early-exits to MLIR #922

Open MeronZerihun opened 3 months ago

MeronZerihun commented 3 months ago

Currently, we don't have constructs to express early exits from structured loops in MLIR. Since early exits are common in many applications, we think it is important to add support for them in loops.

State-of-the-art

In MLIR, a terminator may transfer control flow either to another block in the same region or to the parent operation of the region. The parent operation may in turn transfer control to another block in the region it is itself located in, but only if it is a terminator. Our current loop constructs are strictly defined to contain a limited number of regions and only one terminator operation. In the scf and affine dialects, the for loop is defined to contain a single region that holds its body. Defining multiple basic blocks in the for-loop body seems unachievable, since the loop allows only a single terminator within the region.

The while and do-while loops defined by the scf dialect do allow two regions, the before and after regions, but these regions exist to implement the semantics of those loops. The number and type of terminators we can use in these regions are also restricted.
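To make the limitation concrete, here is a minimal C sketch (not MLIR, and not HEIR code) of how an early exit has to be encoded when the loop body may end in only one kind of terminator: the exit decision becomes extra loop-carried state that is folded into the loop condition, roughly what one would write today with scf.while. The function names `foo_break` and `foo_flag` and the `done` flag are illustrative assumptions.

```c
#include <stdbool.h>

/* Reference semantics: a for loop with a real break. */
int foo_break(int input, int max, int threshold) {
    int sum = 0;
    for (int i = 0; i < max; i += 1) {
        sum = sum + input;
        if (i >= threshold)
            break;
    }
    return sum;
}

/* Same computation without break. The exit decision is carried as loop
 * state (done) and tested in the loop condition, which plays the role of
 * the "before" region of a while loop; the body plays the role of the
 * "after" region and always ends the same way. */
int foo_flag(int input, int max, int threshold) {
    int sum = 0;
    int i = 0;
    bool done = false;
    while (i < max && !done) {
        sum = sum + input;
        done = (i >= threshold);
        i += 1;
    }
    return sum;
}
```

Both functions return the same value for any arguments; the second form is what the proposal below aims to avoid writing by hand.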

Proposed design

Ideally, we want to build new loops that have either (1) multiple regions or (2) multiple basic blocks, as follows. Please note that the multiple-region proposal is similar to this RFC proposal.

Example of a C-like program with break:

int foo(int input, int max, int threshold) {
    int sum = 0;
    for (int i = 0; i < max; i += 1) {
        sum = sum + input;
        if (i >= threshold)
            break;
    }
    return sum;
}

Using multiple basic blocks within the same region:

func.func @foo(%input : i32, %max : index, %threshold : index) -> i32 {
    %zero = arith.constant 0 : index
    %one = arith.constant 1 : index
    // Initial value of the running sum (sum = 0 in the C example)
    %init = arith.constant 0 : i32
    %result = scf.breakable_for %i = %zero to %max step %one iter_args(%arg = %init) -> i32 {
        %sum = arith.addi %arg, %input : i32
        // Break condition: i >= threshold
        %cond = arith.cmpi sge, %i, %threshold : index
        scf.cond_br %cond, ^bb1(%sum : i32), ^bb2(%sum : i32)

        ^bb1(%out : i32):
            // scf.break terminator returns %out and passes control back to parent op, i.e. the scf.breakable_for op
            scf.break %out : i32
        ^bb2(%next : i32):
            // scf.yield terminator yields %next for the next iteration
            scf.yield %next : i32
    }
    return %result : i32
}

Using multiple regions:

func.func @foo(%input : i32, %max : index, %threshold : index) -> i32 {
    %zero = arith.constant 0 : index
    %one = arith.constant 1 : index
    // Initial value of the running sum (sum = 0 in the C example)
    %init = arith.constant 0 : i32
    %result = scf.breakable_for %i = %zero to %max step %one iter_args(%arg = %init) -> i32 {
        %sum = arith.addi %arg, %input : i32
        // Break condition: i >= threshold
        %cond = arith.cmpi sge, %i, %threshold : index
        // Create breakable_if region, return %sum to parent operation
        scf.breakable_if %cond {
            // Terminator for breakable_if
            scf.break %sum : i32
        }
        // Yield %sum for the next iteration
        scf.yield %sum : i32
    }
    return %result : i32
}
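The intended behavior of the two terminators in both variants above can be modeled in plain C. This is only a sketch of the semantics as I read them from the examples, not an implementation of any MLIR op: the body reports which terminator it reached together with its operand, and the parent loop op decides whether to start another iteration or return. The names `TermKind`, `BodyResult`, `body`, and `breakable_for` are hypothetical, and the sketch assumes a positive step.

```c
/* Hypothetical model of the proposed ops; not MLIR or HEIR code. */
typedef enum { TERM_YIELD, TERM_BREAK } TermKind;

typedef struct {
    TermKind kind; /* which terminator the body reached */
    int value;     /* operand of that terminator */
} BodyResult;

/* One execution of the loop body from the example: add input to the
 * carried value, then choose the terminator based on i >= threshold. */
static BodyResult body(int i, int carried, int input, int threshold) {
    BodyResult r;
    r.value = carried + input;
    r.kind = (i >= threshold) ? TERM_BREAK : TERM_YIELD;
    return r;
}

/* The parent op: runs the body once per iteration. On TERM_YIELD the
 * value is carried into the next iteration; on TERM_BREAK the loop
 * stops and the value becomes the result. Assumes step > 0. */
int breakable_for(int lb, int ub, int step, int init,
                  int input, int threshold) {
    int carried = init;
    for (int i = lb; i < ub; i += step) {
        BodyResult r = body(i, carried, input, threshold);
        carried = r.value;
        if (r.kind == TERM_BREAK)
            return carried;
    }
    return carried;
}
```

Calling `breakable_for(0, max, 1, 0, input, threshold)` mirrors the `@foo` examples: the loop result is the last value passed to either terminator, or the initial value if the loop runs zero times.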

Design considerations

References

github-actions[bot] commented 1 month ago

This issue has 1 outstanding TODOs:

This comment was autogenerated by todo-backlinks