tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

Support for distributed layernorm #8796

Open · johanna-rock-tt opened this issue 3 months ago

johanna-rock-tt commented 3 months ago

Currently there is no support for layernorm to be distributed over multiple devices.

Implications

Proposed Solution: Implementation of a distributed layernorm

cglagovichTT commented 3 months ago

Restrictions:

Layernorm part 1: per-device

output of layernorm part 1: each device produces an output of B x 1 x S x 2 tiles where the first column of tiles contains variance per device, and the second column of tiles contains mean per device. Each tile column only has valid data in the first column of the tile.

AllGather on width produces a tensor of shape B x 1 x S x (2 * 32 * num_devices).

Layernorm part 2:

layernorm part 2 will reduce the second input to get global mean and variance, and then apply these to compute the rest of layernorm.

Note that RMSNorm only has to exchange the sum(X**2) values. LNP1 in RMS mode will produce 1 column of tiles, and LNP2 will consume 1 column of input tiles.
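
For intuition, here is a minimal NumPy sketch of the statistics math described above (not the ttnn op itself; the even split of the hidden dim across devices and all shapes are illustrative assumptions):

```python
import numpy as np

# Illustrative shapes: hidden dim H split evenly across num_devices.
num_devices, S, H = 4, 8, 128
x = np.random.randn(S, H)
shards = np.split(x, num_devices, axis=-1)            # one [S, H/num_devices] shard per device

# Part 1 (per device): only local partial sums, no cross-device communication.
local_sum   = np.stack([s.sum(-1)        for s in shards], axis=-1)   # [S, num_devices]
local_sumsq = np.stack([(s ** 2).sum(-1) for s in shards], axis=-1)   # [S, num_devices]

# Part 2 (after the gather): reduce the partials into global statistics.
mean = local_sum.sum(-1) / H
var  = local_sumsq.sum(-1) / H - mean ** 2            # E[x^2] - E[x]^2

assert np.allclose(mean, x.mean(-1))
assert np.allclose(var,  x.var(-1))
# RMSNorm needs only local_sumsq, which is why LNP1 in RMS mode emits a single tile column.
```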

cglagovichTT commented 3 months ago
(image attached)

TT-BrianLiu commented 3 months ago

So far, we don't have any ops that operate across devices other than the CCL ops, which @SeanNijjar is working on. This op seems like something that should be generalized as we start supporting multi-chip ops, so I think it's important we get things right the first time. (We potentially need a wider design review for how we want to handle multi-chip ops.) For example:

cglagovichTT commented 3 months ago

We are thinking of implementing distributed layernorm as 3 separate programs so we can use Sean's AllGather as it exists today. We would invoke it like this:

stats = tt_lib.tensor.layernorm_tophalf(x)
stats = ttnn.all_gather(stats)
out = tt_lib.tensor.layernorm_bottomhalf(x, epsilon, stats, weight, bias)

From talking with Sean, the infra to support CCL within ops isn't mature, so this is probably the fastest way we can bring up this functionality.
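
For reference, here is a self-contained NumPy mock of that three-call flow, checked against a monolithic layernorm. The function names only mirror the proposed ops (this is not the tt_lib/ttnn API), and the stats layout of one sum(x) column plus one sum(x**2) column per device is an assumption based on the description above:

```python
import numpy as np

def layernorm_tophalf(x_dev):
    # Hypothetical stand-in: per-device partial stats, one column of sum(x)
    # and one column of sum(x**2); no cross-device communication here.
    return np.stack([x_dev.sum(-1), (x_dev ** 2).sum(-1)], axis=-1)        # [S, 2]

def all_gather(stats_list):
    # Stand-in for the width all-gather: concatenate every device's stats.
    return np.concatenate(stats_list, axis=-1)                             # [S, 2 * num_devices]

def layernorm_bottomhalf(x_dev, stats, eps, gamma_dev, beta_dev, hidden):
    # Reduce the gathered partials into global mean/var, then normalize the local shard.
    mean = stats[:, 0::2].sum(-1, keepdims=True) / hidden
    var = stats[:, 1::2].sum(-1, keepdims=True) / hidden - mean ** 2
    return (x_dev - mean) / np.sqrt(var + eps) * gamma_dev + beta_dev

num_devices, S, hidden, eps = 4, 8, 128, 1e-5
x, gamma, beta = np.random.randn(S, hidden), np.random.randn(hidden), np.random.randn(hidden)
xs, gs, bs = (np.split(a, num_devices, axis=-1) for a in (x, gamma, beta))

stats = all_gather([layernorm_tophalf(xd) for xd in xs])   # the only cross-device step
out = np.concatenate([layernorm_bottomhalf(xd, stats, eps, gd, bd, hidden)
                      for xd, gd, bd in zip(xs, gs, bs)], axis=-1)

ref = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps) * gamma + beta
assert np.allclose(out, ref)
```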

cglagovichTT commented 3 months ago

We are planning to restrict it to interleaved for two reasons:

  1. We want AllGather to produce a tensor of gathered statistics which each core (parallelizing on sequence length, so we have no inter-core comms) can pick up and reduce. Sharded AllGather currently only supports width-sharded, single-tile-high tensors, so there's no way to AllGather sharded and assign output tiles to the correct worker core.
  2. The purpose of distributed layernorm is to push the AllGather down to the Matmul in the MLP, so we can use AllGather/Matmul overlapping. This optimization helps the most for prefill mode, where we expect inputs to be interleaved.

Thinking about it, we could support sharded input/output, but I think the stats would have to be interleaved in DRAM. But I think models would have to explicitly shard and then interleave the input/output to use this, so we would pay for non-overlapped reads/writes anyway.

TT-BrianLiu commented 3 months ago

> We are thinking of implementing distributed layernorm as 3 separate programs so we can use Sean's AllGather as it exists today. We would invoke it like this:
>
> stats = tt_lib.tensor.layernorm_tophalf(x)
> stats = ttnn.all_gather(stats)
> out = tt_lib.tensor.layernorm_bottomhalf(x, epsilon, stats, weight, bias)
>
> From talking with Sean, the infra to support CCL within ops isn't mature, so this is probably the fastest way we can bring up this functionality.

I see. This makes sense. Two things:

cglagovichTT commented 3 months ago

  1. tophalf and bottomhalf are relatively different: tophalf only computes the local sum(x) and sum(x**2), while bottomhalf takes the gathered stats and computes the rest of the layernorm.
SeanNijjar commented 3 months ago

> This all gather + reduce in layernorm_bottomhalf sounds like a reduce-scatter CCL op. @SeanNijjar Is this something we will be implementing later?

@TT-BrianLiu - yes, this is the plan eventually, but we shouldn't require it to bring up this functionality. The plan is, similar to our op list, to provide all the baseline CCL "primitives": all-gather, all-reduce, reduce-scatter, all-scatter, bcast, send, etc.

Beyond that, I'm hoping to provide a bit more of a programming model (sort of) that should (a) make building these ops simpler, and (b) empower developers like Colman and Johanna to experiment with custom schemes like this without having to depend on me or end up with a potentially compromised implementation that relies on these bigger ops.

cglagovichTT commented 1 month ago

@kevinmiTT11 found what might be a bug in distributed layernorm: since pre-allgather rmsnorm produces an output which is 1 x 1 x R x 1[32] for each device, the allgather afterward produces an output that is 1 x 1 x R x 4[32] for 4 devices. This messes up the computation in post-allgather rmsnorm, since it expects 4 devices to produce 4 tile columns of stats. @kevinmiTT11 has a workaround which reshapes the output of pre-allgather rmsnorm.
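
A small NumPy illustration of my reading of the mismatch (made-up stat values; the 32-wide tile and the layouts below are assumptions): each device's pre-allgather output is one padded tile with valid data in its first column, so the post-allgather op expects the width gather to keep every device's full padded tile, rather than a tensor whose logical width is just num_devices:

```python
import numpy as np

TILE_W, R, num_devices = 32, 64, 4

def device_stats(d):
    # One padded tile per device: valid data only in column 0 (made-up values).
    tile = np.zeros((R, TILE_W))
    tile[:, 0] = float(d + 1)
    return tile

# Layout post-allgather rmsnorm was written against: num_devices full tiles on
# width, i.e. 1 x 1 x R x (num_devices * 32), with valid columns at 0, 32, 64, 96.
expected = np.concatenate([device_stats(d) for d in range(num_devices)], axis=-1)
assert expected.shape == (R, num_devices * TILE_W)

# The reported behaviour: the gathered tensor has logical width num_devices
# (padded to one tile), i.e. 1 x 1 x R x num_devices[32], so the valid columns
# sit next to each other instead of one per tile column, and the post-allgather
# reduction reads the wrong locations.
reported = np.concatenate([device_stats(d)[:, :1] for d in range(num_devices)], axis=-1)
assert reported.shape == (R, num_devices)
```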

johanna-rock-tt commented 1 month ago

@SeanNijjar is this related to the regression you were looking at last week?

SeanNijjar commented 1 month ago

@johanna-rock-tt no, this isn't the same issue. The issue I was talking to you about was a separate, branch-only issue (it has since been resolved and merged).

However, @cglagovichTT's description is different from how I understood your offline description. My understanding was that all-gather should be getting as input something like

[1,1,R,1[32]] x ring_size

with output (per chip)

[1, 1, R, num_chips x 1[32]] - in other words, num_chips worth of tiles on width, each having a column of valid data.

cglagovichTT commented 1 month ago

@SeanNijjar the way you described it is how allgather behaves, and that is what causes the issue in post-allgather rmsnorm. From AllGather we get [1, 1, R, num_chips[32]], but we want [1, 1, R, num_chips * 32].

I'm not sure if AllGather recently changed how it behaves with tile-padded tensors or if we just didn't catch this in our distributed rmsnorm test cases.

SeanNijjar commented 1 month ago

@cglagovichTT I did recently add the new, second, bidirectional mode for all-gather, but fundamentally there should have been no change in behavior. The all-gather and distributed layernorm tests are all passing. It should be impossible for the allgather to emit [1, 1, R, num_chips[32]], and that definitely would have been caught by the tests 🤔

cglagovichTT commented 1 month ago

Let me dig deeper into my model code... I must be missing something then