Open nyck33 opened 5 months ago
My best guess is that it is not like GEMM but is element-wise multiplication.
Based on the code snippet, the input and output shapes for variables `c`, `d`, `e`, and `f` in the `multiply_transpose` function are as follows:
1. **Variable `c`**:
- Input: `a` and `b` are both <2, 3>.
- Operation: transpose(a) results in <3, 2>, transpose(b) results in <3, 2>.
- Output: Assuming element-wise multiplication, `c` is <3, 2>.
2. **Variable `d`**:
- Input: `b` and `a` are both <2, 3>.
- Operation: Similar to `c`, both get transposed to <3, 2>.
- Output: `d` will also be <3, 2>, following the same logic as `c`.
3. **Variable `e`**:
- Input: `c` and `d` are both <3, 2>.
- Operation: transpose(c) and transpose(d) both result in <2, 3>.
- Output: Assuming element-wise multiplication, `e` is <2, 3>.
4. **Variable `f`**:
- Input: `a` is <2, 3>, and `c` is <3, 2>.
- Operation: transpose(a) is <3, 2>, transpose(c) is <2, 3>.
- Output: The code comment indicates a shape inference error. That is consistent with element-wise multiplication, since <3, 2> and <2, 3> do not match. Under traditional matrix multiplication these two shapes would be compatible and would produce <3, 3>.
It's important to note that the output shapes depend on the interpretation of the multiplication operation (element-wise or matrix multiplication) and the capabilities of the MLIR framework in handling dynamic shapes and inferring types.
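The shape walk-through above can be checked with a small sketch. This is plain Python, not MLIR or Toy code: the helper names (`transpose_shape`, `elementwise_mul_shape`, `multiply_transpose_shape`) are hypothetical, and it assumes that `multiply_transpose(x, y)` computes `transpose(x) * transpose(y)` with `*` being element-wise, so both operands must have identical shapes.

```python
from typing import Optional, Tuple

Shape = Tuple[int, int]


def transpose_shape(shape: Shape) -> Shape:
    """Shape of a 2-D transpose: rows and columns swap."""
    rows, cols = shape
    return (cols, rows)


def elementwise_mul_shape(lhs: Shape, rhs: Shape) -> Shape:
    """Element-wise multiply: operand shapes must be identical (assumption: no broadcasting)."""
    if lhs != rhs:
        raise ValueError(f"shape mismatch: {lhs} vs {rhs}")
    return lhs


def multiply_transpose_shape(x: Shape, y: Shape) -> Shape:
    """Hypothetical model of transpose(x) * transpose(y) at the shape level."""
    return elementwise_mul_shape(transpose_shape(x), transpose_shape(y))


a: Shape = (2, 3)
b: Shape = (2, 3)

c = multiply_transpose_shape(a, b)  # both operands become (3, 2)
d = multiply_transpose_shape(b, a)  # same as c
e = multiply_transpose_shape(c, d)  # both operands become (2, 3)

# f mixes (3, 2) with (2, 3), so the element-wise check rejects it.
f: Optional[Shape] = None
f_error = ""
try:
    f = multiply_transpose_shape(a, c)
except ValueError as err:
    f_error = str(err)

print("c:", c, "d:", d, "e:", e, "f:", f, f_error)
```

Under these assumptions only `f` fails, which matches the walk-through.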
It is an element-wise multiplication. You can see the documentation in the Ops.td file: https://github.com/llvm/llvm-project/blob/f070f61fc0b4c731a031fbe9f8e7360c337791c4/mlir/examples/toy/Ch3/include/toy/Ops.td#L207
If matrix A is 2x3, as is B, then both transposes are 3x2, so the shapes seem incompatible under the matrix multiplication rule that the number of columns in A must match the number of rows in B.
How are these valid, except for `var f`?
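For contrast, here is a minimal sketch of the matrix multiplication rule applied to the same operands. The helper `matmul_shape` is hypothetical and is not part of Toy or MLIR. It only illustrates that, if `*` were a matrix product, the two <3, 2> operands would be rejected while the operands of `var f` would be accepted, which is the opposite of the behaviour described in this thread.

```python
from typing import Tuple

Shape = Tuple[int, int]


def matmul_shape(lhs: Shape, rhs: Shape) -> Shape:
    """Matrix product rule: columns of lhs must equal rows of rhs."""
    if lhs[1] != rhs[0]:
        raise ValueError(f"inner dimensions differ: {lhs} x {rhs}")
    return (lhs[0], rhs[1])


def matmul_is_valid(lhs: Shape, rhs: Shape) -> bool:
    """True when the matrix product of the two shapes is defined."""
    try:
        matmul_shape(lhs, rhs)
        return True
    except ValueError:
        return False


# Operands after transposing a<2, 3> and b<2, 3>: both are (3, 2).
c_operands_ok = matmul_is_valid((3, 2), (3, 2))

# Operands of f: transpose(a) is (3, 2), transpose(c) is (2, 3).
f_operands_ok = matmul_is_valid((3, 2), (2, 3))
f_result = matmul_shape((3, 2), (2, 3))

print("c valid as matmul:", c_operands_ok)
print("f valid as matmul:", f_operands_ok, "result shape:", f_result)
```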