This pass needs to contain the core logic for where and how to automatically insert CCL ops. I think a bit of prototyping should be done first, because it feels like there could be a generic OpInterface that TTNN ops implement to denote which dimensions they require to be gathered when sharded across multiple devices.
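To make the interface idea concrete, here is a rough Python model of what such a query could look like. Everything named here (`RequiresGatherInterface`, `dims_requiring_gather`, `gathers_needed`, the `Matmul` stand-in) is hypothetical and only illustrates the shape of the idea; it is not an existing TTNN or MLIR API, and it assumes the matmul inner dim is the last dim of the left input and the second-to-last dim of the right input.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol


class RequiresGatherInterface(Protocol):
    """Hypothetical interface: an op reports which operand dims must be whole."""

    def dims_requiring_gather(self) -> dict[int, set[int]]:
        """Map operand index -> dims that must not stay sharded across devices."""
        ...


@dataclass
class Matmul:
    """Stand-in for a matmul op; only operand ranks matter for this sketch."""

    lhs_rank: int
    rhs_rank: int

    def dims_requiring_gather(self) -> dict[int, set[int]]:
        # Assumed layout: inner dim is the last dim of lhs, second-to-last of rhs.
        return {0: {self.lhs_rank - 1}, 1: {self.rhs_rank - 2}}


def gathers_needed(
    op: RequiresGatherInterface, sharded_dims: dict[int, set[int]]
) -> list[tuple[int, int]]:
    """Return sorted (operand, dim) pairs that need an all_gather before `op`."""
    required = op.dims_requiring_gather()
    return sorted(
        (operand, dim)
        for operand, dims in required.items()
        for dim in dims & sharded_dims.get(operand, set())
    )
```

With a query like this, the pass would not need to know about individual ops: it would intersect each op's answer with the sharding of its operands and insert gathers for whatever overlaps.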
At first, this pass can just special-case the ops, as there are really only a few categories of ops we need to worry about:
Matmul: If the inner dim of either input is split across devices, an all_gather must be inserted to gather the respective inner dimension.
Reduce: If the reduce dim spans multiple devices, a reduce_scatter must replace this operation.
TMs: Likely need an all_gather; these need to be handled on a case-by-case basis, and more exploration needs to be done here.
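The special-cased version of the rules above could look roughly like the following. This is a minimal sketch and not the actual pass: `plan_ccl`, its arguments, and the string "actions" it returns are all made up for illustration, and it models a single operand at a time. TMs are deliberately left unhandled, since the right behavior there is still open.

```python
from __future__ import annotations


def plan_ccl(op_kind: str, sharded: set[int], dims: set[int]) -> list[str]:
    """Return the CCL actions for one operand of one op.

    op_kind: "matmul", "reduce", or anything else (treated as unhandled).
    sharded: dims of the operand that are split across devices.
    dims:    inner dim(s) for matmul, reduced dim(s) for reduce.
    """
    hit = sorted(sharded & dims)
    if not hit:
        return []  # nothing is sharded along a dim this op cares about
    if op_kind == "matmul":
        # Gather the inner dimension before the matmul runs.
        return [f"insert all_gather(dim={d})" for d in hit]
    if op_kind == "reduce":
        # The reduction itself becomes a cross-device collective.
        return [f"replace with reduce_scatter(dim={d})" for d in hit]
    # TMs and anything else: no rule yet, surface it instead of guessing.
    raise NotImplementedError(f"no CCL rule for {op_kind!r}; handle case by case")
```

For example, `plan_ccl("matmul", {1}, {1})` returns `["insert all_gather(dim=1)"]`, while `plan_ccl("matmul", {0}, {1})` returns an empty list because the sharded dim is not the inner dim.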
Depends on: