plaidml / tpp-mlir

TPP experimentation on MLIR for linear algebra
https://arxiv.org/abs/2404.15204

Extend mlir-gen tool for linalg named ops #933 #943

Closed: shahidact closed this 4 months ago

shahidact commented 4 months ago

The mlir-gen tool currently generates linalg.generic ops for basic MLP operators. We want to extend it to also generate linalg named ops (e.g. linalg.matmul) when requested through a command-line option.
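To illustrate the difference this PR is about, here is a minimal Python model. It is not tpp-mlir code, and the helpers `linalg_generic` and `linalg_matmul` are hypothetical. A linalg.generic op spells out its semantics through indexing maps, iterator types and a scalar body. A named op such as linalg.matmul fixes those semantics in its name. The sketch interprets a generic contraction over its iteration space and checks that it agrees with a plain matmul.

```python
from itertools import product


def linalg_generic(ins, out, loop_bounds, indexing_maps, body):
    """Toy model of linalg.generic semantics (hypothetical helper).

    ins:           input tensors as nested lists
    out:           output tensor (nested lists), updated in place and returned
    loop_bounds:   size of each loop dimension d0..dn
    indexing_maps: one function per operand (inputs, then output) mapping the
                   loop index tuple to an element index tuple
    body:          scalar function (*in_elems, out_elem) -> new out_elem
    """
    def get(t, idx):
        for i in idx:
            t = t[i]
        return t

    def put(t, idx, v):
        for i in idx[:-1]:
            t = t[i]
        t[idx[-1]] = v

    *in_maps, out_map = indexing_maps
    # Visit every point of the iteration space, like the generic op's loop nest.
    for loop_idx in product(*(range(b) for b in loop_bounds)):
        in_vals = [get(t, m(loop_idx)) for t, m in zip(ins, in_maps)]
        o_idx = out_map(loop_idx)
        put(out, o_idx, body(*in_vals, get(out, o_idx)))
    return out


def linalg_matmul(a, b, c):
    """Toy model of the named op linalg.matmul: C[m][n] += sum_k A[m][k] * B[k][n]."""
    for m in range(len(a)):
        for n in range(len(b[0])):
            for k in range(len(b)):
                c[m][n] += a[m][k] * b[k][n]
    return c


A = [[1, 2, 3], [4, 5, 6]]       # 2x3
B = [[7, 8], [9, 10], [11, 12]]  # 3x2

# Loop dims: d0 = m, d1 = n (parallel), d2 = k (reduction).
generic_result = linalg_generic(
    ins=[A, B],
    out=[[0, 0], [0, 0]],
    loop_bounds=(2, 2, 3),
    indexing_maps=[
        lambda d: (d[0], d[2]),  # A[m][k]
        lambda d: (d[2], d[1]),  # B[k][n]
        lambda d: (d[0], d[1]),  # C[m][n]
    ],
    body=lambda a, b, c: c + a * b,  # multiply, then accumulate into the output
)
named_result = linalg_matmul(A, B, [[0, 0], [0, 0]])

print(generic_result)  # [[58, 64], [139, 154]]
assert generic_result == named_result
```

The generic form can describe arbitrary contractions, while a named op carries its meaning in the op name, which is usually easier to recognize in tests and pattern matchers.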

This is the first PR toward a larger goal: ensuring we have enough test coverage that passes through the tpp-opt pipeline.

shahidact commented 4 months ago

Could you also extend the mlir-gen-(matmul|fc).mlir tests? One test case with named ops for each should be enough.

I had tried to cover this case earlier too, but ran into cases like the one below, which have irregular tile sizes ("tiles=64,48,64"), and this does not seem to be a priority for now. Moreover, checking it would require DAG matching along the chain "reduce_batch_matmul -> 64x48xbf16 -> broadcast -> 2x16x64x48xbf16".

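// The #map definitions were not part of the original comment; the three below are
// an assumption, reconstructed from the operand shapes and iterator_types.
#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

func.func @entry(%arg0: tensor<2x36x64x64xbf16>, %arg1: tensor<16x36x64x48xbf16>,
                 %arg2: tensor<2x16x64x48xbf16>) -> tensor<2x16x64x48xbf16> {
  %0 = linalg.generic {
      indexing_maps = [#map, #map1, #map2],
      iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1 : tensor<2x36x64x64xbf16>, tensor<16x36x64x48xbf16>)
      outs(%arg2 : tensor<2x16x64x48xbf16>) {
  ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
    %1 = arith.mulf %in, %in_0 : bf16
    %2 = arith.addf %out, %1 : bf16
    linalg.yield %2 : bf16
  } -> tensor<2x16x64x48xbf16>
  return %0 : tensor<2x16x64x48xbf16>
}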
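The generic op above is a blocked (packed) matmul: the outer dimensions index tiles, and the inner two dimensions hold one tile. The original comment does not show the `#map` definitions, so the sketch below assumes the usual blocked-matmul layout `A[d0, d2, d3, d5]`, `B[d1, d2, d5, d4]`, `C[d0, d1, d3, d4]`. This layout is consistent with the operand shapes and iterator types. Under that assumption, each output tile `C[d0][d1]` is a batch-reduce matmul over the K tiles (`d2`), which corresponds to the "reduce_batch_matmul -> 64x48xbf16" part of the chain mentioned in the comment. The sizes are hypothetical and kept small so the pure-Python check runs quickly. The check compares the blocked computation against a plain matmul on the unpacked matrices.

```python
# Hypothetical small sizes. The IR uses (MB, NB, KB) = (2, 16, 36) tiles
# and (BM, BN, BK) = (64, 48, 64) tile sizes.
MB, NB, KB = 2, 3, 4  # number of tiles along M, N and K
BM, BN, BK = 2, 3, 2  # tile sizes along M, N and K


def matmul(a, b):
    """Plain (unblocked) matmul on nested lists: a is MxK, b is KxN."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def pack_a(a):
    """A_packed[d0][d2][d3][d5] = A[d0*BM + d3][d2*BK + d5], shape MBxKBxBMxBK."""
    return [[[[a[m * BM + i][k * BK + kk] for kk in range(BK)]
              for i in range(BM)]
             for k in range(KB)]
            for m in range(MB)]


def pack_b(b):
    """B_packed[d1][d2][d5][d4] = B[d2*BK + d5][d1*BN + d4], shape NBxKBxBKxBN."""
    return [[[[b[k * BK + kk][n * BN + j] for j in range(BN)]
              for kk in range(BK)]
             for k in range(KB)]
            for n in range(NB)]


def unpack_c(cp):
    """C[d0*BM + d3][d1*BN + d4] = C_packed[d0][d1][d3][d4]."""
    return [[cp[i // BM][j // BN][i % BM][j % BN] for j in range(NB * BN)]
            for i in range(MB * BM)]


def batch_reduce_matmul(a_tiles, b_tiles):
    """Sum over the batch (K-tile) dimension of a_tiles[k] @ b_tiles[k]."""
    acc = [[0] * len(b_tiles[0][0]) for _ in range(len(a_tiles[0]))]
    for a_t, b_t in zip(a_tiles, b_tiles):
        for i, row in enumerate(matmul(a_t, b_t)):
            for j, v in enumerate(row):
                acc[i][j] += v
    return acc


def blocked_matmul(a_packed, b_packed):
    """Each output tile C[d0][d1] is one batch-reduce matmul over the K tiles (d2)."""
    return [[batch_reduce_matmul(a_packed[m], b_packed[n]) for n in range(NB)]
            for m in range(MB)]


M, N, K = MB * BM, NB * BN, KB * BK
A = [[(i * K + k) % 7 - 3 for k in range(K)] for i in range(M)]
B = [[(k * N + j) % 5 - 2 for j in range(N)] for k in range(K)]
assert unpack_c(blocked_matmul(pack_a(A), pack_b(B))) == matmul(A, B)
print("blocked matmul matches plain matmul")
```

Integer inputs keep the comparison exact. With bf16 data, the accumulation order of the blocked form could change rounding slightly.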

adam-smnk commented 4 months ago


These two tests, mlir-gen-(matmul|fc).mlir, are not integration tests in the sense that we run them through tpp-run. They only CHECK the generated IR to make sure it is what we expect, i.e., that the IR is valid, and to catch possible mlir-gen regressions, etc.