tenstorrent / tt-metal

TT-NN operator library and TT-Metalium low-level kernel programming model.

Sharding specs of line_all_gather for Llama3-TG #11172

Open · kpaigwar opened this issue 1 month ago

kpaigwar commented 1 month ago
########################################################################################
# Spec 1: fused QKV activations ([1, 1, 32, 1280]), width-sharded across a
# 5x8 core grid, gathered over 4 devices along cluster_axis=1
########################################################################################
fused_query_key_value = {'shape' : [1, 1, 32, 1280], 
                        'shard_shape' : (32, 32)}
all_gather_output = {'shape' : [4, 1, 32, 1280], 
                    'shard_shape' : (32*4, 32)}
output_mem_config = ttnn.create_sharded_memory_config(
                            shape=(32*4, 32),
                            core_grid=ttnn.CoreGrid(y=5, x=8),
                            strategy=ttnn.ShardStrategy.WIDTH,
                            orientation=ttnn.ShardOrientation.ROW_MAJOR,
                            use_height_and_width_as_shard_shape=True,
                        )
gathered_tensor = ttnn.line_all_gather(fused_query_key_value, dim=0, num_links=2, 
                                       cluster_axis=1, device_mesh=self.device_mesh, 
                                       memory_config=output_mem_config)
########################################################################################
# Spec 2: attention output ([1, 1, 32, 2048]), width-sharded across a
# 4x8 core grid, gathered over 8 devices along cluster_axis=0
########################################################################################
attn_output_tensor = {'shape' : [1, 1, 32, 2048],
                      'shard_shape' : (32, 64)}  # i.e. [1, 1, 32, 64] written as its last two dims
all_gather_output = {'shape' : [8, 1, 32, 2048], 
                    'shard_shape' : (32*8, 64)}
output_mem_config = ttnn.create_sharded_memory_config(
                            shape=(32*8, 64),
                            core_grid=ttnn.CoreGrid(y=4, x=8),
                            strategy=ttnn.ShardStrategy.WIDTH,
                            orientation=ttnn.ShardOrientation.ROW_MAJOR,
                            use_height_and_width_as_shard_shape=True,
                        )
gathered_tensor = ttnn.line_all_gather(attn_output_tensor, dim=0, num_links=2, 
                                       cluster_axis=0, device_mesh=self.device_mesh, 
                                       memory_config=output_mem_config)
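
For reference, the shard arithmetic both specs follow can be sanity-checked with a minimal sketch (plain Python, no ttnn dependency; the helper names below are hypothetical): a width-sharded [1, 1, H, W] tensor on an x-by-y core grid carries a per-core shard of (H, W // (x*y)), and line_all_gather along dim 0 over N devices stacks the gathered shards height-wise to (H*N, W // (x*y)).

def width_shard_shape(shape, grid_x, grid_y):
    # Per-core shard of a width-sharded [1, 1, H, W] tensor on an x*y core grid.
    _, _, h, w = shape
    num_cores = grid_x * grid_y
    assert w % num_cores == 0, "width must divide evenly across the core grid"
    return (h, w // num_cores)

def gathered_shard_shape(shard_shape, num_devices):
    # line_all_gather along dim 0 stacks one shard per device height-wise.
    h, w = shard_shape
    return (h * num_devices, w)

# Spec 1: [1, 1, 32, 1280] on a 5x8 grid, gathered over 4 devices
assert width_shard_shape([1, 1, 32, 1280], grid_x=8, grid_y=5) == (32, 32)
assert gathered_shard_shape((32, 32), num_devices=4) == (32 * 4, 32)

# Spec 2: [1, 1, 32, 2048] on a 4x8 grid, gathered over 8 devices
assert width_shard_shape([1, 1, 32, 2048], grid_x=8, grid_y=4) == (32, 64)
assert gathered_shard_shape((32, 64), num_devices=8) == (32 * 8, 64)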
kpaigwar commented 1 month ago

fyi @SeanNijjar @cglagovich