elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

feat: experimental sharding backend #1544

Open: polvalente opened this pull request 1 month ago


Adds a proof-of-concept implementation for sharding as an Nx meta-compiler.

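In the current proposal, we shard inputs according to an arbitrary slicing configuration, and the compiler does its best to propagate those slices through to the output. It then builds a separate {args, function, reducer} tuple for each output shard, where args are the input slices for that shard, function is the function to run on them, and reducer (called caster in the example below) merges the shard's result into the output holder.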
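A minimal sketch of the input-slicing step, using only public Nx functions. It assumes a slicing configuration is a map from axis index to a list of ascending ranges, as in arg1_sharding below. ShardSketch.slices/2 is a hypothetical helper, not part of Nx or of this PR, and keeping unlisted axes whole is an assumption of the sketch rather than a description of how the compiler treats them.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule ShardSketch do
  # Hypothetical helper (not part of Nx): splits `tensor` according to a
  # config of the form %{axis => [range, ...]}. Every combination of ranges
  # across the configured axes becomes one shard; axes that are not listed
  # in the config are kept whole (an assumption of this sketch).
  def slices(tensor, config) do
    config
    |> Enum.sort()
    |> Enum.reduce([tensor], fn {axis, ranges}, shards ->
      for shard <- shards, range <- ranges do
        # Each range is assumed to be ascending with step 1, e.g. 0..0 or 1..1
        Nx.slice_along_axis(shard, range.first, Enum.count(range), axis: axis)
      end
    end)
  end
end

arg1 = Nx.add(Nx.iota({1, 1, 3, 2, 2}), 10)
shards = ShardSketch.slices(arg1, %{4 => [0..0, 1..1]})

# Two shards, each of shape {1, 1, 3, 2, 1}
shards |> Enum.map(&Nx.shape/1) |> IO.inspect()
```

Concatenating the shards back along axis 4 recovers the original tensor, which is the invariant any slicing configuration has to preserve.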

The following example shows how we can shard the function below into 4 separate shards. This happens because there are 2 shards for arg1 (its axis 4 is split into 0..0 and 1..1), and only the first axis of arg0 can be sharded, since its other 2 axes are connected to contracting axes in the dot product. Splitting that first axis (size 2) gives 2 shards for arg0, so 2 × 2 = 4 shards in total.

```elixir
# Inputs are taken to be fully sharded if no specification is given
arg0_sharding = %{}
# Split axis 4 of arg1 into two slices: 0..0 and 1..1
arg1_sharding = %{4 => [0..0, 1..1]}

Nx.default_backend(Nx.BinaryBackend)

fun = fn l, r ->
  x = Nx.add(l, Nx.tensor([[1]]))
  x = Nx.transpose(x, axes: [0, 2, 1])
  y = Nx.subtract(r, 1)
  y = Nx.squeeze(y, axes: [0, 1])
  Nx.dot(x, [2, 1], y, [1, 0])
end

# Alternative functions to try (these need different input shapes than the ones below):
# fun = &Nx.dot(&1, [1, 2], &2, [1, 0])
# fun = &Nx.add(&1, &2)

inputs = [
  Nx.iota({2, 2, 3}, type: :f32),
  Nx.add(Nx.iota({1, 1, 3, 2, 2}), 10)
]

{output_holder, shards} =
  Nx.Defn.jit_apply(
    fun,
    inputs,
    compiler: Nx.Defn.ShardingCompiler,
    sharding_config: [arg0_sharding, arg1_sharding],
    sharding_compiler: Nx.Defn.Evaluator,
    sharding_compiler_options: []
  )

# Run each shard in its own task, then let each shard's caster
# merge its result into the output holder
sharded_result =
  shards
  |> Task.async_stream(fn {arg, shard_fun, caster} ->
    # Print the PID to show that each shard runs in a separate process
    dbg(self())
    {shard_fun.(arg), caster}
  end)
  |> Enum.reduce(output_holder, fn {:ok, {result, caster}}, acc ->
    caster.(result, acc)
  end)
  |> IO.inspect()

# Ensure that the sharded result is the same as the result for the function applied to the unsharded inputs
sharded_result
|> Nx.equal(apply(fun, inputs))
|> Nx.all()
|> Nx.to_number()
|> Kernel.==(1)
|> IO.inspect()
```
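The 2 × 2 = 4 shard count and the role of the caster can be reproduced by hand with public Nx functions, without the experimental compiler. This is only a sketch: the list-taking shard function and the Nx.put_slice-based caster are illustrative assumptions that mirror the {args, function, caster} tuples consumed above, not what Nx.Defn.ShardingCompiler actually generates.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Same example function as in the PR description
fun = fn l, r ->
  x = Nx.add(l, Nx.tensor([[1]]))
  x = Nx.transpose(x, axes: [0, 2, 1])
  y = Nx.subtract(r, 1)
  y = Nx.squeeze(y, axes: [0, 1])
  Nx.dot(x, [2, 1], y, [1, 0])
end

arg0 = Nx.iota({2, 2, 3}, type: :f32)
arg1 = Nx.add(Nx.iota({1, 1, 3, 2, 2}), 10)

# Output axis 0 comes from arg0 axis 0 and output axis 1 from arg1 axis 4;
# every other input axis is squeezed or contracted, so there are 2 x 2 = 4 shards.
# Hypothetical {args, function, caster} tuples, one per output shard.
shards =
  for i <- 0..1, j <- 0..1 do
    args = [
      Nx.slice_along_axis(arg0, i, 1, axis: 0),
      Nx.slice_along_axis(arg1, j, 1, axis: 4)
    ]

    shard_fun = fn [l, r] -> fun.(l, r) end
    # Each shard yields a {1, 1} block; the caster writes it at position [i, j]
    caster = fn result, acc -> Nx.put_slice(acc, [i, j], result) end
    {args, shard_fun, caster}
  end

# Zero-filled holder with the {2, 2} shape and f32 type of the full output
output_holder = Nx.broadcast(0.0, {2, 2})

sharded_result =
  shards
  |> Task.async_stream(fn {args, shard_fun, caster} -> {shard_fun.(args), caster} end)
  |> Enum.reduce(output_holder, fn {:ok, {result, caster}}, acc -> caster.(result, acc) end)

# Compare against the unsharded computation
sharded_result
|> Nx.equal(fun.(arg0, arg1))
|> Nx.all()
|> Nx.to_number()
|> Kernel.==(1)
|> IO.inspect()
```

Because every caster in this sketch writes a disjoint {1, 1} slice, the order in which shard results arrive does not affect the final tensor.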