elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

feat: add graph splitter for all-gather/all-reduce operations #1545

Closed polvalente closed 1 month ago

polvalente commented 1 month ago

For compilation passes we'll have:

1. ShardingPropagation -> Part of what ShardingCompiler currently does; this is essentially the purpose of __compile__. It propagates the sharding throughout the function.
2. GraphSplitter -> This PR (see below).
3. ShardingSeparation -> Also present in ShardingCompiler today; this is mostly what __jit__ currently executes.

This PR adds the GraphSplitter pass that we can use to split computations based on certain conditions (such as operations that require that a given dimension is not sharded).

It also builds out a list of computation specs that we can use to build an execution graph afterwards. Each of these stages might have further fan-out depending on the subsequent ShardingSeparation pass.
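To make the splitting idea concrete, here is a minimal sketch in plain Elixir. It is not the code from this PR and does not use the Nx expression representation: the `SplitSketch` module, the `{id, name, arg_ids}` op tuples, the `split?` predicate, and the shape of the stage maps are all hypothetical, chosen only for illustration. It assumes the ops arrive in topological order and starts a new stage whenever the predicate flags an op (for example, one that needs a dimension to be unsharded).

```elixir
defmodule SplitSketch do
  # Hypothetical, simplified model: an op is {id, name, arg_ids}, and the
  # list of ops is assumed to be in topological order.

  # Splits `ops` into stages, starting a new stage at every op for which
  # `split?.(op)` returns true. Returns one spec map per stage.
  def split(ops, split?) do
    ops
    |> Enum.chunk_while(
      [],
      fn op, acc ->
        if split?.(op) and acc != [] do
          # Close the current stage and start a new one with this op.
          {:cont, Enum.reverse(acc), [op]}
        else
          {:cont, [op | acc]}
        end
      end,
      fn
        [] -> {:cont, []}
        acc -> {:cont, Enum.reverse(acc), []}
      end
    )
    |> Enum.with_index()
    |> Enum.map(fn {stage_ops, index} -> to_spec(stage_ops, index) end)
  end

  # A stage's inputs are the argument ids that the stage itself does not
  # produce, so they must be supplied from outside the stage.
  defp to_spec(stage_ops, index) do
    produced = MapSet.new(stage_ops, fn {id, _name, _args} -> id end)

    inputs =
      stage_ops
      |> Enum.flat_map(fn {_id, _name, args} -> args end)
      |> Enum.reject(&MapSet.member?(produced, &1))
      |> Enum.uniq()

    %{stage: index, ops: stage_ops, inputs: inputs}
  end
end

# Usage: treat :sum as the (hypothetical) op that forces a split.
ops = [
  {:a, :parameter, []},
  {:b, :exp, [:a]},
  {:c, :sum, [:b]},
  {:d, :add, [:c, :a]}
]

stages = SplitSketch.split(ops, fn {_id, name, _args} -> name == :sum end)
```

In this toy model the second stage lists `:b` and `:a` as inputs, which is the kind of information an execution graph would need in order to know what has to be gathered before that stage can run.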