onnx / onnx-mlir

Representation and Reference Lowering of ONNX Models in MLIR Compiler Infrastructure
Apache License 2.0

Added support to generate OpenMP parallel construct clauses, at this time for num_threads and proc_bind #2944

Closed AlexandreEichenberger closed 2 months ago

AlexandreEichenberger commented 2 months ago

Added support to generate the OpenMP parallel construct with num_threads and proc_bind clauses.

First, I added two optional parameters to the krnl.parallel operation:

      %loop_block, %loop_local = krnl.block %0 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
      krnl.parallel(%loop_block), num_threads(%c8_i32) {proc_bind = "spread"} : !krnl.loop
      krnl.iterate(%loop_block) with (%0 -> %arg1 = 0 to 16384){
        %1 = krnl.get_induction_var_value(%loop_block) : (!krnl.loop) -> index
        %2 = vector.load %reshape[%1] : memref<16384xf32>, vector<32xf32>
        %3 = vector.load %reshape_2[%1] : memref<16384xf32>, vector<32xf32>
        %4 = arith.addf %2, %3 : vector<32xf32>
        vector.store %4, %reshape_4[%1] : memref<16384xf32>, vector<32xf32>
      }

This lets the user associate a parallel loop with an optional num_threads and/or proc_bind by passing them to the create.krnl.parallel builder.
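For readers more used to OpenMP at the source level, the blocked loop above corresponds roughly to the C++ sketch below. It is only an analogy for what the two clauses request, not code produced by onnx-mlir. The function name `blockedParallelAdd` and the constant `kBlock` are hypothetical; the block size of 32, the thread count of 8, and the `spread` binding are taken from the IR above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical source-level analog of the blocked krnl loop shown above.
// If the file is compiled without OpenMP support, the pragma is ignored and
// the loop simply runs sequentially with the same result.
std::vector<float> blockedParallelAdd(const std::vector<float> &a,
                                      const std::vector<float> &b) {
  constexpr long kBlock = 32; // mirrors "krnl.block %0 32"
  assert(a.size() == b.size() && a.size() % kBlock == 0);
  std::vector<float> out(a.size());
  const long n = static_cast<long>(a.size());

  // num_threads(8) and proc_bind(spread) mirror the two optional parameters
  // attached to krnl.parallel in the snippet above.
  #pragma omp parallel for num_threads(8) proc_bind(spread)
  for (long i = 0; i < n; i += kBlock) {
    // Each iteration handles one block of 32 elements, like the
    // vector<32xf32> load/add/store in the IR.
    for (long j = i; j < i + kBlock; ++j)
      out[j] = a[j] + b[j];
  }
  return out;
}
```

Calling `blockedParallelAdd(a, b)` on two equally sized inputs whose length is a multiple of 32 returns their element-wise sum; the clauses affect only how many threads run the outer loop and where they are placed.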

When lowering to affine (or when generating affine or scf parallel operations), we insert a KrnlParallelClauseOp inside the loop. It takes one mandatory value, the loop index, which identifies the parallel loop targeted by the clause, plus the optional num_threads (a value) and proc_bind (a string).

  affine.parallel (%arg1) = (0) to (16384) step (32) {
    %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32>
    %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32>
    %2 = arith.addf %0, %1 : vector<32xf32>
    vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32>
    affine.for %arg2 = 0 to 1 {
    }
    krnl.parallel_clause(%arg1), num_threads(%c8_i32) {proc_bind = "spread"} : index
  }

After the parallel constructs are lowered to OpenMP constructs, a simple pass (createProcessKrnlParallelClausePass) identifies each KrnlParallelClauseOp, locates its enclosing omp.parallel construct, and migrates the clauses to that OpenMP construct.

  omp.parallel num_threads(%c8_i32 : i32) proc_bind(spread) {
    omp.wsloop {
      omp.loop_nest (%arg1) : index = (%c0) to (%c16384) step (%c32) {
        memref.alloca_scope  {
          %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32>
          %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32>
          %2 = arith.addf %0, %1 : vector<32xf32>
          vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32>
        }
        omp.yield
      }
      omp.terminator
    }
    omp.terminator
  }
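The migration step can be illustrated with a toy model in plain C++. This is a sketch of the idea only: `ToyOp`, `migrateClauses`, and `buildExample` are hypothetical names, the tree is a stand-in for real IR, and none of it uses the MLIR API or reflects how createProcessKrnlParallelClausePass is actually written. It also assumes that a clause op with no enclosing omp.parallel is simply dropped, which the PR description does not specify.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <vector>

// Toy stand-in for an IR operation; NOT the MLIR Operation class.
struct ToyOp {
  std::string name;
  ToyOp *parent = nullptr;
  std::vector<std::unique_ptr<ToyOp>> body;
  std::optional<int> numThreads;       // optional num_threads clause
  std::optional<std::string> procBind; // optional proc_bind clause

  // Append a child op with the given name and return a pointer to it.
  ToyOp *add(const std::string &childName) {
    body.push_back(std::make_unique<ToyOp>());
    ToyOp *child = body.back().get();
    child->name = childName;
    child->parent = this;
    return child;
  }
};

// For every "krnl.parallel_clause" op, copy its clauses onto the nearest
// enclosing "omp.parallel" op, then erase the clause op.
// Returns how many clause ops were migrated.
int migrateClauses(ToyOp &op) {
  int migrated = 0;
  for (std::size_t i = 0; i < op.body.size();) {
    ToyOp &child = *op.body[i];
    if (child.name != "krnl.parallel_clause") {
      migrated += migrateClauses(child); // recurse into nested regions
      ++i;
      continue;
    }
    // Walk up the parent chain to find the enclosing parallel construct.
    ToyOp *target = &op;
    while (target && target->name != "omp.parallel")
      target = target->parent;
    if (target) {
      if (child.numThreads) target->numThreads = child.numThreads;
      if (child.procBind) target->procBind = child.procBind;
      ++migrated;
    }
    // Assumption: the clause op is removed whether or not a target was found.
    op.body.erase(op.body.begin() + static_cast<std::ptrdiff_t>(i));
  }
  return migrated;
}

// Build a tree shaped like the example: parallel > wsloop > loop_nest,
// with the clause op sitting inside the loop body.
std::unique_ptr<ToyOp> buildExample() {
  auto root = std::make_unique<ToyOp>();
  root->name = "func.func";
  ToyOp *nest =
      root->add("omp.parallel")->add("omp.wsloop")->add("omp.loop_nest");
  nest->add("arith.addf");
  ToyOp *clause = nest->add("krnl.parallel_clause");
  clause->numThreads = 8;
  clause->procBind = "spread";
  return root;
}
```

Running `migrateClauses(*buildExample())` on this tree moves `num_threads = 8` and `proc_bind = "spread"` from the clause op up to the toy `omp.parallel` node and removes the clause op from the loop body, which is the same shape of change visible between the affine and OpenMP listings.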

Added two MLIR lit test files.

jenkins-droid commented 2 months ago

Jenkins Linux s390x Build #15665 [push] Added support to generat... started at 09:54

jenkins-droid commented 2 months ago

Jenkins Linux amd64 Build #15662 [push] Added support to generat... started at 08:54

jenkins-droid commented 2 months ago

Jenkins Linux ppc64le Build #14692 [push] Added support to generat... started at 10:05

jenkins-droid commented 2 months ago

Jenkins Linux amd64 Build #15662 [push] Added support to generat... passed after 1 hr 6 min

jenkins-droid commented 2 months ago

Jenkins Linux s390x Build #15665 [push] Added support to generat... passed after 1 hr 39 min

jenkins-droid commented 2 months ago

Jenkins Linux ppc64le Build #14692 [push] Added support to generat... passed after 2 hr 3 min