plaidml / tpp-mlir

TPP experimentation on MLIR for linear algebra
https://arxiv.org/abs/2404.15204

VNNI BF16 Layout matmul example with a new datatype (vnni_bf16) corresponding to tuple (<bf16,bf16>) #17

Closed. KavithaTipturMadhu closed this 2 years ago

KavithaTipturMadhu commented 2 years ago

#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
 func.func @matmultpp(%A: tensor<2x8xvnni_bf16>,
          %B: tensor<8x4xbf16>, %C: tensor<2x4xvnni_bf16>) -> tensor<2x4xvnni_bf16> attributes {llvm.emit_c_interface} {
    %D = linalg.generic {indexing_maps = [#map0, #map1, #map2],
                         iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%A, %B: tensor<2x8xvnni_bf16>, tensor<8x4xbf16>) outs(%C: tensor<2x4xvnni_bf16>) {
      ^bb0(%a: vnni_bf16, %b: bf16, %c: vnni_bf16):
        %0 = arith.vnni_mulf %a, %b : vnni_bf16
        %1 = arith.vnni_addf %c, %0 : vnni_bf16
        linalg.yield %1 : vnni_bf16
    } -> tensor<2x4xvnni_bf16>
    return %D : tensor<2x4xvnni_bf16>
  }

  func.func @entry() {
    %c0 = arith.constant 0 : index

    // Initialize various matrices.
    %da = arith.constant dense<[
        [ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1 ],
        [ 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2 ],
        [ 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3 ],
        [ 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4 ]
    ]> : tensor<4x8xbf16>
    %da_cast = linalgx.reorder %da: tensor<4x8xbf16> to tensor<2x8xvnni_bf16>
    %db = arith.constant dense<[
        [ 10.1, 11.1, 12.1, 13.1 ],
        [ 10.2, 11.2, 12.2, 13.2 ],
        [ 10.3, 11.3, 12.3, 13.3 ],
        [ 10.4, 11.4, 12.4, 13.4 ],
        [ 10.5, 11.5, 12.5, 13.5 ],
        [ 10.6, 11.6, 12.6, 13.6 ],
        [ 10.7, 11.7, 12.7, 13.7 ],
        [ 10.8, 11.8, 12.8, 13.8 ]
    ]> : tensor<8x4xbf16>

    // Call kernel.
    %C = arith.constant dense<0.0> : tensor<2x4xvnni_bf16>
    %0 = call @matmultpp(%da_cast, %db, %C)
       : (tensor<2x8xvnni_bf16>, tensor<8x4xbf16>, tensor<2x4xvnni_bf16>) -> tensor<2x4xvnni_bf16>

    return
  }
}
KavithaTipturMadhu commented 2 years ago

The above example expects three changes to the MLIR code: a new datatype, vnni_bf16, and two new arith operations, arith.vnni_mulf and arith.vnni_addf.
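
For reference, here is a scalar-level sketch of how those two ops are used in the example above. Both the vnni_bf16 type and the arith.vnni_* op names are hypothetical local extensions proposed in this issue, not upstream MLIR, and the semantics shown are just the reading used above:

    // Hypothetical scalar semantics: a vnni_bf16 value packs two adjacent
    // bf16 elements (one K-pair of the A matrix).
    func.func @vnni_scalar(%a: vnni_bf16, %b: bf16, %c: vnni_bf16) -> vnni_bf16 {
      // multiply both packed lanes of %a by the plain bf16 scalar %b
      %p = arith.vnni_mulf %a, %b : vnni_bf16
      // accumulate lane-wise into the packed accumulator %c
      %r = arith.vnni_addf %c, %p : vnni_bf16
      return %r : vnni_bf16
    }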

rengolin commented 2 years ago

Thanks @KavithaTipturMadhu, looking good!

A few questions:

  1. Your matmultpp accepts tensor<8x4xbf16> for %B but passes it as tensor<2x8xvnni_bf16> to linalg.generic without a cast; that's probably going to break basic validation. I think the function should only accept vnni_bf16-typed tensors.
  2. "cast" is usually a name given to operations where there are no memory changes, only interpretation, while the "cast" between bf16 and vnni_bf16 is a relayout. We already have an extension op linalgx.relayout, could that be used instead?
  3. As we discussed, adding the VNNI types is probably the fastest way to show this works. I proposed we look at tuple<bf16, bf16> as a way to tell other developers (that don't know what VNNI is) what it actually means, but I'm still not sure we ought to be using tuple at all. Let's keep these types for now and ask the community what they think.
  4. Perhaps call it packed_bf16 (or pbf16) and make that a standard type; then arith will work out of the box. This may be worth an RFC upstream.
  5. In the same way, extending arith with the VNNI ops will work for now, but it doesn't scale: we don't want to update all upstream arith ops with some custom type. If we can't get this type into a generic dialect, we may have to operate on it only through our own dialects, but that's for later. For now, we do an extension like we did before.
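
Regarding item 2, a rough sketch of the cast vs. relayout distinction (linalgx.relayout is a project extension; the exact syntax here is approximate):

    // A cast only reinterprets: same elements, no data movement
    // (valid today only when the element types match).
    %dyn = tensor.cast %t : tensor<4x8xbf16> to tensor<?x?xbf16>
    // A relayout physically rearranges the elements into the VNNI shape.
    %init = linalg.init_tensor [2, 8, 2] : tensor<2x8x2xbf16>
    %vnni = linalgx.relayout ins(%t : tensor<4x8xbf16>)
                             outs(%init : tensor<2x8x2xbf16>) -> tensor<2x8x2xbf16>
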
alheinecke commented 2 years ago

I agree with @rengolin that the datatype is, for now, the best way forward. We are currently working on TPPs with support for ARM's MMLA instructions ( https://github.com/libxsmm/libxsmm/tree/sve256_support_bfmla ), and this will trigger even more of these weird formats. We plan to figure out a scalable way of expressing them in Oct'22 inside the TPP spec. So it's good that we have this, but ultimately we want, and must, work with the community on how we express the reorders in a platform-agnostic way.

alheinecke commented 2 years ago

BTW, a side note: we play it "dirty" when writing C TPP code. For stuff that lives in arith, i.e. eltwise operations, we just don't expose the VNNI layout to them; they don't care, since the operation is eltwise. So perhaps a similar cast trick can be used.
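
A sketch of that trick in MLIR terms, assuming the packed data is exposed through the expanded 3-D bf16 shape rather than a packed element type; the eltwise kernel then never needs to know about VNNI:

    #id3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
    // ReLU over the VNNI-shaped tensor<2x8x2xbf16> touches exactly the same
    // bf16 elements as it would over the flat tensor<4x8xbf16>; the body is
    // layout-agnostic because the operation is eltwise.
    func.func @relu_vnni(%in: tensor<2x8x2xbf16>, %out: tensor<2x8x2xbf16>) -> tensor<2x8x2xbf16> {
      %zero = arith.constant 0.0 : bf16
      %r = linalg.generic {indexing_maps = [#id3, #id3],
                           iterator_types = ["parallel", "parallel", "parallel"]}
        ins(%in : tensor<2x8x2xbf16>) outs(%out : tensor<2x8x2xbf16>) {
        ^bb0(%x: bf16, %y: bf16):
          %m = arith.maxf %x, %zero : bf16
          linalg.yield %m : bf16
      } -> tensor<2x8x2xbf16>
      return %r : tensor<2x8x2xbf16>
    }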

alheinecke commented 2 years ago

one comment on the IR above:

    // Call kernel.
    %C = arith.constant dense<0.0> : tensor<2x4xvnni_bf16>
    %0 = call @matmultpp(%da_cast, %db, %C)
       : (tensor<2x8xvnni_bf16>, tensor<8x4xbf16>, tensor<2x4xvnni_bf16>) -> tensor<2x4xvnni_bf16>

In col-major terms: the A matrix is in VNNI format, B is in regular col-major, and C would be in regular col-major. It seems here that the output of matmultpp is in VNNI format?

rengolin commented 2 years ago

BTW, a side note: we play it "dirty" when writing C TPP code. For stuff that lives in arith, i.e. eltwise operations, we just don't expose the VNNI layout to them; they don't care, since the operation is eltwise. So perhaps a similar cast trick can be used.

We discussed this, but the problem here is the relayout.

One way to trick MLIR into doing arith.mulf on a vnni_bf16 while still thinking it is a bf16 is to relayout and cast, but we need to get the shapes right, or the maps on linalg.generic won't work.

Either

    linalgx.relayout ins(%0: tensor<4x8xbf16>) outs(%1: tensor<2x8xvnni_bf16>)
    // Wrong shape?
    tensor.cast ins(%1: tensor<2x8xvnni_bf16>) outs(%2: tensor<4x8xbf16>)

or

    linalgx.relayout ins(%0: tensor<4x8xbf16>) outs(%1: tensor<2x8xvnni_bf16>)
    // Wrong rank?
    tensor.cast ins(%1: tensor<2x8xvnni_bf16>) outs(%2: tensor<2x8x2xbf16>)

seems wrong.

In C code you can be opaque, if you take in pointers and "know" the layout. But in MLIR, the module needs to validate the shapes of all tensors everywhere and they must match some basic standards.

KavithaTipturMadhu commented 2 years ago

one comment on the IR above:

    // Call kernel.
    %C = arith.constant dense<0.0> : tensor<2x4xvnni_bf16>
    %0 = call @matmultpp(%da_cast, %db, %C)
       : (tensor<2x8xvnni_bf16>, tensor<8x4xbf16>, tensor<2x4xvnni_bf16>) -> tensor<2x4xvnni_bf16>

In col-major terms: the A matrix is in VNNI format, B is in regular col-major, and C would be in regular col-major. It seems here that the output of matmultpp is in VNNI format?

I have updated the example such that the result matrix is in flat format. This brings us to two important considerations:

  1. arith.vnni_addf now has addition semantics that operate on a vnni_bf16 token and a bf16 token and produce a bf16 token. This cannot work, because each vnni_bf16 token is expected to be made of 2 bf16 tokens; broadcast semantics could be followed instead to return a vnni_bf16 result. Unless we provide extract operations on vnni_bf16, we cannot operate on bf16 scalars. This will require additional checks when lowering linalg.generic to matmul (see the sketch after this list).

  2. We will need to either circumvent shape checks at the linalg level for the result, or add support for additional checks when interoperating between the two formats, vnni_bf16 and bf16.
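
A sketch of the two alternatives from consideration 1; all of these ops are hypothetical, and neither broadcast semantics nor an extract operation exists yet:

    // (a) broadcast semantics: the bf16 operand is applied to both packed
    //     lanes and the result stays packed as vnni_bf16.
    %acc = arith.vnni_addf %c, %s : vnni_bf16   // %c : vnni_bf16, %s : bf16 broadcast
    // (b) explicit lane extraction, so the generic body only ever sees bf16
    //     scalars (the extract syntax is invented for illustration).
    %lo = vnni.extract %c[0] : vnni_bf16 to bf16
    %hi = vnni.extract %c[1] : vnni_bf16 to bf16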

What are your thoughts about these, @alheinecke, @rengolin, @chelini?

KavithaTipturMadhu commented 2 years ago

One thing we want to achieve is layout transformation, for example from tensor<4x8xbf16> to tensor<2x8xpbf16>. The intuitive way is to use linalgx.relayout to generate tensor<2x8x2xbf16> from tensor<4x8xbf16> and then cast tensor<2x8x2xbf16> to tensor<2x8xpbf16>. The cast operation, however, expects the same element type in both tensors and a compatible shape for the cast to go through. Making this work will require changes in TensorOps as well as memrefs, since ideally this cast is valid between memrefs too. Alternatively, the relayout operation could capture the cast itself, since it is a custom op anyway, but upstreaming the relayout op becomes tricky then (see the sketch below). Thoughts? @alheinecke @rengolin @chelini
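
A sketch of that second alternative, where the custom relayout absorbs the element-type change itself; the packed result type and the maps here are assumptions, mirroring the relayout syntax in the example below:

    // Hypothetical: linalgx.relayout produces the packed element type
    // directly, so no separate tensor.cast is needed.
    #src = affine_map<(d0, d1, d2) -> (d0 * 2 + d2, d1)>  // source element for (row pair, col, lane)
    #dst = affine_map<(d0, d1, d2) -> (d0, d1)>           // packed destination; d2 selects the lane
    %init = linalg.init_tensor [2, 8] : tensor<2x8xpbf16>
    %packed = linalgx.relayout ins(%da : tensor<4x8xbf16>, #src)
                               outs(%init : tensor<2x8xpbf16>, #dst) -> tensor<2x8xpbf16>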

KavithaTipturMadhu commented 2 years ago

#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0 * 2 + d2, d1)>
#map4 = affine_map<(d0, d1 ,d2) -> (d0, d1, d2)>
module {
 func.func @matmultpp(%A: tensor<2x8xpbf16>,
          %B: tensor<8x4xbf16>, %C: tensor<2x4xpbf16>) -> tensor<2x4xpbf16> attributes {llvm.emit_c_interface} {
    %D = linalg.generic {indexing_maps = [#map0, #map1, #map2],
                         iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%A, %B: tensor<2x8xpbf16>, tensor<8x4xbf16>) outs(%C: tensor<2x4xpbf16>) {
      ^bb0(%a: pbf16, %b: bf16, %c: pbf16):
        %0 = arith.mulf %a, %b : pbf16
        %1 = arith.addf %c, %0 : pbf16
        linalg.yield %1 : pbf16
    } -> tensor<2x4xpbf16>
    return %D : tensor<2x4xpbf16>
  }

  func.func @entry() {
    %c0 = arith.constant 0 : index

    %da = arith.constant dense<[
        [ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1 ],
        [ 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2 ],
        [ 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3 ],
        [ 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4 ]
    ]> : tensor<4x8xbf16>
    %0 = linalg.init_tensor [2, 8, 2] : tensor<2x8x2xbf16>
    %da_relayout = linalgx.relayout ins(%da : tensor<4x8xbf16>, #map3) outs(%0 : tensor<2x8x2xbf16>, #map4) -> tensor<2x8x2xbf16>
    %da_cast = tensor.cast %da_relayout : tensor<2x8x2xbf16> to tensor<2x8xpbf16>
    %d1 = arith.pbf16_constant -1.0 : bf16, -2.0 : bf16 : pbf16
    %da_reversecast = tensor.cast %da_cast : tensor<2x8xpbf16> to tensor<2x8x2xbf16>
    %v0 = vector.transfer_read %da_reversecast[%c0, %c0, %c0], %d1 : tensor<2x8x2xbf16>, vector<2x8x2xbf16>
    %f1 = arith.extf %v0 : vector<2x8x2xbf16> to vector<2x8x2xf32>
    vector.print %f1 : vector<2x8x2xf32>

    %db = arith.constant dense<[
        [ 10.1, 11.1, 12.1, 13.1 ],
        [ 10.2, 11.2, 12.2, 13.2 ],
        [ 10.3, 11.3, 12.3, 13.3 ],
        [ 10.4, 11.4, 12.4, 13.4 ],
        [ 10.5, 11.5, 12.5, 13.5 ],
        [ 10.6, 11.6, 12.6, 13.6 ],
        [ 10.7, 11.7, 12.7, 13.7 ],
        [ 10.8, 11.8, 12.8, 13.8 ]
    ]> : tensor<8x4xbf16>

    // Call kernel.
    %C = arith.constant dense<0.0> : tensor<2x4xpbf16>
    %D = call @matmultpp(%da_cast, %db, %C)
       : (tensor<2x8xpbf16>, tensor<8x4xbf16>, tensor<2x4xpbf16>) -> tensor<2x4xpbf16>

    //
    // CHECK: ( ( 388, 426, 462, 500 ), ( 396, 434, 472, 510 ), ( 406, 444, 484, 520 ), ( 414, 454, 492, 532 ) )
    //
    %D_cast = tensor.cast %D : tensor<2x4xpbf16> to tensor<2x4x2xbf16>
    %pad = arith.constant -1.0 : bf16
    %v1 = vector.transfer_read %D_cast[%c0, %c0, %c0], %pad : tensor<2x4x2xbf16>, vector<2x4x2xbf16>
    %f2 = arith.extf %v1 : vector<2x4x2xbf16> to vector<2x4x2xf32>
    vector.print %f2 : vector<2x4x2xf32>

    return
  }
}

@rengolin Here's the example we discussed today. I would like your input on the use of the pbf16 type in linalg.generic.

rengolin commented 2 years ago

#24 has been merged with a working version of this. Thanks Kavitha for driving this discussion!

Next, we need to make sure we can get bf16 through an MLP layer with Relu, calling TPP.