llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org

How are these matmuls in the MLIR toy example valid? #86378

Open nyck33 opened 5 months ago

nyck33 commented 5 months ago
# User defined generic function that operates on unknown shaped arguments.
def multiply_transpose(a, b) {
  return transpose(a) * transpose(b);
}

def main() {
  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
  # The shape is inferred from the supplied literal.
  var a = [[1, 2, 3], [4, 5, 6]];
  # b is identical to a, the literal array is implicitly reshaped: defining new
  # variables is the way to reshape arrays (element count in literal must match
  # the size of specified shape).
  var b<2, 3> = [1, 2, 3, 4, 5, 6];

  # This call will specialize `multiply_transpose` with <2, 3> for both
  # arguments and deduce a return type of <3, 2> in initialization of `c`.
  var c = multiply_transpose(a, b);
  # A second call to `multiply_transpose` with <2, 3> for both arguments will
  # reuse the previously specialized and inferred version and return `<3, 2>`
  var d = multiply_transpose(b, a);
  # A new call with `<3, 2>` for both dimensions will trigger another
  # specialization of `multiply_transpose`.
  var e = multiply_transpose(c, d);
  # Finally, calling into `multiply_transpose` with incompatible shapes
  # (<2, 3> and <3, 2>) will trigger a shape inference error.
  var f = multiply_transpose(a, c);
}

If mat A is 2×3, as is B, then both transposed are 3×2, so the shapes seem incompatible under the rule that the number of columns in A must match the number of rows in B.
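The following minimal Python sketch makes the matrix-multiplication shape rule concrete. It assumes `*` means standard matrix multiplication. The helpers `transpose_shape` and `matmul_shape` are hypothetical and not part of Toy or MLIR. `matmul_shape` returns the result shape when the inner dimensions agree and `None` otherwise.

```python
# Hypothetical helpers illustrating the GEMM shape rule; not part of Toy or MLIR.
from typing import Optional, Tuple

Shape = Tuple[int, int]


def transpose_shape(shape: Shape) -> Shape:
    """Transposing a rank-2 shape swaps its two dimensions."""
    rows, cols = shape
    return (cols, rows)


def matmul_shape(lhs: Shape, rhs: Shape) -> Optional[Shape]:
    """Result shape of lhs @ rhs, or None if the inner dimensions differ.

    For an (m, k) @ (k, n) product, the column count of lhs must equal
    the row count of rhs, and the result is (m, n).
    """
    m, k_lhs = lhs
    k_rhs, n = rhs
    if k_lhs != k_rhs:
        return None
    return (m, n)


a_shape = (2, 3)
c_shape = (3, 2)

# c = multiply_transpose(a, b): (3, 2) @ (3, 2) has inner dims 2 vs 3, so it is invalid.
print(matmul_shape(transpose_shape(a_shape), transpose_shape(a_shape)))  # None
# f = multiply_transpose(a, c): (3, 2) @ (2, 3) is valid and yields (3, 3).
print(matmul_shape(transpose_shape(a_shape), transpose_shape(c_shape)))  # (3, 3)
```

Under this rule, the `c` call would be rejected while the `f` call would type-check, which is the opposite of what the Toy comments say.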

How are all of these valid, apart from `var f`?

nyck33 commented 5 months ago

My best guess is that `*` here is not a matrix multiplication (GEMM) but an element-wise multiplication.

Based on the code snippet, the input and output shapes of the `multiply_transpose` calls that produce `c`, `d`, `e`, and `f` in `main` are as follows:

1. **Variable `c`**:
   - Input: `a` and `b` are both <2, 3>.
   - Operation: transpose(a) and transpose(b) both result in <3, 2>.
   - Output: Assuming element-wise multiplication, `c` is <3, 2>.

2. **Variable `d`**:
   - Input: `b` and `a` are both <2, 3>.
   - Operation: Similar to `c`, both get transposed to <3, 2>.
   - Output: `d` will also be <3, 2>, following the same logic as `c`.

3. **Variable `e`**:
   - Input: `c` and `d` are both <3, 2>.
   - Operation: transpose(c) and transpose(d) both result in <2, 3>.
   - Output: Assuming element-wise multiplication, `e` is <2, 3>.

4. **Variable `f`**:
   - Input: `a` is <2, 3>, and `c` is <3, 2>.
   - Operation: transpose(a) is <3, 2>, while transpose(c) is <2, 3>.
   - Output: The code comment indicates a shape inference error. This is consistent with element-wise multiplication, whose operands must have matching shapes (<3, 2> vs. <2, 3> here). Under standard matrix multiplication rules, <3, 2> × <2, 3> would actually be valid and produce <3, 3>.
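The shapes in the list above can be checked mechanically. The following minimal Python sketch assumes, as the comment does, that `*` is element-wise and therefore requires both operands to have the same shape. `ShapeError`, `elementwise_shape`, and `infer_multiply_transpose` are hypothetical names, not Toy's actual shape-inference code.

```python
# Hypothetical sketch of element-wise shape inference for multiply_transpose.
from typing import Tuple

Shape = Tuple[int, int]


class ShapeError(Exception):
    """Raised when element-wise operands have different shapes."""


def transpose_shape(shape: Shape) -> Shape:
    """Transposing a rank-2 shape swaps its two dimensions."""
    rows, cols = shape
    return (cols, rows)


def elementwise_shape(lhs: Shape, rhs: Shape) -> Shape:
    """Element-wise ops need identical operand shapes; the result keeps that shape."""
    if lhs != rhs:
        raise ShapeError(f"incompatible shapes {lhs} and {rhs}")
    return lhs


def infer_multiply_transpose(a: Shape, b: Shape) -> Shape:
    """Shape of transpose(a) * transpose(b) when * is element-wise."""
    return elementwise_shape(transpose_shape(a), transpose_shape(b))


a = b = (2, 3)
c = infer_multiply_transpose(a, b)  # (3, 2)
d = infer_multiply_transpose(b, a)  # (3, 2)
e = infer_multiply_transpose(c, d)  # (2, 3)
try:
    f = infer_multiply_transpose(a, c)  # (3, 2) vs. (2, 3)
except ShapeError as err:
    print("f:", err)  # f: incompatible shapes (3, 2) and (2, 3)
```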

Note that these output shapes hold only under the element-wise interpretation of `*`. Under matrix multiplication rules, the `c` and `d` calls would fail (<3, 2> × <3, 2> has mismatched inner dimensions), while `f` (<3, 2> × <2, 3>) would be valid. Which interpretation applies depends on how the Toy dialect defines `*`.

DarshanRamakant commented 1 month ago

It is an element-wise multiplication. You can see the documentation in the Toy `Ops.td` file: https://github.com/llvm/llvm-project/blob/f070f61fc0b4c731a031fbe9f8e7360c337791c4/mlir/examples/toy/Ch3/include/toy/Ops.td#L207
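For a value-level view, here is a minimal pure-Python sketch of `multiply_transpose` under this element-wise reading. Nested lists stand in for Toy tensors. The helpers `transpose` and `mul` are hypothetical Python stand-ins for Toy's `transpose` and `*`, not Toy or MLIR APIs.

```python
# Pure-Python stand-ins for Toy's transpose and element-wise *; nested lists model tensors.
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    """Swap rows and columns: result[j][i] == m[i][j]."""
    return [list(col) for col in zip(*m)]


def mul(x: Matrix, y: Matrix) -> Matrix:
    """Element-wise product; both operands must have the same shape."""
    if len(x) != len(y) or any(len(rx) != len(ry) for rx, ry in zip(x, y)):
        raise ValueError("element-wise mul requires operands of the same shape")
    return [[p * q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


def multiply_transpose(a: Matrix, b: Matrix) -> Matrix:
    return mul(transpose(a), transpose(b))


a = [[1, 2, 3], [4, 5, 6]]
b = [[1, 2, 3], [4, 5, 6]]  # same values as b<2, 3> = [1, 2, 3, 4, 5, 6]
c = multiply_transpose(a, b)
print(c)  # [[1, 16], [4, 25], [9, 36]]
```

Each output entry depends only on the operand entries at the same position. That is why two <3, 2> operands yield a <3, 2> result and no inner dimension needs to match.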