Adds a proof-of-concept implementation for sharding as an Nx meta-compiler.
In the current proposal, we shard inputs according to an arbitrary slicing configuration, and the compiler then does its best to propagate those slices through to the output. The compiler can then build a separate {args, function, reducer} tuple for each output shard (a hand-written sketch of consuming one such tuple follows the list), where:
- args: the input arguments, sliced down to the data sections required for that specific output shard
- function: a new compilation of the input function, specialized to the sliced arguments
- reducer: a function responsible for inserting the result shard into the correct place in an output accumulator tensor (bound as caster in the example below)
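For intuition, consuming a single shard by hand could look like the following sketch. It assumes the {output_holder, shards} shape produced by the example further down, and all names here are illustrative:

{sliced_args, shard_fun, reducer} = hd(shards) # one {args, function, reducer} tuple
partial = shard_fun.(sliced_args)              # compute just this output shard
result = reducer.(partial, output_holder)      # place it into the accumulator tensor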
The following example showcases how we can shard the example function into 4 separate shards. This happens because there are 2 shards for arg1, and only the first axis of arg0 can be sharded, since the other two axes are connected to contracting axes in the dot product; arg0 therefore contributes 2 shards along its first axis, giving 2 x 2 = 4 shards in total.
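As a standalone aside (not part of the PoC API), the following plain-Nx snippet illustrates why contracting axes resist sharding: slicing a contracted axis yields partial sums rather than a sub-tensor of the output.

a = Nx.iota({2, 3})
b = Nx.iota({3, 2})
full = Nx.dot(a, b)
# Slice the contracted dimension (axis 1 of a, axis 0 of b) down to length 2:
a_part = Nx.slice_along_axis(a, 0, 2, axis: 1)
b_part = Nx.slice_along_axis(b, 0, 2, axis: 0)
partial = Nx.dot(a_part, b_part)
# partial has the same {2, 2} shape as full, but holds partial sums; shards cut
# along a contracted axis would have to be added together, not placed side by side.

The complete PoC example follows: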
arg0_sharding = %{} # inputs are taken to be fully sharded if no specification is given
arg1_sharding = %{4 => [0..0, 1..1]} # split axis 4 into two single-entry slices
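# For intuition: this config assigns axis 4 of arg1 (size 2) one shard per index;
# a hand-written equivalent of the two resulting slices would be
# Nx.slice_along_axis(arg1, 0, 1, axis: 4) and Nx.slice_along_axis(arg1, 1, 1, axis: 4).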
Nx.default_backend(Nx.BinaryBackend)
fun = fn l, r ->
  x = Nx.add(l, Nx.tensor([[1]]))      # {2, 2, 3} + broadcasted [[1]]
  x = Nx.transpose(x, axes: [0, 2, 1]) # -> {2, 3, 2}
  y = Nx.subtract(r, 1)                # {1, 1, 3, 2, 2}
  y = Nx.squeeze(y, axes: [0, 1])      # -> {3, 2, 2}
  Nx.dot(x, [2, 1], y, [1, 0])         # contract axes [2, 1] of x with [1, 0] of y -> {2, 2}
end
# fun = &Nx.dot(&1, [1, 2], &2, [1, 0])
# fun = &Nx.add(&1, &2)
inputs = [
  # arg0: shape {2, 2, 3}
  Nx.iota({2, 2, 3}, type: :f32),
  # arg1: shape {1, 1, 3, 2, 2}
  Nx.add(Nx.iota({1, 1, 3, 2, 2}), 10)
]
{output_holder, shards} =
  Nx.Defn.jit_apply(
    fun,
    inputs,
    compiler: Nx.Defn.ShardingCompiler,
    sharding_config: [arg0_sharding, arg1_sharding],
    sharding_compiler: Nx.Defn.Evaluator,
    sharding_compiler_options: []
  )
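# At this point, output_holder is the accumulator tensor shaped like the full
# result ({2, 2} here), and shards should hold the 4 {args, function, reducer}
# tuples described above, one per output shard.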
sharded_result =
  shards
  |> Task.async_stream(fn {arg, fun, caster} ->
    # Each shard runs concurrently; dbg/1 shows which process executes it
    dbg(self())
    {fun.(arg), caster}
  end)
  |> Enum.reduce(output_holder, fn {:ok, {result, caster}}, acc ->
    # Insert each shard's result into its place in the output accumulator
    caster.(result, acc)
  end)
  |> IO.inspect()
# Ensure that the sharded result is the same as the result for the function
# applied to the unsharded inputs
Nx.equal(sharded_result, apply(fun, inputs))
|> Nx.all()
|> Nx.to_number()
|> Kernel.==(1)
|> IO.inspect()