tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

Sharded inputs support in line all-gather for TG llama #11069

Closed mikevin920 closed 2 weeks ago

mikevin920 commented 3 months ago

Currently, line all-gather only supports interleaved inputs, which adds sharded-to-interleaved (s2i) and interleaved-to-sharded (i2s) ops to TG Llama and slows down both e2e and device perf.
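For context, a rough sketch of the round-trip this forces today (tensor and config names are illustrative, not taken from the TG Llama model code, and the exact ttnn signatures may differ slightly):

import ttnn

# Because line all-gather only accepts interleaved inputs, the width-sharded
# activation is converted to interleaved first (s2i), gathered, and then
# converted back to sharded (i2s), adding two data-movement ops per all-gather.
x_interleaved = ttnn.sharded_to_interleaved(x_sharded, ttnn.L1_MEMORY_CONFIG)   # s2i
x_gathered = ttnn.line_all_gather(
    x_interleaved, dim=0, num_links=2,
    cluster_axis=1, mesh_device=mesh_device,
)
x_out = ttnn.interleaved_to_sharded(x_gathered, sharded_output_mem_config)      # i2s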

Refer to the device perf sheet for the shapes required for sharded inputs with line all-gather.

SeanNijjar commented 3 months ago

Good news! I've got a branch (snijjar/issue-10786) with fully functional tiled, width-sharded support in line all-gather.

Additionally, this (essentially a rewrite) will provide a lot of other benefits.

These are also coming down the pipe (I'm adding support as I type this; it should be quick, maybe a couple more hours. I know it's not an explicit ask, but the info is in my brain cache, so I'm going for it).

Also this new infra will be portable between CCL ops, so we'll get this for future CCL ops "for free" now too :D

The above is for tiled tensors. Row major hasn't been tested but shouldn't require additional work beyond bug fixes.
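Once that branch lands, the s2i/i2s round-trip from the original workaround should be unnecessary for tiled, width-sharded inputs; roughly (same illustrative names as above, assuming the call signature stays the same as the interleaved path):

# The width-sharded tensor is passed straight to line all-gather and the
# output stays sharded, so no s2i/i2s conversions are needed.
x_gathered = ttnn.line_all_gather(
    x_sharded, dim=0, num_links=2,
    cluster_axis=1, mesh_device=mesh_device,
    memory_config=sharded_output_mem_config,
)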

SeanNijjar commented 3 months ago

PR for these changes: https://github.com/tenstorrent/tt-metal/pull/11074

kpaigwar commented 3 months ago

That's great! Thanks!

cglagovichTT commented 3 months ago

Reopening until we have sharded line all-gather functional in the Llama TG tests.

cglagovichTT commented 1 month ago

@kpaigwar is this closed or shall we keep it open until the PCC issue for one config is fixed? Do you have an issue for the failing shape?

kpaigwar commented 1 month ago

@cglagovichTT, I don't have a separate issue for the failing cases. Let's keep this one open, and I will attach the failing configuration details here.

kpaigwar commented 1 month ago

@SeanNijjar This is the specification of the sharded line_all_gather after the fused_qkv matmul which results in bad PCC:

# Input/output specs for the failing case. The dicts just document the logical
# shape and per-core shard shape; the call below takes the width-sharded ttnn
# tensor matching the fused_query_key_value spec.
fused_query_key_value = {'shape': [1, 1, 32, 1280],
                         'shard_shape': (32, 32)}
all_gather_output = {'shape': [4, 1, 32, 1280],
                     'shard_shape': (32 * 4, 32)}

# Width-sharded output memory config on an 8x5 core grid (40 cores), with the
# shard shape given directly as (height, width).
output_mem_config = ttnn.create_sharded_memory_config(
    shape=(32 * 4, 32),
    core_grid=ttnn.CoreGrid(y=5, x=8),
    strategy=ttnn.ShardStrategy.WIDTH,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=True,
)

# Line all-gather on dim 0 along cluster_axis=1 (4 devices in the line, hence
# the 4x growth of dim 0 in the output spec).
gathered_tensor = ttnn.line_all_gather(
    fused_query_key_value, dim=0, num_links=2,
    cluster_axis=1, mesh_device=self.mesh_device,
    memory_config=output_mem_config,
)