Open silvasean opened 2 years ago
The use of higher-D tensors (or reasoning about operations over a higher-dimensional iteration space) really helped for fusion purposes, especially when fusing with reshapes. I see this case as slightly different: this is a named op, where every dimension carries some significance, so having multiple batch dimensions muddles that a bit. I can't think of a reasonable code-generation strategy that would not just fold all the batch dimensions together to expose the maximum amount of parallelism.
My preference would be to leave the batch_matmul as a 4-D iteration space (batch, M, N, and K). The issue then is making linalg.tensor_expand_shape work for dynamic dimensions. It's a known issue with a known solution: the op needs an extra operand specifying the shape of the result. This operand would mostly be produced by a linalg.init_tensor operation, similar to linalg.generic, where the outs operand is sometimes specified only for its shape.
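A hedged sketch of what that extra-operand form might look like (hypothetical syntax, not something the op accepts today; the value names are illustrative):

```mlir
// The collapsed batch size %b = %b1 * %b2 cannot be split back into two
// dynamic dimensions without knowing %b1 and %b2 individually; an
// init_tensor-shaped operand would carry that information.
%shape = linalg.init_tensor [%b1, %b2, %m, %n] : tensor<?x?x?x?xf32>
// Hypothetical form: the target shape is taken from the %shape operand.
%res = linalg.tensor_expand_shape %mm [[0, 1], [2], [3]], %shape
    : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
```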
IIRC, we want to move more in the direction of supporting higher-D tensors natively, so maybe the linalg.tensor_expand_shape issue isn't so pressing if we can make sure we handle higher-D inputs better.
assigned to @nicolasvasilache
We are planning to go significantly beyond the current capabilities, see: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707
In the meantime, we could make the semantics of tensor.expand_shape and memref.expand_shape more powerful, if this is still useful?
Yes this is still useful for us.
Extended Description
This came up in torch-mlir: we are lowering a (B1, B2, M, K) x (B1, B2, K, N) matmul to linalg.batch_matmul. This requires collapsing B1, B2 -> B, doing the matmul, and then expanding back B -> B1, B2.
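For concreteness, here is a sketch of that lowering with static shapes (value names and sizes are illustrative; the dynamic-shape case is exactly where linalg.tensor_expand_shape falls short):

```mlir
// Collapse the two batch dims of each operand: (B1, B2, M, K) -> (B1*B2, M, K).
%lhs3d = linalg.tensor_collapse_shape %lhs [[0, 1], [2], [3]]
    : tensor<2x3x4x5xf32> into tensor<6x4x5xf32>
%rhs3d = linalg.tensor_collapse_shape %rhs [[0, 1], [2], [3]]
    : tensor<2x3x5x7xf32> into tensor<6x5x7xf32>
// 3-D batched matmul on the collapsed operands
// (zero-fill of %init via linalg.fill omitted for brevity).
%init = linalg.init_tensor [6, 4, 7] : tensor<6x4x7xf32>
%mm = linalg.batch_matmul
    ins(%lhs3d, %rhs3d : tensor<6x4x5xf32>, tensor<6x5x7xf32>)
    outs(%init : tensor<6x4x7xf32>) -> tensor<6x4x7xf32>
// Expand the batch dim back: (B1*B2, M, N) -> (B1, B2, M, N). With dynamic
// B1/B2, this last step cannot be expressed.
%res = linalg.tensor_expand_shape %mm [[0, 1], [2], [3]]
    : tensor<6x4x7xf32> into tensor<2x3x4x7xf32>
```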
linalg.tensor_expand_shape cannot express such a reshaping when the dimensions involved are dynamic. The workaround is to use a linalg.generic, as sketched below.
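A hedged sketch of that workaround: a linalg.generic over the full 5-D iteration space (B1, B2, M, N, K), with two parallel batch dimensions instead of one (value names are illustrative, and %fill is assumed to be a zero-initialized output tensor):

```mlir
#lhs = affine_map<(b1, b2, m, n, k) -> (b1, b2, m, k)>
#rhs = affine_map<(b1, b2, m, n, k) -> (b1, b2, k, n)>
#out = affine_map<(b1, b2, m, n, k) -> (b1, b2, m, n)>
// Multi-batch matmul with dynamic shapes; no collapse/expand needed.
%res = linalg.generic {
    indexing_maps = [#lhs, #rhs, #out],
    iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
    ins(%lhs, %rhs : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
    outs(%fill : tensor<?x?x?x?xf32>) {
  ^bb0(%l: f32, %r: f32, %o: f32):
    %p = arith.mulf %l, %r : f32
    %s = arith.addf %o, %p : f32
    linalg.yield %s : f32
} -> tensor<?x?x?x?xf32>
```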
Maybe linalg.batch_matmul could allow variadic leading dimensions?
Still, it seems useful to support arbitrary expansions in linalg.tensor_expand_shape.