Closed KavithaTipturMadhu closed 2 years ago
The above example expects three changes to the MLIR code: a new datatype, vnni_bf16, and two new arith operations:
- arith.vnni_mulf, which takes a vnni_bf16 operand and a bf16 operand and produces a vnni_bf16 (tuple) result;
- arith.vnni_addf, which operates on two vnni_bf16 operands and produces a vnni_bf16 result.
The example also relayouts %da from 4x8xbf16 to 2x8xvnni_bf16 format.
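A minimal sketch of how the two proposed ops could read (purely hypothetical: neither the vnni_bf16 type nor the vnni_* ops exist upstream, and the syntax below is illustrative only):

%p = arith.vnni_mulf %a, %b : (vnni_bf16, bf16) -> vnni_bf16  // hypothetical: packed * scalar
%acc = arith.vnni_addf %c, %p : vnni_bf16                     // hypothetical: packed + packed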
Thanks @KavithaTipturMadhu, looking good!
A few questions:
- matmultpp accepts tensor<8x4xbf16> for %B but passes it as tensor<2x8xvnni_bf16> to linalg.generic without a cast; that's probably going to break basic validation. I think the function should only accept vnni_bf16-typed tensors.
- The conversion between bf16 and vnni_bf16 is a relayout. We already have an extension op, linalgx.relayout; could that be used instead?
- I see the point of tuple<bf16, bf16> as a way to tell other developers (who don't know what VNNI is) what it actually means, but I'm still not sure we ought to be using tuple at all. Let's keep these types for now and ask the community what they think.
- We could call the type packed_bf16 (or pbf16) and make that a standard type; then arith will work out of the box (see the sketch after this list). This may be worth an RFC upstream.
- Extending arith with the VNNI ops will work for now, but it doesn't scale: we don't want to update all upstream arith ops with some custom type. If we can't get this type into a generic dialect, we may have to operate on it only through our own dialects, but that's for later. For now, we do an extension like we did before.
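As referenced in the packed_bf16 item above, a minimal sketch of what "arith works out of the box" would mean, assuming a standard packed element type spelled pbf16 here (nothing in this snippet exists upstream today):

// Existing arith ops, reused unchanged on the assumed standard packed type.
%x = arith.mulf %p, %q : pbf16
%y = arith.addf %x, %r : pbf16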
I agree with @rengolin that the datatype is, for now, the best way forward. We are currently working on TPPs with support for ARM's MMLA instructions (https://github.com/libxsmm/libxsmm/tree/sve256_support_bfmla), and this will trigger even more of these weird formats. We plan to figure out a scalable way of expressing them in Oct'22 inside the TPP spec. So it's good that we have this, but ultimately we want and must work with the community on how we express the reorders in a platform-agnostic way.
BTW, a side note: we play it "dirty" when writing C TPP code. For stuff that lives in arith, i.e. eltwise operations, we just don't expose the VNNI layout to them; they don't care, since the operation is eltwise. So perhaps a similar cast trick can be used here.
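A minimal sketch of what that trick could look like here, assuming the VNNI buffer is simply viewed with its packed shape tensor<2x8x2xbf16>; the ReLU is just an example of an eltwise op, and all names and shapes are illustrative:

#eltwise = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
%zero = arith.constant 0.0 : bf16
%relu = linalg.generic {indexing_maps = [#eltwise, #eltwise],
                        iterator_types = ["parallel", "parallel", "parallel"]}
  ins(%x : tensor<2x8x2xbf16>) outs(%init : tensor<2x8x2xbf16>) {
^bb0(%in: bf16, %out: bf16):
  // Plain bf16 arith: the VNNI layout never shows up for eltwise ops.
  %m = arith.maxf %in, %zero : bf16
  linalg.yield %m : bf16
} -> tensor<2x8x2xbf16>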
One comment on the IR above:
// Call kernel.
%C = arith.constant dense<0.0> : tensor<2x4xvnni_bf16>
%0 = call @matmultpp(%da_cast, %db, %C)
: (tensor<2x8xvnni_bf16>, tensor<8x4xbf16>, tensor<2x4xvnni_bf16>) -> tensor<2x4xvnni_bf16>
In col-major terms: the A matrix is in VNNI format, B is in regular col-major, and C would be in regular col-major. It seems here that the output of matmultpp is in VNNI format?
> BTW, a side note: we play it "dirty" when writing C TPP code. For stuff that lives in arith, i.e. eltwise operations, we just don't expose the VNNI layout to them; they don't care, since the operation is eltwise. So perhaps a similar cast trick can be used here.
We discussed this, but the problem here is the relayout.
One way to trick MLIR into doing arith.mulf on a vnni_bf16 while still thinking it is a bf16 is to relayout and cast, but we need to get the shapes right, or the maps on linalg.generic won't work.
Either
linalgx.relayout ins(%0 : tensor<4x8xbf16>) outs(%1 : tensor<2x8xvnni_bf16>)
// Wrong shape?
%2 = tensor.cast %1 : tensor<2x8xvnni_bf16> to tensor<4x8xbf16>
or
linalgx.relayout ins(%0 : tensor<4x8xbf16>) outs(%1 : tensor<2x8xvnni_bf16>)
// Wrong rank?
%2 = tensor.cast %1 : tensor<2x8xvnni_bf16> to tensor<2x8x2xbf16>
Both seem wrong.
In C code you can be opaque if you take in pointers and "know" the layout, but in MLIR the module needs to validate the shapes of all tensors everywhere, and they must match some basic standards.
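For reference, a minimal example of what the upstream tensor.cast does allow (shapes illustrative): it only adds or removes static shape information, and the element type must match on both sides, which is why neither form above verifies.

%ok = tensor.cast %t : tensor<2x8x2xbf16> to tensor<?x?x?xbf16>        // fine: same element type, shape info relaxed
// %bad = tensor.cast %t : tensor<2x8x2xbf16> to tensor<2x8xvnni_bf16> // rejected: rank and element type change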
> One comment on the IR above:
>
> // Call kernel.
> %C = arith.constant dense<0.0> : tensor<2x4xvnni_bf16>
> %0 = call @matmultpp(%da_cast, %db, %C)
>   : (tensor<2x8xvnni_bf16>, tensor<8x4xbf16>, tensor<2x4xvnni_bf16>) -> tensor<2x4xvnni_bf16>
>
> In col-major terms: the A matrix is in VNNI format, B is in regular col-major, and C would be in regular col-major. It seems here that the output of matmultpp is in VNNI format?
I have updated the example so that the result matrix is in flat format. This brings us to two important considerations:
1. arith.vnni_addf now has addition semantics that operate on a vnni_bf16 token and a bf16 token and produce a bf16 token. This cannot work, because the expectation is that each vnni_bf16 token is made of 2 bf16 tokens, and broadcast semantics can be followed to return a vnni_bf16 result. Unless we provide extract operations on vnni_bf16 (see the strawman sketch below), we cannot operate on bf16 scalars. This will require additional checks when lowering linalg.generic to matmul.
2. We will need to either circumvent shape checks at the linalg level for the result, or add support for additional checks when interoperating between the two formats, vnni_bf16 and bf16.
What are your thoughts about these, @alheinecke, @rengolin, @chelini?
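Purely as a strawman for point 1 (none of these ops exist; the names are invented for illustration only), explicit extract/pack operations on the packed type could look like:

%lo = vnnix.extract %p[0] : vnni_bf16 to bf16       // hypothetical
%hi = vnnix.extract %p[1] : vnni_bf16 to bf16       // hypothetical
%s0 = arith.addf %lo, %b : bf16
%s1 = arith.addf %hi, %b : bf16
%r  = vnnix.pack %s0, %s1 : bf16, bf16 to vnni_bf16 // hypothetical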
One thing we want to achieve is the layout transformation, for example from tensor<4x8xbf16> to tensor<2x8xpbf16>. The intuitive way is to use linalgx.relayout to generate tensor<2x8x2xbf16> from tensor<4x8xbf16> and then cast tensor<2x8x2xbf16> to tensor<2x8xpbf16>. The cast operation, however, expects the same element type in both tensors, as well as a compatible shape, for the cast to go through. Making this work will require changes in TensorOps as well as in memrefs, since ideally this cast would also be valid between memrefs. Alternatively, the relayout operation could capture the cast itself, since it is a custom op anyway, but upstreaming the relayout op becomes tricky then. Thoughts? @alheinecke @rengolin @chelini
#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>  // map for %A (assumed; #map0 was referenced below but not defined)
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0 * 2 + d2, d1)>
#map4 = affine_map<(d0, d1 ,d2) -> (d0, d1, d2)>
module {
func.func @matmultpp(%A: tensor<2x8xpbf16>,
%B: tensor<8x4xbf16>, %C: tensor<2x4xpbf16>) -> tensor<2x4xpbf16> attributes {llvm.emit_c_interface} {
%D = linalg.generic {indexing_maps = [#map0, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%A, %B: tensor<2x8xpbf16>, tensor<8x4xbf16>) outs(%C: tensor<2x4xpbf16>) {
^bb0(%a: pbf16, %b: bf16, %c: pbf16):
%0 = arith.mulf %a, %b : pbf16
%1 = arith.addf %c, %0 : pbf16
linalg.yield %1 : pbf16
} -> tensor<2x4xpbf16>
return %D : tensor<2x4xpbf16>
}
func.func @entry() {
%c0 = arith.constant 0 : index
%da = arith.constant dense<[
[ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1 ],
[ 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2 ],
[ 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3 ],
[ 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4 ]
]> : tensor<4x8xbf16>
%0 = linalg.init_tensor [2,8,2] : tensor<2x8x2xbf16>
%da_relayout = linalgx.relayout ins(%da : tensor<4x8xbf16>, #map3) outs(%0 : tensor<2x8x2xbf16>, #map4) -> tensor<2x8x2xbf16>
%da_cast = tensor.cast %da_relayout : tensor<2x8x2xbf16> to tensor<2x8xpbf16>
%d1 = arith.pbf16_constant -1.0 : bf16, -2.0 : bf16 : pbf16
%da_reversecast = tensor.cast %da_cast : tensor<2x8xpbf16> to tensor<2x8x2xbf16>
%v0 = vector.transfer_read %da_reversecast[%c0, %c0, %c0], %d1 : tensor<2x8x2xbf16>, vector<2x8x2xbf16>
%f1 = arith.extf %v0 : vector<2x8x2xbf16> to vector<2x8x2xf32>
vector.print %f1 : vector<2x8x2xf32>
%db = arith.constant dense<[
[ 10.1, 11.1, 12.1, 13.1 ],
[ 10.2, 11.2, 12.2, 13.2 ],
[ 10.3, 11.3, 12.3, 13.3 ],
[ 10.4, 11.4, 12.4, 13.4 ],
[ 10.5, 11.5, 12.5, 13.5 ],
[ 10.6, 11.6, 12.6, 13.6 ],
[ 10.7, 11.7, 12.7, 13.7 ],
[ 10.8, 11.8, 12.8, 13.8 ]
]> : tensor<8x4xbf16>
// Call kernel.
%C = arith.constant dense<0.0> : tensor<4x4xbf16>
%1 = call @matmultpp(%da, %db, %C)
  : (tensor<4x8xbf16>, tensor<8x4xbf16>, tensor<4x4xbf16>) -> tensor<4x4xbf16>
//
// CHECK:( ( 388, 426, 462, 500 ), ( 396, 434, 472, 510 ), ( 406, 444, 484, 520 ), ( 414, 454, 492, 532 ) )
//
%d2 = arith.constant -1.0 : bf16
%v1 = vector.transfer_read %1[%c0, %c0], %d2 : tensor<4x4xbf16>, vector<4x4xbf16>
%f2 = arith.extf %v1 : vector<4x4xbf16> to vector<4x4xf32>
vector.print %f2 : vector<4x4xf32>
return
}
}
@rengolin Here's the example we discussed today. I would like your input on the use of the pbf16 type in linalg.generic.
Next, we need to make sure we can get bf16 through an MLP layer with ReLU, calling TPP.
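Not part of this change, but a rough sketch of that next step, using only standard upstream linalg/arith with illustrative shapes and names: one MLP layer in bf16 is a matmul followed by an eltwise ReLU, which should then be mappable onto TPP calls.

#id2 = affine_map<(d0, d1) -> (d0, d1)>
func.func @mlp_layer(%x: tensor<128x256xbf16>, %w: tensor<256x512xbf16>,
                     %acc: tensor<128x512xbf16>) -> tensor<128x512xbf16> {
  %mm = linalg.matmul ins(%x, %w : tensor<128x256xbf16>, tensor<256x512xbf16>)
                      outs(%acc : tensor<128x512xbf16>) -> tensor<128x512xbf16>
  %zero = arith.constant 0.0 : bf16
  %init = linalg.init_tensor [128, 512] : tensor<128x512xbf16>
  %relu = linalg.generic {indexing_maps = [#id2, #id2],
                          iterator_types = ["parallel", "parallel"]}
    ins(%mm : tensor<128x512xbf16>) outs(%init : tensor<128x512xbf16>) {
  ^bb0(%in: bf16, %out: bf16):
    %m = arith.maxf %in, %zero : bf16   // eltwise ReLU
    linalg.yield %m : bf16
  } -> tensor<128x512xbf16>
  return %relu : tensor<128x512xbf16>
}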