Good news! I've got a branch (snijjar/issue-10786) with fully functional tiled, width-sharding support on line all-gather.
Additionally, this (essentially a rewrite) will provide a lot of other benefits:
These are also coming down the pipe (I'm adding support as I type this; it should be quick, maybe a couple more hours. I know it's not an explicit ask, but the info is in my brain cache so I'm going for it).
Also this new infra will be portable between CCL ops, so we'll get this for future CCL ops "for free" now too :D
The above is for tiled tensors. Row major hasn't been tested but shouldn't require additional work beyond bug fixes.
PR for these changes: https://github.com/tenstorrent/tt-metal/pull/11074
That's great! Thanks!
Reopening until we have sharded line allgather functional in Llama TG tests
@kpaigwar is this closed or shall we keep it open until the PCC issue for one config is fixed? Do you have an issue for the failing shape?
@cglagovichTT, I don't have a separate issue for the failing cases. Let's keep this one open and I will attach the failing configuration details here.
@SeanNijjar This is the specification of the sharded line_all_gather after the fused_qkv matmul which results in bad PCC.
fused_query_key_value = {'shape': [1, 1, 32, 1280],
                         'shard_shape': (32, 32)}
all_gather_output = {'shape': [4, 1, 32, 1280],
                     'shard_shape': (32 * 4, 32)}

output_mem_config = ttnn.create_sharded_memory_config(
    shape=(32 * 4, 32),
    core_grid=ttnn.CoreGrid(y=5, x=8),
    strategy=ttnn.ShardStrategy.WIDTH,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=True,
)

gathered_tensor = ttnn.line_all_gather(fused_query_key_value, dim=0, num_links=2,
                                       cluster_axis=1, mesh_device=self.mesh_device,
                                       memory_config=output_mem_config)
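For reference, here is a minimal sketch of how the matching input-side memory config could be expressed, assuming the fused QKV activation is width-sharded as (32, 32) shards across the same 8x5 core grid (40 cores x 32 columns = 1280). The name input_mem_config is illustrative and not taken from the test:

input_mem_config = ttnn.create_sharded_memory_config(
    shape=(32, 32),                      # per-core shard: 32 rows x 32 columns (assumed from the spec above)
    core_grid=ttnn.CoreGrid(y=5, x=8),   # same 40-core grid as the output config
    strategy=ttnn.ShardStrategy.WIDTH,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=True,
)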
Currently, line all-gather only supports interleaved inputs, which adds sharded-to-interleaved (s2i) and interleaved-to-sharded (i2s) ops to TG Llama and slows down both e2e and device perf.
Refer to the device perf sheet for the shapes required for sharded inputs with line all-gather.
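For context, a minimal sketch of the s2i/i2s round-trip this issue wants to eliminate, assuming a width-sharded activation and the same line_all_gather arguments as above (sharded_activation, mesh_device, and output_mem_config are placeholder names, not code from the Llama TG model):

# Convert the width-sharded activation to DRAM-interleaved before the CCL op (s2i).
interleaved_in = ttnn.sharded_to_interleaved(sharded_activation, ttnn.DRAM_MEMORY_CONFIG)

# Run the line all-gather on the interleaved tensor (the only layout currently supported).
gathered = ttnn.line_all_gather(interleaved_in, dim=0, num_links=2,
                                cluster_axis=1, mesh_device=mesh_device,
                                memory_config=ttnn.DRAM_MEMORY_CONFIG)

# Re-shard the gathered result for the downstream op (i2s).
sharded_out = ttnn.interleaved_to_sharded(gathered, output_mem_config)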