iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[Epic] Add BF16 Support #13370

Open silvasean opened 1 year ago

silvasean commented 1 year ago

Task list

What happened?

Repro IR reduced from a large language model: gist

Reproduce with:

iree-compile ir21.mlir --iree-hal-target-backends=cuda --iree-input-type=mhlo --iree-hal-cuda-llvm-target-arch=sm_80 -o ir21.vmfb

(This seems like the kind of issue where a smaller test case will become obvious as debugging proceeds; let me know if that is not the case and I can bisect the test case myself.)

What component(s) does this issue relate to?

Compiler

Version information

iree.git @ ab37989652aed11f7f46498c09b9ac515c83eaa3

manishucsd commented 1 year ago

How do I obtain the failing IR (ir21.mlir) from the gist?

Is the gist == ir21.mlir?

silvasean commented 1 year ago

yes it is

manishucsd commented 1 year ago

Also tagging @KoolJBlack on this bug (see comment here)

KoolJBlack commented 1 year ago

https://github.com/openxla/iree/pull/13371 should fix the issue

manishucsd commented 1 year ago

13371 should fix the issue

So are you able to compile the IR in this bug's gist (= ir21.mlir) after the #13371 fix?

manishucsd commented 1 year ago

I tried the following command line on IREE tip-of-tree and it still fails.

./tools/iree-compile ir21.mlir --iree-hal-target-backends=cuda --iree-input-type=mhlo --iree-hal-cuda-llvm-target-arch=sm_80 -o ir21.vmfb 2> temp.txt

temp.txt

KoolJBlack commented 1 year ago

Sorry, disregard comment in https://github.com/openxla/iree/issues/13370#issuecomment-1532301511. That PR was unrelated to this issue.

manishucsd commented 1 year ago

@KoolJBlack and @MaheshRavishankar, is there anything else in the IREE llvm-integrate commit (iree.git @ https://github.com/openxla/iree/commit/ab37989652aed11f7f46498c09b9ac515c83eaa3) tagged in this bug that could be causing this issue?

MaheshRavishankar commented 1 year ago

Could you do a dump-after-all and narrow down the failure a bit?

manishucsd commented 1 year ago

Could you do a dump-after-all and narrow down the failure a bit?

I ran the following command to debug this:

 ./tools/iree-compile ir21.mlir --iree-hal-target-backends=cuda --iree-input-type=mhlo --iree-hal-cuda-llvm-target-arch=sm_80 -o ir21.vmfb  --mlir-print-ir-before-all --mlir-disable-threading --mlir-elide-elementsattrs-if-larger=4 2> dump_ir_before_after_all.mlir

dump_ir_before_after_all.mlir

The last pass that runs is iree-stream-schedule-execution

@MaheshRavishankar, can you please help me take a look at the dump_ir_before_after_all.mlir file above? Also, would you run a different command to dump the IR and narrow this model-level failure down to the operation level?

MaheshRavishankar commented 1 year ago

OK, this is not the same error as the one Sean reported above. @silvasean, this is a pretty big test case... is there a chance to get a smaller repro?

silvasean commented 1 year ago
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
module {
  func.func @f(%arg0: tensor<3x2048x32x64xbf16>, %arg1: tensor<2048x2048xbf16>) -> tensor<3x32x64x2048xbf16> {
    %cst = arith.constant 0.000000e+00 : bf16
    %0 = tensor.empty() : tensor<3x32x64x2048xbf16>
    %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<3x32x64x2048xbf16>) -> tensor<3x32x64x2048xbf16>
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<3x2048x32x64xbf16>, tensor<2048x2048xbf16>) outs(%1 : tensor<3x32x64x2048xbf16>) {
    ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
      %3 = arith.mulf %in, %in_0 : bf16
      %4 = arith.addf %out, %3 : bf16
      linalg.yield %4 : bf16
    } -> tensor<3x32x64x2048xbf16>
    return %2 : tensor<3x32x64x2048xbf16>
  }
}
silvasean commented 1 year ago

For the curious, this is written as an einsum in the user source code with the following equation: ABD,KDNH->KABNH

https://github.com/google/praxis/blob/55856e84395b098422c5d504f8744567b13038f4/praxis/layers/attentions.py#L842
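To make the indexing maps in the reduced repro concrete, below is a naive Python reference for what that `linalg.generic` computes. It is a sketch under simplifying assumptions: nested lists of Python floats stand in for bf16 tensors, the shapes are tiny instead of 3x2048x32x64 and 2048x2048, and the function name `generic_reference` and the dimension names K, D, N, H, M are chosen here for illustration. K, D, N, H follow the einsum's `KDNH` operand; M is `%arg1`'s leading dimension, which presumably corresponds to the einsum's A and B dimensions flattened together (an inference, not something stated in the thread).

```python
from typing import List

Tensor4 = List[List[List[List[float]]]]
Matrix = List[List[float]]


def generic_reference(arg0: Tensor4, arg1: Matrix) -> Tensor4:
    """Naive loop nest mirroring the reduced linalg.generic (illustrative).

    arg0: shape [K][D][N][H]  (#map  = (d0, d1, d2, d3, d4) -> (d0, d4, d1, d2))
    arg1: shape [M][D]        (#map1 = (d0, d1, d2, d3, d4) -> (d3, d4))
    out:  shape [K][N][H][M]  (#map2 = (d0, d1, d2, d3, d4) -> (d0, d1, d2, d3))
    """
    K, D = len(arg0), len(arg0[0])
    N, H = len(arg0[0][0]), len(arg0[0][0][0])
    M = len(arg1)
    assert all(len(row) == D for row in arg1), "arg1 must be M x D"
    # linalg.fill: every output element starts at 0.0
    out = [[[[0.0] * M for _ in range(H)] for _ in range(N)] for _ in range(K)]
    for d0 in range(K):                  # parallel
        for d1 in range(N):              # parallel
            for d2 in range(H):          # parallel
                for d3 in range(M):      # parallel
                    for d4 in range(D):  # reduction over D
                        # body: %3 = arith.mulf %in, %in_0; %4 = arith.addf %out, %3
                        out[d0][d1][d2][d3] += arg0[d0][d4][d1][d2] * arg1[d3][d4]
    return out


# K=1, D=2, N=1, H=1, M=1: the single output is 1*3 + 2*4
print(generic_reference([[[[1.0]], [[2.0]]]], [[3.0, 4.0]]))  # [[[[11.0]]]]
```

Accumulation happens in Python floats here, so this says nothing about bf16 rounding; it only pins down which elements get multiplied and summed.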

MaheshRavishankar commented 1 year ago

Nice! Thanks!

It's failing in the LLVMGPUVectorToGPU pass.

// -----// IR Dump Before LLVMGPUVectorToGPU (iree-llvmgpu-vector-to-gpu) //----- //
func.func @f_dispatch_0_generic_3x32x64x2048x2048_bf16() {
  %c2 = arith.constant 2 : index
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<0.000000e+00> : vector<16x16xbf16>
  %c16 = arith.constant 16 : index
  %c2048 = arith.constant 2048 : index
  %cst_0 = arith.constant 0.000000e+00 : bf16
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<3x2048x32x64xbf16>
  memref.assume_alignment %0, 64 : memref<3x2048x32x64xbf16>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<2048x2048xbf16>
  memref.assume_alignment %1, 64 : memref<2048x2048xbf16>
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<3x32x64x2048xbf16>
  memref.assume_alignment %2, 64 : memref<3x32x64x2048xbf16>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %3 = gpu.thread_id  x
  %4 = gpu.thread_id  y
  %5:3 = scf.for %arg0 = %c0 to %c2048 step %c16 iter_args(%arg1 = %cst, %arg2 = %cst, %arg3 = %cst) -> (vector<16x16xbf16>, vector<16x16xbf16>, vector<16x16xbf16>) {
    %9 = affine.apply affine_map<()[s0] -> ((s0 floordiv 64) floordiv 2)>()[%workgroup_id_x]
    %10 = affine.apply affine_map<()[s0, s1] -> (s1 * 16 + (s0 floordiv 64) * 32 - ((s0 floordiv 64) floordiv 2) * 64)>()[%workgroup_id_x, %4]
    %11 = vector.transfer_read %0[%c0, %arg0, %9, %10], %cst_0 {in_bounds = [true, true, true]} : memref<3x2048x32x64xbf16>, vector<8x1x16xbf16>
    %12 = affine.apply affine_map<()[s0] -> (s0 + 8)>()[%arg0]
    %13 = vector.transfer_read %0[%c0, %12, %9, %10], %cst_0 {in_bounds = [true, true, true]} : memref<3x2048x32x64xbf16>, vector<8x1x16xbf16>
    %14 = vector.transfer_read %0[%c1, %arg0, %9, %10], %cst_0 {in_bounds = [true, true, true]} : memref<3x2048x32x64xbf16>, vector<8x1x16xbf16>
    %15 = vector.transfer_read %0[%c1, %12, %9, %10], %cst_0 {in_bounds = [true, true, true]} : memref<3x2048x32x64xbf16>, vector<8x1x16xbf16>
    %16 = vector.transfer_read %0[%c2, %arg0, %9, %10], %cst_0 {in_bounds = [true, true, true]} : memref<3x2048x32x64xbf16>, vector<8x1x16xbf16>
    %17 = vector.transfer_read %0[%c2, %12, %9, %10], %cst_0 {in_bounds = [true, true, true]} : memref<3x2048x32x64xbf16>, vector<8x1x16xbf16>
    %18 = affine.apply affine_map<()[s0, s1] -> (s0 * 32 - (s0 floordiv 64) * 2048 + (s1 floordiv 32) * 16)>()[%workgroup_id_x, %3]
    %19 = vector.transfer_read %1[%18, %arg0], %cst_0 {in_bounds = [true, true]} : memref<2048x2048xbf16>, vector<16x8xbf16>
    %20 = vector.transfer_read %1[%18, %12], %cst_0 {in_bounds = [true, true]} : memref<2048x2048xbf16>, vector<16x8xbf16>
    %21 = vector.transpose %11, [1, 0, 2] : vector<8x1x16xbf16> to vector<1x8x16xbf16>
    %22 = vector.extract %21[0] : vector<1x8x16xbf16>
    %23 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %22, %19, %arg1 : vector<8x16xbf16>, vector<16x8xbf16> into vector<16x16xbf16>
    %24 = vector.transpose %14, [1, 0, 2] : vector<8x1x16xbf16> to vector<1x8x16xbf16>
    %25 = vector.extract %24[0] : vector<1x8x16xbf16>
    %26 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %25, %19, %arg2 : vector<8x16xbf16>, vector<16x8xbf16> into vector<16x16xbf16>
    %27 = vector.transpose %16, [1, 0, 2] : vector<8x1x16xbf16> to vector<1x8x16xbf16>
    %28 = vector.extract %27[0] : vector<1x8x16xbf16>
    %29 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %28, %19, %arg3 : vector<8x16xbf16>, vector<16x8xbf16> into vector<16x16xbf16>
    %30 = vector.transpose %13, [1, 0, 2] : vector<8x1x16xbf16> to vector<1x8x16xbf16>
    %31 = vector.extract %30[0] : vector<1x8x16xbf16>
    %32 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %31, %20, %23 : vector<8x16xbf16>, vector<16x8xbf16> into vector<16x16xbf16>
    %33 = vector.transpose %15, [1, 0, 2] : vector<8x1x16xbf16> to vector<1x8x16xbf16>
    %34 = vector.extract %33[0] : vector<1x8x16xbf16>
    %35 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %34, %20, %26 : vector<8x16xbf16>, vector<16x8xbf16> into vector<16x16xbf16>
    %36 = vector.transpose %17, [1, 0, 2] : vector<8x1x16xbf16> to vector<1x8x16xbf16>
    %37 = vector.extract %36[0] : vector<1x8x16xbf16>
    %38 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %37, %20, %29 : vector<8x16xbf16>, vector<16x8xbf16> into vector<16x16xbf16>
    scf.yield %32, %35, %38 : vector<16x16xbf16>, vector<16x16xbf16>, vector<16x16xbf16>
  }
  %6 = affine.apply affine_map<()[s0] -> ((s0 floordiv 64) floordiv 2)>()[%workgroup_id_x]
  %7 = affine.apply affine_map<()[s0, s1] -> (s1 * 16 + (s0 floordiv 64) * 32 - ((s0 floordiv 64) floordiv 2) * 64)>()[%workgroup_id_x, %4]
  %8 = affine.apply affine_map<()[s0, s1] -> (s0 * 32 - (s0 floordiv 64) * 2048 + (s1 floordiv 32) * 16)>()[%workgroup_id_x, %3]
  vector.transfer_write %5#2, %2[%c2, %6, %7, %8] {in_bounds = [true, true]} : vector<16x16xbf16>, memref<3x32x64x2048xbf16>
  vector.transfer_write %5#1, %2[%c1, %6, %7, %8] {in_bounds = [true, true]} : vector<16x16xbf16>, memref<3x32x64x2048xbf16>
  vector.transfer_write %5#0, %2[%c0, %6, %7, %8] {in_bounds = [true, true]} : vector<16x16xbf16>, memref<3x32x64x2048xbf16>
  return
}
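The `affine.apply` ops in the dump (`%6`, `%7`, `%8` after the loop, and their twins `%9`, `%10`, `%18` inside it) decide where each 16x16 `vector.transfer_write` lands; the leading index (`%c0`, `%c1`, `%c2`) is covered by the three loop-carried accumulators. The Python sketch below evaluates those maps directly. Assumptions not taken from the thread: a grid of 4096 workgroups along x (inferred from the constants 64 and 2048 and the 32x64x2048 inner output shape), `gpu.thread_id y` taking the values 0 and 1, and `gpu.thread_id x` mattering only through `floordiv 32`, so it is represented by 0 and 32. Python's `//` is floor division, which matches affine `floordiv`.

```python
from typing import Set, Tuple


def tile_origin(wg: int, tx: int, ty: int) -> Tuple[int, int, int]:
    """Evaluate the dump's affine maps for one (workgroup, thread) pair.

    Returns (n, h, m), the indices passed to vector.transfer_write as
    %6, %7, %8; the 16x16 tile then spans h..h+15 and m..m+15.
    """
    q = wg // 64                              # s0 floordiv 64
    n = q // 2                                # %6 = (s0 floordiv 64) floordiv 2
    h = ty * 16 + q * 32 - n * 64             # %7
    m = wg * 32 - q * 2048 + (tx // 32) * 16  # %8
    return n, h, m


NUM_WORKGROUPS = 4096  # assumed grid size, not stated in the thread

tiles: Set[Tuple[int, int, int]] = {
    tile_origin(wg, tx, ty)
    for wg in range(NUM_WORKGROUPS)
    for tx in (0, 32)  # representatives of thread_id x floordiv 32 == 0 / 1
    for ty in (0, 1)
}

# Origins are all multiples of 16, so distinct origins mean non-overlapping tiles.
print(len(tiles))  # 16384 == (32 * 64 * 2048) // (16 * 16)
# Every tile stays inside the 32x64x2048 (N x H x M) slice.
print(all(n < 32 and h + 16 <= 64 and m + 16 <= 2048 for n, h, m in tiles))  # True
```

Under these assumptions the tiles cover the output exactly once and stay in bounds, which is consistent with the `in_bounds = [true, true]` attributes on the writes.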

This is the error


repro.mlir:6:12: error: 'gpu.subgroup_mma_constant_matrix' op operand #0 must be 8-bit signed integer or 8-bit unsigned integer or 32-bit signless integer or 16-bit float or 32-bit float, but got 'bf16'
    %cst = arith.constant 0.000000e+00 : bf16
           ^
repro.mlir:5:3: note: called from
  func.func @f(%arg0: tensor<3x2048x32x64xbf16>, %arg1: tensor<2048x2048xbf16>) -> tensor<3x32x64x2048xbf16> {
  ^
repro.mlir:6:12: note: see current operation: %4 = "gpu.subgroup_mma_constant_matrix"(%3) : (bf16) -> !gpu.mma_matrix<16x16xbf16, "COp">
    %cst = arith.constant 0.000000e+00 : bf16
           ^
MaheshRavishankar commented 1 year ago

@manishucsd, you can probably take it from here... it's not immediately clear to me from the IR.

stellaraccident commented 1 year ago

So this is actually incomplete support for bf16?

Would have been good to have had a much clearer indication of that vs falling off a ledge.

stellaraccident commented 1 year ago

@rsuderman This is why I'd love it if we were running the regression tests on cuda.

silvasean commented 1 year ago

Would have been good to have had a much clearer indication of that vs falling off a ledge.

Curiously, the original test case doesn't print anything besides the raw assertion failure in the title.

The reduced test case I posted above prints:

<unknown>:0: error: MMAMatrixType elements must be SI8, UI8, I32, F16, or F32
... assertions failure + stack trace ...

The full original repro file doesn't even give a stack trace, just the assertion failure. Curiously, if I pass --mlir-disable-threading, then it does (so I guess the reduced test case just runs on one thread and therefore gives better diagnostics?).

Mahesh's output is the output of the op verifier, not the type "verifier" (MLIR is notoriously bad about verifying types, resulting in bad diagnostics like this case). Mahesh, how did you even get to the point of the op verifier failing instead of the type verifier assertion halting compilation?

MaheshRavishankar commented 1 year ago

I think I was using a RelWithDebInfo build

rsuderman commented 1 year ago

@rsuderman This is why I'd love it if we were running the regression tests on cuda.

Have we managed to restore the nightly PJRT plugin build? I have most of the basic tooling ready so that we can run the JAX tests via a GitHub Action. That was the blocker before I could set up our regression/conformance tests.

stellaraccident commented 1 year ago

Looks like it needs fixes around a PJRT API break.

(I think you can just bump iree manually if you need to get to head on that, but we should fix the PJRT API break to unblock it)

rsuderman commented 1 year ago

Okay, I'll see about fixing the API issues and getting the nightly rolling again. Usually the changes are pretty straightforward. I found a way to make JAX's tests run hermetically, which has been useful for triaging work; I should be able to get something similar working.

stellaraccident commented 1 year ago

This has the look of being related to the split initialization work that Skye was doing. Okwan also has some context on that. There's probably an easy way to get it going and a more complicated thing needed for true multi device.

manishucsd commented 1 year ago

We can let BF16 through so it is not blocked by the MMAMatrixType::verify call to MMAMatrixType::isValidElementType, but, following the discussion here, enabling BF16 will need more work than just letting it pass this check?

Also, we may hit minor bugs when enabling bf16 in the mma.sync path. Fixing those should just mean adding support that is pretty much identical to F16 MMASYNC.
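The check under discussion can be modeled as a membership test over the element types named in the diagnostic "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32". The Python below is only a hypothetical model of `MMAMatrixType::isValidElementType`, not MLIR's C++ code; the string type names and the `allow_bf16` switch are illustrative inventions showing what "letting BF16 go through" would change. As the comment points out, passing this check alone would not make BF16 work end to end.

```python
# Element types accepted per the diagnostic
# "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32".
MMA_ELEMENT_TYPES = frozenset({"si8", "ui8", "i32", "f16", "f32"})


def is_valid_mma_element_type(elem_type: str, allow_bf16: bool = False) -> bool:
    """Hypothetical model of the element-type check (not MLIR's real API).

    allow_bf16 models the proposed relaxation that lets bf16 through.
    """
    if allow_bf16 and elem_type == "bf16":
        return True
    return elem_type in MMA_ELEMENT_TYPES


print(is_valid_mma_element_type("f16"))                    # True
print(is_valid_mma_element_type("bf16"))                   # False: the current failure
print(is_valid_mma_element_type("bf16", allow_bf16=True))  # True once relaxed
```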

MaheshRavishankar commented 1 year ago

Should the title of the issue now be "Add BF16 support for CUDA" or something similar?

allieculp commented 1 year ago

@rsuderman @manishucsd to update task list here and title as needed.

rsuderman commented 1 year ago

I added a basic task list with details and a rough idea of responsibilities. Please update it if they differ from my initial draft.

rsuderman commented 1 year ago

We appear to have a failure further along the pipeline for CUDA. Specifically, it occurs in the stream dialect lowering. @benvanik, could you investigate what is going on in iree-stream-schedule-execution?

stream_error.txt

manishucsd commented 1 year ago

Hi @stellaraccident, @MaheshRavishankar, and @pjannaty, I see this mlir::Type::isBF16(), which makes me believe that BF16 is supported in the LLVM backend. What else is missing, such that we would need inline PTX to support BF16 in the IREE codegen backend?

benvanik commented 1 year ago

Not sure about the stream error; I would need to look, but I'd want to make sure this IR is not the result of other changes. That program is absolutely horrendous, and I'd expect it to be several orders of magnitude slower than it could be if the input weren't garbage. Assume each stream.async.load will take 1-10us and there are 476 of them. That is to say: the performance issues with this program are obvious, and once it runs, don't assume it's going to be representative.
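For scale, the load estimate works out as follows, assuming (as a simplification not stated in the comment) that the 476 loads run back to back without overlapping:

```latex
476 \times 1\,\mu\text{s} = 476\,\mu\text{s} \approx 0.48\,\text{ms},
\qquad
476 \times 10\,\mu\text{s} = 4760\,\mu\text{s} = 4.76\,\text{ms}
```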

pjannaty commented 1 year ago

Checked with @kushanam who will follow up.

rsuderman commented 1 year ago

stream.async.load will take 1-10us and there are 476 of them. That is to say: the performance issues with this program are obvious

Yeah, the input/output issue comes from the JAX "everything is an input" philosophy, so there is not much we can do to work around that. The crash is more of a concern than performance right now.

MaheshRavishankar commented 1 year ago

Hi @stellaraccident, @MaheshRavishankar, and @pjannaty, I see this mlir::Type::isBF16(), which makes me believe that BF16 is supported in the LLVM backend. What else is missing, such that we would need inline PTX to support BF16 in the IREE codegen backend?

That is in MLIR... MLIR has support for BF16... LLVM proper not so sure...

MaheshRavishankar commented 1 year ago

https://github.com/llvm/llvm-project/blob/8052c1e6ebbd993439006bd996bd34b9e8d32f57/llvm/include/llvm/IR/Type.h#L57 I am assuming this is BF16... I don't know whether it is plumbed through all the way to NVPTX...

In any case, start by relaxing the verifier and see where it blows up further.

stellaraccident commented 1 year ago

https://github.com/llvm/llvm-project/blob/8052c1e6ebbd993439006bd996bd34b9e8d32f57/llvm/include/llvm/IR/Type.h#L57 I am assuming this is BF16... I don't know whether it is plumbed through all the way to NVPTX...

I believe it is, but there were some instruction selection issues in LLVM that we were hitting when running through the test suite. I believe someone from NV was working on a patch, but I don't have any status handy.

pjannaty commented 1 year ago

This is the patch: https://reviews.llvm.org/D144911

kushanam commented 1 year ago

This is the patch: https://reviews.llvm.org/D144911

Just updated the patch to address more of the review's requested changes.