For compilation passes we'll have:
ShardingPropagation -> Part of what ShardingCompiler does today; this is essentially what __compile__ is for. It propagates the sharding throughout the function.
GraphSplitter -> This PR (see below).
ShardingSeparation -> Also present in ShardingCompiler today, and mostly what __jit__ currently executes.
This PR adds the GraphSplitter pass, which splits a computation wherever certain conditions
hold (for example, at operations that require a given dimension to be unsharded).
It also builds a list of computation specs that we can later use to build an execution graph.
Each of the resulting stages might fan out further, depending on the subsequent ShardingSeparation pass.
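
As a rough illustration of the splitting idea only, here is a minimal sketch. It is not the implementation in this PR: the Op and ComputationSpec classes, the split_graph function, and the rule that a conflicting op gets a stage of its own are all assumptions made for the example.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Set


@dataclass
class Op:
    # Hypothetical op record; the real pass works on the compiler's own IR.
    name: str
    # Dimension that must be unsharded for this op to run, if any.
    requires_unsharded: Optional[int] = None


@dataclass
class ComputationSpec:
    # Hypothetical spec: the names of the ops that make up one stage.
    ops: List[str] = field(default_factory=list)


def split_graph(ops: List[Op], sharded_dims: Set[int]) -> List[ComputationSpec]:
    """Split a linear op sequence into stages.

    An op whose required dimension is currently sharded closes the stage
    before it and is placed in a stage of its own (an assumed policy).
    """
    specs: List[ComputationSpec] = []
    current: List[str] = []
    for op in ops:
        conflict = (
            op.requires_unsharded is not None
            and op.requires_unsharded in sharded_dims
        )
        if conflict:
            if current:
                specs.append(ComputationSpec(current))
                current = []
            specs.append(ComputationSpec([op.name]))
        else:
            current.append(op.name)
    if current:
        specs.append(ComputationSpec(current))
    return specs


example_ops = [
    Op("matmul"),
    Op("relu"),
    Op("softmax", requires_unsharded=1),
    Op("add"),
]
stages = [spec.ops for spec in split_graph(example_ops, sharded_dims={1})]
```

With dimension 1 sharded, the example yields three stages: the ops before softmax, softmax alone, and the ops after it. With a different sharded dimension, no split happens and everything stays in one stage. The resulting list plays the role of the computation specs from which an execution graph could be built later.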