tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

Enable conversion of all_reduce and GSPMD custom_op into TTIR dialect #1351

Open wooseokTT opened 2 days ago

wooseokTT commented 2 days ago
  1. TT_Reduce_Type is created to share the reduction computation type with the TTNN dialect
  2. AllReduceOp in TTIR is introduced to accommodate the StableHLO all_reduce op
  3. MeshShardOp in TTIR is introduced to capture GSPMD custom sharding
  4. Realistic test cases are added from JAX/PJRT output

The current version of the import path targets GSPMD input, but our future plans mainly focus on supporting Shardy-based JAX/PJRT output.
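
For readers unfamiliar with the two ops, here is a minimal pure-Python sketch of the semantics they are meant to capture: sharding a tensor across a device mesh, then all-reducing so that every device ends up with the same combined value. It is an illustration only, not the TTIR implementation. The names `ReduceType`, `mesh_shard`, and `all_reduce` are hypothetical, and the enum members are assumptions, since the actual TT_Reduce_Type values are not listed in this description.

```python
from enum import Enum
from functools import reduce
from typing import Callable, Dict, List


class ReduceType(Enum):
    # Hypothetical members; the real TT_Reduce_Type values are not given here.
    SUM = "sum"
    MAX = "max"
    MIN = "min"


# Binary combiner used for each reduction kind.
_COMBINE: Dict[ReduceType, Callable[[float, float], float]] = {
    ReduceType.SUM: lambda a, b: a + b,
    ReduceType.MAX: max,
    ReduceType.MIN: min,
}


def mesh_shard(tensor: List[float], num_devices: int) -> List[List[float]]:
    """Split a 1-D tensor into equal contiguous shards, one per device."""
    if len(tensor) % num_devices != 0:
        raise ValueError("tensor length must be divisible by num_devices")
    size = len(tensor) // num_devices
    return [tensor[i * size:(i + 1) * size] for i in range(num_devices)]


def all_reduce(shards: List[List[float]], reduce_type: ReduceType) -> List[List[float]]:
    """Combine shards elementwise; every device receives the same result."""
    combine = _COMBINE[reduce_type]
    # zip(*shards) groups the i-th element of every device's shard together.
    reduced = [reduce(combine, column) for column in zip(*shards)]
    return [list(reduced) for _ in shards]


shards = mesh_shard([1.0, 2.0, 3.0, 4.0], num_devices=2)
print(shards)                              # [[1.0, 2.0], [3.0, 4.0]]
print(all_reduce(shards, ReduceType.SUM))  # [[4.0, 6.0], [4.0, 6.0]]
```

The sketch uses a 1-D mesh and 1-D tensors to keep the example short; real sharding annotations also describe multi-dimensional meshes and which tensor dimensions are split or replicated.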